From 8ab0c2d96b591a92281527dc3295ff0d27cfd6fe Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Mon, 28 Dec 2020 01:10:24 +0000 Subject: [PATCH] Redo the implementation of switchboard and remove the need for connId --- internal/multiplex/session.go | 12 +-- internal/multiplex/stream.go | 4 +- internal/multiplex/switchboard.go | 141 +++++++++++++------------ internal/multiplex/switchboard_test.go | 32 +++--- 4 files changed, 97 insertions(+), 92 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index b9d8540..49a0ed5 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -286,13 +286,11 @@ func (sesh *Session) closeSession() error { sesh.streamsM.Lock() close(sesh.acceptCh) for id, stream := range sesh.streams { - if stream == nil { - continue + if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) { + _ = stream.getRecvBuf().Close() // will not block + delete(sesh.streams, id) + sesh.streamCountDecr() } - atomic.StoreUint32(&stream.closed, 1) - _ = stream.getRecvBuf().Close() // will not block - delete(sesh.streams, id) - sesh.streamCountDecr() } sesh.streamsM.Unlock() return nil @@ -333,7 +331,7 @@ func (sesh *Session) Close() error { if err != nil { return err } - _, err = sesh.sb.send((*buf)[:i], new(uint32)) + _, err = sesh.sb.send((*buf)[:i], new(net.Conn)) if err != nil { return err } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index ffd7e23..9141e59 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -40,7 +40,7 @@ type Stream struct { // recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets // so it won't have to use its priority queue to sort it. // This is not used in unordered connection mode - assignedConnId uint32 + assignedConn net.Conn readFromTimeout time.Duration } @@ -119,7 +119,7 @@ func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { return err } - _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId) + _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 829d944..834fd63 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -28,11 +28,9 @@ type switchboard struct { valve Valve strategy switchboardStrategy - // map of connId to net.Conn - conns sync.Map - numConns uint32 - nextConnId uint32 - randPool sync.Pool + connsM sync.RWMutex + conns []net.Conn + randPool sync.Pool broken uint32 } @@ -46,10 +44,9 @@ func makeSwitchboard(sesh *Session) *switchboard { strategy = FIXED_CONN_MAPPING } sb := &switchboard{ - session: sesh, - strategy: strategy, - valve: sesh.Valve, - nextConnId: 1, + session: sesh, + strategy: strategy, + valve: sesh.Valve, randPool: sync.Pool{New: func() interface{} { return rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) }}, @@ -59,88 +56,95 @@ func makeSwitchboard(sesh *Session) *switchboard { var errBrokenSwitchboard = errors.New("the switchboard is broken") -func (sb *switchboard) connsCount() int { - return int(atomic.LoadUint32(&sb.numConns)) +func (sb *switchboard) delConn(conn net.Conn) { + sb.connsM.Lock() + defer sb.connsM.Unlock() + + if len(sb.conns) <= 1 { + sb.conns = nil + return + } + var i int + var c net.Conn + for i, c = range sb.conns { + if c == conn { + break + } + } + sb.conns = append(sb.conns[:i], sb.conns[i+1:]...) } func (sb *switchboard) addConn(conn net.Conn) { - connId := atomic.AddUint32(&sb.nextConnId, 1) - 1 - atomic.AddUint32(&sb.numConns, 1) - sb.conns.Store(connId, conn) - go sb.deplex(connId, conn) + sb.connsM.Lock() + sb.conns = append(sb.conns, conn) + sb.connsM.Unlock() + go sb.deplex(conn) } -// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable -func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { +// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable +func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) { sb.valve.txWait(len(data)) - if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 { + if atomic.LoadUint32(&sb.broken) == 1 { return 0, errBrokenSwitchboard } var conn net.Conn switch sb.strategy { case UNIFORM_SPREAD: - _, conn, err = sb.pickRandConn() + conn, err = sb.pickRandConn() if err != nil { return 0, errBrokenSwitchboard } case FIXED_CONN_MAPPING: - connI, ok := sb.conns.Load(*connId) - if ok { - conn = connI.(net.Conn) - } else { - var newConnId uint32 - newConnId, conn, err = sb.pickRandConn() - if err != nil { - return 0, errBrokenSwitchboard - } - *connId = newConnId - } + conn = *assignedConn default: return 0, errors.New("unsupported traffic distribution strategy") } - n, err = conn.Write(data) - if err != nil { - sb.conns.Delete(*connId) - sb.session.SetTerminalMsg("failed to write to remote " + err.Error()) - sb.session.passiveClose() - return n, err + if conn != nil { + n, err = conn.Write(data) + if err != nil { + sb.delConn(conn) + } + } else { + conn, err = sb.pickRandConn() + if err != nil { + sb.delConn(conn) + sb.session.SetTerminalMsg("failed to pick a connection " + err.Error()) + sb.session.passiveClose() + return 0, err + } + n, err = conn.Write(data) + if err != nil { + sb.delConn(conn) + sb.session.SetTerminalMsg("failed to send to remote " + err.Error()) + sb.session.passiveClose() + return n, err + } + *assignedConn = conn } sb.valve.AddTx(int64(n)) return n, nil } // returns a random connId -func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) { - connCount := sb.connsCount() - if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 { - return 0, nil, errBrokenSwitchboard +func (sb *switchboard) pickRandConn() (net.Conn, error) { + if atomic.LoadUint32(&sb.broken) == 1 { + return nil, errBrokenSwitchboard } - // there is no guarantee that sb.conns still has the same amount of entries - // between the count loop and the pick loop - // so if the r > len(sb.conns) at the point of range call, the last visited element is picked - var id uint32 - var conn net.Conn randReader := sb.randPool.Get().(*rand.Rand) - r := randReader.Intn(connCount) - sb.randPool.Put(randReader) - var c int - sb.conns.Range(func(connIdI, connI interface{}) bool { - if r == c { - id = connIdI.(uint32) - conn = connI.(net.Conn) - return false - } - c++ - return true - }) - // if len(sb.conns) is 0 - if conn == nil { - return 0, nil, errBrokenSwitchboard + sb.connsM.RLock() + defer sb.connsM.RUnlock() + + connsCount := len(sb.conns) + if connsCount == 0 { + return nil, errBrokenSwitchboard } - return id, conn, nil + r := randReader.Intn(connsCount) + sb.randPool.Put(randReader) + + return sb.conns[r], nil } // actively triggered by session.Close() @@ -148,16 +152,16 @@ func (sb *switchboard) closeAll() { if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { return } - sb.conns.Range(func(key, connI interface{}) bool { - conn := connI.(net.Conn) + sb.connsM.Lock() + for _, conn := range sb.conns { conn.Close() - sb.conns.Delete(key) - return true - }) + } + sb.conns = nil + sb.connsM.Unlock() } // deplex function costantly reads from a TCP connection -func (sb *switchboard) deplex(connId uint32, conn net.Conn) { +func (sb *switchboard) deplex(conn net.Conn) { defer conn.Close() buf := make([]byte, sb.session.connReceiveBufferSize) for { @@ -166,8 +170,7 @@ func (sb *switchboard) deplex(connId uint32, conn net.Conn) { sb.valve.AddRx(int64(n)) if err != nil { log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) - sb.conns.Delete(connId) - atomic.AddUint32(&sb.numConns, ^uint32(0)) + sb.delConn(conn) sb.session.SetTerminalMsg("a connection has dropped unexpectedly") sb.session.passiveClose() return diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index 2c3f36f..d1f6eb4 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -14,14 +14,14 @@ func TestSwitchboard_Send(t *testing.T) { sesh := MakeSession(0, seshConfig) hole0 := connutil.Discard() sesh.sb.addConn(hole0) - connId, _, err := sesh.sb.pickRandConn() + conn, err := sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return } data := make([]byte, 1000) rand.Read(data) - _, err = sesh.sb.send(data, &connId) + _, err = sesh.sb.send(data, &conn) if err != nil { t.Error(err) return @@ -29,23 +29,23 @@ func TestSwitchboard_Send(t *testing.T) { hole1 := connutil.Discard() sesh.sb.addConn(hole1) - connId, _, err = sesh.sb.pickRandConn() + conn, err = sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return } - _, err = sesh.sb.send(data, &connId) + _, err = sesh.sb.send(data, &conn) if err != nil { t.Error(err) return } - connId, _, err = sesh.sb.pickRandConn() + conn, err = sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return } - _, err = sesh.sb.send(data, &connId) + _, err = sesh.sb.send(data, &conn) if err != nil { t.Error(err) return @@ -71,7 +71,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) { seshConfig := SessionConfig{} sesh := MakeSession(0, seshConfig) sesh.sb.addConn(hole) - connId, _, err := sesh.sb.pickRandConn() + conn, err := sesh.sb.pickRandConn() if err != nil { b.Error("failed to get a random conn", err) return @@ -81,7 +81,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) { b.SetBytes(int64(len(data))) b.ResetTimer() for i := 0; i < b.N; i++ { - sesh.sb.send(data, &connId) + sesh.sb.send(data, &conn) } } @@ -92,7 +92,7 @@ func TestSwitchboard_TxCredit(t *testing.T) { sesh := MakeSession(0, seshConfig) hole := connutil.Discard() sesh.sb.addConn(hole) - connId, _, err := sesh.sb.pickRandConn() + conn, err := sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return @@ -103,7 +103,7 @@ func TestSwitchboard_TxCredit(t *testing.T) { t.Run("FIXED CONN MAPPING", func(t *testing.T) { *sesh.sb.valve.(*LimitedValve).tx = 0 sesh.sb.strategy = FIXED_CONN_MAPPING - n, err := sesh.sb.send(data[:10], &connId) + n, err := sesh.sb.send(data[:10], &conn) if err != nil { t.Error(err) return @@ -119,7 +119,7 @@ func TestSwitchboard_TxCredit(t *testing.T) { t.Run("UNIFORM", func(t *testing.T) { *sesh.sb.valve.(*LimitedValve).tx = 0 sesh.sb.strategy = UNIFORM_SPREAD - n, err := sesh.sb.send(data[:10], &connId) + n, err := sesh.sb.send(data[:10], &conn) if err != nil { t.Error(err) return @@ -173,13 +173,17 @@ func TestSwitchboard_ConnsCount(t *testing.T) { } wg.Wait() - if sesh.sb.connsCount() != 1000 { + sesh.sb.connsM.RLock() + if len(sesh.sb.conns) != 1000 { t.Error("connsCount incorrect") } + sesh.sb.connsM.RUnlock() sesh.sb.closeAll() assert.Eventuallyf(t, func() bool { - return sesh.sb.connsCount() == 0 - }, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", sesh.sb.connsCount()) + sesh.sb.connsM.RLock() + defer sesh.sb.connsM.RUnlock() + return len(sesh.sb.conns) == 0 + }, time.Second, 10*time.Millisecond, "connsCount incorrect") }