mirror of https://github.com/cbeuw/Cloak
Multiplex initial commit
This commit is contained in:
commit
44d2c0e073
|
|
@ -0,0 +1,10 @@
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import ()
|
||||||
|
|
||||||
|
type Frame struct {
|
||||||
|
StreamID uint32
|
||||||
|
Seq uint32
|
||||||
|
ClosedStreamID uint32
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/heap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The data is multiplexed through several TCP connections, therefore the
|
||||||
|
// order of arrival is not guaranteed. A stream's first packet may be sent through
|
||||||
|
// connection0 and its second packet may be sent through connection1. Although both
|
||||||
|
// packets are transmitted reliably (as TCP is reliable), packet1 may arrive to the
|
||||||
|
// remote side before packet0.
|
||||||
|
//
|
||||||
|
// However, shadowsocks' protocol does not provide sequence control. We must therefore
|
||||||
|
// make sure packets arrive in order.
|
||||||
|
//
|
||||||
|
// Cloak packets will have a 32-bit sequence number on them, so we know in which order
|
||||||
|
// they should be sent to shadowsocks. In the case that the packets arrive out-of-order,
|
||||||
|
// the code in this file provides buffering and sorting.
|
||||||
|
//
|
||||||
|
// Similar to TCP, the next seq number after 2^32-1 is 0. This is called wrap around.
|
||||||
|
//
|
||||||
|
// Note that in golang, integer overflow results in wrap around
|
||||||
|
//
|
||||||
|
// Stream.nextRecvSeq is the expected sequence number of the next packet
|
||||||
|
// Stream.rev counts the amount of time the sequence number gets wrapped
|
||||||
|
|
||||||
|
type frameNode struct {
|
||||||
|
seq uint32
|
||||||
|
trueSeq uint64
|
||||||
|
frame *Frame
|
||||||
|
}
|
||||||
|
type sorterHeap []*frameNode
|
||||||
|
|
||||||
|
func (sh sorterHeap) Less(i, j int) bool {
|
||||||
|
return sh[i].trueSeq < sh[j].trueSeq
|
||||||
|
}
|
||||||
|
func (sh sorterHeap) Len() int {
|
||||||
|
return len(sh)
|
||||||
|
}
|
||||||
|
func (sh sorterHeap) Swap(i, j int) {
|
||||||
|
sh[i], sh[j] = sh[j], sh[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh *sorterHeap) Push(x interface{}) {
|
||||||
|
*sh = append(*sh, x.(*frameNode))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh *sorterHeap) Pop() interface{} {
|
||||||
|
old := *sh
|
||||||
|
n := len(old)
|
||||||
|
x := old[n-1]
|
||||||
|
*sh = old[0 : n-1]
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) recvNewFrame(f *Frame) {
|
||||||
|
// For the ease of demonstration, assume seq is uint8, i.e. it wraps around after 255
|
||||||
|
fs := &frameNode{
|
||||||
|
f.Seq,
|
||||||
|
0,
|
||||||
|
f,
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: if a malicious client resend a previously sent seq number, what will happen?
|
||||||
|
if fs.seq < s.nextRecvSeq {
|
||||||
|
// e.g. we are on rev=0 (wrap has not happened yet)
|
||||||
|
// and we get the order of recv as 253 254 0 1
|
||||||
|
// after 254, nextN should be 255, but 0 is received and 0 < 255
|
||||||
|
// now 0 should have a trueSeq of 256
|
||||||
|
if !s.wrapMode {
|
||||||
|
// wrapMode is true when the latest seq is wrapped but nextN is not
|
||||||
|
s.wrapMode = true
|
||||||
|
}
|
||||||
|
fs.trueSeq = uint64(2<<16*(s.rev+1)) + uint64(fs.seq) + 1
|
||||||
|
// +1 because wrapped 0 should have trueSeq of 256 instead of 255
|
||||||
|
// when this bit was run on 1, the trueSeq of 1 would become 256
|
||||||
|
} else {
|
||||||
|
fs.trueSeq = uint64(2<<16*s.rev) + uint64(fs.seq)
|
||||||
|
// when this bit was run on 255, the trueSeq of 255 would be 255
|
||||||
|
}
|
||||||
|
heap.Push(&s.sh, fs)
|
||||||
|
|
||||||
|
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
||||||
|
for len(s.sh) > 0 && s.sh[0].seq == s.nextRecvSeq {
|
||||||
|
|
||||||
|
s.sortedBufCh <- heap.Pop(&s.sh).(*frameNode).frame.Payload
|
||||||
|
|
||||||
|
s.nextRecvSeq += 1
|
||||||
|
if s.nextRecvSeq == 0 {
|
||||||
|
// when nextN is wrapped, wrapMode becomes false and rev+1
|
||||||
|
s.rev += 1
|
||||||
|
s.wrapMode = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Copied from smux
|
||||||
|
errBrokenPipe = "broken pipe"
|
||||||
|
acceptBacklog = 1024
|
||||||
|
|
||||||
|
closeBacklog = 512
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
id int
|
||||||
|
|
||||||
|
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
|
||||||
|
obfs func(*Frame) []byte
|
||||||
|
// Remove TLS header, decrypt and unmarshall multiplexing headers
|
||||||
|
deobfs func([]byte) *Frame
|
||||||
|
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
|
||||||
|
obfsedReader func(net.Conn, []byte) (int, error)
|
||||||
|
|
||||||
|
nextStreamIDM sync.Mutex
|
||||||
|
nextStreamID uint32
|
||||||
|
|
||||||
|
streamsM sync.RWMutex
|
||||||
|
streams map[uint32]*Stream
|
||||||
|
|
||||||
|
// Switchboard manages all connections to remote
|
||||||
|
sb *switchboard
|
||||||
|
|
||||||
|
// For accepting new streams
|
||||||
|
acceptCh chan *Stream
|
||||||
|
// Once a stream.Close is called, it sends its streamID to this channel
|
||||||
|
// to be read by another stream to send the streamID to notify the remote
|
||||||
|
// that this stream is closed
|
||||||
|
closeQCh chan uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: put this in main maybe?
|
||||||
|
func MakeSession(id int, conns []net.Conn) *Session {
|
||||||
|
sesh := &Session{
|
||||||
|
id: id,
|
||||||
|
nextStreamID: 0,
|
||||||
|
streams: make(map[uint32]*Stream),
|
||||||
|
acceptCh: make(chan *Stream, acceptBacklog),
|
||||||
|
closeQCh: make(chan uint32, closeBacklog),
|
||||||
|
}
|
||||||
|
sesh.sb = makeSwitchboard(conns, sesh)
|
||||||
|
sesh.sb.run()
|
||||||
|
return sesh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sesh *Session) OpenStream() (*Stream, error) {
|
||||||
|
sesh.nextStreamIDM.Lock()
|
||||||
|
id := sesh.nextStreamID
|
||||||
|
sesh.nextStreamID += 1
|
||||||
|
sesh.nextStreamIDM.Unlock()
|
||||||
|
|
||||||
|
stream := makeStream(id, sesh)
|
||||||
|
|
||||||
|
sesh.streamsM.Lock()
|
||||||
|
sesh.streams[id] = stream
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sesh *Session) AcceptStream() (*Stream, error) {
|
||||||
|
stream := <-sesh.acceptCh
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sesh *Session) delStream(id uint32) {
|
||||||
|
sesh.streamsM.RLock()
|
||||||
|
delete(sesh.streams, id)
|
||||||
|
sesh.streamsM.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sesh *Session) isStream(id uint32) bool {
|
||||||
|
sesh.streamsM.Lock()
|
||||||
|
_, ok := sesh.streams[id]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sesh *Session) getStream(id uint32) *Stream {
|
||||||
|
sesh.streamsM.Lock()
|
||||||
|
defer sesh.streamsM.Unlock()
|
||||||
|
return sesh.streams[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sesh *Session) addStream(id uint32) {
|
||||||
|
stream := makeStream(id, sesh)
|
||||||
|
sesh.acceptCh <- stream
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
readBuffer = 10240
|
||||||
|
)
|
||||||
|
|
||||||
|
type Stream struct {
|
||||||
|
id uint32
|
||||||
|
|
||||||
|
session *Session
|
||||||
|
|
||||||
|
// Copied from smux
|
||||||
|
dieM sync.Mutex
|
||||||
|
die chan struct{}
|
||||||
|
|
||||||
|
// Explanations of the following 4 fields can be found in frameSorter.go
|
||||||
|
nextRecvSeq uint32
|
||||||
|
rev int
|
||||||
|
sh sorterHeap
|
||||||
|
wrapMode bool
|
||||||
|
|
||||||
|
sortedBufCh chan []byte
|
||||||
|
|
||||||
|
nextSendSeqM sync.Mutex
|
||||||
|
nextSendSeq uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeStream(id uint32, sesh *Session) *Stream {
|
||||||
|
stream := &Stream{
|
||||||
|
id: id,
|
||||||
|
session: sesh,
|
||||||
|
}
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||||
|
if len(buf) == 0 {
|
||||||
|
select {
|
||||||
|
case <-stream.die:
|
||||||
|
return 0, errors.New(errBrokenPipe)
|
||||||
|
case data := <-stream.sortedBufCh:
|
||||||
|
if len(data) > 0 {
|
||||||
|
copy(buf, data)
|
||||||
|
return len(data), nil
|
||||||
|
} else {
|
||||||
|
// TODO: close stream here or not?
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, errors.New(errBrokenPipe)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-stream.die:
|
||||||
|
return 0, errors.New(errBrokenPipe)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
var closingID uint32
|
||||||
|
|
||||||
|
select {
|
||||||
|
case closingID = <-stream.session.closeQCh:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
f := &Frame{
|
||||||
|
StreamID: stream.id,
|
||||||
|
Seq: stream.nextSendSeq,
|
||||||
|
ClosedStreamID: closingID,
|
||||||
|
}
|
||||||
|
copy(f.Payload, in)
|
||||||
|
|
||||||
|
stream.nextSendSeqM.Lock()
|
||||||
|
stream.nextSendSeq += 1
|
||||||
|
stream.nextSendSeqM.Unlock()
|
||||||
|
|
||||||
|
tlsRecord := stream.session.obfs(f)
|
||||||
|
stream.session.sb.dispatCh <- tlsRecord
|
||||||
|
|
||||||
|
return len(in), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *Stream) Close() error {
|
||||||
|
stream.session.delStream(stream.id)
|
||||||
|
close(stream.die)
|
||||||
|
close(stream.sortedBufCh)
|
||||||
|
stream.session.closeQCh <- stream.id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,118 @@
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sentNotifyBacklog = 1024
|
||||||
|
dispatchBacklog = 10240
|
||||||
|
)
|
||||||
|
|
||||||
|
type switchboard struct {
|
||||||
|
session *Session
|
||||||
|
|
||||||
|
ces []*connEnclave
|
||||||
|
|
||||||
|
// For telling dispatcher how many bytes have been sent after Connection.send.
|
||||||
|
sentNotifyCh chan *sentNotifier
|
||||||
|
dispatCh chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some data comes from a Stream to be sent through one of the many
|
||||||
|
// remoteConn, but which remoteConn should we use to send the data?
|
||||||
|
//
|
||||||
|
// In this case, we pick the remoteConn that has about the smallest sendQueue.
|
||||||
|
// Though "smallest" is not guaranteed because it doesn't has to be
|
||||||
|
type connEnclave struct {
|
||||||
|
sb *switchboard
|
||||||
|
remoteConn net.Conn
|
||||||
|
sendQueue int
|
||||||
|
}
|
||||||
|
|
||||||
|
type byQ []*connEnclave
|
||||||
|
|
||||||
|
func (a byQ) Len() int {
|
||||||
|
return len(a)
|
||||||
|
}
|
||||||
|
func (a byQ) Swap(i, j int) {
|
||||||
|
a[i], a[j] = a[j], a[i]
|
||||||
|
}
|
||||||
|
func (a byQ) Less(i, j int) bool {
|
||||||
|
return a[i].sendQueue < a[j].sendQueue
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeSwitchboard(conns []net.Conn, sesh *Session) *switchboard {
|
||||||
|
sb := &switchboard{
|
||||||
|
session: sesh,
|
||||||
|
ces: []*connEnclave{},
|
||||||
|
sentNotifyCh: make(chan *sentNotifier, sentNotifyBacklog),
|
||||||
|
dispatCh: make(chan []byte, dispatchBacklog),
|
||||||
|
}
|
||||||
|
for _, c := range conns {
|
||||||
|
ce := &connEnclave{
|
||||||
|
sb: sb,
|
||||||
|
remoteConn: c,
|
||||||
|
sendQueue: 0,
|
||||||
|
}
|
||||||
|
sb.ces = append(sb.ces, ce)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sb *switchboard) run() {
|
||||||
|
go startDispatcher()
|
||||||
|
go startDeplexer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everytime after a remoteConn sends something, it constructs this struct
|
||||||
|
// Which is sent back to dispatch() through sentNotifyCh to tell dispatch
|
||||||
|
// how many bytes it has sent
|
||||||
|
type sentNotifier struct {
|
||||||
|
ce *connEnclave
|
||||||
|
sent int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce *connEnclave) send(data []byte) {
|
||||||
|
// TODO: error handling
|
||||||
|
n, _ := ce.remoteConn.Write(data)
|
||||||
|
sn := &sentNotifier{
|
||||||
|
ce,
|
||||||
|
n,
|
||||||
|
}
|
||||||
|
ce.sb.sentNotifyCh <- sn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatcher sends data coming from a stream to a remote connection
|
||||||
|
func (sb *switchboard) startDispatcher() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// dispatCh receives data from stream.Write
|
||||||
|
case data := <-sb.dispatCh:
|
||||||
|
go sb.ces[0].send(data)
|
||||||
|
sb.ces[0].sendQueue += len(data)
|
||||||
|
case notified := <-sb.sentNotifyCh:
|
||||||
|
notified.ce.sendQueue -= notified.sent
|
||||||
|
sort.Sort(byQ(sb.ces))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deplexer sends data coming from a remote connection to a stream
|
||||||
|
func (sb *switchboard) startDeplexer() {
|
||||||
|
for _, ce := range sb.ces {
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 20480)
|
||||||
|
for {
|
||||||
|
sb.session.obfsedReader(ce.remoteConn, buf)
|
||||||
|
frame := sb.session.deobfs(buf)
|
||||||
|
if !sb.session.isStream(frame.StreamID) {
|
||||||
|
sb.session.addStream(frame.StreamID)
|
||||||
|
}
|
||||||
|
sb.session.getStream(frame.StreamID).recvNewFrame(frame)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue