Experimental support for UDP

This commit is contained in:
Andy Wang 2019-08-14 10:04:27 +01:00
parent c19c43f6e8
commit 44a09219f7
6 changed files with 478 additions and 32 deletions

View File

@ -23,7 +23,7 @@ import (
var version string
func makeSession(sta *client.State, isAdmin bool) *mux.Session {
func makeSession(sta *client.State, isAdmin bool, unordered bool) *mux.Session {
log.Info("Attemtping to start a new session")
if !isAdmin {
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
@ -78,6 +78,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session {
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
Unordered: unordered,
}
sesh := mux.MakeSession(sta.SessionID, seshConfig)
@ -99,6 +100,7 @@ func main() {
var remoteHost string
// The proxy port,should be 443
var remotePort string
var udp bool
var config string
var b64AdminUID string
@ -116,6 +118,7 @@ func main() {
flag.StringVar(&localPort, "l", "1984", "localPort: Cloak listens to proxy clients on this port")
flag.StringVar(&remoteHost, "s", "", "remoteHost: IP of your proxy server")
flag.StringVar(&remotePort, "p", "443", "remotePort: proxy port, should be 443")
flag.BoolVar(&udp, "u", false, "udp: set this flag if the underlying proxy is using UDP protocol")
flag.StringVar(&config, "c", "ckclient.json", "config: path to the configuration file or options seperated with semicolons")
flag.StringVar(&b64AdminUID, "a", "", "adminUID: enter the adminUID to serve the admin api")
askVersion := flag.Bool("v", false, "Print the version number")
@ -164,10 +167,6 @@ func main() {
// IPv6 needs square brackets
listeningIP = "[" + listeningIP + "]"
}
listener, err := net.Listen("tcp", listeningIP+":"+sta.LocalPort)
if err != nil {
log.Fatal(err)
}
var adminUID []byte
if b64AdminUID != "" {
@ -177,50 +176,144 @@ func main() {
}
}
var tcpListener net.Listener
var network string
if udp {
network = "udp"
} else {
network = "tcp"
// TODO use the local variable instead fo sta.LocalPort
tcpListener, err = net.Listen("tcp", listeningIP+":"+sta.LocalPort)
if err != nil {
log.Fatal(err)
}
}
if adminUID != nil {
log.Infof("API base is %v:%v", listeningIP, sta.LocalPort)
sta.SessionID = 0
sta.UID = adminUID
sta.NumConn = 1
} else {
log.Infof("Listening on %v:%v for proxy clients", listeningIP, sta.LocalPort)
log.Infof("Listening on %v %v:%v for proxy clients", network, listeningIP, sta.LocalPort)
}
var sesh *mux.Session
for {
localConn, err := listener.Accept()
if udp {
localUDPAddr, err := net.ResolveUDPAddr("udp", listeningIP+":"+localPort)
if err != nil {
log.Error(err)
continue
log.Fatal(err)
}
if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta, adminUID != nil)
localConn, err := net.ListenUDP("udp", localUDPAddr)
if err != nil {
log.Fatal(err)
}
go func() {
for {
var otherEnd atomic.Value
data := make([]byte, 10240)
i, err := io.ReadAtLeast(localConn, data, 1)
i, oe, err := localConn.ReadFromUDP(data)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
return
}
otherEnd.Store(oe)
if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta, adminUID != nil, true)
}
log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String())
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
localConn.Close()
//localConnWrite.Close()
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Errorf("Failed to write to stream: %v", err)
localConn.Close()
//localConnWrite.Close()
stream.Close()
return
}
go util.Pipe(localConn, stream)
util.Pipe(stream, localConn)
}()
go func() {
buf := make([]byte, 16380)
for {
i, err := io.ReadAtLeast(stream, buf, 1)
if err != nil {
log.Print(err)
go localConn.Close()
go stream.Close()
return
}
i, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr))
if err != nil {
log.Print(err)
go localConn.Close()
go stream.Close()
return
}
}
}()
buf := make([]byte, 16380)
for {
i, oe, err := localConn.ReadFromUDP(buf)
if err != nil {
log.Print(err)
go localConn.Close()
go stream.Close()
return
}
otherEnd.Store(oe)
i, err = stream.Write(buf[:i])
if err != nil {
log.Print(err)
go localConn.Close()
go stream.Close()
return
}
}
}
} else {
for {
localConn, err := tcpListener.Accept()
if err != nil {
log.Fatal(err)
continue
}
if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta, adminUID != nil, false)
}
go func() {
data := make([]byte, 10240)
i, err := io.ReadAtLeast(localConn, data, 1)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
return
}
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
localConn.Close()
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Errorf("Failed to write to stream: %v", err)
localConn.Close()
stream.Close()
return
}
go util.Pipe(localConn, stream)
util.Pipe(stream, localConn)
}()
}
}
}

