diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 76f3aa4..add2de2 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -23,7 +23,7 @@ import ( var version string -func makeSession(sta *client.State, isAdmin bool) *mux.Session { +func makeSession(sta *client.State, isAdmin bool, unordered bool) *mux.Session { log.Info("Attemtping to start a new session") if !isAdmin { // sessionID is usergenerated. There shouldn't be a security concern because the scope of @@ -78,6 +78,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session { Obfuscator: obfuscator, Valve: nil, UnitRead: util.ReadTLS, + Unordered: unordered, } sesh := mux.MakeSession(sta.SessionID, seshConfig) @@ -99,6 +100,7 @@ func main() { var remoteHost string // The proxy port,should be 443 var remotePort string + var udp bool var config string var b64AdminUID string @@ -116,6 +118,7 @@ func main() { flag.StringVar(&localPort, "l", "1984", "localPort: Cloak listens to proxy clients on this port") flag.StringVar(&remoteHost, "s", "", "remoteHost: IP of your proxy server") flag.StringVar(&remotePort, "p", "443", "remotePort: proxy port, should be 443") + flag.BoolVar(&udp, "u", false, "udp: set this flag if the underlying proxy is using UDP protocol") flag.StringVar(&config, "c", "ckclient.json", "config: path to the configuration file or options seperated with semicolons") flag.StringVar(&b64AdminUID, "a", "", "adminUID: enter the adminUID to serve the admin api") askVersion := flag.Bool("v", false, "Print the version number") @@ -164,10 +167,6 @@ func main() { // IPv6 needs square brackets listeningIP = "[" + listeningIP + "]" } - listener, err := net.Listen("tcp", listeningIP+":"+sta.LocalPort) - if err != nil { - log.Fatal(err) - } var adminUID []byte if b64AdminUID != "" { @@ -177,50 +176,144 @@ func main() { } } + var tcpListener net.Listener + var network string + if udp { + network = "udp" + } else { + network = "tcp" + // TODO use the local variable instead fo sta.LocalPort + tcpListener, err = net.Listen("tcp", listeningIP+":"+sta.LocalPort) + if err != nil { + log.Fatal(err) + } + } + if adminUID != nil { log.Infof("API base is %v:%v", listeningIP, sta.LocalPort) sta.SessionID = 0 sta.UID = adminUID sta.NumConn = 1 } else { - log.Infof("Listening on %v:%v for proxy clients", listeningIP, sta.LocalPort) + log.Infof("Listening on %v %v:%v for proxy clients", network, listeningIP, sta.LocalPort) } var sesh *mux.Session - for { - localConn, err := listener.Accept() + if udp { + localUDPAddr, err := net.ResolveUDPAddr("udp", listeningIP+":"+localPort) if err != nil { - log.Error(err) - continue + log.Fatal(err) } - if sesh == nil || sesh.IsClosed() { - sesh = makeSession(sta, adminUID != nil) + localConn, err := net.ListenUDP("udp", localUDPAddr) + if err != nil { + log.Fatal(err) } - go func() { + for { + var otherEnd atomic.Value data := make([]byte, 10240) - i, err := io.ReadAtLeast(localConn, data, 1) + i, oe, err := localConn.ReadFromUDP(data) if err != nil { log.Errorf("Failed to read first packet from proxy client: %v", err) localConn.Close() return } + otherEnd.Store(oe) + + if sesh == nil || sesh.IsClosed() { + sesh = makeSession(sta, adminUID != nil, true) + } + log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String()) stream, err := sesh.OpenStream() if err != nil { log.Errorf("Failed to open stream: %v", err) localConn.Close() + //localConnWrite.Close() return } _, err = stream.Write(data[:i]) if err != nil { log.Errorf("Failed to write to stream: %v", err) localConn.Close() + //localConnWrite.Close() stream.Close() return } - go util.Pipe(localConn, stream) - util.Pipe(stream, localConn) - }() + + go func() { + buf := make([]byte, 16380) + for { + i, err := io.ReadAtLeast(stream, buf, 1) + if err != nil { + log.Print(err) + go localConn.Close() + go stream.Close() + return + } + i, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr)) + if err != nil { + log.Print(err) + go localConn.Close() + go stream.Close() + return + } + } + }() + + buf := make([]byte, 16380) + for { + i, oe, err := localConn.ReadFromUDP(buf) + if err != nil { + log.Print(err) + go localConn.Close() + go stream.Close() + return + } + otherEnd.Store(oe) + i, err = stream.Write(buf[:i]) + if err != nil { + log.Print(err) + go localConn.Close() + go stream.Close() + return + } + } + } + } else { + for { + localConn, err := tcpListener.Accept() + if err != nil { + log.Fatal(err) + continue + } + if sesh == nil || sesh.IsClosed() { + sesh = makeSession(sta, adminUID != nil, false) + } + go func() { + data := make([]byte, 10240) + i, err := io.ReadAtLeast(localConn, data, 1) + if err != nil { + log.Errorf("Failed to read first packet from proxy client: %v", err) + localConn.Close() + return + } + stream, err := sesh.OpenStream() + if err != nil { + log.Errorf("Failed to open stream: %v", err) + localConn.Close() + return + } + _, err = stream.Write(data[:i]) + if err != nil { + log.Errorf("Failed to write to stream: %v", err) + localConn.Close() + stream.Close() + return + } + go util.Pipe(localConn, stream) + util.Pipe(stream, localConn) + }() + } } } diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 1acfc0b..2d5e55d 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -120,6 +120,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { Obfuscator: obfuscator, Valve: nil, UnitRead: util.ReadTLS, + Unordered: ci.Unordered, } sesh, existing, err := user.GetSession(ci.SessionId, seshConfig) if err != nil { @@ -174,8 +175,11 @@ func dispatchConnection(conn net.Conn, sta *server.State) { user.DeleteSession(ci.SessionId, "Failed to connect to proxy server") continue } + log.Debugf("%v endpoint has been successfully connected", ci.ProxyMethod) + go util.Pipe(localConn, newStream) go util.Pipe(newStream, localConn) + } } diff --git a/internal/multiplex/datagramBuffer.go b/internal/multiplex/datagramBuffer.go new file mode 100644 index 0000000..6d27af9 --- /dev/null +++ b/internal/multiplex/datagramBuffer.go @@ -0,0 +1,81 @@ +// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173 + +package multiplex + +import ( + "io" + "sync" +) + +const DATAGRAM_NUMBER_LIMIT = 1024 + +type datagramBuffer struct { + buf [][]byte + closed bool + rwCond *sync.Cond +} + +func NewDatagramBuffer() *datagramBuffer { + d := &datagramBuffer{ + buf: make([][]byte, 0), + rwCond: sync.NewCond(&sync.Mutex{}), + } + return d +} + +func (d *datagramBuffer) Read(target []byte) (int, error) { + d.rwCond.L.Lock() + defer d.rwCond.L.Unlock() + for { + if d.closed && len(d.buf) == 0 { + return 0, io.EOF + } + + if len(d.buf) > 0 { + break + } + d.rwCond.Wait() + } + var data []byte + data, d.buf = d.buf[0], d.buf[1:] + copy(target, data) + // err will always be nil because we have already verified that buf.Len() != 0 + d.rwCond.Broadcast() + return len(data), nil +} + +func (d *datagramBuffer) Write(input []byte) (int, error) { + d.rwCond.L.Lock() + defer d.rwCond.L.Unlock() + for { + if d.closed { + return 0, io.ErrClosedPipe + } + if len(d.buf) <= DATAGRAM_NUMBER_LIMIT { + // if d.buf gets too large, write() will panic. We don't want this to happen + break + } + d.rwCond.Wait() + } + data := make([]byte, len(input)) + copy(data, input) + d.buf = append(d.buf, data) + // err will always be nil + d.rwCond.Broadcast() + return len(data), nil +} + +func (d *datagramBuffer) Close() error { + d.rwCond.L.Lock() + defer d.rwCond.L.Unlock() + + d.closed = true + d.rwCond.Broadcast() + return nil +} + +func (d *datagramBuffer) Len() int { + d.rwCond.L.Lock() + defer d.rwCond.L.Unlock() + return len(d.buf) +} diff --git a/internal/multiplex/datagramBuffer_test.go b/internal/multiplex/datagramBuffer_test.go new file mode 100644 index 0000000..22ec4bd --- /dev/null +++ b/internal/multiplex/datagramBuffer_test.go @@ -0,0 +1,127 @@ +package multiplex + +import ( + "bytes" + "testing" + "time" +) + +func TestDatagramBuffer_RW(t *testing.T) { + pipe := NewDatagramBuffer() + b := []byte{0x01, 0x02, 0x03} + n, err := pipe.Write(b) + if n != len(b) { + t.Error( + "For", "number of bytes written", + "expecting", len(b), + "got", n, + ) + return + } + if err != nil { + t.Error( + "For", "simple write", + "expecting", "nil error", + "got", err, + ) + return + } + + b2 := make([]byte, len(b)) + n, err = pipe.Read(b2) + if n != len(b) { + t.Error( + "For", "number of bytes read", + "expecting", len(b), + "got", n, + ) + return + } + if err != nil { + t.Error( + "For", "simple read", + "expecting", "nil error", + "got", err, + ) + return + } + if !bytes.Equal(b, b2) { + t.Error( + "For", "simple read", + "expecting", b, + "got", b2, + ) + } + if pipe.Len() != 0 { + t.Error("buf len is not 0 after finished reading") + return + } + +} + +func TestDatagramBuffer_BlockingRead(t *testing.T) { + pipe := NewDatagramBuffer() + b := []byte{0x01, 0x02, 0x03} + go func() { + time.Sleep(10 * time.Millisecond) + pipe.Write(b) + }() + b2 := make([]byte, len(b)) + n, err := pipe.Read(b2) + if n != len(b) { + t.Error( + "For", "number of bytes read after block", + "expecting", len(b), + "got", n, + ) + return + } + if err != nil { + t.Error( + "For", "blocked read", + "expecting", "nil error", + "got", err, + ) + return + } + if !bytes.Equal(b, b2) { + t.Error( + "For", "blocked read", + "expecting", b, + "got", b2, + ) + return + } +} + +func TestDatagramBuffer_CloseThenRead(t *testing.T) { + pipe := NewDatagramBuffer() + b := []byte{0x01, 0x02, 0x03} + pipe.Write(b) + b2 := make([]byte, len(b)) + pipe.Close() + n, err := pipe.Read(b2) + if n != len(b) { + t.Error( + "For", "number of bytes read", + "expecting", len(b), + "got", n, + ) + } + if err != nil { + t.Error( + "For", "simple read", + "expecting", "nil error", + "got", err, + ) + return + } + if !bytes.Equal(b, b2) { + t.Error( + "For", "simple read", + "expecting", b, + "got", b2, + ) + return + } +} diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index c393310..f9f7c3c 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -2,6 +2,7 @@ package multiplex import ( "errors" + "io" "net" "time" @@ -14,12 +15,17 @@ import ( var ErrBrokenStream = errors.New("broken stream") +type ReadWriteCloseLener interface { + io.ReadWriteCloser + Len() int +} + type Stream struct { id uint32 session *Session - sortedBuf *bufferedPipe + buf ReadWriteCloseLener sorter *frameSorter @@ -39,12 +45,17 @@ type Stream struct { } func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream { - buf := NewBufferedPipe() + var buf ReadWriteCloseLener + if sesh.Unordered { + buf = NewDatagramBuffer() + } else { + buf = NewBufferedPipe() + } stream := &Stream{ id: id, session: sesh, - sortedBuf: buf, + buf: buf, obfsBuf: make([]byte, 17000), sorter: NewFrameSorter(buf), assignedConnId: assignedConnId, @@ -59,7 +70,7 @@ func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) writeFrame(frame *Frame) { if s.session.Unordered { - s.sortedBuf.Write(frame.Payload) + s.buf.Write(frame.Payload) } else { s.sorter.writeNewFrame(frame) } @@ -74,17 +85,19 @@ func (s *Stream) Read(buf []byte) (n int, err error) { return 0, nil } } + if s.isClosed() { - if s.sortedBuf.Len() == 0 { + // TODO: Len check may not be necessary as this can be offloaded to buffer implementation + if s.buf.Len() == 0 { return 0, ErrBrokenStream } else { - n, err = s.sortedBuf.Read(buf) - //log.Tracef("%v read from stream %v with err %v",n, s.id,err) + n, err = s.buf.Read(buf) + log.Tracef("%v read from stream %v with err %v", n, s.id, err) return } } else { - n, err = s.sortedBuf.Read(buf) - //log.Tracef("%v read from stream %v with err %v",n, s.id,err) + n, err = s.buf.Read(buf) + log.Tracef("%v read from stream %v with err %v", n, s.id, err) return } } @@ -114,7 +127,7 @@ func (s *Stream) Write(in []byte) (n int, err error) { return i, err } n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId) - //log.Tracef("%v sent to remote through stream %v with err %v",n, s.id,err) + log.Tracef("%v sent to remote through stream %v with err %v", len(in), s.id, err) if err != nil { return } @@ -126,7 +139,7 @@ func (s *Stream) Write(in []byte) (n int, err error) { func (s *Stream) _close() { atomic.StoreUint32(&s.closed, 1) s.sorter.Close() // this will trigger frameSorter to return - s.sortedBuf.Close() + s.buf.Close() } // only close locally. Used when the stream close is notified by the remote diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index fea1360..fd9368e 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -11,7 +11,7 @@ import ( "time" ) -func setupSesh() *Session { +func setupSesh(unordered bool) *Session { sessionKey := make([]byte, 32) rand.Read(sessionKey) obfuscator, _ := GenerateObfs(0x00, sessionKey) @@ -20,6 +20,7 @@ func setupSesh() *Session { Obfuscator: obfuscator, Valve: nil, UnitRead: util.ReadTLS, + Unordered: unordered, } return MakeSession(0, seshConfig) } @@ -50,7 +51,7 @@ func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil } func BenchmarkStream_Write(b *testing.B) { const PAYLOAD_LEN = 1000 hole := newBlackHole() - sesh := setupSesh() + sesh := setupSesh(false) sesh.AddConnection(hole) testData := make([]byte, PAYLOAD_LEN) rand.Read(testData) @@ -70,7 +71,7 @@ func BenchmarkStream_Write(b *testing.B) { } func BenchmarkStream_Read(b *testing.B) { - sesh := setupSesh() + sesh := setupSesh(false) const PAYLOAD_LEN = 1000 testPayload := make([]byte, PAYLOAD_LEN) rand.Read(testPayload) @@ -123,7 +124,134 @@ func BenchmarkStream_Read(b *testing.B) { } func TestStream_Read(t *testing.T) { - sesh := setupSesh() + sesh := setupSesh(false) + testPayload := []byte{42, 42, 42} + const PAYLOAD_LEN = 3 + + f := &Frame{ + 1, + 0, + 0, + testPayload, + } + + ch := make(chan []byte) + l, _ := net.Listen("tcp", "127.0.0.1:0") + go func() { + conn, _ := net.Dial("tcp", l.Addr().String()) + for { + data := <-ch + _, err := conn.Write(data) + if err != nil { + t.Error("cannot write to connection", err) + } + } + }() + conn, _ := l.Accept() + sesh.AddConnection(conn) + + var streamID uint32 + buf := make([]byte, 10) + + obfsBuf := make([]byte, 512) + t.Run("Plain read", func(t *testing.T) { + f.StreamID = streamID + i, _ := sesh.Obfs(f, obfsBuf) + streamID++ + ch <- obfsBuf[:i] + time.Sleep(100 * time.Microsecond) + stream, err := sesh.Accept() + if err != nil { + t.Error("failed to accept stream", err) + } + i, err = stream.Read(buf) + if err != nil { + t.Error("failed to read", err) + } + if i != PAYLOAD_LEN { + t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + } + if !bytes.Equal(buf[:i], testPayload) { + t.Error("expected", testPayload, + "got", buf[:i]) + } + }) + t.Run("Nil buf", func(t *testing.T) { + f.StreamID = streamID + i, _ := sesh.Obfs(f, obfsBuf) + streamID++ + ch <- obfsBuf[:i] + time.Sleep(100 * time.Microsecond) + stream, _ := sesh.Accept() + i, err := stream.Read(nil) + if i != 0 || err != nil { + t.Error("expecting", 0, nil, + "got", i, err) + } + + stream.Close() + i, err = stream.Read(nil) + if i != 0 || err != ErrBrokenStream { + t.Error("expecting", 0, ErrBrokenStream, + "got", i, err) + } + + }) + t.Run("Read after stream close", func(t *testing.T) { + f.StreamID = streamID + i, _ := sesh.Obfs(f, obfsBuf) + streamID++ + ch <- obfsBuf[:i] + time.Sleep(100 * time.Microsecond) + stream, _ := sesh.Accept() + stream.Close() + i, err := stream.Read(buf) + if err != nil { + t.Error("failed to read", err) + } + if i != PAYLOAD_LEN { + t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + } + if !bytes.Equal(buf[:i], testPayload) { + t.Error("expected", testPayload, + "got", buf[:i]) + } + _, err = stream.Read(buf) + if err == nil { + t.Error("expecting error", ErrBrokenStream, + "got nil error") + } + }) + t.Run("Read after session close", func(t *testing.T) { + f.StreamID = streamID + i, _ := sesh.Obfs(f, obfsBuf) + streamID++ + ch <- obfsBuf[:i] + time.Sleep(100 * time.Microsecond) + stream, _ := sesh.Accept() + sesh.Close() + i, err := stream.Read(buf) + if err != nil { + t.Error("failed to read", err) + } + if i != PAYLOAD_LEN { + t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + } + if !bytes.Equal(buf[:i], testPayload) { + t.Error("expected", testPayload, + "got", buf[:i]) + } + _, err = stream.Read(buf) + if err == nil { + t.Error("expecting error", ErrBrokenStream, + "got nil error") + } + }) + +} + +func TestStream_UnorderedRead(t *testing.T) { + sesh := setupSesh(true) testPayload := []byte{42, 42, 42} const PAYLOAD_LEN = 3