diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index f46e5da..fa35567 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -13,8 +13,8 @@ import ( type switchboardStrategy int const ( - FIXED_CONN_MAPPING switchboardStrategy = iota - UNIFORM_SPREAD + fixedConnMapping switchboardStrategy = iota + uniformSpread ) // switchboard represents the connection pool. It is responsible for managing @@ -41,9 +41,9 @@ func makeSwitchboard(sesh *Session) *switchboard { var strategy switchboardStrategy if sesh.Unordered { log.Debug("Connection is unordered") - strategy = UNIFORM_SPREAD + strategy = uniformSpread } else { - strategy = FIXED_CONN_MAPPING + strategy = fixedConnMapping } sb := &switchboard{ session: sesh, @@ -58,12 +58,6 @@ func makeSwitchboard(sesh *Session) *switchboard { var errBrokenSwitchboard = errors.New("the switchboard is broken") -func (sb *switchboard) delConn(conn net.Conn) { - if _, ok := sb.conns.LoadAndDelete(conn); ok { - atomic.AddUint32(&sb.connsCount, ^uint32(0)) - } -} - func (sb *switchboard) addConn(conn net.Conn) { atomic.AddUint32(&sb.connsCount, 1) sb.conns.Store(conn, conn) @@ -79,39 +73,38 @@ func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err err var conn net.Conn switch sb.strategy { - case UNIFORM_SPREAD: + case uniformSpread: conn, err = sb.pickRandConn() if err != nil { return 0, errBrokenSwitchboard } - case FIXED_CONN_MAPPING: - conn = *assignedConn - default: - return 0, errors.New("unsupported traffic distribution strategy") - } - - if conn != nil { n, err = conn.Write(data) if err != nil { - sb.delConn(conn) - } - } else { - conn, err = sb.pickRandConn() - if err != nil { - sb.delConn(conn) - sb.session.SetTerminalMsg("failed to pick a connection " + err.Error()) - sb.session.passiveClose() - return 0, err - } - n, err = conn.Write(data) - if err != nil { - sb.delConn(conn) sb.session.SetTerminalMsg("failed to send to remote " + err.Error()) sb.session.passiveClose() return n, err } - *assignedConn = conn + case fixedConnMapping: + conn = *assignedConn + if conn == nil { + conn, err = sb.pickRandConn() + if err != nil { + sb.session.SetTerminalMsg("failed to pick a connection " + err.Error()) + sb.session.passiveClose() + return 0, err + } + *assignedConn = conn + } + n, err = conn.Write(data) + if err != nil { + sb.session.SetTerminalMsg("failed to send to remote " + err.Error()) + sb.session.passiveClose() + return n, err + } + default: + return 0, errors.New("unsupported traffic distribution strategy") } + sb.valve.AddTx(int64(n)) return n, nil } @@ -169,7 +162,6 @@ func (sb *switchboard) deplex(conn net.Conn) { sb.valve.AddRx(int64(n)) if err != nil { log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) - sb.delConn(conn) sb.session.SetTerminalMsg("a connection has dropped unexpectedly") sb.session.passiveClose() return diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index be0acaa..5e95afe 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -101,9 +101,9 @@ func TestSwitchboard_TxCredit(t *testing.T) { data := make([]byte, 1000) rand.Read(data) - t.Run("FIXED CONN MAPPING", func(t *testing.T) { + t.Run("fixed conn mapping", func(t *testing.T) { *sesh.sb.valve.(*LimitedValve).tx = 0 - sesh.sb.strategy = FIXED_CONN_MAPPING + sesh.sb.strategy = fixedConnMapping n, err := sesh.sb.send(data[:10], &conn) if err != nil { t.Error(err) @@ -117,9 +117,9 @@ func TestSwitchboard_TxCredit(t *testing.T) { t.Error("tx credit didn't increase by 10") } }) - t.Run("UNIFORM", func(t *testing.T) { + t.Run("uniform spread", func(t *testing.T) { *sesh.sb.valve.(*LimitedValve).tx = 0 - sesh.sb.strategy = UNIFORM_SPREAD + sesh.sb.strategy = uniformSpread n, err := sesh.sb.send(data[:10], &conn) if err != nil { t.Error(err)