Use default hashmap to store streams. Avoid allocating a stream object on receiving every single frame

This commit is contained in:
Andy Wang 2020-12-22 20:07:17 +00:00
parent fd5005db0a
commit 35f41424c9
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
3 changed files with 55 additions and 28 deletions

View File

@ -63,7 +63,9 @@ type Session struct {
// atomic // atomic
activeStreamCount uint32 activeStreamCount uint32
streams sync.Map
streamsM sync.Mutex
streams map[uint32]*Stream
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool recvFramePool sync.Pool
@ -93,6 +95,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
nextStreamID: 1, nextStreamID: 1,
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }}, recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
streams: map[uint32]*Stream{},
} }
sesh.addrs.Store([]net.Addr{nil, nil}) sesh.addrs.Store([]net.Addr{nil, nil})
@ -149,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) {
return nil, errNoMultiplex return nil, errNoMultiplex
} }
stream := makeStream(sesh, id) stream := makeStream(sesh, id)
sesh.streams.Store(id, stream) sesh.streamsM.Lock()
sesh.streams[id] = stream
sesh.streamsM.Unlock()
sesh.streamCountIncr() sesh.streamCountIncr()
log.Tracef("stream %v of session %v opened", id, sesh.id) log.Tracef("stream %v of session %v opened", id, sesh.id)
return stream, nil return stream, nil
@ -200,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
// We set it as nil to signify that the stream id had existed before. // We set it as nil to signify that the stream id had existed before.
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell // If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
// if the frame it received was from a new stream or a dying stream whose frame arrived late // if the frame it received was from a new stream or a dying stream whose frame arrived late
sesh.streams.Store(s.id, nil) sesh.streamsM.Lock()
sesh.streams[s.id] = nil
sesh.streamsM.Unlock()
if sesh.streamCountDecr() == 0 { if sesh.streamCountDecr() == 0 {
if sesh.Singleplex { if sesh.Singleplex {
return sesh.Close() return sesh.Close()
@ -229,15 +236,19 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
return sesh.passiveClose() return sesh.passiveClose()
} }
newStream := makeStream(sesh, frame.StreamID) sesh.streamsM.Lock()
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) existingStream, existing := sesh.streams[frame.StreamID]
if existing { if existing {
if existingStreamI == nil { sesh.streamsM.Unlock()
if existingStream == nil {
// this is when the stream existed before but has since been closed. We do nothing // this is when the stream existed before but has since been closed. We do nothing
return nil return nil
} }
return existingStreamI.(*Stream).recvFrame(frame) return existingStream.recvFrame(frame)
} else { } else {
newStream := makeStream(sesh, frame.StreamID)
sesh.streams[frame.StreamID] = newStream
sesh.streamsM.Unlock()
// new stream // new stream
sesh.streamCountIncr() sesh.streamCountIncr()
sesh.acceptCh <- newStream sesh.acceptCh <- newStream
@ -265,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
} }
sesh.acceptCh <- nil sesh.acceptCh <- nil
sesh.streams.Range(func(key, streamI interface{}) bool { sesh.streamsM.Lock()
if streamI == nil { for id, stream := range sesh.streams {
return true if stream == nil {
continue
} }
stream := streamI.(*Stream)
atomic.StoreUint32(&stream.closed, 1) atomic.StoreUint32(&stream.closed, 1)
_ = stream.getRecvBuf().Close() // will not block _ = stream.getRecvBuf().Close() // will not block
sesh.streams.Delete(key) delete(sesh.streams, id)
sesh.streamCountDecr() sesh.streamCountDecr()
return true }
}) sesh.streamsM.Unlock()
if closeSwitchboard { if closeSwitchboard {
sesh.sb.closeAll() sesh.sb.closeAll()

View File

@ -112,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err) t.Fatalf("receiving normal frame for stream 1: %v", err)
} }
_, ok := sesh.streams.Load(f1.StreamID) sesh.streamsM.Lock()
_, ok := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if !ok { if !ok {
t.Fatal("failed to fetch stream 1 after receiving it") t.Fatal("failed to fetch stream 1 after receiving it")
} }
@ -132,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 2: %v", err) t.Fatalf("receiving normal frame for stream 2: %v", err)
} }
s2I, ok := sesh.streams.Load(f2.StreamID) sesh.streamsM.Lock()
if s2I == nil || !ok { s2M, ok := sesh.streams[f2.StreamID]
sesh.streamsM.Unlock()
if s2M == nil || !ok {
t.Fatal("failed to fetch stream 2 after receiving it") t.Fatal("failed to fetch stream 2 after receiving it")
} }
if sesh.streamCount() != 2 { if sesh.streamCount() != 2 {
@ -152,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err) t.Fatalf("receiving stream closing frame for stream 1: %v", err)
} }
s1I, _ := sesh.streams.Load(f1.StreamID) sesh.streamsM.Lock()
if s1I != nil { s1M, _ := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Fatal("stream 1 still exist after receiving stream close") t.Fatal("stream 1 still exist after receiving stream close")
} }
s1, _ := sesh.Accept() s1, _ := sesh.Accept()
@ -179,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving stream closing frame for stream 1 %v", err) t.Fatalf("receiving stream closing frame for stream 1 %v", err)
} }
s1I, _ = sesh.streams.Load(f1.StreamID) sesh.streamsM.Lock()
if s1I != nil { s1M, _ = sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Error("stream 1 exists after receiving stream close for the second time") t.Error("stream 1 exists after receiving stream close for the second time")
} }
streamCount := sesh.streamCount() streamCount := sesh.streamCount()
@ -243,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
} }
_, ok := sesh.streams.Load(f1CloseStream.StreamID) sesh.streamsM.Lock()
_, ok := sesh.streams[f1CloseStream.StreamID]
sesh.streamsM.Unlock()
if !ok { if !ok {
t.Fatal("stream 1 doesn't exist") t.Fatal("stream 1 doesn't exist")
} }
@ -334,12 +344,13 @@ func TestParallelStreams(t *testing.T) {
wg.Wait() wg.Wait()
sc := int(sesh.streamCount()) sc := int(sesh.streamCount())
var count int var count int
sesh.streams.Range(func(_, s interface{}) bool { sesh.streamsM.Lock()
for _, s := range sesh.streams {
if s != nil { if s != nil {
count++ count++
} }
return true }
}) sesh.streamsM.Unlock()
if sc != count { if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
} }

View File

@ -167,10 +167,13 @@ func TestStream_Close(t *testing.T) {
return return
} }
if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil { sesh.streamsM.Lock()
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
sesh.streamsM.Unlock()
t.Error("stream still exists") t.Error("stream still exists")
return return
} }
sesh.streamsM.Unlock()
_, err = io.ReadFull(stream, readBuf[1:]) _, err = io.ReadFull(stream, readBuf[1:])
if err != nil { if err != nil {
@ -242,8 +245,10 @@ func TestStream_Close(t *testing.T) {
} }
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
sI, _ := sesh.streams.Load(stream.(*Stream).id) sesh.streamsM.Lock()
return sI == nil s, _ := sesh.streams[stream.(*Stream).id]
sesh.streamsM.Unlock()
return s == nil
}, time.Second, 10*time.Millisecond, "streams still exists") }, time.Second, 10*time.Millisecond, "streams still exists")
}) })