diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index f347ea7..3a6df78 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -150,24 +150,21 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { _ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() if active { - s.writingM.Lock() - defer s.writingM.Unlock() // Notify remote that this stream is closed + padding := genRandomPadding() f := &Frame{ StreamID: s.id, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, Closing: C_STREAM, - Payload: genRandomPadding(), + Payload: padding, } - if s.obfsBuf == nil { - s.obfsBuf = make([]byte, s.session.SendBufferSize) - } - i, err := s.session.Obfs(f, s.obfsBuf) + obfsBuf := make([]byte, len(padding)+64) + i, err := sesh.Obfs(f, obfsBuf) if err != nil { return err } - _, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId) + _, err = sesh.sb.send(obfsBuf[:i], &s.assignedConnId) if err != nil { return err } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index ce8c434..7d2c908 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -84,13 +84,39 @@ func (s *Stream) Read(buf []byte) (n int, err error) { func (s *Stream) WriteTo(w io.Writer) (int64, error) { // will keep writing until the underlying buffer is closed n, err := s.recvBuf.WriteTo(w) + log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream } - log.Tracef("%v read from stream %v with err %v", n, s.id, err) return n, nil } +func (s *Stream) writePayload(seq uint64, payload []byte) error { + f := &Frame{ + StreamID: s.id, + Seq: seq, + Closing: C_NOOP, + Payload: payload, + } + + var cipherTextLen int + cipherTextLen, err := s.session.Obfs(f, s.obfsBuf) + if err != nil { + return err + } + + _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) + log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err) + if err != nil { + if err == errBrokenSwitchboard { + s.session.SetTerminalMsg(err.Error()) + s.session.passiveClose() + } + return err + } + return nil +} + // Write implements io.Write func (s *Stream) Write(in []byte) (n int, err error) { s.writingM.Lock() @@ -113,27 +139,8 @@ func (s *Stream) Write(in []byte) (n int, err error) { } framePayload = in[n : s.session.maxStreamUnitWrite+n] } - - f := &Frame{ - StreamID: s.id, - Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, - Closing: C_NOOP, - Payload: framePayload, - } - - var cipherTextLen int - cipherTextLen, err = s.session.Obfs(f, s.obfsBuf) + err = s.writePayload(atomic.AddUint64(&s.nextSendSeq, 1)-1, framePayload) if err != nil { - return 0, err - } - - _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) - log.Tracef("%v sent to remote through stream %v with err %v", len(framePayload), s.id, err) - if err != nil { - if err == errBrokenSwitchboard { - s.session.SetTerminalMsg(err.Error()) - s.session.passiveClose() - } return } n += len(framePayload) @@ -141,6 +148,29 @@ func (s *Stream) Write(in []byte) (n int, err error) { return } +func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { + s.writingM.Lock() + defer s.writingM.Unlock() + if s.obfsBuf == nil { + s.obfsBuf = make([]byte, s.session.SendBufferSize) + } + for { + read, er := r.Read(s.obfsBuf[HEADER_LEN : HEADER_LEN+s.session.maxStreamUnitWrite]) + if er != nil { + return n, er + } + if s.isClosed() { + return 0, ErrBrokenStream + } + seq := atomic.AddUint64(&s.nextSendSeq, 1) - 1 + err = s.writePayload(seq, s.obfsBuf[HEADER_LEN:HEADER_LEN+read]) + if err != nil { + return + } + n += int64(read) + } +} + func (s *Stream) passiveClose() error { return s.session.closeStream(s, false) } diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 258fa13..4b6fdcd 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -187,12 +187,12 @@ func dispatchConnection(conn net.Conn, sta *State) { go func() { if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil { - log.Tracef("copying stream to proxy client: %v", err) + log.Tracef("copying stream to proxy server: %v", err) } }() go func() { if _, err := common.Copy(newStream, localConn, 0); err != nil { - log.Tracef("copying proxy client to stream: %v", err) + log.Tracef("copying proxy server to stream: %v", err) } }() }