mirror of https://github.com/cbeuw/Cloak
Rework switchboard dispatch
This commit is contained in:
parent
9e4aedbdc1
commit
f476650953
|
|
@ -26,14 +26,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
|
|||
for {
|
||||
i, err := io.ReadAtLeast(src, buf, 1)
|
||||
if err != nil || i == 0 {
|
||||
log.Println(err)
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
}
|
||||
i, err = dst.Write(buf[:i])
|
||||
if err != nil || i == 0 {
|
||||
log.Println(err)
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
//"net/http"
|
||||
//_ "net/http/pprof"
|
||||
"os"
|
||||
"runtime"
|
||||
//"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -27,14 +27,12 @@ func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
|
|||
for {
|
||||
i, err := io.ReadAtLeast(src, buf, 1)
|
||||
if err != nil || i == 0 {
|
||||
log.Println(err)
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
}
|
||||
i, err = dst.Write(buf[:i])
|
||||
if err != nil || i == 0 {
|
||||
log.Println(err)
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
|
|
@ -136,10 +134,10 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
|||
}
|
||||
|
||||
func main() {
|
||||
runtime.SetBlockProfileRate(5)
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
|
||||
}()
|
||||
//runtime.SetBlockProfileRate(5)
|
||||
//go func() {
|
||||
// log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
|
||||
//}()
|
||||
// Should be 127.0.0.1 to listen to ss-server on this machine
|
||||
var localHost string
|
||||
// server_port in ss config, same as remotePort in plugin mode
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package multiplex
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
|
@ -62,7 +61,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([]
|
|||
}
|
||||
|
||||
func (sesh *Session) AddConnection(conn net.Conn) {
|
||||
sesh.sb.newConnCh <- conn
|
||||
sesh.sb.addConn(conn)
|
||||
}
|
||||
|
||||
func (sesh *Session) OpenStream() (*Stream, error) {
|
||||
|
|
@ -106,7 +105,6 @@ func (sesh *Session) getStream(id uint32) *Stream {
|
|||
|
||||
// addStream is used when the remote opened a new stream and we got notified
|
||||
func (sesh *Session) addStream(id uint32) *Stream {
|
||||
log.Printf("Adding stream %v", id)
|
||||
stream := makeStream(id, sesh)
|
||||
sesh.streamsM.Lock()
|
||||
sesh.streams[id] = stream
|
||||
|
|
@ -136,7 +134,7 @@ func (sesh *Session) Close() error {
|
|||
}
|
||||
sesh.streamsM.Unlock()
|
||||
|
||||
close(sesh.sb.die)
|
||||
sesh.sb.shutdown()
|
||||
return nil
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,7 +55,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
|||
if len(buf) == 0 {
|
||||
select {
|
||||
case <-stream.die:
|
||||
log.Printf("Stream %v dying\n", stream.id)
|
||||
return 0, errBrokenStream
|
||||
default:
|
||||
return 0, nil
|
||||
|
|
@ -63,7 +62,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
|||
}
|
||||
select {
|
||||
case <-stream.die:
|
||||
log.Printf("Stream %v dying\n", stream.id)
|
||||
return 0, errBrokenStream
|
||||
case data := <-stream.sortedBufCh:
|
||||
if len(buf) < len(data) {
|
||||
|
|
@ -79,7 +77,6 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
|||
func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||
select {
|
||||
case <-stream.die:
|
||||
log.Printf("Stream %v dying\n", stream.id)
|
||||
return 0, errBrokenStream
|
||||
default:
|
||||
}
|
||||
|
|
@ -94,9 +91,9 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
|
|||
atomic.AddUint32(&stream.nextSendSeq, 1)
|
||||
|
||||
tlsRecord := stream.session.obfs(f)
|
||||
stream.session.sb.dispatCh <- tlsRecord
|
||||
n, err = stream.session.sb.send(tlsRecord)
|
||||
|
||||
return len(in), nil
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +106,6 @@ func (stream *Stream) passiveClose() error {
|
|||
if stream.closing {
|
||||
return errRepeatStreamClosing
|
||||
}
|
||||
log.Printf("ID: %v passiveclosing\n", stream.id)
|
||||
stream.closing = true
|
||||
close(stream.die)
|
||||
stream.session.delStream(stream.id)
|
||||
|
|
@ -125,13 +121,11 @@ func (stream *Stream) Close() error {
|
|||
if stream.closing {
|
||||
return errRepeatStreamClosing
|
||||
}
|
||||
log.Printf("ID: %v closing\n", stream.id)
|
||||
stream.closing = true
|
||||
close(stream.die)
|
||||
|
||||
prand.Seed(int64(stream.id))
|
||||
padLen := int(math.Floor(prand.Float64()*200 + 300))
|
||||
log.Println(padLen)
|
||||
pad := make([]byte, padLen)
|
||||
prand.Read(pad)
|
||||
f := &Frame{
|
||||
|
|
@ -141,7 +135,7 @@ func (stream *Stream) Close() error {
|
|||
Payload: pad,
|
||||
}
|
||||
tlsRecord := stream.session.obfs(f)
|
||||
stream.session.sb.dispatCh <- tlsRecord
|
||||
stream.session.sb.send(tlsRecord)
|
||||
|
||||
stream.session.delStream(stream.id)
|
||||
return nil
|
||||
|
|
@ -150,7 +144,6 @@ func (stream *Stream) Close() error {
|
|||
// Same as Close() but no call to session.delStream.
|
||||
// This is called in session.Close() to avoid mutex deadlock
|
||||
func (stream *Stream) closeNoDelMap() error {
|
||||
log.Printf("ID: %v closing\n", stream.id)
|
||||
|
||||
// Lock here because closing a closed channel causes panic
|
||||
stream.closingM.Lock()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ package multiplex
|
|||
import (
|
||||
"log"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -16,39 +17,19 @@ const (
|
|||
type switchboard struct {
|
||||
session *Session
|
||||
|
||||
optimum atomic.Value
|
||||
cesM sync.RWMutex
|
||||
ces []*connEnclave
|
||||
|
||||
// For telling dispatcher how many bytes have been sent after Connection.send.
|
||||
sentNotifyCh chan *sentNotifier
|
||||
// dispatCh is used by streams to send new data to remote
|
||||
dispatCh chan []byte
|
||||
newConnCh chan net.Conn
|
||||
closingCECh chan *connEnclave
|
||||
die chan struct{}
|
||||
closing bool
|
||||
}
|
||||
|
||||
// Some data comes from a Stream to be sent through one of the many
|
||||
// remoteConn, but which remoteConn should we use to send the data?
|
||||
//
|
||||
// In this case, we pick the remoteConn that has about the smallest sendQueue.
|
||||
// Though "smallest" is not guaranteed because it doesn't has to be
|
||||
type connEnclave struct {
|
||||
sb *switchboard
|
||||
remoteConn net.Conn
|
||||
sendQueue int
|
||||
}
|
||||
|
||||
type byQ []*connEnclave
|
||||
|
||||
func (a byQ) Len() int {
|
||||
return len(a)
|
||||
}
|
||||
func (a byQ) Swap(i, j int) {
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
func (a byQ) Less(i, j int) bool {
|
||||
return a[i].sendQueue < a[j].sendQueue
|
||||
sendQueue uint32
|
||||
}
|
||||
|
||||
// It takes at least 1 conn to start a switchboard
|
||||
|
|
@ -56,11 +37,6 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
|
|||
sb := &switchboard{
|
||||
session: sesh,
|
||||
ces: []*connEnclave{},
|
||||
sentNotifyCh: make(chan *sentNotifier, sentNotifyBacklog),
|
||||
dispatCh: make(chan []byte, dispatchBacklog),
|
||||
newConnCh: make(chan net.Conn, newConnBacklog),
|
||||
closingCECh: make(chan *connEnclave, 5),
|
||||
die: make(chan struct{}),
|
||||
}
|
||||
ce := &connEnclave{
|
||||
sb: sb,
|
||||
|
|
@ -70,80 +46,74 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
|
|||
sb.ces = append(sb.ces, ce)
|
||||
go sb.deplex(ce)
|
||||
|
||||
go sb.dispatch()
|
||||
return sb
|
||||
}
|
||||
|
||||
// Everytime after a remoteConn sends something, it constructs this struct
|
||||
// Which is sent back to dispatch() through sentNotifyCh to tell dispatch
|
||||
// how many bytes it has sent
|
||||
type sentNotifier struct {
|
||||
ce *connEnclave
|
||||
sent int
|
||||
}
|
||||
|
||||
func (ce *connEnclave) send(data []byte) {
|
||||
// TODO: error handling
|
||||
func (sb *switchboard) send(data []byte) (int, error) {
|
||||
ce := sb.optimum.Load().(*connEnclave)
|
||||
atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
|
||||
go sb.updateOptimum()
|
||||
n, err := ce.remoteConn.Write(data)
|
||||
if err != nil {
|
||||
ce.sb.closingCECh <- ce
|
||||
log.Println(err)
|
||||
return 0, err
|
||||
// TODO
|
||||
}
|
||||
|
||||
sn := &sentNotifier{
|
||||
ce,
|
||||
n,
|
||||
}
|
||||
ce.sb.sentNotifyCh <- sn
|
||||
|
||||
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
|
||||
go sb.updateOptimum()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Dispatcher sends data coming from a stream to a remote connection
|
||||
// I used channels here because I didn't want to use mutex
|
||||
func (sb *switchboard) dispatch() {
|
||||
var dying bool
|
||||
for {
|
||||
select {
|
||||
// dispatCh receives data from stream.Write
|
||||
case data := <-sb.dispatCh:
|
||||
go sb.ces[0].send(data)
|
||||
sb.ces[0].sendQueue += len(data)
|
||||
case notified := <-sb.sentNotifyCh:
|
||||
notified.ce.sendQueue -= notified.sent
|
||||
sort.Sort(byQ(sb.ces))
|
||||
case conn := <-sb.newConnCh:
|
||||
log.Println("newConn")
|
||||
func (sb *switchboard) updateOptimum() {
|
||||
currentOpti := sb.optimum.Load().(*connEnclave)
|
||||
currentOptiQ := atomic.LoadUint32(¤tOpti.sendQueue)
|
||||
sb.cesM.RLock()
|
||||
for _, ce := range sb.ces {
|
||||
ceQ := atomic.LoadUint32(&ce.sendQueue)
|
||||
if ceQ < currentOptiQ {
|
||||
currentOpti = ce
|
||||
currentOptiQ = ceQ
|
||||
}
|
||||
}
|
||||
sb.cesM.RUnlock()
|
||||
sb.optimum.Store(currentOpti)
|
||||
}
|
||||
|
||||
func (sb *switchboard) addConn(conn net.Conn) {
|
||||
|
||||
newCe := &connEnclave{
|
||||
sb: sb,
|
||||
remoteConn: conn,
|
||||
sendQueue: 0,
|
||||
}
|
||||
sb.cesM.Lock()
|
||||
sb.ces = append(sb.ces, newCe)
|
||||
sb.cesM.Unlock()
|
||||
sb.optimum.Store(newCe)
|
||||
go sb.deplex(newCe)
|
||||
case closing := <-sb.closingCECh:
|
||||
log.Println("Closing conn")
|
||||
}
|
||||
|
||||
func (sb *switchboard) removeConn(closing *connEnclave) {
|
||||
sb.cesM.Lock()
|
||||
for i, ce := range sb.ces {
|
||||
if closing == ce {
|
||||
sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(sb.ces) == 0 && !dying {
|
||||
sb.cesM.Unlock()
|
||||
if len(sb.ces) == 0 {
|
||||
sb.session.Close()
|
||||
}
|
||||
case <-sb.die:
|
||||
dying = true
|
||||
for _, ce := range sb.ces {
|
||||
ce.remoteConn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deplex function costantly reads from a TCP connection
|
||||
// it is responsible to act in response to the deobfsed header
|
||||
// i.e. should a new stream be added? which existing stream should be closed?
|
||||
func (sb *switchboard) shutdown() {
|
||||
for _, ce := range sb.ces {
|
||||
ce.remoteConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// deplex function costantly reads from a TCP connection, call deobfs and distribute it
|
||||
// to the corresponding frame
|
||||
func (sb *switchboard) deplex(ce *connEnclave) {
|
||||
buf := make([]byte, 20480)
|
||||
for {
|
||||
|
|
@ -151,7 +121,7 @@ func (sb *switchboard) deplex(ce *connEnclave) {
|
|||
if err != nil {
|
||||
log.Println(err)
|
||||
go ce.remoteConn.Close()
|
||||
sb.closingCECh <- ce
|
||||
sb.removeConn(ce)
|
||||
return
|
||||
}
|
||||
frame := sb.session.deobfs(buf[:i])
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import (
|
|||
prand "math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte {
|
||||
|
|
@ -69,7 +68,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
|
|||
left := dataLength
|
||||
readPtr := 5
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
for left != 0 {
|
||||
// If left > buffer size (i.e. our message got segmented), the entire MTU is read
|
||||
// if left = buffer size, the entire buffer is all there left to read
|
||||
|
|
@ -82,7 +80,6 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
|
|||
left -= i
|
||||
readPtr += i
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
n = 5 + dataLength
|
||||
return
|
||||
|
|
|
|||
Loading…
Reference in New Issue