From 298e6249e6419a493300b9ecb89e666a2c67a65d Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 20 Dec 2020 00:21:16 +0000 Subject: [PATCH 01/13] Fixup build scripts --- .github/workflows/release.yml | 5 ++--- release.sh | 6 +++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 39d8cf1..1b59569 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,11 +13,10 @@ jobs: - name: Build run: | export PATH=${PATH}:`go env GOPATH`/bin - v=${{ github.ref }} ./release.sh + v=${GITHUB_REF#refs/*/} ./release.sh - name: Release uses: softprops/action-gh-release@v1 with: - fail_on_unmatched_files: true - files: release/*.* + files: release/* env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/release.sh b/release.sh index 8b271cf..bee82f8 100755 --- a/release.sh +++ b/release.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + go get github.com/mitchellh/gox mkdir -p release @@ -21,9 +23,11 @@ CGO_ENABLED=0 gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarc CGO_ENABLED=0 GOOS="linux" GOARCH="mips" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mips_softfloat-"${v}" CGO_ENABLED=0 GOOS="linux" GOARCH="mipsle" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mipsle_softfloat-"${v}" mv ck-client-* ../../release +popd os="linux" arch="amd64 386 arm arm64" -pushd ../ck-server || exit 1 +pushd cmd/ck-server || exit 1 CGO_ENABLED=0 gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output" mv ck-server-* ../../release +popd \ No newline at end of file From 0d3f8dd27f05f71bdc42abfbf59db2d44affe89b Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Mon, 21 Dec 2020 15:06:46 +0000 Subject: [PATCH 02/13] Allow DatabasePath to be empty if user info database is never used --- README.md | 14 +++--- internal/server/state.go | 11 +++-- internal/server/usermanager/usermanager.go | 1 + internal/server/usermanager/voidmanager.go | 31 +++++++++++++ .../server/usermanager/voidmanager_test.go | 43 +++++++++++++++++++ internal/server/userpanel.go | 3 ++ internal/test/integration_test.go | 43 ++++++------------- 7 files changed, 106 insertions(+), 40 deletions(-) create mode 100644 internal/server/usermanager/voidmanager.go create mode 100644 internal/server/usermanager/voidmanager_test.go diff --git a/README.md b/README.md index 9b50dfc..285575f 100644 --- a/README.md +++ b/README.md @@ -103,15 +103,13 @@ Example: `PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64. -`AdminUID` is the UID of the admin user in base64. - `BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions -`DatabasePath` is the path to `userinfo.db`. If `userinfo.db` doesn't exist in this directory, Cloak will create one -automatically. **If Cloak is started as a Shadowsocks plugin and Shadowsocks is started with its working directory as -/ (e.g. starting ss-server with systemctl), you need to set this field as an absolute path to a desired folder. If you -leave it as default then Cloak will attempt to create userinfo.db under /, which it doesn't have the permission to do so -and will raise an error. See Issue #13.** +`AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`. + +`DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will +create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`. +This field also has no effect if `AdminUID` isn't a valid UID or is empty. `KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the upstream proxy server. Zero or negative value disables it. Default is 0 (disabled). @@ -184,6 +182,8 @@ Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.jso ##### Users subject to bandwidth and credit controls +0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db` + in `DatabasePath` (Cloak will create this file for you if it didn't already exist). 1. On your client, run `ck-client -s -l -a -c ` to enter admin mode 2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data diff --git a/internal/server/state.go b/internal/server/state.go index 576e326..03d9298 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -143,9 +143,14 @@ func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, er err = errors.New("command & control mode not implemented") return } else { - manager, err := usermanager.MakeLocalManager(preParse.DatabasePath, worldState) - if err != nil { - return sta, err + var manager usermanager.UserManager + if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" { + manager = &usermanager.Voidmanager{} + } else { + manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState) + if err != nil { + return sta, err + } } sta.Panel = MakeUserPanel(manager) } diff --git a/internal/server/usermanager/usermanager.go b/internal/server/usermanager/usermanager.go index bfd5cc4..7bf84d5 100644 --- a/internal/server/usermanager/usermanager.go +++ b/internal/server/usermanager/usermanager.go @@ -40,6 +40,7 @@ const ( var ErrUserNotFound = errors.New("UID does not correspond to a user") var ErrSessionsCapReached = errors.New("Sessions cap has reached") +var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified") var ErrNoUpCredit = errors.New("No upload credit left") var ErrNoDownCredit = errors.New("No download credit left") diff --git a/internal/server/usermanager/voidmanager.go b/internal/server/usermanager/voidmanager.go new file mode 100644 index 0000000..a20ab3c --- /dev/null +++ b/internal/server/usermanager/voidmanager.go @@ -0,0 +1,31 @@ +package usermanager + +type Voidmanager struct{} + +func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) { + return 0, 0, ErrMangerIsVoid +} + +func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error { + return ErrMangerIsVoid +} + +func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) { + return nil, ErrMangerIsVoid +} + +func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) { + return nil, ErrMangerIsVoid +} + +func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) { + return UserInfo{}, ErrMangerIsVoid +} + +func (v *Voidmanager) WriteUserInfo(info UserInfo) error { + return ErrMangerIsVoid +} + +func (v *Voidmanager) DeleteUser(UID []byte) error { + return ErrMangerIsVoid +} diff --git a/internal/server/usermanager/voidmanager_test.go b/internal/server/usermanager/voidmanager_test.go new file mode 100644 index 0000000..55ab2b4 --- /dev/null +++ b/internal/server/usermanager/voidmanager_test.go @@ -0,0 +1,43 @@ +package usermanager + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +var v = &Voidmanager{} + +func Test_Voidmanager_AuthenticateUser(t *testing.T) { + _, _, err := v.AuthenticateUser([]byte{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_AuthoriseNewSession(t *testing.T) { + err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_DeleteUser(t *testing.T) { + err := v.DeleteUser([]byte{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_GetUserInfo(t *testing.T) { + _, err := v.GetUserInfo([]byte{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_ListAllUsers(t *testing.T) { + _, err := v.ListAllUsers() + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_UploadStatus(t *testing.T) { + _, err := v.UploadStatus([]StatusUpdate{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_WriteUserInfo(t *testing.T) { + err := v.WriteUserInfo(UserInfo{}) + assert.Equal(t, ErrMangerIsVoid, err) +} diff --git a/internal/server/userpanel.go b/internal/server/userpanel.go index 453ff39..953179e 100644 --- a/internal/server/userpanel.go +++ b/internal/server/userpanel.go @@ -185,6 +185,9 @@ func (panel *userPanel) commitUpdate() error { panel.usageUpdateQueue = make(map[[16]byte]*usagePair) panel.usageUpdateQueueM.Unlock() + if len(statuses) == 0 { + return nil + } responses, err := panel.Manager.UploadStatus(statuses) if err != nil { return err diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index db58b39..5812ba2 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -12,10 +12,8 @@ import ( "github.com/cbeuw/connutil" "github.com/stretchr/testify/assert" "io" - "io/ioutil" "math/rand" "net" - "os" "sync" "testing" "time" @@ -24,8 +22,6 @@ import ( ) const numConns = 200 // -race option limits the number of goroutines to 8192 -const delayBeforeTestingConnClose = 500 * time.Millisecond -const connCloseRetries = 3 func serveTCPEcho(l net.Listener) { for { @@ -137,17 +133,15 @@ func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState) return lcl, rmt, auth } -func basicServerState(ws common.WorldState, db *os.File) *server.State { +func basicServerState(ws common.WorldState) *server.State { var serverConfig = server.RawConfig{ - ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}}, - BindAddr: []string{"fake.com:9999"}, - BypassUID: [][]byte{bypassUID[:]}, - RedirAddr: "fake.com:9999", - PrivateKey: privateKey, - AdminUID: nil, - DatabasePath: db.Name(), - KeepAlive: 15, - CncMode: false, + ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}}, + BindAddr: []string{"fake.com:9999"}, + BypassUID: [][]byte{bypassUID[:]}, + RedirAddr: "fake.com:9999", + PrivateKey: privateKey, + KeepAlive: 15, + CncMode: false, } state, err := server.InitState(serverConfig, ws) if err != nil { @@ -258,13 +252,11 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { } func TestUDP(t *testing.T) { - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) if err != nil { @@ -319,9 +311,7 @@ func TestTCPSingleplex(t *testing.T) { log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState) - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) if err != nil { t.Fatal(err) @@ -381,9 +371,7 @@ func TestTCPMultiplex(t *testing.T) { worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta) if err != nil { @@ -456,11 +444,8 @@ func TestClosingStreamsFromProxy(t *testing.T) { clientConfig := clientConfig clientConfigName := clientConfigName t.Run(clientConfigName, func(t *testing.T) { - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) - lcc, rcc, ai := generateClientConfigs(clientConfig, worldState) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) if err != nil { t.Fatal(err) @@ -519,12 +504,10 @@ func TestClosingStreamsFromProxy(t *testing.T) { } func BenchmarkThroughput(b *testing.B) { - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) const bufSize = 16 * 1024 encryptionMethods := map[string]byte{ From de0daac1233e96be172c704c2b73e5f36cdaab6e Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Mon, 21 Dec 2020 16:37:33 +0000 Subject: [PATCH 03/13] Update deprecated curve25519 functions and defend against low-order point attacks --- internal/client/auth.go | 12 ++++++++++-- internal/ecdh/curve25519.go | 8 +++----- internal/ecdh/curve25519_test.go | 4 ++-- internal/server/TLS.go | 8 +++++++- internal/server/websocket.go | 8 +++++++- 5 files changed, 29 insertions(+), 11 deletions(-) diff --git a/internal/client/auth.go b/internal/client/auth.go index 939a34d..4925541 100644 --- a/internal/client/auth.go +++ b/internal/client/auth.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/ecdh" + log "github.com/sirupsen/logrus" ) const ( @@ -26,7 +27,10 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh | 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes | +----------+----------------+---------------------+-------------+--------------+--------+------------+ */ - ephPv, ephPub, _ := ecdh.GenerateKey(authInfo.WorldState.Rand) + ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand) + if err != nil { + log.Panicf("failed to generate ephemeral key pair: %v", err) + } copy(ret.randPubKey[:], ecdh.Marshal(ephPub)) plaintext := make([]byte, 48) @@ -40,7 +44,11 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh plaintext[41] |= UNORDERED_FLAG } - copy(sharedSecret[:], ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)) + secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey) + if err != nil { + log.Panicf("error in generating shared secret: %v", err) + } + copy(sharedSecret[:], secret) ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext) copy(ret.ciphertextWithTag[:], ciphertextWithTag[:]) return diff --git a/internal/ecdh/curve25519.go b/internal/ecdh/curve25519.go index 94d066b..5744c5e 100644 --- a/internal/ecdh/curve25519.go +++ b/internal/ecdh/curve25519.go @@ -68,13 +68,11 @@ func Unmarshal(data []byte) (crypto.PublicKey, bool) { return &pub, true } -func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) []byte { - var priv, pub, secret *[32]byte +func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) { + var priv, pub *[32]byte priv = privKey.(*[32]byte) pub = pubKey.(*[32]byte) - secret = new([32]byte) - curve25519.ScalarMult(secret, priv, pub) - return secret[:] + return curve25519.X25519(priv[:], pub[:]) } diff --git a/internal/ecdh/curve25519_test.go b/internal/ecdh/curve25519_test.go index 8e9a1c1..39d56ba 100644 --- a/internal/ecdh/curve25519_test.go +++ b/internal/ecdh/curve25519_test.go @@ -90,11 +90,11 @@ func testECDH(t testing.TB) { t.Fatalf("Unmarshal does not work") } - secret1 = GenerateSharedSecret(privKey1, pubKey2) + secret1, err = GenerateSharedSecret(privKey1, pubKey2) if err != nil { t.Error(err) } - secret2 = GenerateSharedSecret(privKey2, pubKey1) + secret2, err = GenerateSharedSecret(privKey2, pubKey1) if err != nil { t.Error(err) } diff --git a/internal/server/TLS.go b/internal/server/TLS.go index 8a0ea6a..0e66387 100644 --- a/internal/server/TLS.go +++ b/internal/server/TLS.go @@ -79,7 +79,13 @@ func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fr return } - copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub)) + var sharedSecret []byte + sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub) + if err != nil { + return + } + + copy(fragments.sharedSecret[:], sharedSecret) var keyShare []byte keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}]) if err != nil { diff --git a/internal/server/websocket.go b/internal/server/websocket.go index 2b192b9..1c9e940 100644 --- a/internal/server/websocket.go +++ b/internal/server/websocket.go @@ -84,7 +84,13 @@ func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fra return } - copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub)) + var sharedSecret []byte + sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub) + if err != nil { + return + } + + copy(fragments.sharedSecret[:], sharedSecret) if len(hidden[32:]) != 64 { err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:])) From c9ac93b0b98e3fcb441eab859232d8fd70392104 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Mon, 21 Dec 2020 20:38:28 +0000 Subject: [PATCH 04/13] Refactor session_test.go --- internal/multiplex/session_test.go | 341 +++++++++++++++-------------- 1 file changed, 174 insertions(+), 167 deletions(-) diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index b280895..89bd410 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -12,10 +12,9 @@ import ( "time" ) -var seshConfigOrdered = SessionConfig{} - -var seshConfigUnordered = SessionConfig{ - Unordered: true, +var seshConfigs = map[string]SessionConfig{ + "ordered": {}, + "unordered": {Unordered: true}, } const testPayloadLen = 1024 @@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) { return ret } - sessionTypes := []struct { - name string - config SessionConfig - }{ - {"ordered", - SessionConfig{}}, - {"unordered", - SessionConfig{Unordered: true}}, + encryptionMethods := map[string]Obfuscator{ + "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), + "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), + "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), } - encryptionMethods := []struct { - name string - obfuscator Obfuscator - }{ - { - "plain", - MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), - }, - { - "aes-gcm", - MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), - }, - { - "chacha20-poly1305", - MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), - }, - } - - for _, st := range sessionTypes { - t.Run(st.name, func(t *testing.T) { - for _, em := range encryptionMethods { - t.Run(em.name, func(t *testing.T) { - st.config.Obfuscator = em.obfuscator - sesh := MakeSession(0, st.config) + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + for method, obfuscator := range encryptionMethods { + obfuscator := obfuscator + t.Run(method, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, err := sesh.Obfs(f, obfsBuf, 0) if err != nil { t.Error(err) @@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + + seshConfig := seshConfigs["ordered"] + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) f1 := &Frame{ 1, @@ -245,8 +226,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + + seshConfig := seshConfigs["ordered"] + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) // receive stream 1 closing first f1CloseStream := &Frame{ @@ -300,119 +283,125 @@ func TestParallelStreams(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - numStreams := acceptBacklog - seqs := make([]*uint64, numStreams) - for i := range seqs { - seqs[i] = new(uint64) - } - randFrame := func() *Frame { - id := rand.Intn(numStreams) - return &Frame{ - uint32(id), - atomic.AddUint64(seqs[id], 1) - 1, - uint8(rand.Intn(2)), - []byte{1, 2, 3, 4}, - } - } + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) - const numOfTests = 5000 - tests := make([]struct { - name string - frame *Frame - }, numOfTests) - for i := range tests { - tests[i].name = strconv.Itoa(i) - tests[i].frame = randFrame() - } - - var wg sync.WaitGroup - for _, tc := range tests { - wg.Add(1) - go func(frame *Frame) { - obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(frame, obfsBuf, 0) - obfsBuf = obfsBuf[0:n] - - err := sesh.recvDataFromRemote(obfsBuf) - if err != nil { - t.Error(err) + numStreams := acceptBacklog + seqs := make([]*uint64, numStreams) + for i := range seqs { + seqs[i] = new(uint64) + } + randFrame := func() *Frame { + id := rand.Intn(numStreams) + return &Frame{ + uint32(id), + atomic.AddUint64(seqs[id], 1) - 1, + uint8(rand.Intn(2)), + []byte{1, 2, 3, 4}, + } } - wg.Done() - }(tc.frame) - } - wg.Wait() - sc := int(sesh.streamCount()) - var count int - sesh.streams.Range(func(_, s interface{}) bool { - if s != nil { - count++ - } - return true - }) - if sc != count { - t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) + const numOfTests = 5000 + tests := make([]struct { + name string + frame *Frame + }, numOfTests) + for i := range tests { + tests[i].name = strconv.Itoa(i) + tests[i].frame = randFrame() + } + + var wg sync.WaitGroup + for _, tc := range tests { + wg.Add(1) + go func(frame *Frame) { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(frame, obfsBuf, 0) + obfsBuf = obfsBuf[0:n] + + err := sesh.recvDataFromRemote(obfsBuf) + if err != nil { + t.Error(err) + } + wg.Done() + }(tc.frame) + } + + wg.Wait() + sc := int(sesh.streamCount()) + var count int + sesh.streams.Range(func(_, s interface{}) bool { + if s != nil { + count++ + } + return true + }) + if sc != count { + t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) + } + }) } } func TestStream_SetReadDeadline(t *testing.T) { - var sessionKey [32]byte - rand.Read(sessionKey[:]) - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + sesh.AddConnection(connutil.Discard()) - testReadDeadline := func(sesh *Session) { - t.Run("read after deadline set", func(t *testing.T) { - stream, _ := sesh.OpenStream() - _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) - _, err := stream.Read(make([]byte, 1)) - if err != ErrTimeout { - t.Errorf("expecting error %v, got %v", ErrTimeout, err) - } - }) + t.Run("read after deadline set", func(t *testing.T) { + stream, _ := sesh.OpenStream() + _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) + _, err := stream.Read(make([]byte, 1)) + if err != ErrTimeout { + t.Errorf("expecting error %v, got %v", ErrTimeout, err) + } + }) - t.Run("unblock when deadline passed", func(t *testing.T) { - stream, _ := sesh.OpenStream() + t.Run("unblock when deadline passed", func(t *testing.T) { + stream, _ := sesh.OpenStream() - done := make(chan struct{}) - go func() { - _, _ = stream.Read(make([]byte, 1)) - done <- struct{}{} - }() + done := make(chan struct{}) + go func() { + _, _ = stream.Read(make([]byte, 1)) + done <- struct{}{} + }() - _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - select { - case <-done: - return - case <-time.After(500 * time.Millisecond): - t.Error("Read did not unblock after deadline has passed") - } + select { + case <-done: + return + case <-time.After(500 * time.Millisecond): + t.Error("Read did not unblock after deadline has passed") + } + }) }) } - - sesh := MakeSession(0, seshConfigOrdered) - sesh.AddConnection(connutil.Discard()) - testReadDeadline(sesh) - sesh = MakeSession(0, seshConfigUnordered) - sesh.AddConnection(connutil.Discard()) - testReadDeadline(sesh) } func TestSession_timeoutAfter(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond - sesh := MakeSession(0, seshConfigOrdered) - assert.Eventually(t, func() bool { - return sesh.IsClosed() - }, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out") + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + seshConfig.InactivityTimeout = 100 * time.Millisecond + sesh := MakeSession(0, seshConfig) + + assert.Eventually(t, func() bool { + return sesh.IsClosed() + }, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out") + }) + } } func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { @@ -429,42 +418,60 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) - b.Run("plain", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAESGCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + + for name, ep := range table { + seshConfig := seshConfigs["ordered"] + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, _ := sesh.Obfs(f, obfsBuf, 0) - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) - - b.Run("aes-gcm", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) - - b.Run("chacha20-poly1305", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) + b.Run(name, func(b *testing.B) { + b.SetBytes(int64(len(f.Payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sesh.recvDataFromRemote(obfsBuf[:n]) + } + }) + } +} + +func BenchmarkMultiStreamWrite(b *testing.B) { + var sessionKey [32]byte + rand.Read(sessionKey[:]) + + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAESGCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + + testPayload := make([]byte, testPayloadLen) + + for name, ep := range table { + b.Run(name, func(b *testing.B) { + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + b.Run(seshType, func(b *testing.B) { + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) + sesh.AddConnection(connutil.Discard()) + b.ResetTimer() + b.SetBytes(testPayloadLen) + b.RunParallel(func(pb *testing.PB) { + stream, _ := sesh.OpenStream() + for pb.Next() { + stream.Write(testPayload) + } + }) + }) + } + }) + } } From 3633c9a03c7795ce2f484618b182d7181e4cd5a5 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 12:39:39 +0000 Subject: [PATCH 05/13] Fix multiplex test as test payload length may be randomised to 0 --- internal/multiplex/mux_test.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index 76344ca..c8c60f4 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -10,7 +10,6 @@ import ( "net" "sync" "testing" - "time" ) func serveEcho(l net.Listener) { @@ -64,21 +63,20 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) { return clientSession, serverSession, paris } -func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { +func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) { var wg sync.WaitGroup for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { - testDataLen := rand.Intn(maxMsgLen) - testData := make([]byte, testDataLen) + testData := make([]byte, msgLen) rand.Read(testData) n, err := conn.Write(testData) - if n != testDataLen { + if n != msgLen { t.Fatalf("written only %v, err %v", n, err) } - recvBuf := make([]byte, testDataLen) + recvBuf := make([]byte, msgLen) _, err = io.ReadFull(conn, recvBuf) if err != nil { t.Fatalf("failed to read back: %v", err) @@ -96,7 +94,7 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { func TestMultiplex(t *testing.T) { const numStreams = 2000 // -race option limits the number of goroutines to 8192 const numConns = 4 - const maxMsgLen = 16384 + const msgLen = 16384 clientSession, serverSession, _ := makeSessionPair(numConns) go serveEcho(serverSession) @@ -111,15 +109,10 @@ func TestMultiplex(t *testing.T) { } //test echo - runEchoTest(t, streams, maxMsgLen) + runEchoTest(t, streams, msgLen) - assert.Eventuallyf(t, func() bool { - return clientSession.streamCount() == numStreams - }, time.Second, 10*time.Millisecond, "client stream count is wrong: %v", clientSession.streamCount()) - - assert.Eventuallyf(t, func() bool { - return serverSession.streamCount() == numStreams - }, time.Second, 10*time.Millisecond, "server stream count is wrong: %v", serverSession.streamCount()) + assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong") + assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong") // close one stream closing, streams := streams[0], streams[1:] From 42f36b94d3c5704b8f2a3a9b13d65366ac13e7e6 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 13:16:48 +0000 Subject: [PATCH 06/13] Achieve zero allocation when writing data through stream --- internal/multiplex/session.go | 16 +++++++--------- internal/multiplex/stream.go | 36 +++++++++++++++-------------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index c808e2b..a32b6bf 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -165,6 +165,7 @@ func (sesh *Session) Accept() (net.Conn, error) { } func (sesh *Session) closeStream(s *Stream, active bool) error { + // must be holding s.wirtingM on entry if atomic.SwapUint32(&s.closed, 1) == 1 { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } @@ -173,16 +174,13 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if active { // Notify remote that this stream is closed padding := genRandomPadding() - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingStream, - Payload: padding, - } - s.nextSendSeq++ + s.writingFrame.Closing = closingStream + s.writingFrame.Payload = padding obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) - i, err := sesh.Obfs(f, obfsBuf, 0) + + i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0) + s.writingFrame.Seq++ if err != nil { return err } @@ -190,7 +188,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if err != nil { return err } - log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq) + log.Tracef("stream %v actively closed.", s.id) } else { log.Tracef("stream %v passively closed", s.id) } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index beee2b8..d64628f 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -27,8 +27,8 @@ type Stream struct { // been read by the consumer through Read or WriteTo recvBuf recvBuffer - writingM sync.Mutex - nextSendSeq uint64 + writingM sync.Mutex + writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom // atomic closed uint32 @@ -63,6 +63,11 @@ func makeStream(sesh *Session, id uint32) *Stream { id: id, session: sesh, recvBuf: recvBuf, + writingFrame: Frame{ + StreamID: id, + Seq: 0, + Closing: closingNothing, + }, } return stream @@ -110,15 +115,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { return n, nil } -func (s *Stream) obfuscateAndSend(f *Frame, payloadOffsetInObfsBuf int) error { +func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { var cipherTextLen int - cipherTextLen, err := s.session.Obfs(f, s.obfsBuf, payloadOffsetInObfsBuf) + cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) if err != nil { return err } _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) - log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -154,14 +158,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { } framePayload = in[n : s.session.maxStreamUnitWrite+n] } - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingNothing, - Payload: framePayload, - } - s.nextSendSeq++ - err = s.obfuscateAndSend(f, 0) + s.writingFrame.Payload = framePayload + err = s.obfuscateAndSend(0) + s.writingFrame.Seq++ if err != nil { return } @@ -193,14 +192,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { } s.writingM.Lock() - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingNothing, - Payload: s.obfsBuf[frameHeaderLength : frameHeaderLength+read], - } - s.nextSendSeq++ - err = s.obfuscateAndSend(f, frameHeaderLength) + s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read] + err = s.obfuscateAndSend(frameHeaderLength) + s.writingFrame.Seq++ s.writingM.Unlock() if err != nil { From badda764549a358e150d2bd766bbc520f7bb01bb Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 13:40:37 +0000 Subject: [PATCH 07/13] Improve data receive benchmark --- internal/multiplex/session_test.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 89bd410..31cee76 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -413,7 +413,6 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { 0, testPayload, } - obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) @@ -424,18 +423,27 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { "chacha20poly1305": EncryptionMethodChaha20Poly1305, } + const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic for name, ep := range table { - seshConfig := seshConfigs["ordered"] - obfuscator, _ := MakeObfuscator(ep, sessionKey) - seshConfig.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfig) - n, _ := sesh.Obfs(f, obfsBuf, 0) - + ep := ep b.Run(name, func(b *testing.B) { + seshConfig := seshConfigs["ordered"] + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) + + binaryFrames := [maxIter][]byte{} + for i := 0; i < maxIter; i++ { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(f, obfsBuf, 0) + binaryFrames[i] = obfsBuf[:n] + f.Seq++ + } + b.SetBytes(int64(len(f.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) + sesh.recvDataFromRemote(binaryFrames[i]) } }) } From 4bc80af9a198ddd807486479e9348aa2e01ae30d Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 14:12:00 +0000 Subject: [PATCH 08/13] Lazily allocate stream receiving buffer --- internal/multiplex/session.go | 2 +- internal/multiplex/stream.go | 36 ++++++++++++++++++----------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index a32b6bf..1a4b8b4 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -169,7 +169,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if atomic.SwapUint32(&s.closed, 1) == 1 { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } - _ = s.recvBuf.Close() // recvBuf.Close should not return error + _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { // Notify remote that this stream is closed diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index d64628f..2245f5b 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -23,8 +23,9 @@ type Stream struct { session *Session + allocIdempot sync.Once // a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't - // been read by the consumer through Read or WriteTo + // been read by the consumer through Read or WriteTo. Lazily allocated recvBuf recvBuffer writingM sync.Mutex @@ -33,11 +34,9 @@ type Stream struct { // atomic closed uint32 - // lazy allocation for obfsBuf. This is desirable because obfsBuf is only used when data is sent from + // obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from // the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste // memory - allocIdempot sync.Once - // obfuscation happens in this buffer obfsBuf []byte // When we want order guarantee (i.e. session.Unordered is false), @@ -52,17 +51,9 @@ type Stream struct { } func makeStream(sesh *Session, id uint32) *Stream { - var recvBuf recvBuffer - if sesh.Unordered { - recvBuf = NewDatagramBufferedPipe() - } else { - recvBuf = NewStreamBuffer() - } - stream := &Stream{ id: id, session: sesh, - recvBuf: recvBuf, writingFrame: Frame{ StreamID: id, Seq: 0, @@ -75,9 +66,20 @@ func makeStream(sesh *Session, id uint32) *Stream { func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } +func (s *Stream) getRecvBuf() recvBuffer { + s.allocIdempot.Do(func() { + if s.session.Unordered { + s.recvBuf = NewDatagramBufferedPipe() + } else { + s.recvBuf = NewStreamBuffer() + } + }) + return s.recvBuf +} + // receive a readily deobfuscated Frame so its payload can later be Read func (s *Stream) recvFrame(frame Frame) error { - toBeClosed, err := s.recvBuf.Write(frame) + toBeClosed, err := s.getRecvBuf().Write(frame) if toBeClosed { err = s.passiveClose() if errors.Is(err, errRepeatStreamClosing) { @@ -96,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { return 0, nil } - n, err = s.recvBuf.Read(buf) + n, err = s.getRecvBuf().Read(buf) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -107,7 +109,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { // WriteTo continuously write data Stream has received into the writer w. func (s *Stream) WriteTo(w io.Writer) (int64, error) { // will keep writing until the underlying buffer is closed - n, err := s.recvBuf.WriteTo(w) + n, err := s.getRecvBuf().WriteTo(w) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -219,8 +221,8 @@ func (s *Stream) Close() error { func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } -func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } -func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } +func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) } +func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } var errNotImplemented = errors.New("Not implemented") From 104117cafbbd01bd01fe5bd11a4fa1fcd06f9567 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 14:32:41 +0000 Subject: [PATCH 09/13] Fix one instance of not accessing recvBuf via the getter --- internal/multiplex/session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 1a4b8b4..9964f74 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -264,7 +264,7 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error { } stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) - _ = stream.recvBuf.Close() // will not block + _ = stream.getRecvBuf().Close() // will not block sesh.streams.Delete(key) sesh.streamCountDecr() return true From 5a3f63f101a05728383d22a90f12670984f69922 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 14:45:29 +0000 Subject: [PATCH 10/13] Reduce allocation of frame objects on receiving data --- internal/multiplex/obfs.go | 22 ++++++++++------------ internal/multiplex/obfs_test.go | 18 +++++++++++------- internal/multiplex/session.go | 9 ++++++++- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 0c1f8c6..1379072 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -12,7 +12,7 @@ import ( ) type Obfser func(*Frame, []byte, int) (int, error) -type Deobfser func([]byte) (*Frame, error) +type Deobfser func(*Frame, []byte) error var u32 = binary.BigEndian.Uint32 var u64 = binary.BigEndian.Uint64 @@ -135,9 +135,9 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { // frame header length + minimum data size (i.e. nonce size of salsa20) const minInputLen = frameHeaderLength + salsa20NonceSize - deobfs := func(in []byte) (*Frame, error) { + deobfs := func(f *Frame, in []byte) error { if len(in) < minInputLen { - return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) } header := in[:frameHeaderLength] @@ -153,7 +153,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { usefulPayloadLen := len(pldWithOverHead) - int(extraLen) if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") } var outputPayload []byte @@ -167,18 +167,16 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { } else { _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) if err != nil { - return nil, err + return err } outputPayload = pldWithOverHead[:usefulPayloadLen] } - ret := &Frame{ - StreamID: streamID, - Seq: seq, - Closing: closing, - Payload: outputPayload, - } - return ret, nil + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil } return deobfs } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 6cbbb5b..99f4f5f 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -17,8 +17,7 @@ func TestGenerateObfs(t *testing.T) { run := func(obfuscator Obfuscator, ct *testing.T) { obfsBuf := make([]byte, 512) - f := &Frame{} - _testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42))) + _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) if err != nil { @@ -26,7 +25,8 @@ func TestGenerateObfs(t *testing.T) { return } - resultFrame, err := obfuscator.Deobfs(obfsBuf[:i]) + var resultFrame Frame + err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) if err != nil { ct.Error("failed to deobfs ", err) return @@ -148,10 +148,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { @@ -162,10 +163,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { @@ -173,10 +175,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, nil) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { @@ -186,10 +189,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 9964f74..6a88aa3 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -65,6 +65,9 @@ type Session struct { activeStreamCount uint32 streams sync.Map + // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame + recvFramePool sync.Pool + // Switchboard manages all connections to remote sb *switchboard @@ -89,6 +92,7 @@ func MakeSession(id uint32, config SessionConfig) *Session { SessionConfig: config, nextStreamID: 1, acceptCh: make(chan *Stream, acceptBacklog), + recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }}, } sesh.addrs.Store([]net.Addr{nil, nil}) @@ -212,7 +216,10 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { // to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new // stream and then writes to the stream buffer func (sesh *Session) recvDataFromRemote(data []byte) error { - frame, err := sesh.Deobfs(data) + frame := sesh.recvFramePool.Get().(*Frame) + defer sesh.recvFramePool.Put(frame) + + err := sesh.Deobfs(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } From ff503b06a8adf6839ffb91dce8a4d527c8c1d156 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 19:39:13 +0000 Subject: [PATCH 11/13] Only allocate and copy frame object into sorter heap when necessary (out of order frame) --- internal/multiplex/datagramBufferedPipe.go | 2 +- internal/multiplex/datagramBufferedPipe_test.go | 10 +++++----- internal/multiplex/recvBuffer.go | 2 +- internal/multiplex/session.go | 4 ++-- internal/multiplex/stream.go | 2 +- internal/multiplex/streamBuffer.go | 7 ++++--- internal/multiplex/streamBuffer_test.go | 7 +++---- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/multiplex/datagramBufferedPipe.go b/internal/multiplex/datagramBufferedPipe.go index e1a0462..a7b99e4 100644 --- a/internal/multiplex/datagramBufferedPipe.go +++ b/internal/multiplex/datagramBufferedPipe.go @@ -112,7 +112,7 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { } } -func (d *datagramBufferedPipe) Write(f Frame) (toBeClosed bool, err error) { +func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() if d.buf == nil { diff --git a/internal/multiplex/datagramBufferedPipe_test.go b/internal/multiplex/datagramBufferedPipe_test.go index 4a5d4e2..6b20f76 100644 --- a/internal/multiplex/datagramBufferedPipe_test.go +++ b/internal/multiplex/datagramBufferedPipe_test.go @@ -10,7 +10,7 @@ func TestDatagramBuffer_RW(t *testing.T) { b := []byte{0x01, 0x02, 0x03} t.Run("simple write", func(t *testing.T) { pipe := NewDatagramBufferedPipe() - _, err := pipe.Write(Frame{Payload: b}) + _, err := pipe.Write(&Frame{Payload: b}) if err != nil { t.Error( "expecting", "nil error", @@ -22,7 +22,7 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("simple read", func(t *testing.T) { pipe := NewDatagramBufferedPipe() - _, _ = pipe.Write(Frame{Payload: b}) + _, _ = pipe.Write(&Frame{Payload: b}) b2 := make([]byte, len(b)) n, err := pipe.Read(b2) if n != len(b) { @@ -55,7 +55,7 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("writing closing frame", func(t *testing.T) { pipe := NewDatagramBufferedPipe() - toBeClosed, err := pipe.Write(Frame{Closing: closingStream}) + toBeClosed, err := pipe.Write(&Frame{Closing: closingStream}) if !toBeClosed { t.Error("should be to be closed") } @@ -77,7 +77,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) { b := []byte{0x01, 0x02, 0x03} go func() { time.Sleep(readBlockTime) - pipe.Write(Frame{Payload: b}) + pipe.Write(&Frame{Payload: b}) }() b2 := make([]byte, len(b)) n, err := pipe.Read(b2) @@ -110,7 +110,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) { func TestDatagramBuffer_CloseThenRead(t *testing.T) { pipe := NewDatagramBufferedPipe() b := []byte{0x01, 0x02, 0x03} - pipe.Write(Frame{Payload: b}) + pipe.Write(&Frame{Payload: b}) b2 := make([]byte, len(b)) pipe.Close() n, err := pipe.Read(b2) diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 0797daf..63f1f6f 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -15,7 +15,7 @@ type recvBuffer interface { // when the buffer is empty. io.ReadCloser io.WriterTo - Write(Frame) (toBeClosed bool, err error) + Write(*Frame) (toBeClosed bool, err error) SetReadDeadline(time time.Time) // SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing // has been written for a while. After that duration it should return ErrTimeout diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 6a88aa3..92b5cf8 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -236,12 +236,12 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { // this is when the stream existed before but has since been closed. We do nothing return nil } - return existingStreamI.(*Stream).recvFrame(*frame) + return existingStreamI.(*Stream).recvFrame(frame) } else { // new stream sesh.streamCountIncr() sesh.acceptCh <- newStream - return newStream.recvFrame(*frame) + return newStream.recvFrame(frame) } } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 2245f5b..d827117 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -78,7 +78,7 @@ func (s *Stream) getRecvBuf() recvBuffer { } // receive a readily deobfuscated Frame so its payload can later be Read -func (s *Stream) recvFrame(frame Frame) error { +func (s *Stream) recvFrame(frame *Frame) error { toBeClosed, err := s.getRecvBuf().Write(frame) if toBeClosed { err = s.passiveClose() diff --git a/internal/multiplex/streamBuffer.go b/internal/multiplex/streamBuffer.go index 4adfae2..13cc523 100644 --- a/internal/multiplex/streamBuffer.go +++ b/internal/multiplex/streamBuffer.go @@ -63,7 +63,7 @@ func NewStreamBuffer() *streamBuffer { return sb } -func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) { +func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) { sb.recvM.Lock() defer sb.recvM.Unlock() // when there'fs no ooo packages in heap and we receive the next package in order @@ -81,10 +81,11 @@ func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) { return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) } - heap.Push(&sb.sh, &f) + saved := *f + heap.Push(&sb.sh, &saved) // Keep popping from the heap until empty or to the point that the wanted seq was not received for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq { - f = *heap.Pop(&sb.sh).(*Frame) + f = heap.Pop(&sb.sh).(*Frame) if f.Closing != closingNothing { return true, nil } else { diff --git a/internal/multiplex/streamBuffer_test.go b/internal/multiplex/streamBuffer_test.go index 67fb3a5..b36bb6a 100644 --- a/internal/multiplex/streamBuffer_test.go +++ b/internal/multiplex/streamBuffer_test.go @@ -20,11 +20,10 @@ func TestRecvNewFrame(t *testing.T) { for _, n := range set { bu64 := make([]byte, 8) binary.BigEndian.PutUint64(bu64, n) - frame := Frame{ + sb.Write(&Frame{ Seq: n, Payload: bu64, - } - sb.Write(frame) + }) } var sortedResult []uint64 @@ -80,7 +79,7 @@ func TestStreamBuffer_RecvThenClose(t *testing.T) { Closing: 0, Payload: testData, } - sb.Write(testFrame) + sb.Write(&testFrame) sb.Close() readBuf := make([]byte, testDataLen) From fd5005db0af215dfe28e8784856a005de5a5fcc9 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 20:16:47 +0000 Subject: [PATCH 12/13] Fix a timing sensitive test on reading data after actively closing a stream --- internal/multiplex/stream_test.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 6ce16d9..84e8982 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -151,6 +151,16 @@ func TestStream_Close(t *testing.T) { t.Error("failed to accept stream", err) return } + + // we read something to wait for the test frame to reach our recvBuffer. + // if it's empty by the point we call stream.Close(), the incoming + // frame will be dropped + readBuf := make([]byte, len(testPayload)) + _, err = io.ReadFull(stream, readBuf[:1]) + if err != nil { + t.Errorf("can't read any data before active closing") + } + err = stream.Close() if err != nil { t.Error("failed to actively close stream", err) @@ -162,8 +172,7 @@ func TestStream_Close(t *testing.T) { return } - readBuf := make([]byte, len(testPayload)) - _, err = io.ReadFull(stream, readBuf) + _, err = io.ReadFull(stream, readBuf[1:]) if err != nil { t.Errorf("can't read residual data %v", err) } From 35f41424c934d1e51314e9a4819e702817926631 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 20:07:17 +0000 Subject: [PATCH 13/13] Use default hashmap to store streams. Avoid allocating a stream object on receiving every single frame --- internal/multiplex/session.go | 39 +++++++++++++++++++----------- internal/multiplex/session_test.go | 33 ++++++++++++++++--------- internal/multiplex/stream_test.go | 11 ++++++--- 3 files changed, 55 insertions(+), 28 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 92b5cf8..6f165b0 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -63,7 +63,9 @@ type Session struct { // atomic activeStreamCount uint32 - streams sync.Map + + streamsM sync.Mutex + streams map[uint32]*Stream // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame recvFramePool sync.Pool @@ -93,6 +95,7 @@ func MakeSession(id uint32, config SessionConfig) *Session { nextStreamID: 1, acceptCh: make(chan *Stream, acceptBacklog), recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }}, + streams: map[uint32]*Stream{}, } sesh.addrs.Store([]net.Addr{nil, nil}) @@ -149,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) { return nil, errNoMultiplex } stream := makeStream(sesh, id) - sesh.streams.Store(id, stream) + sesh.streamsM.Lock() + sesh.streams[id] = stream + sesh.streamsM.Unlock() sesh.streamCountIncr() log.Tracef("stream %v of session %v opened", id, sesh.id) return stream, nil @@ -200,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { // We set it as nil to signify that the stream id had existed before. // If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell // if the frame it received was from a new stream or a dying stream whose frame arrived late - sesh.streams.Store(s.id, nil) + sesh.streamsM.Lock() + sesh.streams[s.id] = nil + sesh.streamsM.Unlock() if sesh.streamCountDecr() == 0 { if sesh.Singleplex { return sesh.Close() @@ -229,15 +236,19 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { return sesh.passiveClose() } - newStream := makeStream(sesh, frame.StreamID) - existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) + sesh.streamsM.Lock() + existingStream, existing := sesh.streams[frame.StreamID] if existing { - if existingStreamI == nil { + sesh.streamsM.Unlock() + if existingStream == nil { // this is when the stream existed before but has since been closed. We do nothing return nil } - return existingStreamI.(*Stream).recvFrame(frame) + return existingStream.recvFrame(frame) } else { + newStream := makeStream(sesh, frame.StreamID) + sesh.streams[frame.StreamID] = newStream + sesh.streamsM.Unlock() // new stream sesh.streamCountIncr() sesh.acceptCh <- newStream @@ -265,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error { } sesh.acceptCh <- nil - sesh.streams.Range(func(key, streamI interface{}) bool { - if streamI == nil { - return true + sesh.streamsM.Lock() + for id, stream := range sesh.streams { + if stream == nil { + continue } - stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) _ = stream.getRecvBuf().Close() // will not block - sesh.streams.Delete(key) + delete(sesh.streams, id) sesh.streamCountDecr() - return true - }) + } + sesh.streamsM.Unlock() if closeSwitchboard { sesh.sb.closeAll() diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 31cee76..f4b32bb 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -112,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) } - _, ok := sesh.streams.Load(f1.StreamID) + sesh.streamsM.Lock() + _, ok := sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() if !ok { t.Fatal("failed to fetch stream 1 after receiving it") } @@ -132,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) } - s2I, ok := sesh.streams.Load(f2.StreamID) - if s2I == nil || !ok { + sesh.streamsM.Lock() + s2M, ok := sesh.streams[f2.StreamID] + sesh.streamsM.Unlock() + if s2M == nil || !ok { t.Fatal("failed to fetch stream 2 after receiving it") } if sesh.streamCount() != 2 { @@ -152,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) } - s1I, _ := sesh.streams.Load(f1.StreamID) - if s1I != nil { + sesh.streamsM.Lock() + s1M, _ := sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() + if s1M != nil { t.Fatal("stream 1 still exist after receiving stream close") } s1, _ := sesh.Accept() @@ -179,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) } - s1I, _ = sesh.streams.Load(f1.StreamID) - if s1I != nil { + sesh.streamsM.Lock() + s1M, _ = sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() + if s1M != nil { t.Error("stream 1 exists after receiving stream close for the second time") } streamCount := sesh.streamCount() @@ -243,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) } - _, ok := sesh.streams.Load(f1CloseStream.StreamID) + sesh.streamsM.Lock() + _, ok := sesh.streams[f1CloseStream.StreamID] + sesh.streamsM.Unlock() if !ok { t.Fatal("stream 1 doesn't exist") } @@ -334,12 +344,13 @@ func TestParallelStreams(t *testing.T) { wg.Wait() sc := int(sesh.streamCount()) var count int - sesh.streams.Range(func(_, s interface{}) bool { + sesh.streamsM.Lock() + for _, s := range sesh.streams { if s != nil { count++ } - return true - }) + } + sesh.streamsM.Unlock() if sc != count { t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 84e8982..c0b86fb 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -167,10 +167,13 @@ func TestStream_Close(t *testing.T) { return } - if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil { + sesh.streamsM.Lock() + if s, _ := sesh.streams[stream.(*Stream).id]; s != nil { + sesh.streamsM.Unlock() t.Error("stream still exists") return } + sesh.streamsM.Unlock() _, err = io.ReadFull(stream, readBuf[1:]) if err != nil { @@ -242,8 +245,10 @@ func TestStream_Close(t *testing.T) { } assert.Eventually(t, func() bool { - sI, _ := sesh.streams.Load(stream.(*Stream).id) - return sI == nil + sesh.streamsM.Lock() + s, _ := sesh.streams[stream.(*Stream).id] + sesh.streamsM.Unlock() + return s == nil }, time.Second, 10*time.Millisecond, "streams still exists") })