Implement stream ReadFrom (flimsy)

This commit is contained in:
Andy Wang 2020-04-12 23:01:30 +01:00
parent 73544c03bb
commit c8368bcc7e
3 changed files with 58 additions and 31 deletions

View File

@ -150,24 +150,21 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
_ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() _ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
if active { if active {
s.writingM.Lock()
defer s.writingM.Unlock()
// Notify remote that this stream is closed // Notify remote that this stream is closed
padding := genRandomPadding()
f := &Frame{ f := &Frame{
StreamID: s.id, StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: C_STREAM, Closing: C_STREAM,
Payload: genRandomPadding(), Payload: padding,
} }
if s.obfsBuf == nil { obfsBuf := make([]byte, len(padding)+64)
s.obfsBuf = make([]byte, s.session.SendBufferSize) i, err := sesh.Obfs(f, obfsBuf)
}
i, err := s.session.Obfs(f, s.obfsBuf)
if err != nil { if err != nil {
return err return err
} }
_, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId) _, err = sesh.sb.send(obfsBuf[:i], &s.assignedConnId)
if err != nil { if err != nil {
return err return err
} }

View File

@ -84,13 +84,39 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
func (s *Stream) WriteTo(w io.Writer) (int64, error) { func (s *Stream) WriteTo(w io.Writer) (int64, error) {
// will keep writing until the underlying buffer is closed // will keep writing until the underlying buffer is closed
n, err := s.recvBuf.WriteTo(w) n, err := s.recvBuf.WriteTo(w)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF { if err == io.EOF {
return n, ErrBrokenStream return n, ErrBrokenStream
} }
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
return n, nil return n, nil
} }
func (s *Stream) writePayload(seq uint64, payload []byte) error {
f := &Frame{
StreamID: s.id,
Seq: seq,
Closing: C_NOOP,
Payload: payload,
}
var cipherTextLen int
cipherTextLen, err := s.session.Obfs(f, s.obfsBuf)
if err != nil {
return err
}
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
}
return err
}
return nil
}
// Write implements io.Write // Write implements io.Write
func (s *Stream) Write(in []byte) (n int, err error) { func (s *Stream) Write(in []byte) (n int, err error) {
s.writingM.Lock() s.writingM.Lock()
@ -113,27 +139,8 @@ func (s *Stream) Write(in []byte) (n int, err error) {
} }
framePayload = in[n : s.session.maxStreamUnitWrite+n] framePayload = in[n : s.session.maxStreamUnitWrite+n]
} }
err = s.writePayload(atomic.AddUint64(&s.nextSendSeq, 1)-1, framePayload)
f := &Frame{
StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: C_NOOP,
Payload: framePayload,
}
var cipherTextLen int
cipherTextLen, err = s.session.Obfs(f, s.obfsBuf)
if err != nil { if err != nil {
return 0, err
}
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v", len(framePayload), s.id, err)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
}
return return
} }
n += len(framePayload) n += len(framePayload)
@ -141,6 +148,29 @@ func (s *Stream) Write(in []byte) (n int, err error) {
return return
} }
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
s.writingM.Lock()
defer s.writingM.Unlock()
if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.SendBufferSize)
}
for {
read, er := r.Read(s.obfsBuf[HEADER_LEN : HEADER_LEN+s.session.maxStreamUnitWrite])
if er != nil {
return n, er
}
if s.isClosed() {
return 0, ErrBrokenStream
}
seq := atomic.AddUint64(&s.nextSendSeq, 1) - 1
err = s.writePayload(seq, s.obfsBuf[HEADER_LEN:HEADER_LEN+read])
if err != nil {
return
}
n += int64(read)
}
}
func (s *Stream) passiveClose() error { func (s *Stream) passiveClose() error {
return s.session.closeStream(s, false) return s.session.closeStream(s, false)
} }

View File

@ -187,12 +187,12 @@ func dispatchConnection(conn net.Conn, sta *State) {
go func() { go func() {
if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil { if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil {
log.Tracef("copying stream to proxy client: %v", err) log.Tracef("copying stream to proxy server: %v", err)
} }
}() }()
go func() { go func() {
if _, err := common.Copy(newStream, localConn, 0); err != nil { if _, err := common.Copy(newStream, localConn, 0); err != nil {
log.Tracef("copying proxy client to stream: %v", err) log.Tracef("copying proxy server to stream: %v", err)
} }
}() }()
} }