Fix race condition in steam closing. Fall back to temp buffer allocation

This commit is contained in:
Andy Wang 2020-12-23 22:33:01 +00:00
parent 53f0116c1d
commit 0209bcd977
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
2 changed files with 15 additions and 11 deletions

View File

@ -180,21 +180,25 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error _ = s.getRecvBuf().Close() // recvBuf.Close should not return error
if active { if active {
// must be holding s.wirtingM on entry tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
if len(s.obfsBuf) < 256+frameHeaderLength+sesh.Obfuscator.maxOverhead {
s.obfsBuf = make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
}
// Notify remote that this stream is closed // Notify remote that this stream is closed
common.CryptoRandRead(s.obfsBuf[:1]) common.CryptoRandRead(tmpBuf[:1])
padLen := int(s.obfsBuf[0]) + 1 padLen := int(tmpBuf[0]) + 1
payload := s.obfsBuf[frameHeaderLength : padLen+frameHeaderLength] payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload) common.CryptoRandRead(payload)
// must be holding s.wirtingM on entry
s.writingFrame.Closing = closingStream s.writingFrame.Closing = closingStream
s.writingFrame.Payload = payload s.writingFrame.Payload = payload
err := s.obfuscateAndSend(frameHeaderLength) cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength)
s.writingFrame.Seq++
if err != nil {
return err
}
_, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId)
if err != nil { if err != nil {
return err return err
} }

View File

@ -118,7 +118,6 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
} }
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
var cipherTextLen int
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
if err != nil { if err != nil {
return err return err
@ -174,11 +173,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
// for readFromTimeout amount of time // for readFromTimeout amount of time
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
s.writingM.Lock()
if s.obfsBuf == nil { if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
} }
s.writingM.Unlock()
for { for {
if s.readFromTimeout != 0 { if s.readFromTimeout != 0 {
if rder, ok := r.(net.Conn); !ok { if rder, ok := r.(net.Conn); !ok {
@ -191,6 +188,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
if er != nil { if er != nil {
return n, er return n, er
} }
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
// to check that here
if s.isClosed() { if s.isClosed() {
return n, ErrBrokenStream return n, ErrBrokenStream
} }