Make and add stream upon reception of all new streamIDs even if they are closing

This commit is contained in:
Andy Wang 2020-01-23 20:30:31 +00:00
parent 39e54bae6c
commit 6f34229aa0
3 changed files with 220 additions and 51 deletions

View File

@ -51,7 +51,9 @@ type Session struct {
// atomic
nextStreamID uint32
streams sync.Map
// atomic
activeStreamCount uint32
streams sync.Map
// Switchboard manages all connections to remote
sb *switchboard
@ -94,6 +96,16 @@ func MakeSession(id uint32, config *SessionConfig) *Session {
return sesh
}
func (sesh *Session) streamCountIncr() uint32 {
return atomic.AddUint32(&sesh.activeStreamCount, 1)
}
func (sesh *Session) streamCountDecr() uint32 {
return atomic.AddUint32(&sesh.activeStreamCount, ^uint32(0))
}
func (sesh *Session) streamCount() uint32 {
return atomic.LoadUint32(&sesh.activeStreamCount)
}
func (sesh *Session) AddConnection(conn net.Conn) {
sesh.sb.addConn(conn)
addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()}
@ -112,6 +124,7 @@ func (sesh *Session) OpenStream() (*Stream, error) {
}
stream := makeStream(sesh, id, connId)
sesh.streams.Store(id, stream)
sesh.streamCountIncr()
log.Tracef("stream %v of session %v opened", id, sesh.id)
return stream, nil
}
@ -159,14 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
log.Tracef("stream %v passively closed", s.id)
}
sesh.streams.Delete(s.id)
var count int
sesh.streams.Range(func(_, _ interface{}) bool {
count += 1
return true
})
if count == 0 {
log.Tracef("session %v has no active stream left", sesh.id)
sesh.streams.Store(s.id, nil)
if sesh.streamCountDecr() == 0 {
log.Debugf("session %v has no active stream left", sesh.id)
go sesh.timeoutAfter(30 * time.Second)
}
return nil
@ -181,32 +189,25 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
}
if frame.Closing == C_STREAM {
streamI, existing := sesh.streams.Load(frame.StreamID)
if existing {
// DO NOT close the stream straight away here because the sequence number of this frame
// hasn't been checked. There may be later data frames which haven't arrived
stream := streamI.(*Stream)
return stream.writeFrame(*frame)
} else {
// If the stream has been closed and the current frame is a closing frame, we do noop
return nil
}
} else if frame.Closing == C_SESSION {
// Closing session
if frame.Closing == C_SESSION {
sesh.SetTerminalMsg("Received a closing notification frame")
return sesh.passiveClose()
} else {
connId, _, _ := sesh.sb.pickRandConn()
// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
newStream := makeStream(sesh, frame.StreamID, connId)
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
if existing {
return existingStreamI.(*Stream).writeFrame(*frame)
} else {
sesh.acceptCh <- newStream
return newStream.writeFrame(*frame)
}
connId, _, _ := sesh.sb.pickRandConn()
// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
newStream := makeStream(sesh, frame.StreamID, connId)
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
if existing {
if existingStreamI == nil {
// this is when the stream existed before but has since been closed. We do nothing
return nil
}
return existingStreamI.(*Stream).writeFrame(*frame)
} else {
sesh.streamCountIncr()
sesh.acceptCh <- newStream
return newStream.writeFrame(*frame)
}
}
@ -232,10 +233,14 @@ func (sesh *Session) passiveClose() error {
sesh.acceptCh <- nil
sesh.streams.Range(func(key, streamI interface{}) bool {
if streamI == nil {
return true
}
stream := streamI.(*Stream)
atomic.StoreUint32(&stream.closed, 1)
_ = stream.recvBuf.Close() // will not block
sesh.streams.Delete(key)
sesh.streamCountDecr()
return true
})
@ -261,10 +266,14 @@ func (sesh *Session) Close() error {
sesh.acceptCh <- nil
sesh.streams.Range(func(key, streamI interface{}) bool {
if streamI == nil {
return true
}
stream := streamI.(*Stream)
atomic.StoreUint32(&stream.closed, 1)
_ = stream.recvBuf.Close() // will not block
sesh.streams.Delete(key)
sesh.streamCountDecr()
return true
})
@ -296,12 +305,8 @@ func (sesh *Session) IsClosed() bool {
func (sesh *Session) timeoutAfter(to time.Duration) {
time.Sleep(to)
var count int
sesh.streams.Range(func(_, _ interface{}) bool {
count += 1
return true
})
if count == 0 && !sesh.IsClosed() {
if sesh.streamCount() == 0 && !sesh.IsClosed() {
sesh.SetTerminalMsg("timeout")
sesh.Close()
}

View File

@ -4,6 +4,8 @@ import (
"bytes"
"github.com/cbeuw/Cloak/internal/util"
"math/rand"
"strconv"
"sync/atomic"
"testing"
)
@ -174,6 +176,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if !ok {
t.Fatal("failed to fetch stream 1 after receiving it")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1")
}
// create stream 2
f2 := &Frame{
@ -188,9 +193,12 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
t.Fatalf("receiving normal frame for stream 2: %v", err)
}
s2I, ok := sesh.streams.Load(f2.StreamID)
if !ok {
if s2I == nil || !ok {
t.Fatal("failed to fetch stream 2 after receiving it")
}
if sesh.streamCount() != 2 {
t.Error("stream count isn't 2")
}
// close stream 1
f1CloseStream := &Frame{
@ -204,33 +212,189 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
}
_, ok = sesh.streams.Load(f1.StreamID)
if ok {
s1I, _ = sesh.streams.Load(f1.StreamID)
if s1I != nil {
t.Fatal("stream 1 still exist after receiving stream close")
}
s1 := s1I.(*Stream)
if !s1.isClosed() {
s1, _ := sesh.Accept()
if !s1.(*Stream).isClosed() {
t.Fatal("stream 1 not marked as closed")
}
payloadBuf := make([]byte, testPayloadLen)
_, err = s1.recvBuf.Read(payloadBuf)
_, err = s1.Read(payloadBuf)
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Fatalf("failed to read from stream 1 after closing: %v", err)
}
s2 := s2I.(*Stream)
if s2.isClosed() {
s2, _ := sesh.Accept()
if s2.(*Stream).isClosed() {
t.Fatal("stream 2 shouldn't be closed")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 closed")
}
// close stream 1 again
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
}
_, ok = sesh.streams.Load(f1.StreamID)
if ok {
t.Fatal("stream 1 exists after receiving stream close for the second time")
s1I, _ = sesh.streams.Load(f1.StreamID)
if s1I != nil {
t.Error("stream 1 exists after receiving stream close for the second time")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 closed twice")
}
// close session
fCloseSession := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: C_SESSION,
Payload: testPayload,
}
n, _ = sesh.Obfs(fCloseSession, obfsBuf)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving session closing frame: %v", err)
}
if !sesh.IsClosed() {
t.Error("session not closed after receiving signal")
}
if !s2.(*Stream).isClosed() {
t.Error("stream 2 isn't closed after session closed")
}
if _, err := s2.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Error("failed to read from stream 2 after session closed")
}
if _, err := s2.Write(testPayload); err == nil {
t.Error("can still write to stream 2 after session closed")
}
if sesh.streamCount() != 0 {
t.Error("stream count isn't 0 after session closed")
}
}
func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
// Tests for when the closing frame of a stream is received first before any data frame
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
obfsBuf := make([]byte, 17000)
sessionKey := make([]byte, 32)
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
rand.Read(sessionKey)
sesh := MakeSession(0, seshConfigOrdered)
// receive stream 1 closing first
f1CloseStream := &Frame{
1,
1,
C_STREAM,
testPayload,
}
n, _ := sesh.Obfs(f1CloseStream, obfsBuf)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
}
_, ok := sesh.streams.Load(f1CloseStream.StreamID)
if !ok {
t.Fatal("stream 1 doesn't exist")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 received")
}
// receive data frame of stream 1 after receiving the closing frame
f1 := &Frame{
1,
0,
C_NOOP,
testPayload,
}
n, _ = sesh.Obfs(f1, obfsBuf)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err)
}
s1, err := sesh.Accept()
if err != nil {
t.Fatal("failed to accept stream 1 after receiving it")
}
payloadBuf := make([]byte, testPayloadLen)
if _, err := s1.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Error("failed to read from steam 1")
}
if !s1.(*Stream).isClosed() {
t.Error("s1 isn't closed")
}
if sesh.streamCount() != 0 {
t.Error("stream count isn't 0 after stream 1 closed")
}
}
func TestParallel(t *testing.T) {
rand.Seed(0)
sessionKey := make([]byte, 32)
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
rand.Read(sessionKey)
sesh := MakeSession(0, seshConfigOrdered)
numStreams := 10
seqs := make([]*uint64, numStreams)
for i, _ := range seqs {
seqs[i] = new(uint64)
}
randFrame := func() *Frame {
id := rand.Intn(numStreams)
return &Frame{
uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
uint8(rand.Intn(2)),
[]byte{1, 2, 3, 4},
}
}
numOfTests := 100
tests := make([]struct {
name string
frame *Frame
}, numOfTests)
for i, _ := range tests {
tests[i].name = strconv.Itoa(i)
tests[i].frame = randFrame()
}
for _, tc := range tests {
go func(frame *Frame) {
data := make([]byte, 1000)
n, _ := sesh.Obfs(frame, data)
data = data[0:n]
err := sesh.recvDataFromRemote(data)
if err != nil {
t.Error(err)
}
}(tc.frame)
}
var count int
sesh.streams.Range(func(_, s interface{}) bool {
if s != nil {
count++
}
return true
})
sc := int(sesh.streamCount())
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
}
}

View File

@ -157,7 +157,7 @@ func TestStream_Close(t *testing.T) {
return
}
if _, ok := sesh.streams.Load(streamID); ok {
if sI, _ := sesh.streams.Load(streamID); sI != nil {
t.Error("stream still exists")
return
}