diff --git a/internal/client/piper.go b/internal/client/piper.go index e610799..b520102 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -4,96 +4,74 @@ import ( "github.com/cbeuw/Cloak/internal/common" "io" "net" - "sync/atomic" "time" mux "github.com/cbeuw/Cloak/internal/multiplex" log "github.com/sirupsen/logrus" ) -func RouteUDP(listen func(string, string) (net.PacketConn, error), localConfig LocalConnConfig, newSeshFunc func() *mux.Session) { +func RouteUDP(acceptFunc func() (*net.UDPConn, error), streamTimeout time.Duration, newSeshFunc func() *mux.Session) { var sesh *mux.Session -start: - localConn, err := listen("udp", localConfig.LocalAddr) + localConn, err := acceptFunc() if err != nil { log.Fatal(err) } - var otherEnd atomic.Value - data := make([]byte, 10240) - i, oe, err := localConn.ReadFrom(data) - if err != nil { - log.Errorf("Failed to read first packet from proxy client: %v", err) - localConn.Close() - return - } - otherEnd.Store(oe) - if sesh == nil || sesh.IsClosed() { - sesh = newSeshFunc() - } - log.Debugf("proxy local address %v", otherEnd.Load().(net.Addr).String()) - stream, err := sesh.OpenStream() - if err != nil { - log.Errorf("Failed to open stream: %v", err) - localConn.Close() - //localConnWrite.Close() - return - } - _, err = stream.Write(data[:i]) - if err != nil { - log.Errorf("Failed to write to stream: %v", err) - localConn.Close() - //localConnWrite.Close() - stream.Close() - return - } + streams := make(map[string]*mux.Stream) - // stream to proxy - go func() { - buf := make([]byte, 16380) - for { - i, err := io.ReadAtLeast(stream, buf, 1) - if err != nil { - log.Print(err) - localConn.Close() - stream.Close() - break - } - _, err = localConn.WriteTo(buf[:i], otherEnd.Load().(net.Addr)) - if err != nil { - log.Print(err) - localConn.Close() - stream.Close() - break - } - } - }() - - // proxy to stream - buf := make([]byte, 16380) - if localConfig.Timeout != 0 { - localConn.SetReadDeadline(time.Now().Add(localConfig.Timeout)) - } + data := make([]byte, 8192) for { - if localConfig.Timeout != 0 { - localConn.SetReadDeadline(time.Now().Add(localConfig.Timeout)) - } - i, oe, err := localConn.ReadFrom(buf) + i, addr, err := localConn.ReadFrom(data) if err != nil { - localConn.Close() - stream.Close() - break + log.Errorf("Failed to read first packet from proxy client: %v", err) + localConn, err = acceptFunc() + if err != nil { + log.Fatal(err) + } + continue } - otherEnd.Store(oe) - _, err = stream.Write(buf[:i]) + + if sesh == nil || sesh.IsClosed() { + sesh = newSeshFunc() + } + + stream, ok := streams[addr.String()] + if !ok { + stream, err = sesh.OpenStream() + if err != nil { + log.Errorf("Failed to open stream: %v", err) + continue + } + streams[addr.String()] = stream + proxyAddr := addr + go func() { + buf := make([]byte, 8192) + for { + n, err := stream.Read(buf) + if err != nil { + log.Tracef("copying stream to proxy client: %v", err) + stream.Close() + return + } + + _, err = localConn.WriteTo(buf[:n], proxyAddr) + if err != nil { + log.Tracef("copying stream to proxy client: %v", err) + stream.Close() + return + } + } + }() + } + + _, err = stream.Write(data[:i]) if err != nil { - localConn.Close() + log.Tracef("copying proxy client to stream: %v", err) + delete(streams, addr.String()) stream.Close() - break + continue } } - goto start - } func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) { diff --git a/internal/multiplex/datagramBuffer.go b/internal/multiplex/datagramBuffer.go index c3ea771..630b2c9 100644 --- a/internal/multiplex/datagramBuffer.go +++ b/internal/multiplex/datagramBuffer.go @@ -98,6 +98,7 @@ func (d *datagramBuffer) WriteTo(w io.Writer) (n int64, err error) { return n, er } d.rwCond.Broadcast() + continue } d.rwCond.Wait() } diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index b9c6cfc..f2b5b8f 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -42,7 +42,6 @@ func serveTCPEcho(l net.Listener) { } } -/* func serveUDPEcho(listener *connutil.PipeListener) { for { conn, err := listener.ListenPacket("udp", "") @@ -56,7 +55,7 @@ func serveUDPEcho(listener *connutil.PipeListener) { defer conn.Close() buf := make([]byte, bufSize) for { - r,_, err := conn.ReadFrom(buf) + r, _, err := conn.ReadFrom(buf) if err != nil { log.Error(err) return @@ -75,8 +74,6 @@ func serveUDPEcho(listener *connutil.PipeListener) { } } -*/ - var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=") var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=") @@ -123,6 +120,18 @@ func basicServerState(ws common.WorldState, db *os.File) *server.State { return state } +type mockUDPDialer struct { + addrCh chan *net.UDPAddr + raddr *net.UDPAddr +} + +func (m *mockUDPDialer) Dial(network, address string) (net.Conn, error) { + if m.raddr == nil { + m.raddr = <-m.addrCh + } + return net.DialUDP("udp", nil, m.raddr) +} + func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) { // transport ckClientDialer, ckServerListener := connutil.DialerListener(10 * 1024) @@ -130,10 +139,23 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a return client.MakeSession(rcc, ai, ckClientDialer, false) } - proxyToCkClientD, proxyToCkClientL := connutil.DialerListener(10 * 1024) + var proxyToCkClientD common.Dialer if ai.Unordered { - go client.RouteUDP(proxyToCkClientL.ListenPacket, lcc, clientSeshMaker) + addrCh := make(chan *net.UDPAddr, 1) + mDialer := &mockUDPDialer{ + addrCh: addrCh, + } + acceptor := func() (*net.UDPConn, error) { + laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, err := net.ListenUDP("udp", laddr) + addrCh <- conn.LocalAddr().(*net.UDPAddr) + return conn, err + } + go client.RouteUDP(acceptor, lcc.Timeout, clientSeshMaker) + proxyToCkClientD = mDialer } else { + var proxyToCkClientL *connutil.PipeListener + proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024) go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker) } @@ -180,7 +202,7 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { func TestUDP(t *testing.T) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") defer os.Remove(tmpDB.Name()) - log.SetLevel(log.FatalLevel) + log.SetLevel(log.TraceLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := basicClientConfigs(worldState) @@ -224,6 +246,17 @@ func TestUDP(t *testing.T) { } }) + t.Run("user echo", func(t *testing.T) { + go serveUDPEcho(pxyServerL) + var conn [1]net.Conn + conn[0], err = pxyClientD.Dial("udp", "") + if err != nil { + t.Error(err) + } + + runEchoTest(t, conn[:], 1024) + }) + } func TestTCP(t *testing.T) {