diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index d48f788..36b29bd 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -184,12 +184,12 @@ func main() { return net.ListenUDP("udp", udpAddr) } - client.RouteUDP(acceptor, localConfig.Timeout, seshMaker) + client.RouteUDP(acceptor, localConfig.Timeout, remoteConfig.Singleplex, seshMaker) } else { listener, err := net.Listen("tcp", localConfig.LocalAddr) if err != nil { log.Fatal(err) } - client.RouteTCP(listener, localConfig.Timeout, seshMaker) + client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, seshMaker) } } diff --git a/internal/client/piper.go b/internal/client/piper.go index 295e3e8..3b077ed 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -10,8 +10,8 @@ import ( log "github.com/sirupsen/logrus" ) -func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, newSeshFunc func() *mux.Session) { - var sesh *mux.Session +func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { + var multiplexSession *mux.Session localConn, err := bindFunc() if err != nil { log.Fatal(err) @@ -27,17 +27,24 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration continue } - if sesh == nil || sesh.IsClosed() || sesh.Singleplex { - sesh = newSeshFunc() + if !singleplex && (multiplexSession == nil || multiplexSession.IsClosed()) { + multiplexSession = newSeshFunc() } stream, ok := streams[addr.String()] if !ok { - stream, err = sesh.OpenStream() + var session *mux.Session + if multiplexSession != nil { + session = multiplexSession + } else { + session = newSeshFunc() + } + + stream, err = session.OpenStream() if err != nil { log.Errorf("Failed to open stream: %v", err) - if sesh.Singleplex { - sesh.Close() + if session.Singleplex { + session.Close() } continue } @@ -74,18 +81,25 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration } } -func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) { - var sesh *mux.Session +func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { + var multiplexSession *mux.Session for { localConn, err := listener.Accept() if err != nil { log.Fatal(err) continue } - if sesh == nil || sesh.IsClosed() || sesh.Singleplex { - sesh = newSeshFunc() + if !singleplex && (multiplexSession == nil || multiplexSession.IsClosed()) { + multiplexSession = newSeshFunc() } - go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) { + go func(multiplexSession *mux.Session, newSingleplexSeshFunc func() *mux.Session, localConn net.Conn, timeout time.Duration) { + var session *mux.Session + if multiplexSession != nil { + session = multiplexSession + } else { + session = newSingleplexSeshFunc() + } + data := make([]byte, 10240) _ = localConn.SetReadDeadline(time.Now().Add(streamTimeout)) i, err := io.ReadAtLeast(localConn, data, 1) @@ -97,12 +111,12 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu var zeroTime time.Time _ = localConn.SetReadDeadline(zeroTime) - stream, err := sesh.OpenStream() + stream, err := session.OpenStream() if err != nil { log.Errorf("Failed to open stream: %v", err) localConn.Close() - if sesh.Singleplex { - sesh.Close() + if session.Singleplex { + session.Close() } return } @@ -123,7 +137,6 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu if _, err = common.Copy(stream, localConn); err != nil { log.Tracef("copying proxy client to stream: %v", err) } - }(sesh, localConn, streamTimeout) + }(multiplexSession, newSeshFunc, localConn, streamTimeout) } - } diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 9f086c1..1e1fdd0 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -206,12 +206,12 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a addrCh <- conn.LocalAddr().(*net.UDPAddr) return conn, err } - go client.RouteUDP(acceptor, lcc.Timeout, clientSeshMaker) + go client.RouteUDP(acceptor, lcc.Timeout, rcc.Singleplex, clientSeshMaker) proxyToCkClientD = mDialer } else { var proxyToCkClientL *connutil.PipeListener proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024) - go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker) + go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, clientSeshMaker) } // set up server