Improve stream closing logic and add tests

This commit is contained in:
Andy Wang 2020-01-22 21:12:32 +00:00
parent af5c8a381f
commit d65aee725a
7 changed files with 120 additions and 23 deletions

View File

@ -52,12 +52,12 @@ func (d *datagramBuffer) Read(target []byte) (int, error) {
return len(data), nil return len(data), nil
} }
func (d *datagramBuffer) Write(f Frame) error { func (d *datagramBuffer) Write(f Frame) (toBeClosed bool, err error) {
d.rwCond.L.Lock() d.rwCond.L.Lock()
defer d.rwCond.L.Unlock() defer d.rwCond.L.Unlock()
for { for {
if atomic.LoadUint32(&d.closed) == 1 { if atomic.LoadUint32(&d.closed) == 1 {
return io.ErrClosedPipe return true, io.ErrClosedPipe
} }
if len(d.buf) <= DATAGRAM_NUMBER_LIMIT { if len(d.buf) <= DATAGRAM_NUMBER_LIMIT {
// if d.buf gets too large, write() will panic. We don't want this to happen // if d.buf gets too large, write() will panic. We don't want this to happen
@ -66,10 +66,10 @@ func (d *datagramBuffer) Write(f Frame) error {
d.rwCond.Wait() d.rwCond.Wait()
} }
if f.Closing == 1 { if f.Closing != C_NOOP {
atomic.StoreUint32(&d.closed, 1) atomic.StoreUint32(&d.closed, 1)
d.rwCond.Broadcast() d.rwCond.Broadcast()
return nil return true, nil
} }
data := make([]byte, len(f.Payload)) data := make([]byte, len(f.Payload))
@ -77,7 +77,7 @@ func (d *datagramBuffer) Write(f Frame) error {
d.buf = append(d.buf, data) d.buf = append(d.buf, data)
// err will always be nil // err will always be nil
d.rwCond.Broadcast() d.rwCond.Broadcast()
return nil return false, nil
} }
func (d *datagramBuffer) Close() error { func (d *datagramBuffer) Close() error {

View File

@ -11,7 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
b := []byte{0x01, 0x02, 0x03} b := []byte{0x01, 0x02, 0x03}
t.Run("simple write", func(t *testing.T) { t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBuffer() pipe := NewDatagramBuffer()
err := pipe.Write(Frame{Payload: b}) _, err := pipe.Write(Frame{Payload: b})
if err != nil { if err != nil {
t.Error( t.Error(
"expecting", "nil error", "expecting", "nil error",
@ -23,7 +23,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
t.Run("simple read", func(t *testing.T) { t.Run("simple read", func(t *testing.T) {
pipe := NewDatagramBuffer() pipe := NewDatagramBuffer()
_ = pipe.Write(Frame{Payload: b}) _, _ = pipe.Write(Frame{Payload: b})
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { if n != len(b) {
@ -56,7 +56,10 @@ func TestDatagramBuffer_RW(t *testing.T) {
t.Run("writing closing frame", func(t *testing.T) { t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBuffer() pipe := NewDatagramBuffer()
err := pipe.Write(Frame{Closing: 1}) toBeClosed, err := pipe.Write(Frame{Closing: C_STREAM})
if !toBeClosed {
t.Error("should be to be closed")
}
if err != nil { if err != nil {
t.Error( t.Error(
"expecting", "nil error", "expecting", "nil error",

View File

@ -4,5 +4,5 @@ import "io"
type recvBuffer interface { type recvBuffer interface {
io.ReadCloser io.ReadCloser
Write(Frame) error Write(Frame) (toBeClosed bool, err error)
} }

View File

@ -172,7 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
return nil return nil
} }
// recvDataFromRemote deobfuscate the frame and send it to the appropriate stream buffer // recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
// stream and then writes to the stream buffer
func (sesh *Session) recvDataFromRemote(data []byte) error { func (sesh *Session) recvDataFromRemote(data []byte) error {
frame, err := sesh.Deobfs(data) frame, err := sesh.Deobfs(data)
if err != nil { if err != nil {
@ -182,10 +184,12 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
if frame.Closing == C_STREAM { if frame.Closing == C_STREAM {
streamI, existing := sesh.streams.Load(frame.StreamID) streamI, existing := sesh.streams.Load(frame.StreamID)
if existing { if existing {
// If the stream has been closed and the current frame is a closing frame, we do noop // DO NOT close the stream (or session below) straight away here because the sequence number of this frame
// hasn't been checked. There may be later data frames which haven't arrived
stream := streamI.(*Stream) stream := streamI.(*Stream)
return stream.writeFrame(*frame) return stream.writeFrame(*frame)
} else { } else {
// If the stream has been closed and the current frame is a closing frame, we do noop
return nil return nil
} }
} else if frame.Closing == C_SESSION { } else if frame.Closing == C_SESSION {

View File

@ -143,7 +143,95 @@ func TestRecvDataFromRemote(t *testing.T) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload) t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
} }
}) })
}
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
obfsBuf := make([]byte, 17000)
sessionKey := make([]byte, 32)
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
rand.Read(sessionKey)
sesh := MakeSession(0, seshConfigOrdered)
f1 := &Frame{
1,
0,
C_NOOP,
testPayload,
}
// create stream 1
n, _ := sesh.Obfs(f1, obfsBuf)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err)
}
s1I, ok := sesh.streams.Load(f1.StreamID)
if !ok {
t.Fatal("failed to fetch stream 1 after receiving it")
}
// create stream 2
f2 := &Frame{
2,
0,
C_NOOP,
testPayload,
}
n, _ = sesh.Obfs(f2, obfsBuf)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 2: %v", err)
}
s2I, ok := sesh.streams.Load(f2.StreamID)
if !ok {
t.Fatal("failed to fetch stream 2 after receiving it")
}
// close stream 1
f1CloseStream := &Frame{
1,
1,
C_STREAM,
testPayload,
}
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
}
_, ok = sesh.streams.Load(f1.StreamID)
if ok {
t.Fatal("stream 1 still exist after receiving stream close")
}
s1 := s1I.(*Stream)
if !s1.isClosed() {
t.Fatal("stream 1 not marked as closed")
}
payloadBuf := make([]byte, testPayloadLen)
_, err = s1.recvBuf.Read(payloadBuf)
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Fatalf("failed to read from stream 1 after closing: %v", err)
}
s2 := s2I.(*Stream)
if s2.isClosed() {
t.Fatal("stream 2 shouldn't be closed")
}
// close stream 1 again
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
}
_, ok = sesh.streams.Load(f1.StreamID)
if ok {
t.Fatal("stream 1 exists after receiving stream close for the second time")
}
} }
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
@ -211,5 +299,4 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
b.SetBytes(int64(n)) b.SetBytes(int64(n))
} }
}) })
} }

View File

@ -59,7 +59,11 @@ func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
func (s *Stream) writeFrame(frame Frame) error { func (s *Stream) writeFrame(frame Frame) error {
return s.recvBuf.Write(frame) toBeClosed, err := s.recvBuf.Write(frame)
if toBeClosed {
return s.passiveClose()
}
return err
} }
// Read implements io.Read // Read implements io.Read
@ -99,7 +103,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
f := &Frame{ f := &Frame{
StreamID: s.id, StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: 0, Closing: C_NOOP,
Payload: in, Payload: in,
} }

View File

@ -61,39 +61,38 @@ func NewStreamBuffer() *streamBuffer {
// recvNewFrame is a forever running loop which receives frames unordered, // recvNewFrame is a forever running loop which receives frames unordered,
// cache and order them and send them into sortedBufCh // cache and order them and send them into sortedBufCh
func (sb *streamBuffer) Write(f Frame) error { func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
sb.recvM.Lock() sb.recvM.Lock()
defer sb.recvM.Unlock() defer sb.recvM.Unlock()
// when there's no ooo packages in heap and we receive the next package in order // when there's no ooo packages in heap and we receive the next package in order
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq { if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
if f.Closing == 1 { if f.Closing != C_NOOP {
sb.buf.Close() sb.buf.Close()
return nil return true, nil
} else { } else {
sb.buf.Write(f.Payload) sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1 sb.nextRecvSeq += 1
} }
return nil return false, nil
} }
if f.Seq < sb.nextRecvSeq { if f.Seq < sb.nextRecvSeq {
return fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
} }
heap.Push(&sb.sh, &f) heap.Push(&sb.sh, &f)
// Keep popping from the heap until empty or to the point that the wanted seq was not received // Keep popping from the heap until empty or to the point that the wanted seq was not received
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq { for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
f = *heap.Pop(&sb.sh).(*Frame) f = *heap.Pop(&sb.sh).(*Frame)
if f.Closing == 1 { if f.Closing != C_NOOP {
// empty data indicates closing signal
sb.buf.Close() sb.buf.Close()
return nil return true, nil
} else { } else {
sb.buf.Write(f.Payload) sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1 sb.nextRecvSeq += 1
} }
} }
return nil return false, nil
} }
func (sb *streamBuffer) Read(buf []byte) (int, error) { func (sb *streamBuffer) Read(buf []byte) (int, error) {