Use sync.Map for lock free pickRandConn

This commit is contained in:
Andy Wang 2020-12-29 19:53:14 +00:00
parent 8ab0c2d96b
commit b4d65d8a0e
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
3 changed files with 36 additions and 42 deletions

View File

@ -23,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream") var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream") var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type switchboardStrategy int
type SessionConfig struct { type SessionConfig struct {
Obfuscator Obfuscator

View File

@ -10,6 +10,8 @@ import (
"time" "time"
) )
type switchboardStrategy int
const ( const (
FIXED_CONN_MAPPING switchboardStrategy = iota FIXED_CONN_MAPPING switchboardStrategy = iota
UNIFORM_SPREAD UNIFORM_SPREAD
@ -28,9 +30,9 @@ type switchboard struct {
valve Valve valve Valve
strategy switchboardStrategy strategy switchboardStrategy
connsM sync.RWMutex conns sync.Map
conns []net.Conn connsCount uint32
randPool sync.Pool randPool sync.Pool
broken uint32 broken uint32
} }
@ -57,27 +59,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
var errBrokenSwitchboard = errors.New("the switchboard is broken") var errBrokenSwitchboard = errors.New("the switchboard is broken")
func (sb *switchboard) delConn(conn net.Conn) { func (sb *switchboard) delConn(conn net.Conn) {
sb.connsM.Lock() if _, ok := sb.conns.LoadAndDelete(conn); ok {
defer sb.connsM.Unlock() atomic.AddUint32(&sb.connsCount, ^uint32(0))
if len(sb.conns) <= 1 {
sb.conns = nil
return
} }
var i int
var c net.Conn
for i, c = range sb.conns {
if c == conn {
break
}
}
sb.conns = append(sb.conns[:i], sb.conns[i+1:]...)
} }
func (sb *switchboard) addConn(conn net.Conn) { func (sb *switchboard) addConn(conn net.Conn) {
sb.connsM.Lock() atomic.AddUint32(&sb.connsCount, 1)
sb.conns = append(sb.conns, conn) sb.conns.Store(conn, conn)
sb.connsM.Unlock()
go sb.deplex(conn) go sb.deplex(conn)
} }
@ -133,18 +122,28 @@ func (sb *switchboard) pickRandConn() (net.Conn, error) {
return nil, errBrokenSwitchboard return nil, errBrokenSwitchboard
} }
randReader := sb.randPool.Get().(*rand.Rand) connsCount := atomic.LoadUint32(&sb.connsCount)
sb.connsM.RLock()
defer sb.connsM.RUnlock()
connsCount := len(sb.conns)
if connsCount == 0 { if connsCount == 0 {
return nil, errBrokenSwitchboard return nil, errBrokenSwitchboard
} }
r := randReader.Intn(connsCount)
randReader := sb.randPool.Get().(*rand.Rand)
r := randReader.Intn(int(connsCount))
sb.randPool.Put(randReader) sb.randPool.Put(randReader)
return sb.conns[r], nil var c int
var ret net.Conn
sb.conns.Range(func(_, conn interface{}) bool {
if r == c {
ret = conn.(net.Conn)
return false
}
c++
return true
})
return ret, nil
} }
// actively triggered by session.Close() // actively triggered by session.Close()
@ -152,12 +151,12 @@ func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return return
} }
sb.connsM.Lock() sb.conns.Range(func(_, conn interface{}) bool {
for _, conn := range sb.conns { conn.(net.Conn).Close()
conn.Close() sb.conns.Delete(conn)
} atomic.AddUint32(&sb.connsCount, ^uint32(0))
sb.conns = nil return true
sb.connsM.Unlock() })
} }
// deplex function costantly reads from a TCP connection // deplex function costantly reads from a TCP connection

View File

@ -5,6 +5,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"math/rand" "math/rand"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
@ -173,17 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
} }
wg.Wait() wg.Wait()
sesh.sb.connsM.RLock() if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
if len(sesh.sb.conns) != 1000 {
t.Error("connsCount incorrect") t.Error("connsCount incorrect")
} }
sesh.sb.connsM.RUnlock()
sesh.sb.closeAll() sesh.sb.closeAll()
assert.Eventuallyf(t, func() bool { assert.Eventuallyf(t, func() bool {
sesh.sb.connsM.RLock() return atomic.LoadUint32(&sesh.sb.connsCount) == 0
defer sesh.sb.connsM.RUnlock() }, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
return len(sesh.sb.conns) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect")
} }