mirror of https://github.com/cbeuw/Cloak
Use default hashmap to store streams. Avoid allocating a stream object on receiving every single frame
This commit is contained in:
parent
fd5005db0a
commit
35f41424c9
|
|
@ -63,7 +63,9 @@ type Session struct {
|
|||
|
||||
// atomic
|
||||
activeStreamCount uint32
|
||||
streams sync.Map
|
||||
|
||||
streamsM sync.Mutex
|
||||
streams map[uint32]*Stream
|
||||
|
||||
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
||||
recvFramePool sync.Pool
|
||||
|
|
@ -93,6 +95,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
|||
nextStreamID: 1,
|
||||
acceptCh: make(chan *Stream, acceptBacklog),
|
||||
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
|
||||
streams: map[uint32]*Stream{},
|
||||
}
|
||||
sesh.addrs.Store([]net.Addr{nil, nil})
|
||||
|
||||
|
|
@ -149,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
|||
return nil, errNoMultiplex
|
||||
}
|
||||
stream := makeStream(sesh, id)
|
||||
sesh.streams.Store(id, stream)
|
||||
sesh.streamsM.Lock()
|
||||
sesh.streams[id] = stream
|
||||
sesh.streamsM.Unlock()
|
||||
sesh.streamCountIncr()
|
||||
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
||||
return stream, nil
|
||||
|
|
@ -200,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
|||
// We set it as nil to signify that the stream id had existed before.
|
||||
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
|
||||
// if the frame it received was from a new stream or a dying stream whose frame arrived late
|
||||
sesh.streams.Store(s.id, nil)
|
||||
sesh.streamsM.Lock()
|
||||
sesh.streams[s.id] = nil
|
||||
sesh.streamsM.Unlock()
|
||||
if sesh.streamCountDecr() == 0 {
|
||||
if sesh.Singleplex {
|
||||
return sesh.Close()
|
||||
|
|
@ -229,15 +236,19 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
|||
return sesh.passiveClose()
|
||||
}
|
||||
|
||||
newStream := makeStream(sesh, frame.StreamID)
|
||||
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
|
||||
sesh.streamsM.Lock()
|
||||
existingStream, existing := sesh.streams[frame.StreamID]
|
||||
if existing {
|
||||
if existingStreamI == nil {
|
||||
sesh.streamsM.Unlock()
|
||||
if existingStream == nil {
|
||||
// this is when the stream existed before but has since been closed. We do nothing
|
||||
return nil
|
||||
}
|
||||
return existingStreamI.(*Stream).recvFrame(frame)
|
||||
return existingStream.recvFrame(frame)
|
||||
} else {
|
||||
newStream := makeStream(sesh, frame.StreamID)
|
||||
sesh.streams[frame.StreamID] = newStream
|
||||
sesh.streamsM.Unlock()
|
||||
// new stream
|
||||
sesh.streamCountIncr()
|
||||
sesh.acceptCh <- newStream
|
||||
|
|
@ -265,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
|
|||
}
|
||||
sesh.acceptCh <- nil
|
||||
|
||||
sesh.streams.Range(func(key, streamI interface{}) bool {
|
||||
if streamI == nil {
|
||||
return true
|
||||
sesh.streamsM.Lock()
|
||||
for id, stream := range sesh.streams {
|
||||
if stream == nil {
|
||||
continue
|
||||
}
|
||||
stream := streamI.(*Stream)
|
||||
atomic.StoreUint32(&stream.closed, 1)
|
||||
_ = stream.getRecvBuf().Close() // will not block
|
||||
sesh.streams.Delete(key)
|
||||
delete(sesh.streams, id)
|
||||
sesh.streamCountDecr()
|
||||
return true
|
||||
})
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
|
||||
if closeSwitchboard {
|
||||
sesh.sb.closeAll()
|
||||
|
|
|
|||
|
|
@ -112,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||
}
|
||||
_, ok := sesh.streams.Load(f1.StreamID)
|
||||
sesh.streamsM.Lock()
|
||||
_, ok := sesh.streams[f1.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if !ok {
|
||||
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||
}
|
||||
|
|
@ -132,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||
}
|
||||
s2I, ok := sesh.streams.Load(f2.StreamID)
|
||||
if s2I == nil || !ok {
|
||||
sesh.streamsM.Lock()
|
||||
s2M, ok := sesh.streams[f2.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if s2M == nil || !ok {
|
||||
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||
}
|
||||
if sesh.streamCount() != 2 {
|
||||
|
|
@ -152,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
s1I, _ := sesh.streams.Load(f1.StreamID)
|
||||
if s1I != nil {
|
||||
sesh.streamsM.Lock()
|
||||
s1M, _ := sesh.streams[f1.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if s1M != nil {
|
||||
t.Fatal("stream 1 still exist after receiving stream close")
|
||||
}
|
||||
s1, _ := sesh.Accept()
|
||||
|
|
@ -179,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
||||
}
|
||||
s1I, _ = sesh.streams.Load(f1.StreamID)
|
||||
if s1I != nil {
|
||||
sesh.streamsM.Lock()
|
||||
s1M, _ = sesh.streams[f1.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if s1M != nil {
|
||||
t.Error("stream 1 exists after receiving stream close for the second time")
|
||||
}
|
||||
streamCount := sesh.streamCount()
|
||||
|
|
@ -243,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
_, ok := sesh.streams.Load(f1CloseStream.StreamID)
|
||||
sesh.streamsM.Lock()
|
||||
_, ok := sesh.streams[f1CloseStream.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if !ok {
|
||||
t.Fatal("stream 1 doesn't exist")
|
||||
}
|
||||
|
|
@ -334,12 +344,13 @@ func TestParallelStreams(t *testing.T) {
|
|||
wg.Wait()
|
||||
sc := int(sesh.streamCount())
|
||||
var count int
|
||||
sesh.streams.Range(func(_, s interface{}) bool {
|
||||
sesh.streamsM.Lock()
|
||||
for _, s := range sesh.streams {
|
||||
if s != nil {
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
if sc != count {
|
||||
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -167,10 +167,13 @@ func TestStream_Close(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil {
|
||||
sesh.streamsM.Lock()
|
||||
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
|
||||
sesh.streamsM.Unlock()
|
||||
t.Error("stream still exists")
|
||||
return
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
|
||||
_, err = io.ReadFull(stream, readBuf[1:])
|
||||
if err != nil {
|
||||
|
|
@ -242,8 +245,10 @@ func TestStream_Close(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
sI, _ := sesh.streams.Load(stream.(*Stream).id)
|
||||
return sI == nil
|
||||
sesh.streamsM.Lock()
|
||||
s, _ := sesh.streams[stream.(*Stream).id]
|
||||
sesh.streamsM.Unlock()
|
||||
return s == nil
|
||||
}, time.Second, 10*time.Millisecond, "streams still exists")
|
||||
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue