From 44d2c0e0730c0c8ce8186dd7b1d77f3bde29e2c3 Mon Sep 17 00:00:00 2001
From: Qian Wang
Date: Fri, 5 Oct 2018 23:44:20 +0100
Subject: [PATCH] Multiplex initial commit

---
 internal/multiplex/frame.go       |   8 ++
 internal/multiplex/frameSorter.go |  96 +++++++++++++++++++
 internal/multiplex/session.go     | 107 ++++++++++++++++++++
 internal/multiplex/stream.go      | 100 +++++++++++++++++++
 internal/multiplex/switchboard.go | 124 ++++++++++++++++++++++
 5 files changed, 435 insertions(+)
 create mode 100644 internal/multiplex/frame.go
 create mode 100644 internal/multiplex/frameSorter.go
 create mode 100644 internal/multiplex/session.go
 create mode 100644 internal/multiplex/stream.go
 create mode 100644 internal/multiplex/switchboard.go

diff --git a/internal/multiplex/frame.go b/internal/multiplex/frame.go
new file mode 100644
index 0000000..49328e9
--- /dev/null
+++ b/internal/multiplex/frame.go
@@ -0,0 +1,8 @@
+package multiplex
+
+type Frame struct {
+    StreamID       uint32
+    Seq            uint32
+    ClosedStreamID uint32
+    Payload        []byte
+}
diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go
new file mode 100644
index 0000000..432fdfe
--- /dev/null
+++ b/internal/multiplex/frameSorter.go
@@ -0,0 +1,96 @@
+package multiplex
+
+import (
+    "container/heap"
+)
+
+// The data is multiplexed through several TCP connections, so the order
+// of arrival is not guaranteed. A stream's first packet may be sent
+// through connection0 and its second packet through connection1. Although
+// both packets are transmitted reliably (TCP is reliable), packet1 may
+// arrive at the remote side before packet0.
+//
+// However, shadowsocks' protocol does not provide sequence control, so
+// we must make sure the packets are delivered to shadowsocks in order.
+//
+// Each Cloak packet carries a 32-bit sequence number, so we know the
+// order in which packets should be handed to shadowsocks. When packets
+// arrive out of order, the code in this file buffers and sorts them.
+//
+// As in TCP, the next seq number after 2^32-1 is 0. This is called wrap
+// around. Note that in Go, unsigned integer arithmetic wraps around too.
+//
+// Stream.nextRecvSeq is the expected sequence number of the next packet.
+// Stream.rev counts the number of times the sequence number has wrapped.
+
+type frameNode struct {
+    seq     uint32
+    trueSeq uint64
+    frame   *Frame
+}
+
+type sorterHeap []*frameNode
+
+func (sh sorterHeap) Less(i, j int) bool {
+    return sh[i].trueSeq < sh[j].trueSeq
+}
+func (sh sorterHeap) Len() int {
+    return len(sh)
+}
+func (sh sorterHeap) Swap(i, j int) {
+    sh[i], sh[j] = sh[j], sh[i]
+}
+
+func (sh *sorterHeap) Push(x interface{}) {
+    *sh = append(*sh, x.(*frameNode))
+}
+
+func (sh *sorterHeap) Pop() interface{} {
+    old := *sh
+    n := len(old)
+    x := old[n-1]
+    *sh = old[0 : n-1]
+    return x
+}
+
+func (s *Stream) recvNewFrame(f *Frame) {
+    // For ease of demonstration, the examples in the comments below assume
+    // seq is a uint8, i.e. it wraps around after 255
+    fs := &frameNode{
+        seq:     f.Seq,
+        trueSeq: 0,
+        frame:   f,
+    }
+
+    // TODO: if a malicious client resends a previously seen seq number, what happens?
+    if fs.seq < s.nextRecvSeq {
+        // e.g. we are on rev=0 (no wrap has happened yet) and the frames
+        // arrive in the order 253 254 0 1. After 254, nextRecvSeq is 255,
+        // but 0 is received and 0 < 255: 0 has wrapped around, so its
+        // trueSeq should be 256
+        if !s.wrapMode {
+            // wrapMode is true when the latest seq has wrapped but
+            // nextRecvSeq has not
+            s.wrapMode = true
+        }
+        fs.trueSeq = uint64(s.rev+1)<<32 + uint64(fs.seq)
+        // in the uint8 example, the wrapped 0 gets trueSeq 1*256+0 = 256
+    } else {
+        fs.trueSeq = uint64(s.rev)<<32 + uint64(fs.seq)
+        // in the uint8 example, 255 on rev=0 gets trueSeq 0*256+255 = 255
+    }
+    heap.Push(&s.sh, fs)
+
+    // Keep popping from the heap until it is empty or until the next
+    // wanted seq has not yet been received
+    for len(s.sh) > 0 && s.sh[0].seq == s.nextRecvSeq {
+        s.sortedBufCh <- heap.Pop(&s.sh).(*frameNode).frame.Payload
+
+        s.nextRecvSeq += 1
+        if s.nextRecvSeq == 0 {
+            // nextRecvSeq has wrapped: wrapMode ends and rev increments
+            s.rev += 1
+            s.wrapMode = false
+        }
+    }
+}
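
A quick way to sanity-check the wrap-around arithmetic: the standalone sketch below (not part of the patch) mirrors recvNewFrame's trueSeq computation, with the period shrunk from 1<<32 to 1<<8 so the numbers match the uint8 examples in the comments.

    package main

    import "fmt"

    // trueSeq mirrors the epoch arithmetic in recvNewFrame, with a
    // uint8-sized period of 256
    func trueSeq(seq uint8, rev int, wrapped bool) uint64 {
        if wrapped {
            return uint64(rev+1)<<8 + uint64(seq)
        }
        return uint64(rev)<<8 + uint64(seq)
    }

    func main() {
        fmt.Println(trueSeq(254, 0, false)) // 254
        fmt.Println(trueSeq(255, 0, false)) // 255
        fmt.Println(trueSeq(0, 0, true))    // 256: the wrapped 0 sorts after 255
        fmt.Println(trueSeq(1, 0, true))    // 257
    }
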
diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go
new file mode 100644
index 0000000..ef3d32f
--- /dev/null
+++ b/internal/multiplex/session.go
@@ -0,0 +1,107 @@
+package multiplex
+
+import (
+    "net"
+    "sync"
+)
+
+const (
+    // Copied from smux
+    errBrokenPipe = "broken pipe"
+
+    acceptBacklog = 1024
+    closeBacklog  = 512
+)
+
+type Session struct {
+    id int
+
+    // Used in Stream.Write. Adds the multiplexing headers, encrypts and
+    // adds the TLS record header
+    obfs func(*Frame) []byte
+    // Removes the TLS record header, decrypts and unmarshals the
+    // multiplexing headers
+    deobfs func([]byte) *Frame
+    // Reads exactly one TLS record, the same as GoQuiet's ReadTillDrain
+    obfsedReader func(net.Conn, []byte) (int, error)
+
+    nextStreamIDM sync.Mutex
+    nextStreamID  uint32
+
+    streamsM sync.RWMutex
+    streams  map[uint32]*Stream
+
+    // The switchboard manages all connections to the remote
+    sb *switchboard
+
+    // For accepting new streams
+    acceptCh chan *Stream
+
+    // Once Stream.Close is called, the stream sends its streamID to this
+    // channel, to be picked up by another stream's Write so that the
+    // remote is notified that this stream is closed
+    closeQCh chan uint32
+}
+
+// TODO: put this in main maybe?
+func MakeSession(id int, conns []net.Conn) *Session {
+    sesh := &Session{
+        id:           id,
+        nextStreamID: 0,
+        streams:      make(map[uint32]*Stream),
+        acceptCh:     make(chan *Stream, acceptBacklog),
+        closeQCh:     make(chan uint32, closeBacklog),
+    }
+    sesh.sb = makeSwitchboard(conns, sesh)
+    sesh.sb.run()
+    return sesh
+}
+
+func (sesh *Session) OpenStream() (*Stream, error) {
+    sesh.nextStreamIDM.Lock()
+    id := sesh.nextStreamID
+    sesh.nextStreamID += 1
+    sesh.nextStreamIDM.Unlock()
+
+    stream := makeStream(id, sesh)
+
+    sesh.streamsM.Lock()
+    sesh.streams[id] = stream
+    sesh.streamsM.Unlock()
+    return stream, nil
+}
+
+func (sesh *Session) AcceptStream() (*Stream, error) {
+    stream := <-sesh.acceptCh
+    return stream, nil
+}
+
+func (sesh *Session) delStream(id uint32) {
+    // delete mutates the map, so this needs the write lock
+    sesh.streamsM.Lock()
+    delete(sesh.streams, id)
+    sesh.streamsM.Unlock()
+}
+
+func (sesh *Session) isStream(id uint32) bool {
+    sesh.streamsM.RLock()
+    _, ok := sesh.streams[id]
+    sesh.streamsM.RUnlock()
+    return ok
+}
+
+func (sesh *Session) getStream(id uint32) *Stream {
+    sesh.streamsM.RLock()
+    defer sesh.streamsM.RUnlock()
+    return sesh.streams[id]
+}
+
+// addStream registers a stream initiated by the remote and queues it for
+// AcceptStream
+func (sesh *Session) addStream(id uint32) {
+    stream := makeStream(id, sesh)
+    sesh.streamsM.Lock()
+    sesh.streams[id] = stream
+    sesh.streamsM.Unlock()
+    sesh.acceptCh <- stream
+}
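
For orientation, this is roughly how the client side would drive the API. A sketch only: the address is hypothetical, errors are skipped, and the session's unexported obfs/deobfs/obfsedReader fields still have to be populated (within this package, or via a future constructor parameter) before any data can actually flow.

    conns := make([]net.Conn, 0, 4)
    for i := 0; i < 4; i++ {
        c, err := net.Dial("tcp", "192.0.2.1:443") // hypothetical remote
        if err != nil {
            panic(err)
        }
        conns = append(conns, c)
    }
    sesh := multiplex.MakeSession(0, conns)
    stream, _ := sesh.OpenStream()
    stream.Write([]byte("hello"))
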
diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go
new file mode 100644
index 0000000..dac338d
--- /dev/null
+++ b/internal/multiplex/stream.go
@@ -0,0 +1,100 @@
+package multiplex
+
+import (
+    "errors"
+    "io"
+    "sync"
+)
+
+const (
+    readBuffer = 10240
+)
+
+type Stream struct {
+    id uint32
+
+    session *Session
+
+    // Copied from smux
+    dieM sync.Mutex
+    die  chan struct{}
+
+    // Explanations of the following 4 fields can be found in frameSorter.go
+    nextRecvSeq uint32
+    rev         int
+    sh          sorterHeap
+    wrapMode    bool
+
+    // Payloads put back into order by recvNewFrame, ready to be Read
+    sortedBufCh chan []byte
+
+    nextSendSeqM sync.Mutex
+    nextSendSeq  uint32
+}
+
+func makeStream(id uint32, sesh *Session) *Stream {
+    stream := &Stream{
+        id:          id,
+        session:     sesh,
+        die:         make(chan struct{}),
+        sortedBufCh: make(chan []byte, readBuffer),
+    }
+    return stream
+}
+
+func (stream *Stream) Read(buf []byte) (n int, err error) {
+    if len(buf) == 0 {
+        return 0, nil
+    }
+    select {
+    case <-stream.die:
+        return 0, errors.New(errBrokenPipe)
+    case data := <-stream.sortedBufCh:
+        if len(data) > 0 {
+            // TODO: data longer than buf is currently truncated
+            n = copy(buf, data)
+            return n, nil
+        }
+        // TODO: close the stream here or not?
+        return 0, io.EOF
+    }
+}
+
+func (stream *Stream) Write(in []byte) (n int, err error) {
+    select {
+    case <-stream.die:
+        return 0, errors.New(errBrokenPipe)
+    default:
+    }
+
+    // Piggyback a pending close notification onto this frame, if any
+    var closingID uint32
+    select {
+    case closingID = <-stream.session.closeQCh:
+    default:
+    }
+
+    stream.nextSendSeqM.Lock()
+    f := &Frame{
+        StreamID:       stream.id,
+        Seq:            stream.nextSendSeq,
+        ClosedStreamID: closingID,
+        Payload:        make([]byte, len(in)),
+    }
+    copy(f.Payload, in)
+    stream.nextSendSeq += 1
+    stream.nextSendSeqM.Unlock()
+
+    tlsRecord := stream.session.obfs(f)
+    stream.session.sb.dispatCh <- tlsRecord
+
+    return len(in), nil
+}
+
+func (stream *Stream) Close() error {
+    stream.session.delStream(stream.id)
+    close(stream.die)
+    close(stream.sortedBufCh)
+    stream.session.closeQCh <- stream.id
+    return nil
+}
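
Since Stream implements both Read and Write, plumbing it into a local connection reduces to two io.Copy calls. A sketch, assuming localConn is a net.Conn accepted elsewhere:

    go func() {
        // local -> remote; Stream.Write frames and obfuscates each chunk
        io.Copy(stream, localConn)
    }()
    // remote -> local; frames are delivered in seq order by the sorter
    io.Copy(localConn, stream)
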
diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go
new file mode 100644
index 0000000..af83099
--- /dev/null
+++ b/internal/multiplex/switchboard.go
@@ -0,0 +1,124 @@
+package multiplex
+
+import (
+    "net"
+    "sort"
+)
+
+const (
+    sentNotifyBacklog = 1024
+    dispatchBacklog   = 10240
+)
+
+type switchboard struct {
+    session *Session
+
+    ces []*connEnclave
+
+    // Tells the dispatcher how many bytes a remoteConn has sent after
+    // each Connection.send
+    sentNotifyCh chan *sentNotifier
+    dispatCh     chan []byte
+}
+
+// Some data comes from a Stream to be sent through one of the many
+// remoteConns, but which remoteConn should we use to send the data?
+//
+// We pick the remoteConn with roughly the smallest sendQueue. "Smallest"
+// is not guaranteed because the queue lengths change concurrently while
+// we pick.
+type connEnclave struct {
+    sb         *switchboard
+    remoteConn net.Conn
+    sendQueue  int
+}
+
+type byQ []*connEnclave
+
+func (a byQ) Len() int {
+    return len(a)
+}
+func (a byQ) Swap(i, j int) {
+    a[i], a[j] = a[j], a[i]
+}
+func (a byQ) Less(i, j int) bool {
+    return a[i].sendQueue < a[j].sendQueue
+}
+
+func makeSwitchboard(conns []net.Conn, sesh *Session) *switchboard {
+    sb := &switchboard{
+        session:      sesh,
+        ces:          []*connEnclave{},
+        sentNotifyCh: make(chan *sentNotifier, sentNotifyBacklog),
+        dispatCh:     make(chan []byte, dispatchBacklog),
+    }
+    for _, c := range conns {
+        ce := &connEnclave{
+            sb:         sb,
+            remoteConn: c,
+            sendQueue:  0,
+        }
+        sb.ces = append(sb.ces, ce)
+    }
+
+    return sb
+}
+
+func (sb *switchboard) run() {
+    go sb.startDispatcher()
+    go sb.startDeplexer()
+}
+
+// Every time a remoteConn has sent something, it constructs this struct
+// and sends it back through sentNotifyCh to tell the dispatcher how many
+// bytes it has sent
+type sentNotifier struct {
+    ce   *connEnclave
+    sent int
+}
+
+func (ce *connEnclave) send(data []byte) {
+    // TODO: error handling
+    n, _ := ce.remoteConn.Write(data)
+    sn := &sentNotifier{
+        ce,
+        n,
+    }
+    ce.sb.sentNotifyCh <- sn
+}
+
+// The dispatcher sends data coming from a stream to a remote connection
+func (sb *switchboard) startDispatcher() {
+    for {
+        select {
+        // dispatCh receives data from stream.Write
+        case data := <-sb.dispatCh:
+            // after the last sort, ces[0] has the smallest sendQueue
+            go sb.ces[0].send(data)
+            sb.ces[0].sendQueue += len(data)
+        case notified := <-sb.sentNotifyCh:
+            notified.ce.sendQueue -= notified.sent
+            sort.Sort(byQ(sb.ces))
+        }
+    }
+}
+
+// The deplexer sends data coming from a remote connection to a stream
+func (sb *switchboard) startDeplexer() {
+    for _, ce := range sb.ces {
+        // pass ce as an argument so each goroutine gets its own copy of
+        // the loop variable
+        go func(ce *connEnclave) {
+            buf := make([]byte, 20480)
+            for {
+                // TODO: error handling
+                n, _ := sb.session.obfsedReader(ce.remoteConn, buf)
+                frame := sb.session.deobfs(buf[:n])
+                if !sb.session.isStream(frame.StreamID) {
+                    sb.session.addStream(frame.StreamID)
+                }
+                sb.session.getStream(frame.StreamID).recvNewFrame(frame)
+            }
+        }(ce)
+    }
+}
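
On the receiving side, the deplexer's addStream hands remotely-initiated streams to Session.AcceptStream. A sketch of a server loop built on that, with an echo handler standing in for real proxy logic:

    for {
        stream, _ := sesh.AcceptStream()
        go func(s *multiplex.Stream) {
            buf := make([]byte, 4096)
            for {
                n, err := s.Read(buf)
                if err != nil { // e.g. io.EOF or broken pipe
                    return
                }
                s.Write(buf[:n]) // echo back, purely for illustration
            }
        }(stream)
    }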