View File

@ -120,6 +120,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
Unordered: ci.Unordered,
}
sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
if err != nil {
@ -174,8 +175,11 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
user.DeleteSession(ci.SessionId, "Failed to connect to proxy server")
continue
}
log.Debugf("%v endpoint has been successfully connected", ci.ProxyMethod)
go util.Pipe(localConn, newStream)
go util.Pipe(newStream, localConn)
}
}

View File

@ -0,0 +1,81 @@
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
package multiplex
import (
"io"
"sync"
)
const DATAGRAM_NUMBER_LIMIT = 1024
type datagramBuffer struct {
buf [][]byte
closed bool
rwCond *sync.Cond
}
func NewDatagramBuffer() *datagramBuffer {
d := &datagramBuffer{
buf: make([][]byte, 0),
rwCond: sync.NewCond(&sync.Mutex{}),
}
return d
}
func (d *datagramBuffer) Read(target []byte) (int, error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed && len(d.buf) == 0 {
return 0, io.EOF
}
if len(d.buf) > 0 {
break
}
d.rwCond.Wait()
}
var data []byte
data, d.buf = d.buf[0], d.buf[1:]
copy(target, data)
// err will always be nil because we have already verified that buf.Len() != 0
d.rwCond.Broadcast()
return len(data), nil
}
func (d *datagramBuffer) Write(input []byte) (int, error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed {
return 0, io.ErrClosedPipe
}
if len(d.buf) <= DATAGRAM_NUMBER_LIMIT {
// if d.buf gets too large, write() will panic. We don't want this to happen
break
}
d.rwCond.Wait()
}
data := make([]byte, len(input))
copy(data, input)
d.buf = append(d.buf, data)
// err will always be nil
d.rwCond.Broadcast()
return len(data), nil
}
func (d *datagramBuffer) Close() error {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.closed = true
d.rwCond.Broadcast()
return nil
}
func (d *datagramBuffer) Len() int {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
return len(d.buf)
}

View File

@ -0,0 +1,127 @@
package multiplex
import (
"bytes"
"testing"
"time"
)
func TestDatagramBuffer_RW(t *testing.T) {
pipe := NewDatagramBuffer()
b := []byte{0x01, 0x02, 0x03}
n, err := pipe.Write(b)
if n != len(b) {
t.Error(
"For", "number of bytes written",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple write",
"expecting", "nil error",
"got", err,
)
return
}
b2 := make([]byte, len(b))
n, err = pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
}
if pipe.Len() != 0 {
t.Error("buf len is not 0 after finished reading")
return
}
}
func TestDatagramBuffer_BlockingRead(t *testing.T) {
pipe := NewDatagramBuffer()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(10 * time.Millisecond)
pipe.Write(b)
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read after block",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
}
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
pipe := NewDatagramBuffer()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
}

View File

