diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 92b5cf8..6f165b0 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -63,7 +63,9 @@ type Session struct { // atomic activeStreamCount uint32 - streams sync.Map + + streamsM sync.Mutex + streams map[uint32]*Stream // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame recvFramePool sync.Pool @@ -93,6 +95,7 @@ func MakeSession(id uint32, config SessionConfig) *Session { nextStreamID: 1, acceptCh: make(chan *Stream, acceptBacklog), recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }}, + streams: map[uint32]*Stream{}, } sesh.addrs.Store([]net.Addr{nil, nil}) @@ -149,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) { return nil, errNoMultiplex } stream := makeStream(sesh, id) - sesh.streams.Store(id, stream) + sesh.streamsM.Lock() + sesh.streams[id] = stream + sesh.streamsM.Unlock() sesh.streamCountIncr() log.Tracef("stream %v of session %v opened", id, sesh.id) return stream, nil @@ -200,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { // We set it as nil to signify that the stream id had existed before. // If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell // if the frame it received was from a new stream or a dying stream whose frame arrived late - sesh.streams.Store(s.id, nil) + sesh.streamsM.Lock() + sesh.streams[s.id] = nil + sesh.streamsM.Unlock() if sesh.streamCountDecr() == 0 { if sesh.Singleplex { return sesh.Close() @@ -229,15 +236,19 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { return sesh.passiveClose() } - newStream := makeStream(sesh, frame.StreamID) - existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) + sesh.streamsM.Lock() + existingStream, existing := sesh.streams[frame.StreamID] if existing { - if existingStreamI == nil { + sesh.streamsM.Unlock() + if existingStream == nil { // this is when the stream existed before but has since been closed. We do nothing return nil } - return existingStreamI.(*Stream).recvFrame(frame) + return existingStream.recvFrame(frame) } else { + newStream := makeStream(sesh, frame.StreamID) + sesh.streams[frame.StreamID] = newStream + sesh.streamsM.Unlock() // new stream sesh.streamCountIncr() sesh.acceptCh <- newStream @@ -265,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error { } sesh.acceptCh <- nil - sesh.streams.Range(func(key, streamI interface{}) bool { - if streamI == nil { - return true + sesh.streamsM.Lock() + for id, stream := range sesh.streams { + if stream == nil { + continue } - stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) _ = stream.getRecvBuf().Close() // will not block - sesh.streams.Delete(key) + delete(sesh.streams, id) sesh.streamCountDecr() - return true - }) + } + sesh.streamsM.Unlock() if closeSwitchboard { sesh.sb.closeAll() diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 31cee76..f4b32bb 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -112,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) } - _, ok := sesh.streams.Load(f1.StreamID) + sesh.streamsM.Lock() + _, ok := sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() if !ok { t.Fatal("failed to fetch stream 1 after receiving it") } @@ -132,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) } - s2I, ok := sesh.streams.Load(f2.StreamID) - if s2I == nil || !ok { + sesh.streamsM.Lock() + s2M, ok := sesh.streams[f2.StreamID] + sesh.streamsM.Unlock() + if s2M == nil || !ok { t.Fatal("failed to fetch stream 2 after receiving it") } if sesh.streamCount() != 2 { @@ -152,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) } - s1I, _ := sesh.streams.Load(f1.StreamID) - if s1I != nil { + sesh.streamsM.Lock() + s1M, _ := sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() + if s1M != nil { t.Fatal("stream 1 still exist after receiving stream close") } s1, _ := sesh.Accept() @@ -179,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) } - s1I, _ = sesh.streams.Load(f1.StreamID) - if s1I != nil { + sesh.streamsM.Lock() + s1M, _ = sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() + if s1M != nil { t.Error("stream 1 exists after receiving stream close for the second time") } streamCount := sesh.streamCount() @@ -243,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) } - _, ok := sesh.streams.Load(f1CloseStream.StreamID) + sesh.streamsM.Lock() + _, ok := sesh.streams[f1CloseStream.StreamID] + sesh.streamsM.Unlock() if !ok { t.Fatal("stream 1 doesn't exist") } @@ -334,12 +344,13 @@ func TestParallelStreams(t *testing.T) { wg.Wait() sc := int(sesh.streamCount()) var count int - sesh.streams.Range(func(_, s interface{}) bool { + sesh.streamsM.Lock() + for _, s := range sesh.streams { if s != nil { count++ } - return true - }) + } + sesh.streamsM.Unlock() if sc != count { t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 84e8982..c0b86fb 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -167,10 +167,13 @@ func TestStream_Close(t *testing.T) { return } - if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil { + sesh.streamsM.Lock() + if s, _ := sesh.streams[stream.(*Stream).id]; s != nil { + sesh.streamsM.Unlock() t.Error("stream still exists") return } + sesh.streamsM.Unlock() _, err = io.ReadFull(stream, readBuf[1:]) if err != nil { @@ -242,8 +245,10 @@ func TestStream_Close(t *testing.T) { } assert.Eventually(t, func() bool { - sI, _ := sesh.streams.Load(stream.(*Stream).id) - return sI == nil + sesh.streamsM.Lock() + s, _ := sesh.streams[stream.(*Stream).id] + sesh.streamsM.Unlock() + return s == nil }, time.Second, 10*time.Millisecond, "streams still exists") })