diff --git a/internal/common/copy.go b/internal/common/copy.go index 1d57b97..bf5ffb7 100644 --- a/internal/common/copy.go +++ b/internal/common/copy.go @@ -41,18 +41,18 @@ import ( // copyBuffer is the actual implementation of Copy and CopyBuffer. // if buf is nil, one is allocated. func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int64, err error) { - /* - // If the reader has a WriteTo method, use it to do the copy. - // Avoids an allocation and a copy. - if wt, ok := src.(WriterTo); ok { - return wt.WriteTo(dst) - } - // Similarly, if the writer has a ReadFrom method, use it to do the copy. - if rt, ok := dst.(ReaderFrom); ok { - return rt.ReadFrom(src) - } + defer func() { src.Close(); dst.Close() }() + + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) + } + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(io.ReaderFrom); ok { + return rt.ReadFrom(src) + } - */ //if buf == nil { size := 32 * 1024 /* @@ -97,7 +97,5 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int break } } - src.Close() - dst.Close() return written, err } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 86451af..948d323 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -84,6 +84,7 @@ func MakeSession(id uint32, config SessionConfig) *Session { if config.MaxFrameSize <= 0 { sesh.MaxFrameSize = defaultSendRecvBufSize - 1024 } + // todo: validation. this must be smaller than the buffer sizes sesh.maxStreamUnitWrite = sesh.MaxFrameSize - HEADER_LEN - sesh.Obfuscator.minOverhead sbConfig := switchboardConfig{ @@ -156,12 +157,11 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { s.writingM.Lock() defer s.writingM.Unlock() // Notify remote that this stream is closed - pad := genRandomPadding() f := &Frame{ StreamID: s.id, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, Closing: C_STREAM, - Payload: pad, + Payload: genRandomPadding(), } if s.obfsBuf == nil { diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 3dfb8a2..a4d5dd6 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -80,7 +80,11 @@ func (s *Stream) Read(buf []byte) (n int, err error) { } log.Tracef("%v read from stream %v with err %v", n, s.id, err) return +} +func (s *Stream) WriteTo(w io.Writer) (int64, error) { + // will keep writing until the underlying buffer is closed + return s.recvBuf.WriteTo(w) } // Write implements io.Write @@ -91,7 +95,6 @@ func (s *Stream) Write(in []byte) (n int, err error) { // in the middle of the execution of Write. This may cause the closing frame // to be sent before the data frame and cause loss of packet. //log.Tracef("attempting to write %v bytes to stream %v",len(in),s.id) - // todo: forbid concurrent write s.writingM.Lock() defer s.writingM.Unlock() if s.isClosed() {