@ -2,6 +2,7 @@ package multiplex
import (
"errors"
"io"
"net"
"time"
@ -14,12 +15,17 @@ import (
var ErrBrokenStream = errors.New("broken stream")
type ReadWriteCloseLener interface {
io.ReadWriteCloser
Len() int
}
type Stream struct {
id uint32
session *Session
sortedBuf *bufferedPipe
buf ReadWriteCloseLener
sorter *frameSorter
@ -39,12 +45,17 @@ type Stream struct {
}
func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
buf := NewBufferedPipe()
var buf ReadWriteCloseLener
if sesh.Unordered {
buf = NewDatagramBuffer()
} else {
buf = NewBufferedPipe()
}
stream := &Stream{
id: id,
session: sesh,
sortedBuf: buf,
buf: buf,
obfsBuf: make([]byte, 17000),
sorter: NewFrameSorter(buf),
assignedConnId: assignedConnId,
@ -59,7 +70,7 @@ func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
func (s *Stream) writeFrame(frame *Frame) {
if s.session.Unordered {
s.sortedBuf.Write(frame.Payload)
s.buf.Write(frame.Payload)
} else {
s.sorter.writeNewFrame(frame)
}
@ -74,17 +85,19 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
return 0, nil
}
}
if s.isClosed() {
if s.sortedBuf.Len() == 0 {
// TODO: Len check may not be necessary as this can be offloaded to buffer implementation
if s.buf.Len() == 0 {
return 0, ErrBrokenStream
} else {
n, err = s.sortedBuf.Read(buf)
//log.Tracef("%v read from stream %v with err %v",n, s.id,err)
n, err = s.buf.Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
return
}
} else {
n, err = s.sortedBuf.Read(buf)
//log.Tracef("%v read from stream %v with err %v",n, s.id,err)
n, err = s.buf.Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
return
}
}
@ -114,7 +127,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
return i, err
}
n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
//log.Tracef("%v sent to remote through stream %v with err %v",n, s.id,err)
log.Tracef("%v sent to remote through stream %v with err %v", len(in), s.id, err)
if err != nil {
return
}
@ -126,7 +139,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
func (s *Stream) _close() {
atomic.StoreUint32(&s.closed, 1)
s.sorter.Close() // this will trigger frameSorter to return
s.sortedBuf.Close()
s.buf.Close()
}
// only close locally. Used when the stream close is notified by the remote

View File

@ -11,7 +11,7 @@ import (
"time"
)
func setupSesh() *Session {
func setupSesh(unordered bool) *Session {
sessionKey := make([]byte, 32)
rand.Read(sessionKey)
obfuscator, _ := GenerateObfs(0x00, sessionKey)
@ -20,6 +20,7 @@ func setupSesh() *Session {
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
Unordered: unordered,
}
return MakeSession(0, seshConfig)
}
@ -50,7 +51,7 @@ func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil }
func BenchmarkStream_Write(b *testing.B) {
const PAYLOAD_LEN = 1000
hole := newBlackHole()
sesh := setupSesh()
sesh := setupSesh(false)
sesh.AddConnection(hole)
testData := make([]byte, PAYLOAD_LEN)
rand.Read(testData)
@ -70,7 +71,7 @@ func BenchmarkStream_Write(b *testing.B) {
}
func BenchmarkStream_Read(b *testing.B) {
sesh := setupSesh()
sesh := setupSesh(false)
const PAYLOAD_LEN = 1000
testPayload := make([]byte, PAYLOAD_LEN)
rand.Read(testPayload)
@ -123,7 +124,134 @@ func BenchmarkStream_Read(b *testing.B) {
}
func TestStream_Read(t *testing.T) {
sesh := setupSesh()
sesh := setupSesh(false)
testPayload := []byte{42, 42, 42}
const PAYLOAD_LEN = 3
f := &Frame{
1,
0,
0,
testPayload,
}
ch := make(chan []byte)
l, _ := net.Listen("tcp", "127.0.0.1:0")
go func() {
conn, _ := net.Dial("tcp", l.Addr().String())
for {
data := <-ch
_, err := conn.Write(data)
if err != nil {
t.Error("cannot write to connection", err)
}
}
}()
conn, _ := l.Accept()
sesh.AddConnection(conn)
var streamID uint32
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
}
i, err = stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
}
})
t.Run("Nil buf", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept()
i, err := stream.Read(nil)
if i != 0 || err != nil {
t.Error("expecting", 0, nil,
"got", i, err)
}
stream.Close()
i, err = stream.Read(nil)
if i != 0 || err != ErrBrokenStream {
t.Error("expecting", 0, ErrBrokenStream,
"got", i, err)
}
})
t.Run("Read after stream close", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept()
stream.Close()
i, err := stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
}
_, err = stream.Read(buf)
if err == nil {
t.Error("expecting error", ErrBrokenStream,
"got nil error")
}
})
t.Run("Read after session close", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept()
sesh.Close()
i, err := stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
}
_, err = stream.Read(buf)
if err == nil {
t.Error("expecting error", ErrBrokenStream,
"got nil error")
}
})
}
func TestStream_UnorderedRead(t *testing.T) {
sesh := setupSesh(true)
testPayload := []byte{42, 42, 42}
const PAYLOAD_LEN = 3