Rework switchboard dispatch

This commit is contained in:
Qian Wang 2018-10-28 21:22:38 +00:00
parent 9e4aedbdc1
commit f476650953
6 changed files with 77 additions and 123 deletions

View File

@ -26,14 +26,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
for {
i, err := io.ReadAtLeast(src, buf, 1)
if err != nil || i == 0 {
log.Println(err)
go dst.Close()
go src.Close()
return
}
i, err = dst.Write(buf[:i])
if err != nil || i == 0 {
log.Println(err)
go dst.Close()
go src.Close()
return

View File

@ -6,10 +6,10 @@ import (
"io"
"log"
"net"
"net/http"
_ "net/http/pprof"
//"net/http"
//_ "net/http/pprof"
"os"
"runtime"
//"runtime"
"strings"
"time"
@ -27,14 +27,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
for {
i, err := io.ReadAtLeast(src, buf, 1)
if err != nil || i == 0 {
log.Println(err)
go dst.Close()
go src.Close()
return
}
i, err = dst.Write(buf[:i])
if err != nil || i == 0 {
log.Println(err)
go dst.Close()
go src.Close()
return
@ -136,10 +134,10 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
}
func main() {
runtime.SetBlockProfileRate(5)
go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}()
//runtime.SetBlockProfileRate(5)
//go func() {
// log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
//}()
// Should be 127.0.0.1 to listen to ss-server on this machine
var localHost string
// server_port in ss config, same as remotePort in plugin mode

View File

@ -2,7 +2,6 @@ package multiplex
import (
"errors"
"log"
"net"
"sync"
"sync/atomic"
@ -62,7 +61,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([]
}
func (sesh *Session) AddConnection(conn net.Conn) {
sesh.sb.newConnCh <- conn
sesh.sb.addConn(conn)
}
func (sesh *Session) OpenStream() (*Stream, error) {
@ -106,7 +105,6 @@ func (sesh *Session) getStream(id uint32) *Stream {
// addStream is used when the remote opened a new stream and we got notified
func (sesh *Session) addStream(id uint32) *Stream {
log.Printf("Adding stream %v", id)
stream := makeStream(id, sesh)
sesh.streamsM.Lock()
sesh.streams[id] = stream
@ -136,7 +134,7 @@ func (sesh *Session) Close() error {
}
sesh.streamsM.Unlock()
close(sesh.sb.die)
sesh.sb.shutdown()
return nil
}

View File

@ -55,7 +55,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
if len(buf) == 0 {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errBrokenStream
default:
return 0, nil
@ -63,7 +62,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
}
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errBrokenStream
case data := <-stream.sortedBufCh:
if len(buf) < len(data) {
@ -79,7 +77,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
func (stream *Stream) Write(in []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errBrokenStream
default:
}
@ -94,9 +91,9 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
atomic.AddUint32(&stream.nextSendSeq, 1)
tlsRecord := stream.session.obfs(f)
stream.session.sb.dispatCh <- tlsRecord
n, err = stream.session.sb.send(tlsRecord)
return len(in), nil
return
}
@ -109,7 +106,6 @@ func (stream *Stream) passiveClose() error {
if stream.closing {
return errRepeatStreamClosing
}
log.Printf("ID: %v passiveclosing\n", stream.id)
stream.closing = true
close(stream.die)
stream.session.delStream(stream.id)
@ -125,13 +121,11 @@ func (stream *Stream) Close() error {
if stream.closing {
return errRepeatStreamClosing
}
log.Printf("ID: %v closing\n", stream.id)
stream.closing = true
close(stream.die)
prand.Seed(int64(stream.id))
padLen := int(math.Floor(prand.Float64()*200 + 300))
log.Println(padLen)
pad := make([]byte, padLen)
prand.Read(pad)
f := &Frame{
@ -141,7 +135,7 @@ func (stream *Stream) Close() error {
Payload: pad,
}
tlsRecord := stream.session.obfs(f)
stream.session.sb.dispatCh <- tlsRecord
stream.session.sb.send(tlsRecord)
stream.session.delStream(stream.id)
return nil
@ -150,7 +144,6 @@ func (stream *Stream) Close() error {
// Same as Close() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock
func (stream *Stream) closeNoDelMap() error {
log.Printf("ID: %v closing\n", stream.id)
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()

View File

@ -3,7 +3,8 @@ package multiplex
import (
"log"
"net"
"sort"
"sync"
"sync/atomic"
)
const (
@ -16,39 +17,19 @@ const (
type switchboard struct {
session *Session
optimum atomic.Value
cesM sync.RWMutex
ces []*connEnclave
// For telling dispatcher how many bytes have been sent after Connection.send.
sentNotifyCh chan *sentNotifier
// dispatCh is used by streams to send new data to remote
dispatCh chan []byte
newConnCh chan net.Conn
closingCECh chan *connEnclave
die chan struct{}
closing bool
}
// Some data comes from a Stream to be sent through one of the many
// remoteConn, but which remoteConn should we use to send the data?
//
// In this case, we pick the remoteConn that has about the smallest sendQueue.
// Though "smallest" is not guaranteed, because it doesn't have to be
type connEnclave struct {
sb *switchboard
remoteConn net.Conn
sendQueue int
}
type byQ []*connEnclave
func (a byQ) Len() int {
return len(a)
}
func (a byQ) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}
func (a byQ) Less(i, j int) bool {
return a[i].sendQueue < a[j].sendQueue
sendQueue uint32
}
// It takes at least 1 conn to start a switchboard
@ -56,11 +37,6 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
sb := &switchboard{
session: sesh,
ces: []*connEnclave{},
sentNotifyCh: make(chan *sentNotifier, sentNotifyBacklog),
dispatCh: make(chan []byte, dispatchBacklog),
newConnCh: make(chan net.Conn, newConnBacklog),
closingCECh: make(chan *connEnclave, 5),
die: make(chan struct{}),
}
ce := &connEnclave{
sb: sb,
@ -70,80 +46,74 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
sb.ces = append(sb.ces, ce)
go sb.deplex(ce)
go sb.dispatch()
return sb
}
// Every time a remoteConn sends something, it constructs this struct,
// which is sent back to dispatch() through sentNotifyCh to tell dispatch
// how many bytes it has sent
type sentNotifier struct {
ce *connEnclave
sent int
}
func (ce *connEnclave) send(data []byte) {
// TODO: error handling
func (sb *switchboard) send(data []byte) (int, error) {
ce := sb.optimum.Load().(*connEnclave)
atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
go sb.updateOptimum()
n, err := ce.remoteConn.Write(data)
if err != nil {
ce.sb.closingCECh <- ce
log.Println(err)
return 0, err
// TODO
}
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
go sb.updateOptimum()
return n, nil
}
sn := &sentNotifier{
ce,
n,
func (sb *switchboard) updateOptimum() {
currentOpti := sb.optimum.Load().(*connEnclave)
currentOptiQ := atomic.LoadUint32(&currentOpti.sendQueue)
sb.cesM.RLock()
for _, ce := range sb.ces {
ceQ := atomic.LoadUint32(&ce.sendQueue)
if ceQ < currentOptiQ {
currentOpti = ce
currentOptiQ = ceQ
}
ce.sb.sentNotifyCh <- sn
}
sb.cesM.RUnlock()
sb.optimum.Store(currentOpti)
}
// Dispatcher sends data coming from a stream to a remote connection
// I used channels here because I didn't want to use mutex
func (sb *switchboard) dispatch() {
var dying bool
for {
select {
// dispatCh receives data from stream.Write
case data := <-sb.dispatCh:
go sb.ces[0].send(data)
sb.ces[0].sendQueue += len(data)
case notified := <-sb.sentNotifyCh:
notified.ce.sendQueue -= notified.sent
sort.Sort(byQ(sb.ces))
case conn := <-sb.newConnCh:
log.Println("newConn")
func (sb *switchboard) addConn(conn net.Conn) {
newCe := &connEnclave{
sb: sb,
remoteConn: conn,
sendQueue: 0,
}
sb.cesM.Lock()
sb.ces = append(sb.ces, newCe)
sb.cesM.Unlock()
sb.optimum.Store(newCe)
go sb.deplex(newCe)
case closing := <-sb.closingCECh:
log.Println("Closing conn")
}
func (sb *switchboard) removeConn(closing *connEnclave) {
sb.cesM.Lock()
for i, ce := range sb.ces {
if closing == ce {
sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
break
}
}
if len(sb.ces) == 0 && !dying {
sb.cesM.Unlock()
if len(sb.ces) == 0 {
sb.session.Close()
}
case <-sb.die:
dying = true
for _, ce := range sb.ces {
ce.remoteConn.Close()
}
return
}
}
}
// deplex function constantly reads from a TCP connection
// it is responsible to act in response to the deobfsed header
// i.e. should a new stream be added? which existing stream should be closed?
// shutdown closes every remote connection currently registered with the
// switchboard. It is invoked from Session.Close() to tear down all
// underlying transport connections at once.
// NOTE(review): sb.ces is iterated here without holding sb.cesM, while
// addConn/removeConn mutate the slice under that lock — confirm callers
// guarantee no concurrent modification, or take sb.cesM.RLock() here.
func (sb *switchboard) shutdown() {
for _, ce := range sb.ces {
// Close the underlying net.Conn; deplex() on this conn will then see a
// read error and call removeConn for its enclave.
ce.remoteConn.Close()
}
}
// deplex function constantly reads from a TCP connection, calls deobfs and
// distributes the result to the corresponding frame
func (sb *switchboard) deplex(ce *connEnclave) {
buf := make([]byte, 20480)
for {
@ -151,7 +121,7 @@ func (sb *switchboard) deplex(ce *connEnclave) {
if err != nil {
log.Println(err)
go ce.remoteConn.Close()
sb.closingCECh <- ce
sb.removeConn(ce)
return
}
frame := sb.session.deobfs(buf[:i])

View File

@ -9,7 +9,6 @@ import (
prand "math/rand"
"net"
"strconv"
"time"
)
func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte {
@ -69,7 +68,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
left := dataLength
readPtr := 5
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
for left != 0 {
// If left > buffer size (i.e. our message got segmented), the entire MTU is read
// if left = buffer size, the entire buffer is all there left to read
@ -82,7 +80,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
left -= i
readPtr += i
}
conn.SetReadDeadline(time.Time{})
n = 5 + dataLength
return