diff --git a/internal/client/piper.go b/internal/client/piper.go index 92b92a4..c867db5 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -4,6 +4,7 @@ import ( "github.com/cbeuw/Cloak/internal/common" "io" "net" + "sync" "time" mux "github.com/cbeuw/Cloak/internal/multiplex" @@ -17,7 +18,7 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration log.Fatal(err) } - streams := make(map[string]*mux.Stream) + var streams sync.Map data := make([]byte, 8192) for { @@ -31,13 +32,15 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration sesh = newSeshFunc() } - stream, ok := streams[addr.String()] + var stream *mux.Stream + streamObj, ok := streams.Load(addr.String()) if !ok { if singleplex { sesh = newSeshFunc() } stream, err = sesh.OpenStream() + streamObj = stream if err != nil { if singleplex { sesh.Close() @@ -45,8 +48,9 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration log.Errorf("Failed to open stream: %v", err) continue } + _ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) - streams[addr.String()] = stream + streams.Store(addr.String(), stream) proxyAddr := addr go func(stream *mux.Stream, localConn *net.UDPConn) { buf := make([]byte, 8192) @@ -54,13 +58,16 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration n, err := stream.Read(buf) if err != nil { log.Tracef("copying stream to proxy client: %v", err) + streams.Delete(addr.String()) stream.Close() return } + _ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) _, err = localConn.WriteTo(buf[:n], proxyAddr) if err != nil { log.Tracef("copying stream to proxy client: %v", err) + streams.Delete(addr.String()) stream.Close() return } @@ -68,13 +75,15 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration }(stream, localConn) } + stream = streamObj.(*mux.Stream) _, err = stream.Write(data[:i]) if err != nil { log.Tracef("copying proxy client to stream: %v", err) - delete(streams, addr.String()) + streams.Delete(addr.String()) stream.Close() continue } + _ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) } }