From d65aee725a4ec001950abf3c785e86f0774e7b20 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 22 Jan 2020 21:12:32 +0000 Subject: [PATCH] Improve stream closing logic and add tests --- internal/multiplex/datagramBuffer.go | 10 +-- internal/multiplex/datagramBuffer_test.go | 9 ++- internal/multiplex/recvBuffer.go | 2 +- internal/multiplex/session.go | 8 +- internal/multiplex/session_test.go | 89 ++++++++++++++++++++++- internal/multiplex/stream.go | 8 +- internal/multiplex/streamBuffer.go | 17 ++--- 7 files changed, 120 insertions(+), 23 deletions(-) diff --git a/internal/multiplex/datagramBuffer.go b/internal/multiplex/datagramBuffer.go index b37d99f..eebbf48 100644 --- a/internal/multiplex/datagramBuffer.go +++ b/internal/multiplex/datagramBuffer.go @@ -52,12 +52,12 @@ func (d *datagramBuffer) Read(target []byte) (int, error) { return len(data), nil } -func (d *datagramBuffer) Write(f Frame) error { +func (d *datagramBuffer) Write(f Frame) (toBeClosed bool, err error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() for { if atomic.LoadUint32(&d.closed) == 1 { - return io.ErrClosedPipe + return true, io.ErrClosedPipe } if len(d.buf) <= DATAGRAM_NUMBER_LIMIT { // if d.buf gets too large, write() will panic. We don't want this to happen @@ -66,10 +66,10 @@ func (d *datagramBuffer) Write(f Frame) error { d.rwCond.Wait() } - if f.Closing == 1 { + if f.Closing != C_NOOP { atomic.StoreUint32(&d.closed, 1) d.rwCond.Broadcast() - return nil + return true, nil } data := make([]byte, len(f.Payload)) @@ -77,7 +77,7 @@ func (d *datagramBuffer) Write(f Frame) error { d.buf = append(d.buf, data) // err will always be nil d.rwCond.Broadcast() - return nil + return false, nil } func (d *datagramBuffer) Close() error { diff --git a/internal/multiplex/datagramBuffer_test.go b/internal/multiplex/datagramBuffer_test.go index 3a7396d..4907866 100644 --- a/internal/multiplex/datagramBuffer_test.go +++ b/internal/multiplex/datagramBuffer_test.go @@ -11,7 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) { b := []byte{0x01, 0x02, 0x03} t.Run("simple write", func(t *testing.T) { pipe := NewDatagramBuffer() - err := pipe.Write(Frame{Payload: b}) + _, err := pipe.Write(Frame{Payload: b}) if err != nil { t.Error( "expecting", "nil error", @@ -23,7 +23,7 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("simple read", func(t *testing.T) { pipe := NewDatagramBuffer() - _ = pipe.Write(Frame{Payload: b}) + _, _ = pipe.Write(Frame{Payload: b}) b2 := make([]byte, len(b)) n, err := pipe.Read(b2) if n != len(b) { @@ -56,7 +56,10 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("writing closing frame", func(t *testing.T) { pipe := NewDatagramBuffer() - err := pipe.Write(Frame{Closing: 1}) + toBeClosed, err := pipe.Write(Frame{Closing: C_STREAM}) + if !toBeClosed { + t.Error("should be to be closed") + } if err != nil { t.Error( "expecting", "nil error", diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 146224c..db5c844 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -4,5 +4,5 @@ import "io" type recvBuffer interface { io.ReadCloser - Write(Frame) error + Write(Frame) (toBeClosed bool, err error) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 321b02d..86e7ff0 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -172,7 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { return nil } -// recvDataFromRemote deobfuscate the frame and send it to the appropriate stream buffer +// recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame +// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new +// stream and then writes to the stream buffer func (sesh *Session) recvDataFromRemote(data []byte) error { frame, err := sesh.Deobfs(data) if err != nil { @@ -182,10 +184,12 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { if frame.Closing == C_STREAM { streamI, existing := sesh.streams.Load(frame.StreamID) if existing { - // If the stream has been closed and the current frame is a closing frame, we do noop + // DO NOT close the stream (or session below) straight away here because the sequence number of this frame + // hasn't been checked. There may be later data frames which haven't arrived stream := streamI.(*Stream) return stream.writeFrame(*frame) } else { + // If the stream has been closed and the current frame is a closing frame, we do noop return nil } } else if frame.Closing == C_SESSION { diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index bcb9909..595c0d8 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -143,7 +143,95 @@ func TestRecvDataFromRemote(t *testing.T) { t.Errorf("Expecting %x, got %x", testPayload, resultPayload) } }) +} +func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { + testPayloadLen := 1024 + testPayload := make([]byte, testPayloadLen) + rand.Read(testPayload) + obfsBuf := make([]byte, 17000) + + sessionKey := make([]byte, 32) + obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) + seshConfigOrdered.Obfuscator = obfuscator + + rand.Read(sessionKey) + sesh := MakeSession(0, seshConfigOrdered) + + f1 := &Frame{ + 1, + 0, + C_NOOP, + testPayload, + } + // create stream 1 + n, _ := sesh.Obfs(f1, obfsBuf) + err := sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving normal frame for stream 1: %v", err) + } + s1I, ok := sesh.streams.Load(f1.StreamID) + if !ok { + t.Fatal("failed to fetch stream 1 after receiving it") + } + + // create stream 2 + f2 := &Frame{ + 2, + 0, + C_NOOP, + testPayload, + } + n, _ = sesh.Obfs(f2, obfsBuf) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving normal frame for stream 2: %v", err) + } + s2I, ok := sesh.streams.Load(f2.StreamID) + if !ok { + t.Fatal("failed to fetch stream 2 after receiving it") + } + + // close stream 1 + f1CloseStream := &Frame{ + 1, + 1, + C_STREAM, + testPayload, + } + n, _ = sesh.Obfs(f1CloseStream, obfsBuf) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving stream closing frame for stream 1: %v", err) + } + _, ok = sesh.streams.Load(f1.StreamID) + if ok { + t.Fatal("stream 1 still exist after receiving stream close") + } + s1 := s1I.(*Stream) + if !s1.isClosed() { + t.Fatal("stream 1 not marked as closed") + } + payloadBuf := make([]byte, testPayloadLen) + _, err = s1.recvBuf.Read(payloadBuf) + if err != nil || !bytes.Equal(payloadBuf, testPayload) { + t.Fatalf("failed to read from stream 1 after closing: %v", err) + } + s2 := s2I.(*Stream) + if s2.isClosed() { + t.Fatal("stream 2 shouldn't be closed") + } + + // close stream 1 again + n, _ = sesh.Obfs(f1CloseStream, obfsBuf) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving stream closing frame for stream 1: %v", err) + } + _, ok = sesh.streams.Load(f1.StreamID) + if ok { + t.Fatal("stream 1 exists after receiving stream close for the second time") + } } func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { @@ -211,5 +299,4 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { b.SetBytes(int64(n)) } }) - } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 84301a3..4d1bf39 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -59,7 +59,11 @@ func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream { func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) writeFrame(frame Frame) error { - return s.recvBuf.Write(frame) + toBeClosed, err := s.recvBuf.Write(frame) + if toBeClosed { + return s.passiveClose() + } + return err } // Read implements io.Read @@ -99,7 +103,7 @@ func (s *Stream) Write(in []byte) (n int, err error) { f := &Frame{ StreamID: s.id, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, - Closing: 0, + Closing: C_NOOP, Payload: in, } diff --git a/internal/multiplex/streamBuffer.go b/internal/multiplex/streamBuffer.go index 44cd147..9785a08 100644 --- a/internal/multiplex/streamBuffer.go +++ b/internal/multiplex/streamBuffer.go @@ -61,39 +61,38 @@ func NewStreamBuffer() *streamBuffer { // recvNewFrame is a forever running loop which receives frames unordered, // cache and order them and send them into sortedBufCh -func (sb *streamBuffer) Write(f Frame) error { +func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) { sb.recvM.Lock() defer sb.recvM.Unlock() // when there'fs no ooo packages in heap and we receive the next package in order if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq { - if f.Closing == 1 { + if f.Closing != C_NOOP { sb.buf.Close() - return nil + return true, nil } else { sb.buf.Write(f.Payload) sb.nextRecvSeq += 1 } - return nil + return false, nil } if f.Seq < sb.nextRecvSeq { - return fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) + return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) } heap.Push(&sb.sh, &f) // Keep popping from the heap until empty or to the point that the wanted seq was not received for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq { f = *heap.Pop(&sb.sh).(*Frame) - if f.Closing == 1 { - // empty data indicates closing signal + if f.Closing != C_NOOP { sb.buf.Close() - return nil + return true, nil } else { sb.buf.Write(f.Payload) sb.nextRecvSeq += 1 } } - return nil + return false, nil } func (sb *streamBuffer) Read(buf []byte) (int, error) {