mirror of https://github.com/cbeuw/Cloak
Fix race condition in steam closing. Fall back to temp buffer allocation
This commit is contained in:
parent
53f0116c1d
commit
0209bcd977
|
|
@ -180,21 +180,25 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
|||
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
||||
|
||||
if active {
|
||||
// must be holding s.wirtingM on entry
|
||||
if len(s.obfsBuf) < 256+frameHeaderLength+sesh.Obfuscator.maxOverhead {
|
||||
s.obfsBuf = make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
||||
}
|
||||
tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
||||
|
||||
// Notify remote that this stream is closed
|
||||
common.CryptoRandRead(s.obfsBuf[:1])
|
||||
padLen := int(s.obfsBuf[0]) + 1
|
||||
payload := s.obfsBuf[frameHeaderLength : padLen+frameHeaderLength]
|
||||
common.CryptoRandRead(tmpBuf[:1])
|
||||
padLen := int(tmpBuf[0]) + 1
|
||||
payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength]
|
||||
common.CryptoRandRead(payload)
|
||||
|
||||
// must be holding s.wirtingM on entry
|
||||
s.writingFrame.Closing = closingStream
|
||||
s.writingFrame.Payload = payload
|
||||
|
||||
err := s.obfuscateAndSend(frameHeaderLength)
|
||||
cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength)
|
||||
s.writingFrame.Seq++
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -118,7 +118,6 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
|||
}
|
||||
|
||||
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
|
||||
var cipherTextLen int
|
||||
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -174,11 +173,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
|||
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
|
||||
// for readFromTimeout amount of time
|
||||
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
s.writingM.Lock()
|
||||
if s.obfsBuf == nil {
|
||||
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
|
||||
}
|
||||
s.writingM.Unlock()
|
||||
for {
|
||||
if s.readFromTimeout != 0 {
|
||||
if rder, ok := r.(net.Conn); !ok {
|
||||
|
|
@ -191,6 +188,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
if er != nil {
|
||||
return n, er
|
||||
}
|
||||
|
||||
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
|
||||
// to check that here
|
||||
if s.isClosed() {
|
||||
return n, ErrBrokenStream
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue