mirror of https://github.com/cbeuw/Cloak
Improve stream closing logic and add tests
This commit is contained in:
parent
af5c8a381f
commit
d65aee725a
|
|
@ -52,12 +52,12 @@ func (d *datagramBuffer) Read(target []byte) (int, error) {
|
|||
return len(data), nil
|
||||
}
|
||||
|
||||
func (d *datagramBuffer) Write(f Frame) error {
|
||||
func (d *datagramBuffer) Write(f Frame) (toBeClosed bool, err error) {
|
||||
d.rwCond.L.Lock()
|
||||
defer d.rwCond.L.Unlock()
|
||||
for {
|
||||
if atomic.LoadUint32(&d.closed) == 1 {
|
||||
return io.ErrClosedPipe
|
||||
return true, io.ErrClosedPipe
|
||||
}
|
||||
if len(d.buf) <= DATAGRAM_NUMBER_LIMIT {
|
||||
// if d.buf gets too large, write() will panic. We don't want this to happen
|
||||
|
|
@ -66,10 +66,10 @@ func (d *datagramBuffer) Write(f Frame) error {
|
|||
d.rwCond.Wait()
|
||||
}
|
||||
|
||||
if f.Closing == 1 {
|
||||
if f.Closing != C_NOOP {
|
||||
atomic.StoreUint32(&d.closed, 1)
|
||||
d.rwCond.Broadcast()
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
data := make([]byte, len(f.Payload))
|
||||
|
|
@ -77,7 +77,7 @@ func (d *datagramBuffer) Write(f Frame) error {
|
|||
d.buf = append(d.buf, data)
|
||||
// err will always be nil
|
||||
d.rwCond.Broadcast()
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (d *datagramBuffer) Close() error {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
|||
b := []byte{0x01, 0x02, 0x03}
|
||||
t.Run("simple write", func(t *testing.T) {
|
||||
pipe := NewDatagramBuffer()
|
||||
err := pipe.Write(Frame{Payload: b})
|
||||
_, err := pipe.Write(Frame{Payload: b})
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"expecting", "nil error",
|
||||
|
|
@ -23,7 +23,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
|||
|
||||
t.Run("simple read", func(t *testing.T) {
|
||||
pipe := NewDatagramBuffer()
|
||||
_ = pipe.Write(Frame{Payload: b})
|
||||
_, _ = pipe.Write(Frame{Payload: b})
|
||||
b2 := make([]byte, len(b))
|
||||
n, err := pipe.Read(b2)
|
||||
if n != len(b) {
|
||||
|
|
@ -56,7 +56,10 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
|||
|
||||
t.Run("writing closing frame", func(t *testing.T) {
|
||||
pipe := NewDatagramBuffer()
|
||||
err := pipe.Write(Frame{Closing: 1})
|
||||
toBeClosed, err := pipe.Write(Frame{Closing: C_STREAM})
|
||||
if !toBeClosed {
|
||||
t.Error("should be to be closed")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"expecting", "nil error",
|
||||
|
|
|
|||
|
|
@ -4,5 +4,5 @@ import "io"
|
|||
|
||||
type recvBuffer interface {
|
||||
io.ReadCloser
|
||||
Write(Frame) error
|
||||
Write(Frame) (toBeClosed bool, err error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,7 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// recvDataFromRemote deobfuscate the frame and send it to the appropriate stream buffer
|
||||
// recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame
|
||||
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
|
||||
// stream and then writes to the stream buffer
|
||||
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||
frame, err := sesh.Deobfs(data)
|
||||
if err != nil {
|
||||
|
|
@ -182,10 +184,12 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
|||
if frame.Closing == C_STREAM {
|
||||
streamI, existing := sesh.streams.Load(frame.StreamID)
|
||||
if existing {
|
||||
// If the stream has been closed and the current frame is a closing frame, we do noop
|
||||
// DO NOT close the stream (or session below) straight away here because the sequence number of this frame
|
||||
// hasn't been checked. There may be later data frames which haven't arrived
|
||||
stream := streamI.(*Stream)
|
||||
return stream.writeFrame(*frame)
|
||||
} else {
|
||||
// If the stream has been closed and the current frame is a closing frame, we do noop
|
||||
return nil
|
||||
}
|
||||
} else if frame.Closing == C_SESSION {
|
||||
|
|
|
|||
|
|
@ -143,7 +143,95 @@ func TestRecvDataFromRemote(t *testing.T) {
|
|||
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||
testPayloadLen := 1024
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
rand.Read(testPayload)
|
||||
obfsBuf := make([]byte, 17000)
|
||||
|
||||
sessionKey := make([]byte, 32)
|
||||
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
|
||||
seshConfigOrdered.Obfuscator = obfuscator
|
||||
|
||||
rand.Read(sessionKey)
|
||||
sesh := MakeSession(0, seshConfigOrdered)
|
||||
|
||||
f1 := &Frame{
|
||||
1,
|
||||
0,
|
||||
C_NOOP,
|
||||
testPayload,
|
||||
}
|
||||
// create stream 1
|
||||
n, _ := sesh.Obfs(f1, obfsBuf)
|
||||
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||
}
|
||||
s1I, ok := sesh.streams.Load(f1.StreamID)
|
||||
if !ok {
|
||||
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||
}
|
||||
|
||||
// create stream 2
|
||||
f2 := &Frame{
|
||||
2,
|
||||
0,
|
||||
C_NOOP,
|
||||
testPayload,
|
||||
}
|
||||
n, _ = sesh.Obfs(f2, obfsBuf)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||
}
|
||||
s2I, ok := sesh.streams.Load(f2.StreamID)
|
||||
if !ok {
|
||||
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||
}
|
||||
|
||||
// close stream 1
|
||||
f1CloseStream := &Frame{
|
||||
1,
|
||||
1,
|
||||
C_STREAM,
|
||||
testPayload,
|
||||
}
|
||||
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
_, ok = sesh.streams.Load(f1.StreamID)
|
||||
if ok {
|
||||
t.Fatal("stream 1 still exist after receiving stream close")
|
||||
}
|
||||
s1 := s1I.(*Stream)
|
||||
if !s1.isClosed() {
|
||||
t.Fatal("stream 1 not marked as closed")
|
||||
}
|
||||
payloadBuf := make([]byte, testPayloadLen)
|
||||
_, err = s1.recvBuf.Read(payloadBuf)
|
||||
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Fatalf("failed to read from stream 1 after closing: %v", err)
|
||||
}
|
||||
s2 := s2I.(*Stream)
|
||||
if s2.isClosed() {
|
||||
t.Fatal("stream 2 shouldn't be closed")
|
||||
}
|
||||
|
||||
// close stream 1 again
|
||||
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
_, ok = sesh.streams.Load(f1.StreamID)
|
||||
if ok {
|
||||
t.Fatal("stream 1 exists after receiving stream close for the second time")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||
|
|
@ -211,5 +299,4 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
|||
b.SetBytes(int64(n))
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,11 @@ func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
|
|||
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
||||
|
||||
func (s *Stream) writeFrame(frame Frame) error {
|
||||
return s.recvBuf.Write(frame)
|
||||
toBeClosed, err := s.recvBuf.Write(frame)
|
||||
if toBeClosed {
|
||||
return s.passiveClose()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Read implements io.Read
|
||||
|
|
@ -99,7 +103,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
|||
f := &Frame{
|
||||
StreamID: s.id,
|
||||
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
|
||||
Closing: 0,
|
||||
Closing: C_NOOP,
|
||||
Payload: in,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,39 +61,38 @@ func NewStreamBuffer() *streamBuffer {
|
|||
|
||||
// recvNewFrame is a forever running loop which receives frames unordered,
|
||||
// cache and order them and send them into sortedBufCh
|
||||
func (sb *streamBuffer) Write(f Frame) error {
|
||||
func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
|
||||
sb.recvM.Lock()
|
||||
defer sb.recvM.Unlock()
|
||||
// when there'fs no ooo packages in heap and we receive the next package in order
|
||||
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
|
||||
if f.Closing == 1 {
|
||||
if f.Closing != C_NOOP {
|
||||
sb.buf.Close()
|
||||
return nil
|
||||
return true, nil
|
||||
} else {
|
||||
sb.buf.Write(f.Payload)
|
||||
sb.nextRecvSeq += 1
|
||||
}
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if f.Seq < sb.nextRecvSeq {
|
||||
return fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
||||
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
||||
}
|
||||
|
||||
heap.Push(&sb.sh, &f)
|
||||
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
||||
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
|
||||
f = *heap.Pop(&sb.sh).(*Frame)
|
||||
if f.Closing == 1 {
|
||||
// empty data indicates closing signal
|
||||
if f.Closing != C_NOOP {
|
||||
sb.buf.Close()
|
||||
return nil
|
||||
return true, nil
|
||||
} else {
|
||||
sb.buf.Write(f.Payload)
|
||||
sb.nextRecvSeq += 1
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (sb *streamBuffer) Read(buf []byte) (int, error) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue