diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go
index 23be6fa..802b50a 100644
--- a/internal/multiplex/session.go
+++ b/internal/multiplex/session.go
@@ -79,7 +79,11 @@ func (sesh *Session) OpenStream() (*Stream, error) {
 	}
 	id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
 	// Because atomic.AddUint32 returns the value after incrementation
-	stream := makeStream(id, sesh)
+	connId, err := sesh.sb.assignRandomConn()
+	if err != nil {
+		return nil, err
+	}
+	stream := makeStream(sesh, id, connId)
 	sesh.streamsM.Lock()
 	sesh.streams[id] = stream
 	sesh.streamsM.Unlock()
@@ -123,7 +127,13 @@ func (sesh *Session) recvDataFromRemote(data []byte) {
 		// If the stream has been closed and the current frame is a closing frame, we do noop
 		return
 	} else {
+		// it may be tempting to use the connId from which the frame was received. However it doesn't make
+		// any difference because we only care to send the data from the same stream through the same
+		// TCP connection. The remote may use a different connection to send the same stream than the one the client
+		// uses to send.
+		connId, _ := sesh.sb.assignRandomConn()
+		// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
-		stream = makeStream(frame.StreamID, sesh)
+		stream = makeStream(sesh, frame.StreamID, connId)
 		sesh.streams[frame.StreamID] = stream
 		sesh.acceptCh <- stream
 		stream.writeFrame(frame)
diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go
index c8b4611..8f3a380 100644
--- a/internal/multiplex/stream.go
+++ b/internal/multiplex/stream.go
@@ -28,27 +28,34 @@ type Stream struct {
 
 	writingM sync.RWMutex
 
-	// close(die) is used to notify different goroutines that this stream is closing
 	closed uint32
 
 	obfsBuf []byte
+
+	// we assign each stream a fixed underlying TCP connection to utilise order guarantee provided by TCP itself
+	// so that frameSorter should have few to no ooo frames to deal with
+	// overall the streams in a session should be uniformly distributed across all connections
+	assignedConnId uint32
 }
 
-func makeStream(id uint32, sesh *Session) *Stream {
+func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
 	buf := NewBufferedPipe()
 	stream := &Stream{
-		id:        id,
-		session:   sesh,
-		sortedBuf: buf,
-		obfsBuf:   make([]byte, 17000),
-		sorter:    NewFrameSorter(buf),
+		id:             id,
+		session:        sesh,
+		sortedBuf:      buf,
+		obfsBuf:        make([]byte, 17000),
+		sorter:         NewFrameSorter(buf),
+		assignedConnId: assignedConnId,
 	}
 	log.Tracef("stream %v opened", id)
 	return stream
 }
 
+//func (s *Stream) reassignConnId(connId uint32) { atomic.StoreUint32(&s.assignedConnId,connId)}
+
 func (s *Stream) isClosed() bool {
 	return atomic.LoadUint32(&s.closed) == 1
 }
 
 func (s *Stream) writeFrame(frame *Frame) {
 	s.sorter.writeNewFrame(frame)
 }
@@ -96,7 +103,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
 	if err != nil {
 		return i, err
 	}
-	n, err = s.session.sb.Write(s.obfsBuf[:i])
+	n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
 	return
 }
 
