mirror of https://github.com/cbeuw/Cloak
Use sync.Map for lock free pickRandConn
This commit is contained in:
parent
8ab0c2d96b
commit
b4d65d8a0e
|
|
@ -23,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
|
||||||
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
||||||
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
|
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
|
||||||
|
|
||||||
type switchboardStrategy int
|
|
||||||
|
|
||||||
type SessionConfig struct {
|
type SessionConfig struct {
|
||||||
Obfuscator
|
Obfuscator
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type switchboardStrategy int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
FIXED_CONN_MAPPING switchboardStrategy = iota
|
FIXED_CONN_MAPPING switchboardStrategy = iota
|
||||||
UNIFORM_SPREAD
|
UNIFORM_SPREAD
|
||||||
|
|
@ -28,9 +30,9 @@ type switchboard struct {
|
||||||
valve Valve
|
valve Valve
|
||||||
strategy switchboardStrategy
|
strategy switchboardStrategy
|
||||||
|
|
||||||
connsM sync.RWMutex
|
conns sync.Map
|
||||||
conns []net.Conn
|
connsCount uint32
|
||||||
randPool sync.Pool
|
randPool sync.Pool
|
||||||
|
|
||||||
broken uint32
|
broken uint32
|
||||||
}
|
}
|
||||||
|
|
@ -57,27 +59,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
|
||||||
var errBrokenSwitchboard = errors.New("the switchboard is broken")
|
var errBrokenSwitchboard = errors.New("the switchboard is broken")
|
||||||
|
|
||||||
func (sb *switchboard) delConn(conn net.Conn) {
|
func (sb *switchboard) delConn(conn net.Conn) {
|
||||||
sb.connsM.Lock()
|
if _, ok := sb.conns.LoadAndDelete(conn); ok {
|
||||||
defer sb.connsM.Unlock()
|
atomic.AddUint32(&sb.connsCount, ^uint32(0))
|
||||||
|
|
||||||
if len(sb.conns) <= 1 {
|
|
||||||
sb.conns = nil
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
var i int
|
|
||||||
var c net.Conn
|
|
||||||
for i, c = range sb.conns {
|
|
||||||
if c == conn {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sb.conns = append(sb.conns[:i], sb.conns[i+1:]...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sb *switchboard) addConn(conn net.Conn) {
|
func (sb *switchboard) addConn(conn net.Conn) {
|
||||||
sb.connsM.Lock()
|
atomic.AddUint32(&sb.connsCount, 1)
|
||||||
sb.conns = append(sb.conns, conn)
|
sb.conns.Store(conn, conn)
|
||||||
sb.connsM.Unlock()
|
|
||||||
go sb.deplex(conn)
|
go sb.deplex(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -133,18 +122,28 @@ func (sb *switchboard) pickRandConn() (net.Conn, error) {
|
||||||
return nil, errBrokenSwitchboard
|
return nil, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
|
|
||||||
randReader := sb.randPool.Get().(*rand.Rand)
|
connsCount := atomic.LoadUint32(&sb.connsCount)
|
||||||
sb.connsM.RLock()
|
|
||||||
defer sb.connsM.RUnlock()
|
|
||||||
|
|
||||||
connsCount := len(sb.conns)
|
|
||||||
if connsCount == 0 {
|
if connsCount == 0 {
|
||||||
return nil, errBrokenSwitchboard
|
return nil, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
r := randReader.Intn(connsCount)
|
|
||||||
|
randReader := sb.randPool.Get().(*rand.Rand)
|
||||||
|
|
||||||
|
r := randReader.Intn(int(connsCount))
|
||||||
sb.randPool.Put(randReader)
|
sb.randPool.Put(randReader)
|
||||||
|
|
||||||
return sb.conns[r], nil
|
var c int
|
||||||
|
var ret net.Conn
|
||||||
|
sb.conns.Range(func(_, conn interface{}) bool {
|
||||||
|
if r == c {
|
||||||
|
ret = conn.(net.Conn)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// actively triggered by session.Close()
|
// actively triggered by session.Close()
|
||||||
|
|
@ -152,12 +151,12 @@ func (sb *switchboard) closeAll() {
|
||||||
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sb.connsM.Lock()
|
sb.conns.Range(func(_, conn interface{}) bool {
|
||||||
for _, conn := range sb.conns {
|
conn.(net.Conn).Close()
|
||||||
conn.Close()
|
sb.conns.Delete(conn)
|
||||||
}
|
atomic.AddUint32(&sb.connsCount, ^uint32(0))
|
||||||
sb.conns = nil
|
return true
|
||||||
sb.connsM.Unlock()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// deplex function costantly reads from a TCP connection
|
// deplex function costantly reads from a TCP connection
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -173,17 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
sesh.sb.connsM.RLock()
|
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
|
||||||
if len(sesh.sb.conns) != 1000 {
|
|
||||||
t.Error("connsCount incorrect")
|
t.Error("connsCount incorrect")
|
||||||
}
|
}
|
||||||
sesh.sb.connsM.RUnlock()
|
|
||||||
|
|
||||||
sesh.sb.closeAll()
|
sesh.sb.closeAll()
|
||||||
|
|
||||||
assert.Eventuallyf(t, func() bool {
|
assert.Eventuallyf(t, func() bool {
|
||||||
sesh.sb.connsM.RLock()
|
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
|
||||||
defer sesh.sb.connsM.RUnlock()
|
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
|
||||||
return len(sesh.sb.conns) == 0
|
|
||||||
}, time.Second, 10*time.Millisecond, "connsCount incorrect")
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue