diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index fc21c00..504298d 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -179,7 +179,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } - _ = s.getRecvBuf().Close() // recvBuf.Close should not return error + _ = s.recvBuf.Close() // recvBuf.Close should not return error if active { tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte) @@ -285,7 +285,7 @@ func (sesh *Session) closeSession() error { close(sesh.acceptCh) for id, stream := range sesh.streams { if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) { - _ = stream.getRecvBuf().Close() // will not block + _ = stream.recvBuf.Close() // will not block delete(sesh.streams, id) sesh.streamCountDecr() } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 9141e59..8c2ea15 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -23,9 +23,8 @@ type Stream struct { session *Session - allocIdempot sync.Once // a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't - // been read by the consumer through Read or WriteTo. Lazily allocated + // been read by the consumer through Read or WriteTo. recvBuf recvBuffer writingM sync.Mutex @@ -56,25 +55,20 @@ func makeStream(sesh *Session, id uint32) *Stream { }, } + if sesh.Unordered { + stream.recvBuf = NewDatagramBufferedPipe() + } else { + stream.recvBuf = NewStreamBuffer() + } + return stream } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } -func (s *Stream) getRecvBuf() recvBuffer { - s.allocIdempot.Do(func() { - if s.session.Unordered { - s.recvBuf = NewDatagramBufferedPipe() - } else { - s.recvBuf = NewStreamBuffer() - } - }) - return s.recvBuf -} - // receive a readily deobfuscated Frame so its payload can later be Read func (s *Stream) recvFrame(frame *Frame) error { - toBeClosed, err := s.getRecvBuf().Write(frame) + toBeClosed, err := s.recvBuf.Write(frame) if toBeClosed { err = s.passiveClose() if errors.Is(err, errRepeatStreamClosing) { @@ -93,7 +87,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { return 0, nil } - n, err = s.getRecvBuf().Read(buf) + n, err = s.recvBuf.Read(buf) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -104,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { // WriteTo continuously write data Stream has received into the writer w. func (s *Stream) WriteTo(w io.Writer) (int64, error) { // will keep writing until the underlying buffer is closed - n, err := s.getRecvBuf().WriteTo(w) + n, err := s.recvBuf.WriteTo(w) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -215,8 +209,8 @@ func (s *Stream) Close() error { func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } -func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) } -func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil } +func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } +func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } var errNotImplemented = errors.New("Not implemented")