Rework switchboard dispatch

This commit is contained in:
Qian Wang 2018-10-28 21:22:38 +00:00
parent 9e4aedbdc1
commit f476650953
6 changed files with 77 additions and 123 deletions

View File

@ -26,14 +26,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
for { for {
i, err := io.ReadAtLeast(src, buf, 1) i, err := io.ReadAtLeast(src, buf, 1)
if err != nil || i == 0 { if err != nil || i == 0 {
log.Println(err)
go dst.Close() go dst.Close()
go src.Close() go src.Close()
return return
} }
i, err = dst.Write(buf[:i]) i, err = dst.Write(buf[:i])
if err != nil || i == 0 { if err != nil || i == 0 {
log.Println(err)
go dst.Close() go dst.Close()
go src.Close() go src.Close()
return return

View File

@ -6,10 +6,10 @@ import (
"io" "io"
"log" "log"
"net" "net"
"net/http" //"net/http"
_ "net/http/pprof" //_ "net/http/pprof"
"os" "os"
"runtime" //"runtime"
"strings" "strings"
"time" "time"
@ -27,14 +27,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
for { for {
i, err := io.ReadAtLeast(src, buf, 1) i, err := io.ReadAtLeast(src, buf, 1)
if err != nil || i == 0 { if err != nil || i == 0 {
log.Println(err)
go dst.Close() go dst.Close()
go src.Close() go src.Close()
return return
} }
i, err = dst.Write(buf[:i]) i, err = dst.Write(buf[:i])
if err != nil || i == 0 { if err != nil || i == 0 {
log.Println(err)
go dst.Close() go dst.Close()
go src.Close() go src.Close()
return return
@ -136,10 +134,10 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
} }
func main() { func main() {
runtime.SetBlockProfileRate(5) //runtime.SetBlockProfileRate(5)
go func() { //go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) // log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}() //}()
// Should be 127.0.0.1 to listen to ss-server on this machine // Should be 127.0.0.1 to listen to ss-server on this machine
var localHost string var localHost string
// server_port in ss config, same as remotePort in plugin mode // server_port in ss config, same as remotePort in plugin mode

View File

@ -2,7 +2,6 @@ package multiplex
import ( import (
"errors" "errors"
"log"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -62,7 +61,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([]
} }
func (sesh *Session) AddConnection(conn net.Conn) { func (sesh *Session) AddConnection(conn net.Conn) {
sesh.sb.newConnCh <- conn sesh.sb.addConn(conn)
} }
func (sesh *Session) OpenStream() (*Stream, error) { func (sesh *Session) OpenStream() (*Stream, error) {
@ -106,7 +105,6 @@ func (sesh *Session) getStream(id uint32) *Stream {
// addStream is used when the remote opened a new stream and we got notified // addStream is used when the remote opened a new stream and we got notified
func (sesh *Session) addStream(id uint32) *Stream { func (sesh *Session) addStream(id uint32) *Stream {
log.Printf("Adding stream %v", id)
stream := makeStream(id, sesh) stream := makeStream(id, sesh)
sesh.streamsM.Lock() sesh.streamsM.Lock()
sesh.streams[id] = stream sesh.streams[id] = stream
@ -136,7 +134,7 @@ func (sesh *Session) Close() error {
} }
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
close(sesh.sb.die) sesh.sb.shutdown()
return nil return nil
} }

View File

@ -55,7 +55,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
if len(buf) == 0 { if len(buf) == 0 {
select { select {
case <-stream.die: case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errBrokenStream return 0, errBrokenStream
default: default:
return 0, nil return 0, nil
@ -63,7 +62,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
} }
select { select {
case <-stream.die: case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errBrokenStream return 0, errBrokenStream
case data := <-stream.sortedBufCh: case data := <-stream.sortedBufCh:
if len(buf) < len(data) { if len(buf) < len(data) {
@ -79,7 +77,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
func (stream *Stream) Write(in []byte) (n int, err error) { func (stream *Stream) Write(in []byte) (n int, err error) {
select { select {
case <-stream.die: case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errBrokenStream return 0, errBrokenStream
default: default:
} }
@ -94,9 +91,9 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
atomic.AddUint32(&stream.nextSendSeq, 1) atomic.AddUint32(&stream.nextSendSeq, 1)
tlsRecord := stream.session.obfs(f) tlsRecord := stream.session.obfs(f)
stream.session.sb.dispatCh <- tlsRecord n, err = stream.session.sb.send(tlsRecord)
return len(in), nil return
} }
@ -109,7 +106,6 @@ func (stream *Stream) passiveClose() error {
if stream.closing { if stream.closing {
return errRepeatStreamClosing return errRepeatStreamClosing
} }
log.Printf("ID: %v passiveclosing\n", stream.id)
stream.closing = true stream.closing = true
close(stream.die) close(stream.die)
stream.session.delStream(stream.id) stream.session.delStream(stream.id)
@ -125,13 +121,11 @@ func (stream *Stream) Close() error {
if stream.closing { if stream.closing {
return errRepeatStreamClosing return errRepeatStreamClosing
} }
log.Printf("ID: %v closing\n", stream.id)
stream.closing = true stream.closing = true
close(stream.die) close(stream.die)
prand.Seed(int64(stream.id)) prand.Seed(int64(stream.id))
padLen := int(math.Floor(prand.Float64()*200 + 300)) padLen := int(math.Floor(prand.Float64()*200 + 300))
log.Println(padLen)
pad := make([]byte, padLen) pad := make([]byte, padLen)
prand.Read(pad) prand.Read(pad)
f := &Frame{ f := &Frame{
@ -141,7 +135,7 @@ func (stream *Stream) Close() error {
Payload: pad, Payload: pad,
} }
tlsRecord := stream.session.obfs(f) tlsRecord := stream.session.obfs(f)
stream.session.sb.dispatCh <- tlsRecord stream.session.sb.send(tlsRecord)
stream.session.delStream(stream.id) stream.session.delStream(stream.id)
return nil return nil
@ -150,7 +144,6 @@ func (stream *Stream) Close() error {
// Same as Close() but no call to session.delStream. // Same as Close() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock // This is called in session.Close() to avoid mutex deadlock
func (stream *Stream) closeNoDelMap() error { func (stream *Stream) closeNoDelMap() error {
log.Printf("ID: %v closing\n", stream.id)
// Lock here because closing a closed channel causes panic // Lock here because closing a closed channel causes panic
stream.closingM.Lock() stream.closingM.Lock()

View File

@ -3,7 +3,8 @@ package multiplex
import ( import (
"log" "log"
"net" "net"
"sort" "sync"
"sync/atomic"
) )
const ( const (
@ -16,39 +17,19 @@ const (
type switchboard struct { type switchboard struct {
session *Session session *Session
optimum atomic.Value
cesM sync.RWMutex
ces []*connEnclave ces []*connEnclave
// For telling dispatcher how many bytes have been sent after Connection.send.
sentNotifyCh chan *sentNotifier
// dispatCh is used by streams to send new data to remote
dispatCh chan []byte
newConnCh chan net.Conn
closingCECh chan *connEnclave
die chan struct{}
closing bool
} }
// Some data comes from a Stream to be sent through one of the many // Some data comes from a Stream to be sent through one of the many
// remoteConn, but which remoteConn should we use to send the data? // remoteConn, but which remoteConn should we use to send the data?
// //
// In this case, we pick the remoteConn that has about the smallest sendQueue. // In this case, we pick the remoteConn that has about the smallest sendQueue.
// Though "smallest" is not guaranteed because it doesn't have to be
type connEnclave struct { type connEnclave struct {
sb *switchboard sb *switchboard
remoteConn net.Conn remoteConn net.Conn
sendQueue int sendQueue uint32
}
type byQ []*connEnclave
func (a byQ) Len() int {
return len(a)
}
func (a byQ) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}
func (a byQ) Less(i, j int) bool {
return a[i].sendQueue < a[j].sendQueue
} }
// It takes at least 1 conn to start a switchboard // It takes at least 1 conn to start a switchboard
@ -56,11 +37,6 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
sb := &switchboard{ sb := &switchboard{
session: sesh, session: sesh,
ces: []*connEnclave{}, ces: []*connEnclave{},
sentNotifyCh: make(chan *sentNotifier, sentNotifyBacklog),
dispatCh: make(chan []byte, dispatchBacklog),
newConnCh: make(chan net.Conn, newConnBacklog),
closingCECh: make(chan *connEnclave, 5),
die: make(chan struct{}),
} }
ce := &connEnclave{ ce := &connEnclave{
sb: sb, sb: sb,
@ -70,80 +46,74 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
sb.ces = append(sb.ces, ce) sb.ces = append(sb.ces, ce)
go sb.deplex(ce) go sb.deplex(ce)
go sb.dispatch()
return sb return sb
} }
// Every time after a remoteConn sends something, it constructs this struct func (sb *switchboard) send(data []byte) (int, error) {
// Which is sent back to dispatch() through sentNotifyCh to tell dispatch ce := sb.optimum.Load().(*connEnclave)
// how many bytes it has sent atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
type sentNotifier struct { go sb.updateOptimum()
ce *connEnclave
sent int
}
func (ce *connEnclave) send(data []byte) {
// TODO: error handling
n, err := ce.remoteConn.Write(data) n, err := ce.remoteConn.Write(data)
if err != nil { if err != nil {
ce.sb.closingCECh <- ce return 0, err
log.Println(err) // TODO
}
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
go sb.updateOptimum()
return n, nil
} }
sn := &sentNotifier{ func (sb *switchboard) updateOptimum() {
ce, currentOpti := sb.optimum.Load().(*connEnclave)
n, currentOptiQ := atomic.LoadUint32(&currentOpti.sendQueue)
sb.cesM.RLock()
for _, ce := range sb.ces {
ceQ := atomic.LoadUint32(&ce.sendQueue)
if ceQ < currentOptiQ {
currentOpti = ce
currentOptiQ = ceQ
} }
ce.sb.sentNotifyCh <- sn }
sb.cesM.RUnlock()
sb.optimum.Store(currentOpti)
} }
// Dispatcher sends data coming from a stream to a remote connection func (sb *switchboard) addConn(conn net.Conn) {
// I used channels here because I didn't want to use mutex
func (sb *switchboard) dispatch() {
var dying bool
for {
select {
// dispatCh receives data from stream.Write
case data := <-sb.dispatCh:
go sb.ces[0].send(data)
sb.ces[0].sendQueue += len(data)
case notified := <-sb.sentNotifyCh:
notified.ce.sendQueue -= notified.sent
sort.Sort(byQ(sb.ces))
case conn := <-sb.newConnCh:
log.Println("newConn")
newCe := &connEnclave{ newCe := &connEnclave{
sb: sb, sb: sb,
remoteConn: conn, remoteConn: conn,
sendQueue: 0, sendQueue: 0,
} }
sb.cesM.Lock()
sb.ces = append(sb.ces, newCe) sb.ces = append(sb.ces, newCe)
sb.cesM.Unlock()
sb.optimum.Store(newCe)
go sb.deplex(newCe) go sb.deplex(newCe)
case closing := <-sb.closingCECh: }
log.Println("Closing conn")
func (sb *switchboard) removeConn(closing *connEnclave) {
sb.cesM.Lock()
for i, ce := range sb.ces { for i, ce := range sb.ces {
if closing == ce { if closing == ce {
sb.ces = append(sb.ces[:i], sb.ces[i+1:]...) sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
break break
} }
} }
if len(sb.ces) == 0 && !dying { sb.cesM.Unlock()
if len(sb.ces) == 0 {
sb.session.Close() sb.session.Close()
} }
case <-sb.die:
dying = true
for _, ce := range sb.ces {
ce.remoteConn.Close()
}
return
}
}
} }
// deplex function constantly reads from a TCP connection func (sb *switchboard) shutdown() {
// it is responsible to act in response to the deobfsed header for _, ce := range sb.ces {
// i.e. should a new stream be added? which existing stream should be closed? ce.remoteConn.Close()
}
}
// deplex function constantly reads from a TCP connection, calls deobfs and distributes it
// to the corresponding frame
func (sb *switchboard) deplex(ce *connEnclave) { func (sb *switchboard) deplex(ce *connEnclave) {
buf := make([]byte, 20480) buf := make([]byte, 20480)
for { for {
@ -151,7 +121,7 @@ func (sb *switchboard) deplex(ce *connEnclave) {
if err != nil { if err != nil {
log.Println(err) log.Println(err)
go ce.remoteConn.Close() go ce.remoteConn.Close()
sb.closingCECh <- ce sb.removeConn(ce)
return return
} }
frame := sb.session.deobfs(buf[:i]) frame := sb.session.deobfs(buf[:i])

View File

@ -9,7 +9,6 @@ import (
prand "math/rand" prand "math/rand"
"net" "net"
"strconv" "strconv"
"time"
) )
func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte { func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte {
@ -69,7 +68,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
left := dataLength left := dataLength
readPtr := 5 readPtr := 5
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
for left != 0 { for left != 0 {
// If left > buffer size (i.e. our message got segmented), the entire MTU is read // If left > buffer size (i.e. our message got segmented), the entire MTU is read
// if left = buffer size, the entire buffer is all there left to read // if left = buffer size, the entire buffer is all there left to read
@ -82,7 +80,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
left -= i left -= i
readPtr += i readPtr += i
} }
conn.SetReadDeadline(time.Time{})
n = 5 + dataLength n = 5 + dataLength
return return