From 0209bcd977ce28273cfed0363adee4c88530edab Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 22:33:01 +0000 Subject: [PATCH] Fix race condition in steam closing. Fall back to temp buffer allocation --- internal/multiplex/session.go | 20 ++++++++++++-------- internal/multiplex/stream.go | 6 +++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index fa9fc14..0a961f1 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -180,21 +180,25 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { - // must be holding s.wirtingM on entry - if len(s.obfsBuf) < 256+frameHeaderLength+sesh.Obfuscator.maxOverhead { - s.obfsBuf = make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) - } + tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) // Notify remote that this stream is closed - common.CryptoRandRead(s.obfsBuf[:1]) - padLen := int(s.obfsBuf[0]) + 1 - payload := s.obfsBuf[frameHeaderLength : padLen+frameHeaderLength] + common.CryptoRandRead(tmpBuf[:1]) + padLen := int(tmpBuf[0]) + 1 + payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength] common.CryptoRandRead(payload) + // must be holding s.wirtingM on entry s.writingFrame.Closing = closingStream s.writingFrame.Payload = payload - err := s.obfuscateAndSend(frameHeaderLength) + cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength) + s.writingFrame.Seq++ + if err != nil { + return err + } + + _, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId) if err != nil { return err } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 90d5c16..84f106a 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -118,7 +118,6 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { } func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { - var cipherTextLen int cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) if err != nil { return err @@ -174,11 +173,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // for readFromTimeout amount of time func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { - s.writingM.Lock() if s.obfsBuf == nil { s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) } - s.writingM.Unlock() for { if s.readFromTimeout != 0 { if rder, ok := r.(net.Conn); !ok { @@ -191,6 +188,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { if er != nil { return n, er } + + // the above read may have been unblocked by another goroutine calling stream.Close(), so we need + // to check that here if s.isClosed() { return n, ErrBrokenStream }