mirror of https://github.com/cbeuw/Cloak
Improve stream closing logic and add tests
This commit is contained in:
parent
af5c8a381f
commit
d65aee725a
|
|
@ -52,12 +52,12 @@ func (d *datagramBuffer) Read(target []byte) (int, error) {
|
||||||
return len(data), nil
|
return len(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *datagramBuffer) Write(f Frame) error {
|
func (d *datagramBuffer) Write(f Frame) (toBeClosed bool, err error) {
|
||||||
d.rwCond.L.Lock()
|
d.rwCond.L.Lock()
|
||||||
defer d.rwCond.L.Unlock()
|
defer d.rwCond.L.Unlock()
|
||||||
for {
|
for {
|
||||||
if atomic.LoadUint32(&d.closed) == 1 {
|
if atomic.LoadUint32(&d.closed) == 1 {
|
||||||
return io.ErrClosedPipe
|
return true, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
if len(d.buf) <= DATAGRAM_NUMBER_LIMIT {
|
if len(d.buf) <= DATAGRAM_NUMBER_LIMIT {
|
||||||
// if d.buf gets too large, write() will panic. We don't want this to happen
|
// if d.buf gets too large, write() will panic. We don't want this to happen
|
||||||
|
|
@ -66,10 +66,10 @@ func (d *datagramBuffer) Write(f Frame) error {
|
||||||
d.rwCond.Wait()
|
d.rwCond.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.Closing == 1 {
|
if f.Closing != C_NOOP {
|
||||||
atomic.StoreUint32(&d.closed, 1)
|
atomic.StoreUint32(&d.closed, 1)
|
||||||
d.rwCond.Broadcast()
|
d.rwCond.Broadcast()
|
||||||
return nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data := make([]byte, len(f.Payload))
|
data := make([]byte, len(f.Payload))
|
||||||
|
|
@ -77,7 +77,7 @@ func (d *datagramBuffer) Write(f Frame) error {
|
||||||
d.buf = append(d.buf, data)
|
d.buf = append(d.buf, data)
|
||||||
// err will always be nil
|
// err will always be nil
|
||||||
d.rwCond.Broadcast()
|
d.rwCond.Broadcast()
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *datagramBuffer) Close() error {
|
func (d *datagramBuffer) Close() error {
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
b := []byte{0x01, 0x02, 0x03}
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
t.Run("simple write", func(t *testing.T) {
|
t.Run("simple write", func(t *testing.T) {
|
||||||
pipe := NewDatagramBuffer()
|
pipe := NewDatagramBuffer()
|
||||||
err := pipe.Write(Frame{Payload: b})
|
_, err := pipe.Write(Frame{Payload: b})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(
|
t.Error(
|
||||||
"expecting", "nil error",
|
"expecting", "nil error",
|
||||||
|
|
@ -23,7 +23,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
|
|
||||||
t.Run("simple read", func(t *testing.T) {
|
t.Run("simple read", func(t *testing.T) {
|
||||||
pipe := NewDatagramBuffer()
|
pipe := NewDatagramBuffer()
|
||||||
_ = pipe.Write(Frame{Payload: b})
|
_, _ = pipe.Write(Frame{Payload: b})
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
if n != len(b) {
|
||||||
|
|
@ -56,7 +56,10 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
|
|
||||||
t.Run("writing closing frame", func(t *testing.T) {
|
t.Run("writing closing frame", func(t *testing.T) {
|
||||||
pipe := NewDatagramBuffer()
|
pipe := NewDatagramBuffer()
|
||||||
err := pipe.Write(Frame{Closing: 1})
|
toBeClosed, err := pipe.Write(Frame{Closing: C_STREAM})
|
||||||
|
if !toBeClosed {
|
||||||
|
t.Error("should be to be closed")
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(
|
t.Error(
|
||||||
"expecting", "nil error",
|
"expecting", "nil error",
|
||||||
|
|
|
||||||
|
|
@ -4,5 +4,5 @@ import "io"
|
||||||
|
|
||||||
type recvBuffer interface {
|
type recvBuffer interface {
|
||||||
io.ReadCloser
|
io.ReadCloser
|
||||||
Write(Frame) error
|
Write(Frame) (toBeClosed bool, err error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// recvDataFromRemote deobfuscate the frame and send it to the appropriate stream buffer
|
// recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame
|
||||||
|
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
|
||||||
|
// stream and then writes to the stream buffer
|
||||||
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
frame, err := sesh.Deobfs(data)
|
frame, err := sesh.Deobfs(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -182,10 +184,12 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
if frame.Closing == C_STREAM {
|
if frame.Closing == C_STREAM {
|
||||||
streamI, existing := sesh.streams.Load(frame.StreamID)
|
streamI, existing := sesh.streams.Load(frame.StreamID)
|
||||||
if existing {
|
if existing {
|
||||||
// If the stream has been closed and the current frame is a closing frame, we do noop
|
// DO NOT close the stream (or session below) straight away here because the sequence number of this frame
|
||||||
|
// hasn't been checked. There may be later data frames which haven't arrived
|
||||||
stream := streamI.(*Stream)
|
stream := streamI.(*Stream)
|
||||||
return stream.writeFrame(*frame)
|
return stream.writeFrame(*frame)
|
||||||
} else {
|
} else {
|
||||||
|
// If the stream has been closed and the current frame is a closing frame, we do noop
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
} else if frame.Closing == C_SESSION {
|
} else if frame.Closing == C_SESSION {
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,95 @@ func TestRecvDataFromRemote(t *testing.T) {
|
||||||
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
|
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
|
testPayloadLen := 1024
|
||||||
|
testPayload := make([]byte, testPayloadLen)
|
||||||
|
rand.Read(testPayload)
|
||||||
|
obfsBuf := make([]byte, 17000)
|
||||||
|
|
||||||
|
sessionKey := make([]byte, 32)
|
||||||
|
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
|
||||||
|
seshConfigOrdered.Obfuscator = obfuscator
|
||||||
|
|
||||||
|
rand.Read(sessionKey)
|
||||||
|
sesh := MakeSession(0, seshConfigOrdered)
|
||||||
|
|
||||||
|
f1 := &Frame{
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
C_NOOP,
|
||||||
|
testPayload,
|
||||||
|
}
|
||||||
|
// create stream 1
|
||||||
|
n, _ := sesh.Obfs(f1, obfsBuf)
|
||||||
|
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||||
|
}
|
||||||
|
s1I, ok := sesh.streams.Load(f1.StreamID)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||||
|
}
|
||||||
|
|
||||||
|
// create stream 2
|
||||||
|
f2 := &Frame{
|
||||||
|
2,
|
||||||
|
0,
|
||||||
|
C_NOOP,
|
||||||
|
testPayload,
|
||||||
|
}
|
||||||
|
n, _ = sesh.Obfs(f2, obfsBuf)
|
||||||
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||||
|
}
|
||||||
|
s2I, ok := sesh.streams.Load(f2.StreamID)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||||
|
}
|
||||||
|
|
||||||
|
// close stream 1
|
||||||
|
f1CloseStream := &Frame{
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
C_STREAM,
|
||||||
|
testPayload,
|
||||||
|
}
|
||||||
|
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
|
||||||
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||||
|
}
|
||||||
|
_, ok = sesh.streams.Load(f1.StreamID)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("stream 1 still exist after receiving stream close")
|
||||||
|
}
|
||||||
|
s1 := s1I.(*Stream)
|
||||||
|
if !s1.isClosed() {
|
||||||
|
t.Fatal("stream 1 not marked as closed")
|
||||||
|
}
|
||||||
|
payloadBuf := make([]byte, testPayloadLen)
|
||||||
|
_, err = s1.recvBuf.Read(payloadBuf)
|
||||||
|
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||||
|
t.Fatalf("failed to read from stream 1 after closing: %v", err)
|
||||||
|
}
|
||||||
|
s2 := s2I.(*Stream)
|
||||||
|
if s2.isClosed() {
|
||||||
|
t.Fatal("stream 2 shouldn't be closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// close stream 1 again
|
||||||
|
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
|
||||||
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||||
|
}
|
||||||
|
_, ok = sesh.streams.Load(f1.StreamID)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("stream 1 exists after receiving stream close for the second time")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||||
|
|
@ -211,5 +299,4 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,11 @@ func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
|
||||||
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
||||||
|
|
||||||
func (s *Stream) writeFrame(frame Frame) error {
|
func (s *Stream) writeFrame(frame Frame) error {
|
||||||
return s.recvBuf.Write(frame)
|
toBeClosed, err := s.recvBuf.Write(frame)
|
||||||
|
if toBeClosed {
|
||||||
|
return s.passiveClose()
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read implements io.Read
|
// Read implements io.Read
|
||||||
|
|
@ -99,7 +103,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
f := &Frame{
|
f := &Frame{
|
||||||
StreamID: s.id,
|
StreamID: s.id,
|
||||||
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
|
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
|
||||||
Closing: 0,
|
Closing: C_NOOP,
|
||||||
Payload: in,
|
Payload: in,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,39 +61,38 @@ func NewStreamBuffer() *streamBuffer {
|
||||||
|
|
||||||
// recvNewFrame is a forever running loop which receives frames unordered,
|
// recvNewFrame is a forever running loop which receives frames unordered,
|
||||||
// cache and order them and send them into sortedBufCh
|
// cache and order them and send them into sortedBufCh
|
||||||
func (sb *streamBuffer) Write(f Frame) error {
|
func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
|
||||||
sb.recvM.Lock()
|
sb.recvM.Lock()
|
||||||
defer sb.recvM.Unlock()
|
defer sb.recvM.Unlock()
|
||||||
// when there'fs no ooo packages in heap and we receive the next package in order
|
// when there'fs no ooo packages in heap and we receive the next package in order
|
||||||
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
|
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
|
||||||
if f.Closing == 1 {
|
if f.Closing != C_NOOP {
|
||||||
sb.buf.Close()
|
sb.buf.Close()
|
||||||
return nil
|
return true, nil
|
||||||
} else {
|
} else {
|
||||||
sb.buf.Write(f.Payload)
|
sb.buf.Write(f.Payload)
|
||||||
sb.nextRecvSeq += 1
|
sb.nextRecvSeq += 1
|
||||||
}
|
}
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.Seq < sb.nextRecvSeq {
|
if f.Seq < sb.nextRecvSeq {
|
||||||
return fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
||||||
}
|
}
|
||||||
|
|
||||||
heap.Push(&sb.sh, &f)
|
heap.Push(&sb.sh, &f)
|
||||||
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
||||||
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
|
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
|
||||||
f = *heap.Pop(&sb.sh).(*Frame)
|
f = *heap.Pop(&sb.sh).(*Frame)
|
||||||
if f.Closing == 1 {
|
if f.Closing != C_NOOP {
|
||||||
// empty data indicates closing signal
|
|
||||||
sb.buf.Close()
|
sb.buf.Close()
|
||||||
return nil
|
return true, nil
|
||||||
} else {
|
} else {
|
||||||
sb.buf.Write(f.Payload)
|
sb.buf.Write(f.Payload)
|
||||||
sb.nextRecvSeq += 1
|
sb.nextRecvSeq += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sb *streamBuffer) Read(buf []byte) (int, error) {
|
func (sb *streamBuffer) Read(buf []byte) (int, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue