diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index d48f788..efe51c2 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -169,6 +169,7 @@ func main() { } log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) seshMaker = func() *mux.Session { + authInfo := authInfo // copy the struct because we are overwriting SessionId // sessionID is usergenerated. There shouldn't be a security concern because the scope of // sessionID is limited to its UID. quad := make([]byte, 4) @@ -184,12 +185,12 @@ func main() { return net.ListenUDP("udp", udpAddr) } - client.RouteUDP(acceptor, localConfig.Timeout, seshMaker) + client.RouteUDP(acceptor, localConfig.Timeout, remoteConfig.Singleplex, seshMaker) } else { listener, err := net.Listen("tcp", localConfig.LocalAddr) if err != nil { log.Fatal(err) } - client.RouteTCP(listener, localConfig.Timeout, seshMaker) + client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, seshMaker) } } diff --git a/internal/client/piper.go b/internal/client/piper.go index 295e3e8..92b92a4 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" ) -func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, newSeshFunc func() *mux.Session) { +func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { var sesh *mux.Session localConn, err := bindFunc() if err != nil { @@ -27,18 +27,22 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration continue } - if sesh == nil || sesh.IsClosed() || sesh.Singleplex { + if !singleplex && (sesh == nil || sesh.IsClosed()) { sesh = newSeshFunc() } stream, ok := streams[addr.String()] if !ok { + if singleplex { + sesh = newSeshFunc() + } + stream, err = sesh.OpenStream() if err != nil { - log.Errorf("Failed to open stream: %v", err) - if sesh.Singleplex { + if singleplex { sesh.Close() } + log.Errorf("Failed to open stream: %v", err) continue } @@ -74,7 +78,7 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration } } -func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) { +func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { var sesh *mux.Session for { localConn, err := listener.Accept() @@ -82,10 +86,14 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu log.Fatal(err) continue } - if sesh == nil || sesh.IsClosed() || sesh.Singleplex { + if !singleplex && (sesh == nil || sesh.IsClosed()) { sesh = newSeshFunc() } go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) { + if singleplex { + sesh = newSeshFunc() + } + data := make([]byte, 10240) _ = localConn.SetReadDeadline(time.Now().Add(streamTimeout)) i, err := io.ReadAtLeast(localConn, data, 1) @@ -101,7 +109,7 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu if err != nil { log.Errorf("Failed to open stream: %v", err) localConn.Close() - if sesh.Singleplex { + if singleplex { sesh.Close() } return @@ -125,5 +133,4 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu } }(sesh, localConn, streamTimeout) } - } diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 9f086c1..1a3d0fb 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -185,7 +185,9 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a // whatever connection initiator (including a proper ck-client) netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024) + clientSeshMaker := func() *mux.Session { + ai := ai quad := make([]byte, 4) common.RandRead(ai.WorldState.Rand, quad) ai.SessionId = binary.BigEndian.Uint32(quad) @@ -206,12 +208,12 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a addrCh <- conn.LocalAddr().(*net.UDPAddr) return conn, err } - go client.RouteUDP(acceptor, lcc.Timeout, clientSeshMaker) + go client.RouteUDP(acceptor, lcc.Timeout, rcc.Singleplex, clientSeshMaker) proxyToCkClientD = mDialer } else { var proxyToCkClientL *connutil.PipeListener proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024) - go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker) + go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, clientSeshMaker) } // set up server