@@ -139,7 +146,7 @@ func (s *Stream) Close() error {
 	if err != nil {
 		return err
 	}
-	_, err = s.session.sb.Write(s.obfsBuf[:i])
+	_, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
 	if err != nil {
 		return err
 	}
diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go
index c6a7db0..f3e840b 100644
--- a/internal/multiplex/stream_test.go
+++ b/internal/multiplex/stream_test.go
@@ -99,6 +99,7 @@ func TestStream_Read(t *testing.T) {
 		i, _ := sesh.Obfs(f, obfsBuf)
 		streamID++
 		ch <- obfsBuf[:i]
+		time.Sleep(100 * time.Microsecond)
 		stream, err := sesh.Accept()
 		if err != nil {
 			t.Error("failed to accept stream", err)
@@ -120,6 +121,7 @@ func TestStream_Read(t *testing.T) {
 		i, _ := sesh.Obfs(f, obfsBuf)
 		streamID++
 		ch <- obfsBuf[:i]
+		time.Sleep(100 * time.Microsecond)
 		stream, _ := sesh.Accept()
 		i, err := stream.Read(nil)
 		if i != 0 || err != nil {
@@ -140,6 +142,7 @@ func TestStream_Read(t *testing.T) {
 		i, _ := sesh.Obfs(f, obfsBuf)
 		streamID++
 		ch <- obfsBuf[:i]
+		time.Sleep(100 * time.Microsecond)
 		stream, _ := sesh.Accept()
 		stream.Close()
 		i, err := stream.Read(buf)
@@ -164,6 +167,7 @@ func TestStream_Read(t *testing.T) {
 		i, _ := sesh.Obfs(f, obfsBuf)
 		streamID++
 		ch <- obfsBuf[:i]
+		time.Sleep(100 * time.Microsecond)
 		stream, _ := sesh.Accept()
 		sesh.Close()
 		i, err := stream.Read(buf)
diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go
index 79b03d2..f650ca2 100644
--- a/internal/multiplex/switchboard.go
+++ b/internal/multiplex/switchboard.go
@@ -3,6 +3,7 @@ package multiplex
 import (
 	"errors"
 	log "github.com/sirupsen/logrus"
+	"math/rand"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -14,38 +15,20 @@ type switchboard struct {
 
 	*Valve
 
-	// optimum is the connEnclave with the smallest sendQueue
-	optimum atomic.Value // *connEnclave
-	cesM    sync.RWMutex
-	ces     []*connEnclave
+	connsM     sync.RWMutex
+	conns      map[uint32]net.Conn
+	nextConnId uint32
 
 	broken uint32
 }
 
-func (sb *switchboard) getOptimum() *connEnclave {
-	if i := sb.optimum.Load(); i == nil {
-		return nil
-	} else {
-		return i.(*connEnclave)
-	}
-}
-
-// Some data comes from a Stream to be sent through one of the many
-// remoteConn, but which remoteConn should we use to send the data?
-//
-// In this case, we pick the remoteConn that has about the smallest sendQueue.
-type connEnclave struct {
-	remoteConn net.Conn
-	sendQueue  uint32
-}
-
 func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
 	// rates are uint64 because in the usermanager we want the bandwidth to be atomically
 	// operated (so that the bandwidth can change on the fly).
 	sb := &switchboard{
 		session: sesh,
 		Valve:   valve,
-		ces:     []*connEnclave{},
+		conns:   make(map[uint32]net.Conn),
 	}
 	return sb
 }
@@ -53,65 +36,19 @@ func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
 var errNilOptimum = errors.New("The optimal connection is nil")
 var errBrokenSwitchboard = errors.New("the switchboard is broken")
 
-func (sb *switchboard) Write(data []byte) (int, error) {
-	if atomic.LoadUint32(&sb.broken) == 1 {
-		return 0, errBrokenSwitchboard
-	}
-	ce := sb.getOptimum()
-	if ce == nil {
-		return 0, errNilOptimum
-	}
-	atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
-	go sb.updateOptimum()
-	n, err := ce.remoteConn.Write(data)
-	if err != nil {
-		return n, err
-	}
-	sb.txWait(n)
-	sb.Valve.AddTx(int64(n))
-	atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
-	go sb.updateOptimum()
-	return n, nil
-}
-
-func (sb *switchboard) updateOptimum() {
-	currentOpti := sb.getOptimum()
-	currentOptiQ := atomic.LoadUint32(&currentOpti.sendQueue)
-	sb.cesM.RLock()
-	for _, ce := range sb.ces {
-		ceQ := atomic.LoadUint32(&ce.sendQueue)
-		if ceQ < currentOptiQ {
-			currentOpti = ce
-			currentOptiQ = ceQ
-		}
-	}
-	sb.cesM.RUnlock()
-	sb.optimum.Store(currentOpti)
-}
-
 func (sb *switchboard) addConn(conn net.Conn) {
-	var sendQueue uint32
-	newCe := &connEnclave{
-		remoteConn: conn,
-		sendQueue:  sendQueue,
-	}
-	sb.cesM.Lock()
-	sb.ces = append(sb.ces, newCe)
-	sb.cesM.Unlock()
-	sb.optimum.Store(newCe)
-	go sb.deplex(newCe)
+	connId := atomic.AddUint32(&sb.nextConnId, 1) - 1
+	sb.connsM.Lock()
+	sb.conns[connId] = conn
+	sb.connsM.Unlock()
+	go sb.deplex(connId, conn)
 }
 
-func (sb *switchboard) removeConn(closing *connEnclave) {
-	sb.cesM.Lock()
-	for i, ce := range sb.ces {
-		if closing == ce {
-			sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
-			break
-		}
-	}
-	remaining := len(sb.ces)
-	sb.cesM.Unlock()
+func (sb *switchboard) removeConn(connId uint32) {
+	sb.connsM.Lock()
+	delete(sb.conns, connId)
+	remaining := len(sb.conns)
+	sb.connsM.Unlock()
 	if remaining == 0 {
 		atomic.StoreUint32(&sb.broken, 1)
 		sb.session.SetTerminalMsg("no underlying connection left")
@@ -119,29 +56,93 @@ func (sb *switchboard) removeConn(connId uint32) {
 	}
 }
 
+// a pointer to connId is passed here so that the switchboard can reassign it
+// when the stream's assigned connection no longer exists
+func (sb *switchboard) send(data []byte, connId *uint32) (int, error) {
+	var conn net.Conn
+	sb.connsM.RLock()
+	conn, ok := sb.conns[*connId]
+	sb.connsM.RUnlock()
+	if ok {
+		return conn.Write(data)
+	} else {
+		// do not call assignRandomConn() here.
+		// we'll have to do connsM.RLock() after we get a new connId from assignRandomConn, in order to
+		// get the new conn through conns[newConnId]
+		// however between connsM.RUnlock() in assignRandomConn and our call to connsM.RLock(), things may happen.
+		// in particular if newConnId is removed between the RUnlock and RLock, conns[newConnId] will return
+		// a nil pointer. To prevent this we must get newConnId and the reference to conn itself in one single mutex
+		// protection
+		if atomic.LoadUint32(&sb.broken) == 1 {
+			return 0, errBrokenSwitchboard
+		}
+		sb.connsM.RLock()
+		if len(sb.conns) == 0 {
+			// the last connection may have been removed between the broken check above and this
+			// lock; rand.Intn(0) would panic
+			sb.connsM.RUnlock()
+			return 0, errBrokenSwitchboard
+		}
+		// conn ids are not contiguous once connections have been removed, so rand.Intn(len(conns))
+		// is not necessarily a live key of the map; skip a random number of live entries instead
+		skip := rand.Intn(len(sb.conns))
+		for newConnId, c := range sb.conns {
+			if skip == 0 {
+				// re-pin the stream to a connection that still exists
+				*connId = newConnId
+				conn = c
+				break
+			}
+			skip--
+		}
+		sb.connsM.RUnlock()
+		return conn.Write(data)
+	}
+
+}
+
+// assignRandomConn returns the id of a randomly chosen live connection, or
+// errBrokenSwitchboard if the switchboard is broken or has no connections left
+func (sb *switchboard) assignRandomConn() (uint32, error) {
+	if atomic.LoadUint32(&sb.broken) == 1 {
+		return 0, errBrokenSwitchboard
+	}
+	sb.connsM.RLock()
+	defer sb.connsM.RUnlock()
+	if len(sb.conns) == 0 {
+		// rand.Intn(0) panics, and there is nothing to assign anyway
+		return 0, errBrokenSwitchboard
+	}
+	// conn ids are not contiguous once connections have been removed, so rand.Intn(len(conns))
+	// is not necessarily a live key of the map; skip a random number of live entries instead
+	skip := rand.Intn(len(sb.conns))
+	for connId := range sb.conns {
+		if skip == 0 {
+			return connId, nil
+		}
+		skip--
+	}
+	return 0, errBrokenSwitchboard // unreachable: skip < len(sb.conns)
+}
+
 // actively triggered by session.Close()
 func (sb *switchboard) closeAll() {
 	if atomic.SwapUint32(&sb.broken, 1) == 1 {
 		return
 	}
-	sb.cesM.RLock()
-	for _, ce := range sb.ces {
-		ce.remoteConn.Close()
+	sb.connsM.RLock()
+	for _, conn := range sb.conns {
+		conn.Close()
 	}
-	sb.cesM.RUnlock()
+	sb.connsM.RUnlock()
 }
 
 // deplex function costantly reads from a TCP connection
-func (sb *switchboard) deplex(ce *connEnclave) {
+func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
 	buf := make([]byte, 20480)
 	for {
-		n, err := sb.session.unitRead(ce.remoteConn, buf)
+		n, err := sb.session.unitRead(conn, buf)
 		sb.rxWait(n)
 		sb.Valve.AddRx(int64(n))
 		if err != nil {
 			log.Tracef("a connection for session %v has closed: %v", sb.session.id, err)
-			go ce.remoteConn.Close()
-			sb.removeConn(ce)
+			go conn.Close()
+			sb.removeConn(connId)
 			return
 		}