mirror of https://github.com/cbeuw/Cloak
Make and add stream upon reception of all new streamIDs even if they are closing
This commit is contained in:
parent
39e54bae6c
commit
6f34229aa0
|
|
@ -51,7 +51,9 @@ type Session struct {
|
|||
// atomic
|
||||
nextStreamID uint32
|
||||
|
||||
streams sync.Map
|
||||
// atomic
|
||||
activeStreamCount uint32
|
||||
streams sync.Map
|
||||
|
||||
// Switchboard manages all connections to remote
|
||||
sb *switchboard
|
||||
|
|
@ -94,6 +96,16 @@ func MakeSession(id uint32, config *SessionConfig) *Session {
|
|||
return sesh
|
||||
}
|
||||
|
||||
func (sesh *Session) streamCountIncr() uint32 {
|
||||
return atomic.AddUint32(&sesh.activeStreamCount, 1)
|
||||
}
|
||||
func (sesh *Session) streamCountDecr() uint32 {
|
||||
return atomic.AddUint32(&sesh.activeStreamCount, ^uint32(0))
|
||||
}
|
||||
func (sesh *Session) streamCount() uint32 {
|
||||
return atomic.LoadUint32(&sesh.activeStreamCount)
|
||||
}
|
||||
|
||||
func (sesh *Session) AddConnection(conn net.Conn) {
|
||||
sesh.sb.addConn(conn)
|
||||
addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()}
|
||||
|
|
@ -112,6 +124,7 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
|||
}
|
||||
stream := makeStream(sesh, id, connId)
|
||||
sesh.streams.Store(id, stream)
|
||||
sesh.streamCountIncr()
|
||||
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
||||
return stream, nil
|
||||
}
|
||||
|
|
@ -159,14 +172,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
|||
log.Tracef("stream %v passively closed", s.id)
|
||||
}
|
||||
|
||||
sesh.streams.Delete(s.id)
|
||||
var count int
|
||||
sesh.streams.Range(func(_, _ interface{}) bool {
|
||||
count += 1
|
||||
return true
|
||||
})
|
||||
if count == 0 {
|
||||
log.Tracef("session %v has no active stream left", sesh.id)
|
||||
sesh.streams.Store(s.id, nil)
|
||||
if sesh.streamCountDecr() == 0 {
|
||||
log.Debugf("session %v has no active stream left", sesh.id)
|
||||
go sesh.timeoutAfter(30 * time.Second)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -181,32 +189,25 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
|||
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
||||
}
|
||||
|
||||
if frame.Closing == C_STREAM {
|
||||
streamI, existing := sesh.streams.Load(frame.StreamID)
|
||||
if existing {
|
||||
// DO NOT close the stream straight away here because the sequence number of this frame
|
||||
// hasn't been checked. There may be later data frames which haven't arrived
|
||||
stream := streamI.(*Stream)
|
||||
return stream.writeFrame(*frame)
|
||||
} else {
|
||||
// If the stream has been closed and the current frame is a closing frame, we do noop
|
||||
return nil
|
||||
}
|
||||
} else if frame.Closing == C_SESSION {
|
||||
// Closing session
|
||||
if frame.Closing == C_SESSION {
|
||||
sesh.SetTerminalMsg("Received a closing notification frame")
|
||||
return sesh.passiveClose()
|
||||
} else {
|
||||
connId, _, _ := sesh.sb.pickRandConn()
|
||||
// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
|
||||
newStream := makeStream(sesh, frame.StreamID, connId)
|
||||
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
|
||||
if existing {
|
||||
return existingStreamI.(*Stream).writeFrame(*frame)
|
||||
} else {
|
||||
sesh.acceptCh <- newStream
|
||||
return newStream.writeFrame(*frame)
|
||||
}
|
||||
|
||||
connId, _, _ := sesh.sb.pickRandConn()
|
||||
// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
|
||||
newStream := makeStream(sesh, frame.StreamID, connId)
|
||||
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
|
||||
if existing {
|
||||
if existingStreamI == nil {
|
||||
// this is when the stream existed before but has since been closed. We do nothing
|
||||
return nil
|
||||
}
|
||||
return existingStreamI.(*Stream).writeFrame(*frame)
|
||||
} else {
|
||||
sesh.streamCountIncr()
|
||||
sesh.acceptCh <- newStream
|
||||
return newStream.writeFrame(*frame)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -232,10 +233,14 @@ func (sesh *Session) passiveClose() error {
|
|||
sesh.acceptCh <- nil
|
||||
|
||||
sesh.streams.Range(func(key, streamI interface{}) bool {
|
||||
if streamI == nil {
|
||||
return true
|
||||
}
|
||||
stream := streamI.(*Stream)
|
||||
atomic.StoreUint32(&stream.closed, 1)
|
||||
_ = stream.recvBuf.Close() // will not block
|
||||
sesh.streams.Delete(key)
|
||||
sesh.streamCountDecr()
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -261,10 +266,14 @@ func (sesh *Session) Close() error {
|
|||
sesh.acceptCh <- nil
|
||||
|
||||
sesh.streams.Range(func(key, streamI interface{}) bool {
|
||||
if streamI == nil {
|
||||
return true
|
||||
}
|
||||
stream := streamI.(*Stream)
|
||||
atomic.StoreUint32(&stream.closed, 1)
|
||||
_ = stream.recvBuf.Close() // will not block
|
||||
sesh.streams.Delete(key)
|
||||
sesh.streamCountDecr()
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -296,12 +305,8 @@ func (sesh *Session) IsClosed() bool {
|
|||
|
||||
func (sesh *Session) timeoutAfter(to time.Duration) {
|
||||
time.Sleep(to)
|
||||
var count int
|
||||
sesh.streams.Range(func(_, _ interface{}) bool {
|
||||
count += 1
|
||||
return true
|
||||
})
|
||||
if count == 0 && !sesh.IsClosed() {
|
||||
|
||||
if sesh.streamCount() == 0 && !sesh.IsClosed() {
|
||||
sesh.SetTerminalMsg("timeout")
|
||||
sesh.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"bytes"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
@ -174,6 +176,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1")
|
||||
}
|
||||
|
||||
// create stream 2
|
||||
f2 := &Frame{
|
||||
|
|
@ -188,9 +193,12 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||
}
|
||||
s2I, ok := sesh.streams.Load(f2.StreamID)
|
||||
if !ok {
|
||||
if s2I == nil || !ok {
|
||||
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||
}
|
||||
if sesh.streamCount() != 2 {
|
||||
t.Error("stream count isn't 2")
|
||||
}
|
||||
|
||||
// close stream 1
|
||||
f1CloseStream := &Frame{
|
||||
|
|
@ -204,33 +212,189 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
_, ok = sesh.streams.Load(f1.StreamID)
|
||||
if ok {
|
||||
s1I, _ = sesh.streams.Load(f1.StreamID)
|
||||
if s1I != nil {
|
||||
t.Fatal("stream 1 still exist after receiving stream close")
|
||||
}
|
||||
s1 := s1I.(*Stream)
|
||||
if !s1.isClosed() {
|
||||
s1, _ := sesh.Accept()
|
||||
if !s1.(*Stream).isClosed() {
|
||||
t.Fatal("stream 1 not marked as closed")
|
||||
}
|
||||
payloadBuf := make([]byte, testPayloadLen)
|
||||
_, err = s1.recvBuf.Read(payloadBuf)
|
||||
_, err = s1.Read(payloadBuf)
|
||||
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Fatalf("failed to read from stream 1 after closing: %v", err)
|
||||
}
|
||||
s2 := s2I.(*Stream)
|
||||
if s2.isClosed() {
|
||||
s2, _ := sesh.Accept()
|
||||
if s2.(*Stream).isClosed() {
|
||||
t.Fatal("stream 2 shouldn't be closed")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1 after stream 1 closed")
|
||||
}
|
||||
|
||||
// close stream 1 again
|
||||
n, _ = sesh.Obfs(f1CloseStream, obfsBuf)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
||||
}
|
||||
_, ok = sesh.streams.Load(f1.StreamID)
|
||||
if ok {
|
||||
t.Fatal("stream 1 exists after receiving stream close for the second time")
|
||||
s1I, _ = sesh.streams.Load(f1.StreamID)
|
||||
if s1I != nil {
|
||||
t.Error("stream 1 exists after receiving stream close for the second time")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1 after stream 1 closed twice")
|
||||
}
|
||||
|
||||
// close session
|
||||
fCloseSession := &Frame{
|
||||
StreamID: 0xffffffff,
|
||||
Seq: 0,
|
||||
Closing: C_SESSION,
|
||||
Payload: testPayload,
|
||||
}
|
||||
n, _ = sesh.Obfs(fCloseSession, obfsBuf)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving session closing frame: %v", err)
|
||||
}
|
||||
if !sesh.IsClosed() {
|
||||
t.Error("session not closed after receiving signal")
|
||||
}
|
||||
if !s2.(*Stream).isClosed() {
|
||||
t.Error("stream 2 isn't closed after session closed")
|
||||
}
|
||||
if _, err := s2.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Error("failed to read from stream 2 after session closed")
|
||||
}
|
||||
if _, err := s2.Write(testPayload); err == nil {
|
||||
t.Error("can still write to stream 2 after session closed")
|
||||
}
|
||||
if sesh.streamCount() != 0 {
|
||||
t.Error("stream count isn't 0 after session closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||
// Tests for when the closing frame of a stream is received first before any data frame
|
||||
testPayloadLen := 1024
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
rand.Read(testPayload)
|
||||
obfsBuf := make([]byte, 17000)
|
||||
|
||||
sessionKey := make([]byte, 32)
|
||||
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
|
||||
seshConfigOrdered.Obfuscator = obfuscator
|
||||
|
||||
rand.Read(sessionKey)
|
||||
sesh := MakeSession(0, seshConfigOrdered)
|
||||
|
||||
// receive stream 1 closing first
|
||||
f1CloseStream := &Frame{
|
||||
1,
|
||||
1,
|
||||
C_STREAM,
|
||||
testPayload,
|
||||
}
|
||||
n, _ := sesh.Obfs(f1CloseStream, obfsBuf)
|
||||
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
_, ok := sesh.streams.Load(f1CloseStream.StreamID)
|
||||
if !ok {
|
||||
t.Fatal("stream 1 doesn't exist")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1 after stream 1 received")
|
||||
}
|
||||
|
||||
// receive data frame of stream 1 after receiving the closing frame
|
||||
f1 := &Frame{
|
||||
1,
|
||||
0,
|
||||
C_NOOP,
|
||||
testPayload,
|
||||
}
|
||||
n, _ = sesh.Obfs(f1, obfsBuf)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||
}
|
||||
s1, err := sesh.Accept()
|
||||
if err != nil {
|
||||
t.Fatal("failed to accept stream 1 after receiving it")
|
||||
}
|
||||
payloadBuf := make([]byte, testPayloadLen)
|
||||
if _, err := s1.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Error("failed to read from steam 1")
|
||||
}
|
||||
if !s1.(*Stream).isClosed() {
|
||||
t.Error("s1 isn't closed")
|
||||
}
|
||||
if sesh.streamCount() != 0 {
|
||||
t.Error("stream count isn't 0 after stream 1 closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
rand.Seed(0)
|
||||
|
||||
sessionKey := make([]byte, 32)
|
||||
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
|
||||
seshConfigOrdered.Obfuscator = obfuscator
|
||||
rand.Read(sessionKey)
|
||||
sesh := MakeSession(0, seshConfigOrdered)
|
||||
|
||||
numStreams := 10
|
||||
seqs := make([]*uint64, numStreams)
|
||||
for i, _ := range seqs {
|
||||
seqs[i] = new(uint64)
|
||||
}
|
||||
randFrame := func() *Frame {
|
||||
id := rand.Intn(numStreams)
|
||||
return &Frame{
|
||||
uint32(id),
|
||||
atomic.AddUint64(seqs[id], 1) - 1,
|
||||
uint8(rand.Intn(2)),
|
||||
[]byte{1, 2, 3, 4},
|
||||
}
|
||||
}
|
||||
|
||||
numOfTests := 100
|
||||
tests := make([]struct {
|
||||
name string
|
||||
frame *Frame
|
||||
}, numOfTests)
|
||||
for i, _ := range tests {
|
||||
tests[i].name = strconv.Itoa(i)
|
||||
tests[i].frame = randFrame()
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
go func(frame *Frame) {
|
||||
data := make([]byte, 1000)
|
||||
n, _ := sesh.Obfs(frame, data)
|
||||
data = data[0:n]
|
||||
|
||||
err := sesh.recvDataFromRemote(data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}(tc.frame)
|
||||
}
|
||||
|
||||
var count int
|
||||
sesh.streams.Range(func(_, s interface{}) bool {
|
||||
if s != nil {
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
sc := int(sesh.streamCount())
|
||||
if sc != count {
|
||||
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ func TestStream_Close(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if _, ok := sesh.streams.Load(streamID); ok {
|
||||
if sI, _ := sesh.streams.Load(streamID); sI != nil {
|
||||
t.Error("stream still exists")
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue