mirror of https://github.com/cbeuw/Cloak
Use default hashmap to store streams. Avoid allocating a stream object on receiving every single frame
This commit is contained in:
parent
fd5005db0a
commit
35f41424c9
|
|
@ -63,7 +63,9 @@ type Session struct {
|
||||||
|
|
||||||
// atomic
|
// atomic
|
||||||
activeStreamCount uint32
|
activeStreamCount uint32
|
||||||
streams sync.Map
|
|
||||||
|
streamsM sync.Mutex
|
||||||
|
streams map[uint32]*Stream
|
||||||
|
|
||||||
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
||||||
recvFramePool sync.Pool
|
recvFramePool sync.Pool
|
||||||
|
|
@ -93,6 +95,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
||||||
nextStreamID: 1,
|
nextStreamID: 1,
|
||||||
acceptCh: make(chan *Stream, acceptBacklog),
|
acceptCh: make(chan *Stream, acceptBacklog),
|
||||||
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
|
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
|
||||||
|
streams: map[uint32]*Stream{},
|
||||||
}
|
}
|
||||||
sesh.addrs.Store([]net.Addr{nil, nil})
|
sesh.addrs.Store([]net.Addr{nil, nil})
|
||||||
|
|
||||||
|
|
@ -149,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
||||||
return nil, errNoMultiplex
|
return nil, errNoMultiplex
|
||||||
}
|
}
|
||||||
stream := makeStream(sesh, id)
|
stream := makeStream(sesh, id)
|
||||||
sesh.streams.Store(id, stream)
|
sesh.streamsM.Lock()
|
||||||
|
sesh.streams[id] = stream
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
sesh.streamCountIncr()
|
sesh.streamCountIncr()
|
||||||
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
||||||
return stream, nil
|
return stream, nil
|
||||||
|
|
@ -200,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
// We set it as nil to signify that the stream id had existed before.
|
// We set it as nil to signify that the stream id had existed before.
|
||||||
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
|
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
|
||||||
// if the frame it received was from a new stream or a dying stream whose frame arrived late
|
// if the frame it received was from a new stream or a dying stream whose frame arrived late
|
||||||
sesh.streams.Store(s.id, nil)
|
sesh.streamsM.Lock()
|
||||||
|
sesh.streams[s.id] = nil
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
if sesh.streamCountDecr() == 0 {
|
if sesh.streamCountDecr() == 0 {
|
||||||
if sesh.Singleplex {
|
if sesh.Singleplex {
|
||||||
return sesh.Close()
|
return sesh.Close()
|
||||||
|
|
@ -229,15 +236,19 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
return sesh.passiveClose()
|
return sesh.passiveClose()
|
||||||
}
|
}
|
||||||
|
|
||||||
newStream := makeStream(sesh, frame.StreamID)
|
sesh.streamsM.Lock()
|
||||||
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
|
existingStream, existing := sesh.streams[frame.StreamID]
|
||||||
if existing {
|
if existing {
|
||||||
if existingStreamI == nil {
|
sesh.streamsM.Unlock()
|
||||||
|
if existingStream == nil {
|
||||||
// this is when the stream existed before but has since been closed. We do nothing
|
// this is when the stream existed before but has since been closed. We do nothing
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return existingStreamI.(*Stream).recvFrame(frame)
|
return existingStream.recvFrame(frame)
|
||||||
} else {
|
} else {
|
||||||
|
newStream := makeStream(sesh, frame.StreamID)
|
||||||
|
sesh.streams[frame.StreamID] = newStream
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
// new stream
|
// new stream
|
||||||
sesh.streamCountIncr()
|
sesh.streamCountIncr()
|
||||||
sesh.acceptCh <- newStream
|
sesh.acceptCh <- newStream
|
||||||
|
|
@ -265,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
|
||||||
}
|
}
|
||||||
sesh.acceptCh <- nil
|
sesh.acceptCh <- nil
|
||||||
|
|
||||||
sesh.streams.Range(func(key, streamI interface{}) bool {
|
sesh.streamsM.Lock()
|
||||||
if streamI == nil {
|
for id, stream := range sesh.streams {
|
||||||
return true
|
if stream == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
stream := streamI.(*Stream)
|
|
||||||
atomic.StoreUint32(&stream.closed, 1)
|
atomic.StoreUint32(&stream.closed, 1)
|
||||||
_ = stream.getRecvBuf().Close() // will not block
|
_ = stream.getRecvBuf().Close() // will not block
|
||||||
sesh.streams.Delete(key)
|
delete(sesh.streams, id)
|
||||||
sesh.streamCountDecr()
|
sesh.streamCountDecr()
|
||||||
return true
|
}
|
||||||
})
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
if closeSwitchboard {
|
if closeSwitchboard {
|
||||||
sesh.sb.closeAll()
|
sesh.sb.closeAll()
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||||
}
|
}
|
||||||
_, ok := sesh.streams.Load(f1.StreamID)
|
sesh.streamsM.Lock()
|
||||||
|
_, ok := sesh.streams[f1.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("failed to fetch stream 1 after receiving it")
|
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||||
}
|
}
|
||||||
|
|
@ -132,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||||
}
|
}
|
||||||
s2I, ok := sesh.streams.Load(f2.StreamID)
|
sesh.streamsM.Lock()
|
||||||
if s2I == nil || !ok {
|
s2M, ok := sesh.streams[f2.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
if s2M == nil || !ok {
|
||||||
t.Fatal("failed to fetch stream 2 after receiving it")
|
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||||
}
|
}
|
||||||
if sesh.streamCount() != 2 {
|
if sesh.streamCount() != 2 {
|
||||||
|
|
@ -152,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||||
}
|
}
|
||||||
s1I, _ := sesh.streams.Load(f1.StreamID)
|
sesh.streamsM.Lock()
|
||||||
if s1I != nil {
|
s1M, _ := sesh.streams[f1.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
if s1M != nil {
|
||||||
t.Fatal("stream 1 still exist after receiving stream close")
|
t.Fatal("stream 1 still exist after receiving stream close")
|
||||||
}
|
}
|
||||||
s1, _ := sesh.Accept()
|
s1, _ := sesh.Accept()
|
||||||
|
|
@ -179,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
||||||
}
|
}
|
||||||
s1I, _ = sesh.streams.Load(f1.StreamID)
|
sesh.streamsM.Lock()
|
||||||
if s1I != nil {
|
s1M, _ = sesh.streams[f1.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
if s1M != nil {
|
||||||
t.Error("stream 1 exists after receiving stream close for the second time")
|
t.Error("stream 1 exists after receiving stream close for the second time")
|
||||||
}
|
}
|
||||||
streamCount := sesh.streamCount()
|
streamCount := sesh.streamCount()
|
||||||
|
|
@ -243,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
||||||
}
|
}
|
||||||
_, ok := sesh.streams.Load(f1CloseStream.StreamID)
|
sesh.streamsM.Lock()
|
||||||
|
_, ok := sesh.streams[f1CloseStream.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("stream 1 doesn't exist")
|
t.Fatal("stream 1 doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
@ -334,12 +344,13 @@ func TestParallelStreams(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
sc := int(sesh.streamCount())
|
sc := int(sesh.streamCount())
|
||||||
var count int
|
var count int
|
||||||
sesh.streams.Range(func(_, s interface{}) bool {
|
sesh.streamsM.Lock()
|
||||||
|
for _, s := range sesh.streams {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
return true
|
}
|
||||||
})
|
sesh.streamsM.Unlock()
|
||||||
if sc != count {
|
if sc != count {
|
||||||
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -167,10 +167,13 @@ func TestStream_Close(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil {
|
sesh.streamsM.Lock()
|
||||||
|
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
t.Error("stream still exists")
|
t.Error("stream still exists")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
_, err = io.ReadFull(stream, readBuf[1:])
|
_, err = io.ReadFull(stream, readBuf[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -242,8 +245,10 @@ func TestStream_Close(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
sI, _ := sesh.streams.Load(stream.(*Stream).id)
|
sesh.streamsM.Lock()
|
||||||
return sI == nil
|
s, _ := sesh.streams[stream.(*Stream).id]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
return s == nil
|
||||||
}, time.Second, 10*time.Millisecond, "streams still exists")
|
}, time.Second, 10*time.Millisecond, "streams still exists")
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue