mirror of https://github.com/cbeuw/Cloak
Use sync.Map for lock free pickRandConn
This commit is contained in:
parent
8ab0c2d96b
commit
b4d65d8a0e
|
|
@ -23,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
|
|||
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
||||
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
|
||||
|
||||
type switchboardStrategy int
|
||||
|
||||
type SessionConfig struct {
|
||||
Obfuscator
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
type switchboardStrategy int
|
||||
|
||||
const (
|
||||
FIXED_CONN_MAPPING switchboardStrategy = iota
|
||||
UNIFORM_SPREAD
|
||||
|
|
@ -28,9 +30,9 @@ type switchboard struct {
|
|||
valve Valve
|
||||
strategy switchboardStrategy
|
||||
|
||||
connsM sync.RWMutex
|
||||
conns []net.Conn
|
||||
randPool sync.Pool
|
||||
conns sync.Map
|
||||
connsCount uint32
|
||||
randPool sync.Pool
|
||||
|
||||
broken uint32
|
||||
}
|
||||
|
|
@ -57,27 +59,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
|
|||
var errBrokenSwitchboard = errors.New("the switchboard is broken")
|
||||
|
||||
func (sb *switchboard) delConn(conn net.Conn) {
|
||||
sb.connsM.Lock()
|
||||
defer sb.connsM.Unlock()
|
||||
|
||||
if len(sb.conns) <= 1 {
|
||||
sb.conns = nil
|
||||
return
|
||||
if _, ok := sb.conns.LoadAndDelete(conn); ok {
|
||||
atomic.AddUint32(&sb.connsCount, ^uint32(0))
|
||||
}
|
||||
var i int
|
||||
var c net.Conn
|
||||
for i, c = range sb.conns {
|
||||
if c == conn {
|
||||
break
|
||||
}
|
||||
}
|
||||
sb.conns = append(sb.conns[:i], sb.conns[i+1:]...)
|
||||
}
|
||||
|
||||
func (sb *switchboard) addConn(conn net.Conn) {
|
||||
sb.connsM.Lock()
|
||||
sb.conns = append(sb.conns, conn)
|
||||
sb.connsM.Unlock()
|
||||
atomic.AddUint32(&sb.connsCount, 1)
|
||||
sb.conns.Store(conn, conn)
|
||||
go sb.deplex(conn)
|
||||
}
|
||||
|
||||
|
|
@ -133,18 +122,28 @@ func (sb *switchboard) pickRandConn() (net.Conn, error) {
|
|||
return nil, errBrokenSwitchboard
|
||||
}
|
||||
|
||||
randReader := sb.randPool.Get().(*rand.Rand)
|
||||
sb.connsM.RLock()
|
||||
defer sb.connsM.RUnlock()
|
||||
|
||||
connsCount := len(sb.conns)
|
||||
connsCount := atomic.LoadUint32(&sb.connsCount)
|
||||
if connsCount == 0 {
|
||||
return nil, errBrokenSwitchboard
|
||||
}
|
||||
r := randReader.Intn(connsCount)
|
||||
|
||||
randReader := sb.randPool.Get().(*rand.Rand)
|
||||
|
||||
r := randReader.Intn(int(connsCount))
|
||||
sb.randPool.Put(randReader)
|
||||
|
||||
return sb.conns[r], nil
|
||||
var c int
|
||||
var ret net.Conn
|
||||
sb.conns.Range(func(_, conn interface{}) bool {
|
||||
if r == c {
|
||||
ret = conn.(net.Conn)
|
||||
return false
|
||||
}
|
||||
c++
|
||||
return true
|
||||
})
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// actively triggered by session.Close()
|
||||
|
|
@ -152,12 +151,12 @@ func (sb *switchboard) closeAll() {
|
|||
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
||||
return
|
||||
}
|
||||
sb.connsM.Lock()
|
||||
for _, conn := range sb.conns {
|
||||
conn.Close()
|
||||
}
|
||||
sb.conns = nil
|
||||
sb.connsM.Unlock()
|
||||
sb.conns.Range(func(_, conn interface{}) bool {
|
||||
conn.(net.Conn).Close()
|
||||
sb.conns.Delete(conn)
|
||||
atomic.AddUint32(&sb.connsCount, ^uint32(0))
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// deplex function costantly reads from a TCP connection
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -173,17 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
|
|||
}
|
||||
wg.Wait()
|
||||
|
||||
sesh.sb.connsM.RLock()
|
||||
if len(sesh.sb.conns) != 1000 {
|
||||
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
|
||||
t.Error("connsCount incorrect")
|
||||
}
|
||||
sesh.sb.connsM.RUnlock()
|
||||
|
||||
sesh.sb.closeAll()
|
||||
|
||||
assert.Eventuallyf(t, func() bool {
|
||||
sesh.sb.connsM.RLock()
|
||||
defer sesh.sb.connsM.RUnlock()
|
||||
return len(sesh.sb.conns) == 0
|
||||
}, time.Second, 10*time.Millisecond, "connsCount incorrect")
|
||||
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
|
||||
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue