From 5d4e8b8d8d73ad45a0ce302fedcf4d35d5538831 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sat, 11 Apr 2020 22:37:15 +0100 Subject: [PATCH] Refactor udp piping and add tests --- cmd/ck-client/ck-client.go | 2 +- internal/client/piper.go | 16 ++-- internal/multiplex/stream.go | 4 + internal/test/integration_test.go | 125 +++++++++++++++++++++++++----- 4 files changed, 118 insertions(+), 29 deletions(-) diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 9284b50..738e582 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -166,7 +166,7 @@ func main() { } if authInfo.Unordered { - client.RouteUDP(localConfig, seshMaker) + client.RouteUDP(net.ListenPacket, localConfig, seshMaker) } else { listener, err := net.Listen("tcp", localConfig.LocalAddr) if err != nil { diff --git a/internal/client/piper.go b/internal/client/piper.go index 1c1a03a..28e84aa 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -11,20 +11,16 @@ import ( log "github.com/sirupsen/logrus" ) -func RouteUDP(localConfig LocalConnConfig, newSeshFunc func() *mux.Session) { +func RouteUDP(listen func(string, string) (net.PacketConn, error), localConfig LocalConnConfig, newSeshFunc func() *mux.Session) { var sesh *mux.Session - localUDPAddr, err := net.ResolveUDPAddr("udp", localConfig.LocalAddr) - if err != nil { - log.Fatal(err) - } start: - localConn, err := net.ListenUDP("udp", localUDPAddr) + localConn, err := listen("udp", localConfig.LocalAddr) if err != nil { log.Fatal(err) } var otherEnd atomic.Value data := make([]byte, 10240) - i, oe, err := localConn.ReadFromUDP(data) + i, oe, err := localConn.ReadFrom(data) if err != nil { log.Errorf("Failed to read first packet from proxy client: %v", err) localConn.Close() @@ -35,7 +31,7 @@ start: if sesh == nil || sesh.IsClosed() { sesh = newSeshFunc() } - log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String()) + log.Debugf("proxy local address %v", otherEnd.Load().(net.Addr).String()) stream, err := sesh.OpenStream() if err != nil { log.Errorf("Failed to open stream: %v", err) @@ -63,7 +59,7 @@ start: stream.Close() break } - _, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr)) + _, err = localConn.WriteTo(buf[:i], otherEnd.Load().(net.Addr)) if err != nil { log.Print(err) localConn.Close() @@ -82,7 +78,7 @@ start: if localConfig.Timeout != 0 { localConn.SetReadDeadline(time.Now().Add(localConfig.Timeout)) } - i, oe, err := localConn.ReadFromUDP(buf) + i, oe, err := localConn.ReadFrom(buf) if err != nil { localConn.Close() stream.Close() diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 87f0f14..f3eadef 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -101,6 +101,10 @@ func (s *Stream) Write(in []byte) (n int, err error) { if len(in)-n <= s.session.maxStreamUnitWrite { framePayload = in[n:] } else { + if s.session.Unordered { // no splitting + err = io.ErrShortBuffer + return + } framePayload = in[n : s.session.maxStreamUnitWrite+n] } diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 12a107d..f6bcca2 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -21,23 +21,60 @@ import ( log "github.com/sirupsen/logrus" ) -func serveEcho(l net.Listener) { +func serveTCPEcho(l net.Listener) { for { conn, err := l.Accept() if err != nil { - // TODO: pass the error back + log.Error(err) return } go func() { + conn := conn _, err := io.Copy(conn, conn) if err != nil { - // TODO: pass the error back + conn.Close() + log.Error(err) return } }() } } +/* +func serveUDPEcho(listener *connutil.PipeListener) { + for { + conn, err := listener.ListenPacket("udp", "") + if err != nil { + log.Error(err) + return + } + const bufSize = 32 * 1024 + go func() { + conn := conn + defer conn.Close() + buf := make([]byte, bufSize) + for { + r,_, err := conn.ReadFrom(buf) + if err != nil { + log.Error(err) + return + } + w, err := conn.WriteTo(buf[:r], nil) + if err != nil { + log.Error(err) + return + } + if r != w { + log.Error("written not eqal to read") + return + } + } + }() + } +} + +*/ + var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=") var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=") @@ -45,7 +82,7 @@ var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7h func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) { var clientConfig = client.RawConfig{ ServerName: "www.example.com", - ProxyMethod: "test", + ProxyMethod: "tcp", EncryptionMethod: "plain", UID: bypassUID[:], PublicKey: publicKey, @@ -66,7 +103,7 @@ func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client func basicServerState(ws common.WorldState, db *os.File) *server.State { var serverConfig = server.RawConfig{ - ProxyBook: map[string][]string{"test": {"tcp", "fake.com:9999"}}, + ProxyBook: map[string][]string{"tcp": {"tcp", "fake.com:9999"}, "udp": {"udp", "fake.com:9999"}}, BindAddr: []string{"fake.com:9999"}, BypassUID: [][]byte{bypassUID[:]}, RedirAddr: "fake.com:9999", @@ -84,16 +121,19 @@ func basicServerState(ws common.WorldState, db *os.File) *server.State { return state } -func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, net.Listener, common.Dialer, net.Listener, error) { +func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) { // transport ckClientDialer, ckServerListener := connutil.DialerListener(10 * 1024) - clientSeshMaker := func() *mux.Session { return client.MakeSession(rcc, ai, ckClientDialer, false) } proxyToCkClientD, proxyToCkClientL := connutil.DialerListener(10 * 1024) - go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker) + if ai.Unordered { + go client.RouteUDP(proxyToCkClientL.ListenPacket, lcc, clientSeshMaker) + } else { + go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker) + } // set up server ckServerToProxyD, ckServerToProxyL := connutil.DialerListener(10 * 1024) @@ -106,12 +146,12 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a return proxyToCkClientD, ckServerToProxyL, ckClientDialer, ckServerToWebL, nil } -func runEchoTest(t *testing.T, conns []net.Conn) { +func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { var wg sync.WaitGroup for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { - testDataLen := rand.Intn(65536) + testDataLen := rand.Intn(maxMsgLen) testData := make([]byte, testDataLen) rand.Read(testData) @@ -135,10 +175,59 @@ func runEchoTest(t *testing.T, conns []net.Conn) { wg.Wait() } +func TestUDP(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + log.SetLevel(log.TraceLevel) + + worldState := common.WorldOfTime(time.Unix(10, 0)) + lcc, rcc, ai := basicClientConfigs(worldState) + ai.ProxyMethod = "udp" + ai.Unordered = true + sta := basicServerState(worldState, tmpDB) + + pxyClientD, pxyServerL, _, _, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + + t.Run("simple send", func(t *testing.T) { + pxyClientConn, err := pxyClientD.Dial("udp", "") + if err != nil { + t.Error(err) + } + + const testDataLen = 1500 + testData := make([]byte, testDataLen) + rand.Read(testData) + n, err := pxyClientConn.Write(testData) + if n != testDataLen { + t.Errorf("wrong length sent: %v", n) + } + if err != nil { + t.Error(err) + } + + pxyServerConn, err := pxyServerL.ListenPacket("", "") + if err != nil { + t.Error(err) + } + recvBuf := make([]byte, testDataLen+100) + r, _, err := pxyServerConn.ReadFrom(recvBuf) + if err != nil { + t.Error(err) + } + if !bytes.Equal(testData, recvBuf[:r]) { + t.Error("read wrong data") + } + }) + +} + func TestTCP(t *testing.T) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") defer os.Remove(tmpDB.Name()) - log.SetLevel(log.FatalLevel) + log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := basicClientConfigs(worldState) @@ -155,7 +244,7 @@ func TestTCP(t *testing.T) { writeData := make([]byte, dataLen) rand.Read(writeData) t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) { - go serveEcho(pxyServerL) + go serveTCPEcho(pxyServerL) conn, err := pxyClientD.Dial("", "") if err != nil { t.Error(err) @@ -182,7 +271,7 @@ func TestTCP(t *testing.T) { }) t.Run("user echo", func(t *testing.T) { - go serveEcho(pxyServerL) + go serveTCPEcho(pxyServerL) const numConns = 2000 // -race option limits the number of goroutines to 8192 var conns [numConns]net.Conn for i := 0; i < numConns; i++ { @@ -192,11 +281,11 @@ func TestTCP(t *testing.T) { } } - runEchoTest(t, conns[:]) + runEchoTest(t, conns[:], 65536) }) t.Run("redir echo", func(t *testing.T) { - go serveEcho(rdirServerL) + go serveTCPEcho(rdirServerL) const numConns = 2000 // -race option limits the number of goroutines to 8192 var conns [numConns]net.Conn for i := 0; i < numConns; i++ { @@ -205,14 +294,14 @@ func TestTCP(t *testing.T) { t.Error(err) } } - runEchoTest(t, conns[:]) + runEchoTest(t, conns[:], 65536) }) } func TestClosingStreamsFromProxy(t *testing.T) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") defer os.Remove(tmpDB.Name()) - log.SetLevel(log.FatalLevel) + log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := basicClientConfigs(worldState) sta := basicServerState(worldState, tmpDB) @@ -247,7 +336,7 @@ func TestClosingStreamsFromProxy(t *testing.T) { func BenchmarkThroughput(b *testing.B) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") defer os.Remove(tmpDB.Name()) - log.SetLevel(log.FatalLevel) + log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := basicClientConfigs(worldState) sta := basicServerState(worldState, tmpDB)