From b4d65d8a0e1cc8dd5e4bb84356c4e3546c531f9f Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 29 Dec 2020 19:53:14 +0000 Subject: [PATCH] Use sync.Map for lock free pickRandConn --- internal/multiplex/session.go | 2 - internal/multiplex/switchboard.go | 65 +++++++++++++------------- internal/multiplex/switchboard_test.go | 11 ++--- 3 files changed, 36 insertions(+), 42 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 49a0ed5..fc21c00 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -23,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session") var errRepeatStreamClosing = errors.New("trying to close a closed stream") var errNoMultiplex = errors.New("a singleplexing session can have only one stream") -type switchboardStrategy int - type SessionConfig struct { Obfuscator diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 834fd63..f46e5da 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -10,6 +10,8 @@ import ( "time" ) +type switchboardStrategy int + const ( FIXED_CONN_MAPPING switchboardStrategy = iota UNIFORM_SPREAD @@ -28,9 +30,9 @@ type switchboard struct { valve Valve strategy switchboardStrategy - connsM sync.RWMutex - conns []net.Conn - randPool sync.Pool + conns sync.Map + connsCount uint32 + randPool sync.Pool broken uint32 } @@ -57,27 +59,14 @@ func makeSwitchboard(sesh *Session) *switchboard { var errBrokenSwitchboard = errors.New("the switchboard is broken") func (sb *switchboard) delConn(conn net.Conn) { - sb.connsM.Lock() - defer sb.connsM.Unlock() - - if len(sb.conns) <= 1 { - sb.conns = nil - return + if _, ok := sb.conns.LoadAndDelete(conn); ok { + atomic.AddUint32(&sb.connsCount, ^uint32(0)) } - var i int - var c net.Conn - for i, c = range sb.conns { - if c == conn { - break - } - } - sb.conns = append(sb.conns[:i], sb.conns[i+1:]...) } func (sb *switchboard) addConn(conn net.Conn) { - sb.connsM.Lock() - sb.conns = append(sb.conns, conn) - sb.connsM.Unlock() + atomic.AddUint32(&sb.connsCount, 1) + sb.conns.Store(conn, conn) go sb.deplex(conn) } @@ -133,18 +122,28 @@ func (sb *switchboard) pickRandConn() (net.Conn, error) { return nil, errBrokenSwitchboard } - randReader := sb.randPool.Get().(*rand.Rand) - sb.connsM.RLock() - defer sb.connsM.RUnlock() - - connsCount := len(sb.conns) + connsCount := atomic.LoadUint32(&sb.connsCount) if connsCount == 0 { return nil, errBrokenSwitchboard } - r := randReader.Intn(connsCount) + + randReader := sb.randPool.Get().(*rand.Rand) + + r := randReader.Intn(int(connsCount)) sb.randPool.Put(randReader) - return sb.conns[r], nil + var c int + var ret net.Conn + sb.conns.Range(func(_, conn interface{}) bool { + if r == c { + ret = conn.(net.Conn) + return false + } + c++ + return true + }) + + return ret, nil } // actively triggered by session.Close() @@ -152,12 +151,12 @@ func (sb *switchboard) closeAll() { if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { return } - sb.connsM.Lock() - for _, conn := range sb.conns { - conn.Close() - } - sb.conns = nil - sb.connsM.Unlock() + sb.conns.Range(func(_, conn interface{}) bool { + conn.(net.Conn).Close() + sb.conns.Delete(conn) + atomic.AddUint32(&sb.connsCount, ^uint32(0)) + return true + }) } // deplex function costantly reads from a TCP connection diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index d1f6eb4..be0acaa 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" "math/rand" "sync" + "sync/atomic" "testing" "time" ) @@ -173,17 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) { } wg.Wait() - sesh.sb.connsM.RLock() - if len(sesh.sb.conns) != 1000 { + if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 { t.Error("connsCount incorrect") } - sesh.sb.connsM.RUnlock() sesh.sb.closeAll() assert.Eventuallyf(t, func() bool { - sesh.sb.connsM.RLock() - defer sesh.sb.connsM.RUnlock() - return len(sesh.sb.conns) == 0 - }, time.Second, 10*time.Millisecond, "connsCount incorrect") + return atomic.LoadUint32(&sesh.sb.connsCount) == 0 + }, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount)) }