diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 701f217..432c0dc 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -56,22 +56,30 @@ func (sb *switchboard) addConn(conn net.Conn) { // a pointer to connId is passed here so that the switchboard can reassign it func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { + writeAndRegUsage := func(conn net.Conn, d []byte) (int, error) { + n, err = conn.Write(d) + if err != nil { + sb.close("failed to write to remote " + err.Error()) + return n, err + } + sb.AddTx(int64(n)) + return n, nil + } + sb.Valve.txWait(len(data)) sb.connsM.RLock() defer sb.connsM.RUnlock() - if sb.strategy == UNIFORM_SPREAD { - if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 { - return 0, errBrokenSwitchboard - } + if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 { + return 0, errBrokenSwitchboard + } + if sb.strategy == UNIFORM_SPREAD { r := rand.Intn(len(sb.conns)) var c int for newConnId := range sb.conns { if r == c { conn, _ := sb.conns[newConnId] - n, err = conn.Write(data) - sb.AddTx(int64(n)) - return + return writeAndRegUsage(conn, data) } c++ } @@ -80,9 +88,7 @@ func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { var conn net.Conn conn, ok := sb.conns[*connId] if ok { - n, err = conn.Write(data) - sb.AddTx(int64(n)) - return + return writeAndRegUsage(conn, data) } else { // do not call assignRandomConn() here. // we'll have to do connsM.RLock() after we get a new connId from assignRandomConn, in order to @@ -91,19 +97,13 @@ func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { // in particular if newConnId is removed between the RUnlock and RLock, conns[newConnId] will return // a nil pointer. To prevent this we must get newConnId and the reference to conn itself in one single mutex // protection - if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 { - return 0, errBrokenSwitchboard - } - r := rand.Intn(len(sb.conns)) var c int for newConnId := range sb.conns { if r == c { connId = &newConnId conn, _ = sb.conns[newConnId] - n, err = conn.Write(data) - sb.AddTx(int64(n)) - return + return writeAndRegUsage(conn, data) } c++ } @@ -132,6 +132,14 @@ func (sb *switchboard) assignRandomConn() (uint32, error) { return 0, errBrokenSwitchboard } +func (sb *switchboard) close(terminalMsg string) { + atomic.StoreUint32(&sb.broken, 1) + if !sb.session.IsClosed() { + sb.session.SetTerminalMsg(terminalMsg) + sb.session.passiveClose() + } +} + // actively triggered by session.Close() func (sb *switchboard) closeAll() { sb.connsM.Lock() @@ -152,11 +160,7 @@ func (sb *switchboard) deplex(connId uint32, conn net.Conn) { if err != nil { log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) go conn.Close() - atomic.StoreUint32(&sb.broken, 1) - if !sb.session.IsClosed() { - sb.session.SetTerminalMsg("a connection has dropped unexpectedly") - sb.session.passiveClose() - } + sb.close("a connection has dropped unexpectedly") return }