From c26be98e79736a911935a1eae5ea4c06e2753e9b Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 3 Nov 2019 12:22:12 +0000 Subject: [PATCH] Use sync.Map in multiplex instead of manual locks --- cmd/ck-server/ck-server.go | 1 + internal/multiplex/session.go | 70 +++++++++-------- internal/multiplex/session_test.go | 24 +++++- internal/multiplex/stream_test.go | 2 +- internal/multiplex/switchboard.go | 101 ++++++++++++------------- internal/multiplex/switchboard_test.go | 10 +-- 6 files changed, 109 insertions(+), 99 deletions(-) diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 652c7fd..7a64e9d 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -167,6 +167,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { user.CloseSession(ci.SessionId, "") return } else { + // TODO: other errors continue } } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index d8f179c..55a24b1 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -51,8 +51,7 @@ type Session struct { // atomic nextStreamID uint32 - streamsM sync.Mutex - streams map[uint32]*Stream + streams sync.Map // Switchboard manages all connections to remote sb *switchboard @@ -73,7 +72,6 @@ func MakeSession(id uint32, config *SessionConfig) *Session { id: id, SessionConfig: config, nextStreamID: 1, - streams: make(map[uint32]*Stream), acceptCh: make(chan *Stream, acceptBacklog), } sesh.addrs.Store([]net.Addr{nil, nil}) @@ -108,14 +106,12 @@ func (sesh *Session) OpenStream() (*Stream, error) { } id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1 // Because atomic.AddUint32 returns the value after incrementation - connId, err := sesh.sb.assignRandomConn() + connId, _, err := sesh.sb.pickRandConn() if err != nil { return nil, err } stream := makeStream(sesh, id, connId) - sesh.streamsM.Lock() - sesh.streams[id] = stream - sesh.streamsM.Unlock() + sesh.streams.Store(id, stream) log.Tracef("stream %v of session %v opened", id, sesh.id) return stream, nil } @@ -128,9 +124,7 @@ func (sesh *Session) Accept() (net.Conn, error) { if stream == nil { return nil, ErrBrokenSession } - sesh.streamsM.Lock() - sesh.streams[stream.id] = stream - sesh.streamsM.Unlock() + sesh.streams.Store(stream.id, stream) log.Tracef("stream %v of session %v accepted", stream.id, sesh.id) return stream, nil } @@ -166,26 +160,29 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { log.Tracef("stream %v passively closed", s.id) } - sesh.streamsM.Lock() - delete(sesh.streams, s.id) - if len(sesh.streams) == 0 { + sesh.streams.Delete(s.id) + var count int + sesh.streams.Range(func(_, _ interface{}) bool { + count += 1 + return true + }) + if count == 0 { log.Tracef("session %v has no active stream left", sesh.id) go sesh.timeoutAfter(30 * time.Second) } - sesh.streamsM.Unlock() return nil } +// recvDataFromRemote deobfuscate the frame and send it to the appropriate stream buffer func (sesh *Session) recvDataFromRemote(data []byte) error { frame, err := sesh.Deobfs(data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } - sesh.streamsM.Lock() - stream, existing := sesh.streams[frame.StreamID] - sesh.streamsM.Unlock() + streamI, existing := sesh.streams.Load(frame.StreamID) if existing { + stream := streamI.(*Stream) return stream.writeFrame(*frame) } else { if frame.Closing == 1 { @@ -200,9 +197,9 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { // any difference because we only care to send the data from the same stream through the same // TCP connection. The remote may use a different connection to send the same stream than the one the client // use to send. - connId, _ := sesh.sb.assignRandomConn() + connId, _, _ := sesh.sb.pickRandConn() // we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write - stream = makeStream(sesh, frame.StreamID, connId) + stream := makeStream(sesh, frame.StreamID, connId) sesh.acceptCh <- stream return stream.writeFrame(*frame) } @@ -230,13 +227,13 @@ func (sesh *Session) passiveClose() error { } sesh.acceptCh <- nil - sesh.streamsM.Lock() - for id, stream := range sesh.streams { + sesh.streams.Range(func(key, streamI interface{}) bool { + stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) - _ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() - delete(sesh.streams, id) - } - sesh.streamsM.Unlock() + _ = stream.recvBuf.Close() // will not block + sesh.streams.Delete(key) + return true + }) sesh.sb.closeAll() log.Debugf("session %v closed gracefully", sesh.id) @@ -259,13 +256,13 @@ func (sesh *Session) Close() error { } sesh.acceptCh <- nil - sesh.streamsM.Lock() - for id, stream := range sesh.streams { + sesh.streams.Range(func(key, streamI interface{}) bool { + stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) - _ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() - delete(sesh.streams, id) - } - sesh.streamsM.Unlock() + _ = stream.recvBuf.Close() // will not block + sesh.streams.Delete(key) + return true + }) pad := genRandomPadding() f := &Frame{ @@ -295,13 +292,14 @@ func (sesh *Session) IsClosed() bool { func (sesh *Session) timeoutAfter(to time.Duration) { time.Sleep(to) - sesh.streamsM.Lock() - if len(sesh.streams) == 0 && !sesh.IsClosed() { - sesh.streamsM.Unlock() + var count int + sesh.streams.Range(func(_, _ interface{}) bool { + count += 1 + return true + }) + if count == 0 && !sesh.IsClosed() { sesh.SetTerminalMsg("timeout") sesh.Close() - } else { - sesh.streamsM.Unlock() } } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index ac6e027..bcb9909 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -40,7 +40,11 @@ func TestRecvDataFromRemote(t *testing.T) { sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) - sesh.recvDataFromRemote(obfsBuf[:n]) + err := sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } stream, err := sesh.Accept() if err != nil { t.Error(err) @@ -63,7 +67,11 @@ func TestRecvDataFromRemote(t *testing.T) { sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) - sesh.recvDataFromRemote(obfsBuf[:n]) + err := sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } stream, err := sesh.Accept() if err != nil { t.Error(err) @@ -86,7 +94,11 @@ func TestRecvDataFromRemote(t *testing.T) { sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) - sesh.recvDataFromRemote(obfsBuf[:n]) + err := sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } stream, err := sesh.Accept() if err != nil { t.Error(err) @@ -110,7 +122,11 @@ func TestRecvDataFromRemote(t *testing.T) { sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) - sesh.recvDataFromRemote(obfsBuf[:n]) + err := sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } stream, err := sesh.Accept() if err != nil { t.Error(err) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 9d1103c..ee4d939 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -157,7 +157,7 @@ func TestStream_Close(t *testing.T) { return } - if _, ok := sesh.streams[streamID]; ok { + if _, ok := sesh.streams.Load(streamID); ok { t.Error("stream still exists") return } diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 432c0dc..b5d0f35 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -25,8 +25,7 @@ type switchboard struct { *switchboardConfig - connsM sync.RWMutex - conns map[uint32]net.Conn + conns sync.Map nextConnId uint32 broken uint32 @@ -38,7 +37,6 @@ func makeSwitchboard(sesh *Session, config *switchboardConfig) *switchboard { sb := &switchboard{ session: sesh, switchboardConfig: config, - conns: make(map[uint32]net.Conn), } return sb } @@ -46,11 +44,19 @@ func makeSwitchboard(sesh *Session, config *switchboardConfig) *switchboard { var errNilOptimum = errors.New("The optimal connection is nil") var errBrokenSwitchboard = errors.New("the switchboard is broken") +func (sb *switchboard) connsCount() int { + // count the number of entries in conns + var count int + sb.conns.Range(func(_, _ interface{}) bool { + count += 1 + return true + }) + return count +} + func (sb *switchboard) addConn(conn net.Conn) { connId := atomic.AddUint32(&sb.nextConnId, 1) - 1 - sb.connsM.Lock() - sb.conns[connId] = conn - sb.connsM.Unlock() + sb.conns.Store(connId, conn) go sb.deplex(connId, conn) } @@ -67,69 +73,58 @@ func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { } sb.Valve.txWait(len(data)) - sb.connsM.RLock() - defer sb.connsM.RUnlock() - if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 { + connCount := sb.connsCount() + if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 { return 0, errBrokenSwitchboard } if sb.strategy == UNIFORM_SPREAD { - r := rand.Intn(len(sb.conns)) - var c int - for newConnId := range sb.conns { - if r == c { - conn, _ := sb.conns[newConnId] - return writeAndRegUsage(conn, data) - } - c++ + _, conn, err := sb.pickRandConn() + if err != nil { + return 0, errBrokenSwitchboard } - return 0, errBrokenSwitchboard + return writeAndRegUsage(conn, data) } else { - var conn net.Conn - conn, ok := sb.conns[*connId] + connI, ok := sb.conns.Load(*connId) + conn := connI.(net.Conn) if ok { return writeAndRegUsage(conn, data) } else { - // do not call assignRandomConn() here. - // we'll have to do connsM.RLock() after we get a new connId from assignRandomConn, in order to - // get the new conn through conns[newConnId] - // however between connsM.RUnlock() in assignRandomConn and our call to connsM.RLock(), things may happen. - // in particular if newConnId is removed between the RUnlock and RLock, conns[newConnId] will return - // a nil pointer. To prevent this we must get newConnId and the reference to conn itself in one single mutex - // protection - r := rand.Intn(len(sb.conns)) - var c int - for newConnId := range sb.conns { - if r == c { - connId = &newConnId - conn, _ = sb.conns[newConnId] - return writeAndRegUsage(conn, data) - } - c++ + newConnId, conn, err := sb.pickRandConn() + if err != nil { + return 0, errBrokenSwitchboard } - return 0, errBrokenSwitchboard + connId = &newConnId + return writeAndRegUsage(conn, data) } } } // returns a random connId -func (sb *switchboard) assignRandomConn() (uint32, error) { - sb.connsM.RLock() - defer sb.connsM.RUnlock() - if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 { - return 0, errBrokenSwitchboard +func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) { + connCount := sb.connsCount() + if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 { + return 0, nil, errBrokenSwitchboard } - r := rand.Intn(len(sb.conns)) - var c int - for connId := range sb.conns { + // there is no guarantee that sb.conns still has the same amount of entries + // between the count loop and the pick loop + // so if the r > len(sb.conns) at the point of range call, the last visited element is picked + var id uint32 + var conn net.Conn + r := rand.Intn(connCount) + sb.conns.Range(func(connIdI, connI interface{}) bool { + var c int if r == c { - return connId, nil + id = connIdI.(uint32) + conn = connI.(net.Conn) + return false } c++ - } - return 0, errBrokenSwitchboard + return true + }) + return id, conn, nil } func (sb *switchboard) close(terminalMsg string) { @@ -142,12 +137,12 @@ func (sb *switchboard) close(terminalMsg string) { // actively triggered by session.Close() func (sb *switchboard) closeAll() { - sb.connsM.Lock() - for key, conn := range sb.conns { + sb.conns.Range(func(key, connI interface{}) bool { + conn := connI.(net.Conn) conn.Close() - delete(sb.conns, key) - } - sb.connsM.Unlock() + sb.conns.Delete(key) + return true + }) } // deplex function costantly reads from a TCP connection diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index 494ea2e..bfc2fa4 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -21,7 +21,7 @@ func TestSwitchboard_Send(t *testing.T) { sesh := MakeSession(0, seshConfig) hole0 := getHole() sesh.sb.addConn(hole0) - connId, err := sesh.sb.assignRandomConn() + connId, _, err := sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return @@ -36,7 +36,7 @@ func TestSwitchboard_Send(t *testing.T) { hole1 := getHole() sesh.sb.addConn(hole1) - connId, err = sesh.sb.assignRandomConn() + connId, _, err = sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return @@ -47,7 +47,7 @@ func TestSwitchboard_Send(t *testing.T) { return } - connId, err = sesh.sb.assignRandomConn() + connId, _, err = sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return @@ -88,7 +88,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) { } sesh := MakeSession(0, seshConfig) sesh.sb.addConn(hole) - connId, err := sesh.sb.assignRandomConn() + connId, _, err := sesh.sb.pickRandConn() if err != nil { b.Error("failed to get a random conn", err) return @@ -115,7 +115,7 @@ func TestSwitchboard_TxCredit(t *testing.T) { sesh := MakeSession(0, seshConfig) hole := newBlackHole() sesh.sb.addConn(hole) - connId, err := sesh.sb.assignRandomConn() + connId, _, err := sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return