From f476650953a636a51ea94b5f41b05d08b3cd1ef0 Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Sun, 28 Oct 2018 21:22:38 +0000 Subject: [PATCH] Rework switchboard dispatch --- cmd/ck-client/ck-client.go | 2 - cmd/ck-server/ck-server.go | 16 ++- internal/multiplex/session.go | 6 +- internal/multiplex/stream.go | 13 +-- internal/multiplex/switchboard.go | 160 ++++++++++++------------------ internal/util/util.go | 3 - 6 files changed, 77 insertions(+), 123 deletions(-) diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 7954c1d..bc8a72e 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -26,14 +26,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) { for { i, err := io.ReadAtLeast(src, buf, 1) if err != nil || i == 0 { - log.Println(err) go dst.Close() go src.Close() return } i, err = dst.Write(buf[:i]) if err != nil || i == 0 { - log.Println(err) go dst.Close() go src.Close() return diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 2485fe2..11378da 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -6,10 +6,10 @@ import ( "io" "log" "net" - "net/http" - _ "net/http/pprof" + //"net/http" + //_ "net/http/pprof" "os" - "runtime" + //"runtime" "strings" "time" @@ -27,14 +27,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) { for { i, err := io.ReadAtLeast(src, buf, 1) if err != nil || i == 0 { - log.Println(err) go dst.Close() go src.Close() return } i, err = dst.Write(buf[:i]) if err != nil || i == 0 { - log.Println(err) go dst.Close() go src.Close() return @@ -136,10 +134,10 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } func main() { - runtime.SetBlockProfileRate(5) - go func() { - log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) - }() + //runtime.SetBlockProfileRate(5) + //go func() { + // log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) + //}() // Should be 127.0.0.1 to listen to ss-server on this machine var localHost 
string // server_port in ss config, same as remotePort in plugin mode diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index e8c3216..23206f3 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -2,7 +2,6 @@ package multiplex import ( "errors" - "log" "net" "sync" "sync/atomic" @@ -62,7 +61,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([] } func (sesh *Session) AddConnection(conn net.Conn) { - sesh.sb.newConnCh <- conn + sesh.sb.addConn(conn) } func (sesh *Session) OpenStream() (*Stream, error) { @@ -106,7 +105,6 @@ func (sesh *Session) getStream(id uint32) *Stream { // addStream is used when the remote opened a new stream and we got notified func (sesh *Session) addStream(id uint32) *Stream { - log.Printf("Adding stream %v", id) stream := makeStream(id, sesh) sesh.streamsM.Lock() sesh.streams[id] = stream @@ -136,7 +134,7 @@ func (sesh *Session) Close() error { } sesh.streamsM.Unlock() - close(sesh.sb.die) + sesh.sb.shutdown() return nil } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 04b55db..777e4ea 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -55,7 +55,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { if len(buf) == 0 { select { case <-stream.die: - log.Printf("Stream %v dying\n", stream.id) return 0, errBrokenStream default: return 0, nil @@ -63,7 +62,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { } select { case <-stream.die: - log.Printf("Stream %v dying\n", stream.id) return 0, errBrokenStream case data := <-stream.sortedBufCh: if len(buf) < len(data) { @@ -79,7 +77,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { func (stream *Stream) Write(in []byte) (n int, err error) { select { case <-stream.die: - log.Printf("Stream %v dying\n", stream.id) return 0, errBrokenStream default: } @@ -94,9 +91,9 @@ func (stream *Stream) Write(in []byte) (n int, err 
error) { atomic.AddUint32(&stream.nextSendSeq, 1) tlsRecord := stream.session.obfs(f) - stream.session.sb.dispatCh <- tlsRecord + n, err = stream.session.sb.send(tlsRecord) - return len(in), nil + return } @@ -109,7 +106,6 @@ func (stream *Stream) passiveClose() error { if stream.closing { return errRepeatStreamClosing } - log.Printf("ID: %v passiveclosing\n", stream.id) stream.closing = true close(stream.die) stream.session.delStream(stream.id) @@ -125,13 +121,11 @@ func (stream *Stream) Close() error { if stream.closing { return errRepeatStreamClosing } - log.Printf("ID: %v closing\n", stream.id) stream.closing = true close(stream.die) prand.Seed(int64(stream.id)) padLen := int(math.Floor(prand.Float64()*200 + 300)) - log.Println(padLen) pad := make([]byte, padLen) prand.Read(pad) f := &Frame{ @@ -141,7 +135,7 @@ func (stream *Stream) Close() error { Payload: pad, } tlsRecord := stream.session.obfs(f) - stream.session.sb.dispatCh <- tlsRecord + stream.session.sb.send(tlsRecord) stream.session.delStream(stream.id) return nil @@ -150,7 +144,6 @@ func (stream *Stream) Close() error { // Same as Close() but no call to session.delStream. // This is called in session.Close() to avoid mutex deadlock func (stream *Stream) closeNoDelMap() error { - log.Printf("ID: %v closing\n", stream.id) // Lock here because closing a closed channel causes panic stream.closingM.Lock() diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index caa2bb6..70a2cab 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -3,7 +3,8 @@ package multiplex import ( "log" "net" - "sort" + "sync" + "sync/atomic" ) const ( @@ -16,51 +17,26 @@ const ( type switchboard struct { session *Session - ces []*connEnclave - - // For telling dispatcher how many bytes have been sent after Connection.send. 
- sentNotifyCh chan *sentNotifier - // dispatCh is used by streams to send new data to remote - dispatCh chan []byte - newConnCh chan net.Conn - closingCECh chan *connEnclave - die chan struct{} - closing bool + optimum atomic.Value + cesM sync.RWMutex + ces []*connEnclave } // Some data comes from a Stream to be sent through one of the many // remoteConn, but which remoteConn should we use to send the data? // // In this case, we pick the remoteConn that has about the smallest sendQueue. -// Though "smallest" is not guaranteed because it doesn't has to be type connEnclave struct { sb *switchboard remoteConn net.Conn - sendQueue int -} - -type byQ []*connEnclave - -func (a byQ) Len() int { - return len(a) -} -func (a byQ) Swap(i, j int) { - a[i], a[j] = a[j], a[i] -} -func (a byQ) Less(i, j int) bool { - return a[i].sendQueue < a[j].sendQueue + sendQueue uint32 } // It takes at least 1 conn to start a switchboard func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard { sb := &switchboard{ - session: sesh, - ces: []*connEnclave{}, - sentNotifyCh: make(chan *sentNotifier, sentNotifyBacklog), - dispatCh: make(chan []byte, dispatchBacklog), - newConnCh: make(chan net.Conn, newConnBacklog), - closingCECh: make(chan *connEnclave, 5), - die: make(chan struct{}), + session: sesh, + ces: []*connEnclave{}, } ce := &connEnclave{ sb: sb, @@ -70,80 +46,74 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard { sb.ces = append(sb.ces, ce) go sb.deplex(ce) - go sb.dispatch() return sb } -// Everytime after a remoteConn sends something, it constructs this struct -// Which is sent back to dispatch() through sentNotifyCh to tell dispatch -// how many bytes it has sent -type sentNotifier struct { - ce *connEnclave - sent int -} - -func (ce *connEnclave) send(data []byte) { - // TODO: error handling +func (sb *switchboard) send(data []byte) (int, error) { + ce := sb.optimum.Load().(*connEnclave) + atomic.AddUint32(&ce.sendQueue, uint32(len(data))) + go 
sb.updateOptimum() n, err := ce.remoteConn.Write(data) if err != nil { - ce.sb.closingCECh <- ce - log.Println(err) + return 0, err + // TODO } - - sn := &sentNotifier{ - ce, - n, - } - ce.sb.sentNotifyCh <- sn - + atomic.AddUint32(&ce.sendQueue, ^uint32(n-1)) + go sb.updateOptimum() + return n, nil } -// Dispatcher sends data coming from a stream to a remote connection -// I used channels here because I didn't want to use mutex -func (sb *switchboard) dispatch() { - var dying bool - for { - select { - // dispatCh receives data from stream.Write - case data := <-sb.dispatCh: - go sb.ces[0].send(data) - sb.ces[0].sendQueue += len(data) - case notified := <-sb.sentNotifyCh: - notified.ce.sendQueue -= notified.sent - sort.Sort(byQ(sb.ces)) - case conn := <-sb.newConnCh: - log.Println("newConn") - newCe := &connEnclave{ - sb: sb, - remoteConn: conn, - sendQueue: 0, - } - sb.ces = append(sb.ces, newCe) - go sb.deplex(newCe) - case closing := <-sb.closingCECh: - log.Println("Closing conn") - for i, ce := range sb.ces { - if closing == ce { - sb.ces = append(sb.ces[:i], sb.ces[i+1:]...) - break - } - } - if len(sb.ces) == 0 && !dying { - sb.session.Close() - } - case <-sb.die: - dying = true - for _, ce := range sb.ces { - ce.remoteConn.Close() - } - return +func (sb *switchboard) updateOptimum() { + currentOpti := sb.optimum.Load().(*connEnclave) + currentOptiQ := atomic.LoadUint32(¤tOpti.sendQueue) + sb.cesM.RLock() + for _, ce := range sb.ces { + ceQ := atomic.LoadUint32(&ce.sendQueue) + if ceQ < currentOptiQ { + currentOpti = ce + currentOptiQ = ceQ } } + sb.cesM.RUnlock() + sb.optimum.Store(currentOpti) } -// deplex function costantly reads from a TCP connection -// it is responsible to act in response to the deobfsed header -// i.e. should a new stream be added? which existing stream should be closed? 
+func (sb *switchboard) addConn(conn net.Conn) { + + newCe := &connEnclave{ + sb: sb, + remoteConn: conn, + sendQueue: 0, + } + sb.cesM.Lock() + sb.ces = append(sb.ces, newCe) + sb.cesM.Unlock() + sb.optimum.Store(newCe) + go sb.deplex(newCe) +} + +func (sb *switchboard) removeConn(closing *connEnclave) { + sb.cesM.Lock() + for i, ce := range sb.ces { + if closing == ce { + sb.ces = append(sb.ces[:i], sb.ces[i+1:]...) + break + } + } + sb.cesM.Unlock() + if len(sb.ces) == 0 { + sb.session.Close() + } +} + +func (sb *switchboard) shutdown() { + for _, ce := range sb.ces { + ce.remoteConn.Close() + } +} + +// deplex function constantly reads from a TCP connection, calls deobfs and distributes it +// to the corresponding frame func (sb *switchboard) deplex(ce *connEnclave) { buf := make([]byte, 20480) for { @@ -151,7 +121,7 @@ func (sb *switchboard) deplex(ce *connEnclave) { if err != nil { log.Println(err) go ce.remoteConn.Close() - sb.closingCECh <- ce + sb.removeConn(ce) return } frame := sb.session.deobfs(buf[:i]) diff --git a/internal/util/util.go b/internal/util/util.go index 82f32d7..ef38ffa 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -9,7 +9,6 @@ import ( prand "math/rand" "net" "strconv" - "time" ) func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte { @@ -69,7 +68,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) { left := dataLength readPtr := 5 - conn.SetReadDeadline(time.Now().Add(3 * time.Second)) for left != 0 { // If left > buffer size (i.e. our message got segmented), the entire MTU is read // if left = buffer size, the entire buffer is all there left to read @@ -82,7 +80,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) { left -= i readPtr += i } - conn.SetReadDeadline(time.Time{}) n = 5 + dataLength return