diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 47434ea..0c8b9a7 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -51,7 +51,9 @@ type Session struct { // atomic nextStreamID uint32 - streams sync.Map + // atomic + activeStreamCount uint32 + streams sync.Map // Switchboard manages all connections to remote sb *switchboard @@ -94,6 +96,16 @@ func MakeSession(id uint32, config *SessionConfig) *Session { return sesh } +func (sesh *Session) streamCountIncr() uint32 { + return atomic.AddUint32(&sesh.activeStreamCount, 1) +} +func (sesh *Session) streamCountDecr() uint32 { + return atomic.AddUint32(&sesh.activeStreamCount, ^uint32(0)) +} +func (sesh *Session) streamCount() uint32 { + return atomic.LoadUint32(&sesh.activeStreamCount) +} + func (sesh *Session) AddConnection(conn net.Conn) { sesh.sb.addConn(conn) addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()} @@ -112,6 +124,7 @@ func (sesh *Session) OpenStream() (*Stream, error) { } stream := makeStream(sesh, id, connId) sesh.streams.Store(id, stream) + sesh.streamCountIncr() log.Tracef("stream %v of session %v opened", id, sesh.id) return stream, nil } @@ -159,14 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { log.Tracef("stream %v passively closed", s.id) } - sesh.streams.Delete(s.id) - var count int - sesh.streams.Range(func(_, _ interface{}) bool { - count += 1 - return true - }) - if count == 0 { - log.Tracef("session %v has no active stream left", sesh.id) + sesh.streams.Store(s.id, nil) + if sesh.streamCountDecr() == 0 { + log.Debugf("session %v has no active stream left", sesh.id) go sesh.timeoutAfter(30 * time.Second) } return nil @@ -181,32 +189,25 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } - if frame.Closing == C_STREAM { - streamI, existing := sesh.streams.Load(frame.StreamID) - if existing { - // DO NOT close the stream straight away here because the sequence number of this frame - // hasn't been checked. There may be later data frames which haven't arrived - stream := streamI.(*Stream) - return stream.writeFrame(*frame) - } else { - // If the stream has been closed and the current frame is a closing frame, we do noop - return nil - } - } else if frame.Closing == C_SESSION { - // Closing session + if frame.Closing == C_SESSION { sesh.SetTerminalMsg("Received a closing notification frame") return sesh.passiveClose() - } else { - connId, _, _ := sesh.sb.pickRandConn() - // we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write - newStream := makeStream(sesh, frame.StreamID, connId) - existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) - if existing { - return existingStreamI.(*Stream).writeFrame(*frame) - } else { - sesh.acceptCh <- newStream - return newStream.writeFrame(*frame) + } + + connId, _, _ := sesh.sb.pickRandConn() + // we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write + newStream := makeStream(sesh, frame.StreamID, connId) + existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) + if existing { + if existingStreamI == nil { + // this is when the stream existed before but has since been closed. We do nothing + return nil } + return existingStreamI.(*Stream).writeFrame(*frame) + } else { + sesh.streamCountIncr() + sesh.acceptCh <- newStream + return newStream.writeFrame(*frame) } } @@ -232,10 +233,14 @@ func (sesh *Session) passiveClose() error { sesh.acceptCh <- nil sesh.streams.Range(func(key, streamI interface{}) bool { + if streamI == nil { + return true + } stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) _ = stream.recvBuf.Close() // will not block sesh.streams.Delete(key) + sesh.streamCountDecr() return true }) @@ -261,10 +266,14 @@ func (sesh *Session) Close() error { sesh.acceptCh <- nil sesh.streams.Range(func(key, streamI interface{}) bool { + if streamI == nil { + return true + } stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) _ = stream.recvBuf.Close() // will not block sesh.streams.Delete(key) + sesh.streamCountDecr() return true }) @@ -296,12 +305,8 @@ func (sesh *Session) IsClosed() bool { func (sesh *Session) timeoutAfter(to time.Duration) { time.Sleep(to) - var count int - sesh.streams.Range(func(_, _ interface{}) bool { - count += 1 - return true - }) - if count == 0 && !sesh.IsClosed() { + + if sesh.streamCount() == 0 && !sesh.IsClosed() { sesh.SetTerminalMsg("timeout") sesh.Close() } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 595c0d8..7aa6d5d 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -4,6 +4,8 @@ import ( "bytes" "github.com/cbeuw/Cloak/internal/util" "math/rand" + "strconv" + "sync/atomic" "testing" ) @@ -174,6 +176,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if !ok { t.Fatal("failed to fetch stream 1 after receiving it") } + if sesh.streamCount() != 1 { + t.Error("stream count isn't 1") + } // create stream 2 f2 := &Frame{ @@ -188,9 +193,12 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { t.Fatalf("receiving normal frame for stream 2: %v", err) } s2I, ok := sesh.streams.Load(f2.StreamID) - if !ok { + if s2I == nil || !ok { t.Fatal("failed to fetch stream 2 after receiving it") } + if sesh.streamCount() != 2 { + t.Error("stream count isn't 2") + } // close stream 1 f1CloseStream := &Frame{ @@ -204,33 +212,189 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) } - _, ok = sesh.streams.Load(f1.StreamID) - if ok { + s1I, _ = sesh.streams.Load(f1.StreamID) + if s1I != nil { t.Fatal("stream 1 still exist after receiving stream close") } - s1 := s1I.(*Stream) - if !s1.isClosed() { + s1, _ := sesh.Accept() + if !s1.(*Stream).isClosed() { t.Fatal("stream 1 not marked as closed") } payloadBuf := make([]byte, testPayloadLen) - _, err = s1.recvBuf.Read(payloadBuf) + _, err = s1.Read(payloadBuf) if err != nil || !bytes.Equal(payloadBuf, testPayload) { t.Fatalf("failed to read from stream 1 after closing: %v", err) } - s2 := s2I.(*Stream) - if s2.isClosed() { + s2, _ := sesh.Accept() + if s2.(*Stream).isClosed() { t.Fatal("stream 2 shouldn't be closed") } + if sesh.streamCount() != 1 { + t.Error("stream count isn't 1 after stream 1 closed") + } // close stream 1 again n, _ = sesh.Obfs(f1CloseStream, obfsBuf) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { - t.Fatalf("receiving stream closing frame for stream 1: %v", err) + t.Fatalf("receiving stream closing frame for stream 1 %v", err) } - _, ok = sesh.streams.Load(f1.StreamID) - if ok { - t.Fatal("stream 1 exists after receiving stream close for the second time") + s1I, _ = sesh.streams.Load(f1.StreamID) + if s1I != nil { + t.Error("stream 1 exists after receiving stream close for the second time") + } + if sesh.streamCount() != 1 { + t.Error("stream count isn't 1 after stream 1 closed twice") + } + + // close session + fCloseSession := &Frame{ + StreamID: 0xffffffff, + Seq: 0, + Closing: C_SESSION, + Payload: testPayload, + } + n, _ = sesh.Obfs(fCloseSession, obfsBuf) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving session closing frame: %v", err) + } + if !sesh.IsClosed() { + t.Error("session not closed after receiving signal") + } + if !s2.(*Stream).isClosed() { + t.Error("stream 2 isn't closed after session closed") + } + if _, err := s2.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) { + t.Error("failed to read from stream 2 after session closed") + } + if _, err := s2.Write(testPayload); err == nil { + t.Error("can still write to stream 2 after session closed") + } + if sesh.streamCount() != 0 { + t.Error("stream count isn't 0 after session closed") + } +} + +func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { + // Tests for when the closing frame of a stream is received first before any data frame + testPayloadLen := 1024 + testPayload := make([]byte, testPayloadLen) + rand.Read(testPayload) + obfsBuf := make([]byte, 17000) + + sessionKey := make([]byte, 32) + obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) + seshConfigOrdered.Obfuscator = obfuscator + + rand.Read(sessionKey) + sesh := MakeSession(0, seshConfigOrdered) + + // receive stream 1 closing first + f1CloseStream := &Frame{ + 1, + 1, + C_STREAM, + testPayload, + } + n, _ := sesh.Obfs(f1CloseStream, obfsBuf) + err := sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) + } + _, ok := sesh.streams.Load(f1CloseStream.StreamID) + if !ok { + t.Fatal("stream 1 doesn't exist") + } + if sesh.streamCount() != 1 { + t.Error("stream count isn't 1 after stream 1 received") + } + + // receive data frame of stream 1 after receiving the closing frame + f1 := &Frame{ + 1, + 0, + C_NOOP, + testPayload, + } + n, _ = sesh.Obfs(f1, obfsBuf) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Fatalf("receiving normal frame for stream 1: %v", err) + } + s1, err := sesh.Accept() + if err != nil { + t.Fatal("failed to accept stream 1 after receiving it") + } + payloadBuf := make([]byte, testPayloadLen) + if _, err := s1.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) { + t.Error("failed to read from steam 1") + } + if !s1.(*Stream).isClosed() { + t.Error("s1 isn't closed") + } + if sesh.streamCount() != 0 { + t.Error("stream count isn't 0 after stream 1 closed") + } +} + +func TestParallel(t *testing.T) { + rand.Seed(0) + + sessionKey := make([]byte, 32) + obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) + seshConfigOrdered.Obfuscator = obfuscator + rand.Read(sessionKey) + sesh := MakeSession(0, seshConfigOrdered) + + numStreams := 10 + seqs := make([]*uint64, numStreams) + for i, _ := range seqs { + seqs[i] = new(uint64) + } + randFrame := func() *Frame { + id := rand.Intn(numStreams) + return &Frame{ + uint32(id), + atomic.AddUint64(seqs[id], 1) - 1, + uint8(rand.Intn(2)), + []byte{1, 2, 3, 4}, + } + } + + numOfTests := 100 + tests := make([]struct { + name string + frame *Frame + }, numOfTests) + for i, _ := range tests { + tests[i].name = strconv.Itoa(i) + tests[i].frame = randFrame() + } + + for _, tc := range tests { + go func(frame *Frame) { + data := make([]byte, 1000) + n, _ := sesh.Obfs(frame, data) + data = data[0:n] + + err := sesh.recvDataFromRemote(data) + if err != nil { + t.Error(err) + } + }(tc.frame) + } + + var count int + sesh.streams.Range(func(_, s interface{}) bool { + if s != nil { + count++ + } + return true + }) + sc := int(sesh.streamCount()) + if sc != count { + t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) } } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index ee4d939..25ecf8c 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -157,7 +157,7 @@ func TestStream_Close(t *testing.T) { return } - if _, ok := sesh.streams.Load(streamID); ok { + if sI, _ := sesh.streams.Load(streamID); sI != nil { t.Error("stream still exists") return }