diff --git a/internal/multiplex/bufferedPipe.go b/internal/multiplex/bufferedPipe.go new file mode 100644 index 0000000..4b940a4 --- /dev/null +++ b/internal/multiplex/bufferedPipe.go @@ -0,0 +1,78 @@ +// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173 + +package multiplex + +import ( + "bytes" + "io" + "sync" +) + +const BUF_SIZE_LIMIT = 1 << 20 * 500 + +type bufferedPipe struct { + buf *bytes.Buffer + closed bool + rwCond *sync.Cond +} + +func NewBufferedPipe() *bufferedPipe { + p := &bufferedPipe{ + buf: new(bytes.Buffer), + rwCond: sync.NewCond(&sync.Mutex{}), + } + return p +} + +func (p *bufferedPipe) Read(target []byte) (int, error) { + p.rwCond.L.Lock() + defer p.rwCond.L.Unlock() + for { + if p.closed && p.buf.Len() == 0 { + return 0, io.EOF + } + + if p.buf.Len() > 0 { + break + } + p.rwCond.Wait() + } + n, err := p.buf.Read(target) + // err will always be nil because we have already verified that buf.Len() != 0 + p.rwCond.Broadcast() + return n, err +} + +func (p *bufferedPipe) Write(input []byte) (int, error) { + p.rwCond.L.Lock() + defer p.rwCond.L.Unlock() + for { + if p.closed { + return 0, io.ErrClosedPipe + } + if p.buf.Len() <= BUF_SIZE_LIMIT { + // if p.buf gets too large, write() will panic. We don't want this to happen + break + } + p.rwCond.Wait() + } + n, err := p.buf.Write(input) + // err will always be nil + p.rwCond.Broadcast() + return n, err +} + +func (p *bufferedPipe) Close() error { + p.rwCond.L.Lock() + defer p.rwCond.L.Unlock() + + p.closed = true + p.rwCond.Broadcast() + return nil +} + +func (p *bufferedPipe) Len() int { + p.rwCond.L.Lock() + defer p.rwCond.L.Unlock() + return p.buf.Len() +} diff --git a/internal/multiplex/bufferedPipe_test.go b/internal/multiplex/bufferedPipe_test.go new file mode 100644 index 0000000..4c94be2 --- /dev/null +++ b/internal/multiplex/bufferedPipe_test.go @@ -0,0 +1,166 @@ +package multiplex + +import ( + "bytes" + "testing" + "time" +) + +func TestPipeRW(t *testing.T) { + pipe := NewBufferedPipe() + b := []byte{0x01, 0x02, 0x03} + n, err := pipe.Write(b) + if n != len(b) { + t.Error( + "For", "number of bytes written", + "expecting", len(b), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "simple write", + "expecting", "nil error", + "got", err, + ) + } + + b2 := make([]byte, len(b)) + n, err = pipe.Read(b2) + if n != len(b) { + t.Error( + "For", "number of bytes read", + "expecting", len(b), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "simple read", + "expecting", "nil error", + "got", err, + ) + } + if !bytes.Equal(b, b2) { + t.Error( + "For", "simple read", + "expecting", b, + "got", b2, + ) + } + +} + +func TestReadBlock(t *testing.T) { + pipe := NewBufferedPipe() + b := []byte{0x01, 0x02, 0x03} + go func() { + time.Sleep(10 * time.Millisecond) + pipe.Write(b) + }() + b2 := make([]byte, len(b)) + n, err := pipe.Read(b2) + if n != len(b) { + t.Error( + "For", "number of bytes read after block", + "expecting", len(b), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "blocked read", + "expecting", "nil error", + "got", err, + ) + } + if !bytes.Equal(b, b2) { + t.Error( + "For", "blocked read", + "expecting", b, + "got", b2, + ) + } +} + +func TestPartialRead(t *testing.T) { + pipe := NewBufferedPipe() + b := []byte{0x01, 0x02, 0x03} + pipe.Write(b) + b1 := make([]byte, 1) + n, err := pipe.Read(b1) + if n != len(b1) { + t.Error( + "For", "number of bytes in partial read of 1", + "expecting", len(b1), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "partial read of 1", + "expecting", "nil error", + "got", err, + ) + } + if b1[0] != b[0] { + t.Error( + "For", "partial read of 1", + "expecting", b[0], + "got", b1[0], + ) + } + b2 := make([]byte, 2) + n, err = pipe.Read(b2) + if n != len(b2) { + t.Error( + "For", "number of bytes in partial read of 2", + "expecting", len(b2), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "partial read of 2", + "expecting", "nil error", + "got", err, + ) + } + if !bytes.Equal(b[1:], b2) { + t.Error( + "For", "partial read of 2", + "expecting", b[1:], + "got", b2, + ) + } +} + +func TestReadAfterClose(t *testing.T) { + pipe := NewBufferedPipe() + b := []byte{0x01, 0x02, 0x03} + pipe.Write(b) + b2 := make([]byte, len(b)) + pipe.Close() + n, err := pipe.Read(b2) + if n != len(b) { + t.Error( + "For", "number of bytes read", + "expecting", len(b), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "simple read", + "expecting", "nil error", + "got", err, + ) + } + if !bytes.Equal(b, b2) { + t.Error( + "For", "simple read", + "expecting", b, + "got", b2, + ) + } +} diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go index 88ba386..051aa48 100644 --- a/internal/multiplex/frameSorter.go +++ b/internal/multiplex/frameSorter.go @@ -73,10 +73,10 @@ func (s *Stream) recvNewFrame() { if len(s.sh) == 0 && f.Seq == s.nextRecvSeq { if f.Closing == 1 { // empty data indicates closing signal - s.sortedBufCh <- []byte{} + s.passiveClose() return } else { - s.sortedBufCh <- f.Payload + s.sortedBuf.Write(f.Payload) s.nextRecvSeq += 1 if s.nextRecvSeq == 0 { // getting wrapped s.rev += 1 @@ -115,10 +115,10 @@ func (s *Stream) recvNewFrame() { f = heap.Pop(&s.sh).(*frameNode).frame if f.Closing == 1 { // empty data indicates closing signal - s.sortedBufCh <- []byte{} + s.passiveClose() return } else { - s.sortedBufCh <- f.Payload + s.sortedBuf.Write(f.Payload) s.nextRecvSeq += 1 if s.nextRecvSeq == 0 { // getting wrapped s.rev += 1 diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 0f887c1..5a6585a 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -2,8 +2,6 @@ package multiplex import ( "errors" - "io" - "log" "net" "time" @@ -29,10 +27,8 @@ type Stream struct { // New frames are received through newFrameCh by frameSorter newFrameCh chan *Frame - // sortedBufCh are order-sorted data ready to be read raw - sortedBufCh chan []byte - feederR *io.PipeReader - feederW *io.PipeWriter + + sortedBuf *bufferedPipe // atomic nextSendSeq uint32 @@ -45,45 +41,18 @@ type Stream struct { } func makeStream(id uint32, sesh *Session) *Stream { - r, w := io.Pipe() stream := &Stream{ - id: id, - session: sesh, - die: make(chan struct{}), - sh: []*frameNode{}, - newFrameCh: make(chan *Frame, 1024), - sortedBufCh: make(chan []byte, 1024), - feederR: r, - feederW: w, + id: id, + session: sesh, + die: make(chan struct{}), + sh: []*frameNode{}, + newFrameCh: make(chan *Frame, 1024), + sortedBuf: NewBufferedPipe(), } go stream.recvNewFrame() - go stream.feed() return stream } -func (stream *Stream) feed() { - for { - select { - case <-stream.die: - return - case data := <-stream.sortedBufCh: - if len(data) == 0 { - stream.passiveClose() - return - } - _, err := stream.feederW.Write(data) - if err != nil { - if err == io.ErrClosedPipe { - stream.Close() - return - } else { - log.Println(err) - } - } - } - } -} - func (stream *Stream) Read(buf []byte) (n int, err error) { if len(buf) == 0 { select { @@ -95,9 +64,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { } select { case <-stream.die: - return 0, ErrBrokenStream + if stream.sortedBuf.Len() == 0 { + return 0, ErrBrokenStream + } else { + return stream.sortedBuf.Read(buf) + } default: - return stream.feederR.Read(buf) + return stream.sortedBuf.Read(buf) } } @@ -168,9 +141,8 @@ func (stream *Stream) Close() error { tlsRecord, _ := stream.session.obfs(f) stream.session.sb.send(tlsRecord) + stream.sortedBuf.Close() stream.session.delStream(stream.id) - stream.feederW.Close() - stream.feederR.Close() //log.Printf("%v actively closed\n", stream.id) stream.writingM.Unlock() return nil