diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 3a6df78..49572cc 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -154,10 +154,11 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { padding := genRandomPadding() f := &Frame{ StreamID: s.id, - Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, + Seq: s.nextSendSeq, Closing: C_STREAM, Payload: padding, } + s.nextSendSeq++ obfsBuf := make([]byte, len(padding)+64) i, err := sesh.Obfs(f, obfsBuf) @@ -168,7 +169,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if err != nil { return err } - log.Tracef("stream %v actively closed", s.id) + log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq) } else { log.Tracef("stream %v passively closed", s.id) } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 7d2c908..0b6974e 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -20,7 +20,6 @@ type Stream struct { recvBuf recvBuffer - // atomic nextSendSeq uint64 writingM sync.Mutex @@ -74,10 +73,10 @@ func (s *Stream) Read(buf []byte) (n int, err error) { } n, err = s.recvBuf.Read(buf) + log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream } - log.Tracef("%v read from stream %v with err %v", n, s.id, err) return } @@ -91,14 +90,7 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { return n, nil } -func (s *Stream) writePayload(seq uint64, payload []byte) error { - f := &Frame{ - StreamID: s.id, - Seq: seq, - Closing: C_NOOP, - Payload: payload, - } - +func (s *Stream) sendFrame(f *Frame) error { var cipherTextLen int cipherTextLen, err := s.session.Obfs(f, s.obfsBuf) if err != nil { @@ -106,7 +98,7 @@ func (s *Stream) writePayload(seq uint64, payload []byte) error { } _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) - log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err) + log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -139,7 +131,14 @@ func (s *Stream) Write(in []byte) (n int, err error) { } framePayload = in[n : s.session.maxStreamUnitWrite+n] } - err = s.writePayload(atomic.AddUint64(&s.nextSendSeq, 1)-1, framePayload) + f := &Frame{ + StreamID: s.id, + Seq: s.nextSendSeq, + Closing: C_NOOP, + Payload: framePayload, + } + s.nextSendSeq++ + err = s.sendFrame(f) if err != nil { return } @@ -149,8 +148,6 @@ func (s *Stream) Write(in []byte) (n int, err error) { } func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { - s.writingM.Lock() - defer s.writingM.Unlock() if s.obfsBuf == nil { s.obfsBuf = make([]byte, s.session.SendBufferSize) } @@ -160,10 +157,20 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { return n, er } if s.isClosed() { - return 0, ErrBrokenStream + return n, ErrBrokenStream } - seq := atomic.AddUint64(&s.nextSendSeq, 1) - 1 - err = s.writePayload(seq, s.obfsBuf[HEADER_LEN:HEADER_LEN+read]) + + s.writingM.Lock() + f := &Frame{ + StreamID: s.id, + Seq: s.nextSendSeq, + Closing: C_NOOP, + Payload: s.obfsBuf[HEADER_LEN : HEADER_LEN+read], + } + s.nextSendSeq++ + err = s.sendFrame(f) + s.writingM.Unlock() + if err != nil { return } @@ -177,6 +184,9 @@ func (s *Stream) passiveClose() error { // active close. Close locally and tell the remote that this stream is being closed func (s *Stream) Close() error { + s.writingM.Lock() + defer s.writingM.Unlock() + return s.session.closeStream(s, true) }