diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 1a43651..e535fc4 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -14,7 +14,6 @@ import ( "github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/client" - mux "github.com/cbeuw/Cloak/internal/multiplex" log "github.com/sirupsen/logrus" ) @@ -152,7 +151,7 @@ func main() { } } - var seshMaker func() *mux.Session + var seshMaker func() *client.CloakClient d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive} @@ -162,8 +161,8 @@ func main() { authInfo.SessionId = 0 remoteConfig.NumConn = 1 - seshMaker = func() *mux.Session { - return client.MakeSession(remoteConfig, authInfo, d) + seshMaker = func() *client.CloakClient { + return client.NewCloakClient(remoteConfig, authInfo, d) } } else { var network string @@ -173,7 +172,7 @@ func main() { network = "TCP" } log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) - seshMaker = func() *mux.Session { + seshMaker = func() *client.CloakClient { authInfo := authInfo // copy the struct because we are overwriting SessionId randByte := make([]byte, 1) @@ -185,7 +184,7 @@ func main() { quad := make([]byte, 4) common.RandRead(authInfo.WorldState.Rand, quad) authInfo.SessionId = binary.BigEndian.Uint32(quad) - return client.MakeSession(remoteConfig, authInfo, d) + return client.NewCloakClient(remoteConfig, authInfo, d) } } diff --git a/internal/client/connector.go b/internal/client/connector.go index 9bd1f7b..69a9fc3 100644 --- a/internal/client/connector.go +++ b/internal/client/connector.go @@ -12,8 +12,16 @@ import ( log "github.com/sirupsen/logrus" ) -// On different invocations to MakeSession, authInfo.SessionId MUST be different -func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *mux.Session { +type CloakClient struct { + connConfig RemoteConnConfig + authInfo AuthInfo + dialer common.Dialer + + session *mux.Session +} + +// On different invocations to NewCloakClient, authInfo.SessionId MUST be different +func NewCloakClient(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *CloakClient { log.Info("Attempting to start a new session") connsCh := make(chan net.Conn, connConfig.NumConn) @@ -61,13 +69,31 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D Unordered: authInfo.Unordered, MsgOnWireSizeLimit: appDataMaxLength, } - sesh := mux.MakeSession(authInfo.SessionId, seshConfig) + session := mux.MakeSession(authInfo.SessionId, seshConfig) for i := 0; i < connConfig.NumConn; i++ { conn := <-connsCh - sesh.AddConnection(conn) + session.AddConnection(conn) } log.Infof("Session %v established", authInfo.SessionId) - return sesh + + return &CloakClient{ + connConfig: connConfig, + authInfo: authInfo, + dialer: dialer, + session: session, + } +} + +func (client *CloakClient) Dial() (net.Conn, error) { + return client.session.OpenStream() +} + +func (client *CloakClient) Close() error { + return client.session.Close() +} + +func (client *CloakClient) IsClosed() bool { + return client.session.IsClosed() } diff --git a/internal/client/piper.go b/internal/client/piper.go index 43186ac..11cc366 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -8,18 +8,17 @@ import ( "github.com/cbeuw/Cloak/internal/common" - mux "github.com/cbeuw/Cloak/internal/multiplex" log "github.com/sirupsen/logrus" ) -func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { - var sesh *mux.Session +func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *CloakClient) { + var cloakClient *CloakClient localConn, err := bindFunc() if err != nil { log.Fatal(err) } - streams := make(map[string]*mux.Stream) + streams := make(map[string]net.Conn) var streamsMutex sync.Mutex data := make([]byte, 8192) @@ -30,21 +29,21 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration continue } - if !singleplex && (sesh == nil || sesh.IsClosed()) { - sesh = newSeshFunc() + if !singleplex && (cloakClient == nil || cloakClient.IsClosed()) { + cloakClient = newSeshFunc() } streamsMutex.Lock() stream, ok := streams[addr.String()] if !ok { if singleplex { - sesh = newSeshFunc() + cloakClient = newSeshFunc() } - stream, err = sesh.OpenStream() + stream, err = cloakClient.Dial() if err != nil { if singleplex { - sesh.Close() + cloakClient.Close() } log.Errorf("Failed to open stream: %v", err) streamsMutex.Unlock() @@ -56,7 +55,7 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration _ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) proxyAddr := addr - go func(stream *mux.Stream, localConn *net.UDPConn) { + go func(stream net.Conn, localConn *net.UDPConn) { buf := make([]byte, 8192) for { n, err := stream.Read(buf) @@ -95,18 +94,18 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration } } -func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { - var sesh *mux.Session +func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *CloakClient) { + var cloakClient *CloakClient for { localConn, err := listener.Accept() if err != nil { log.Fatal(err) continue } - if !singleplex && (sesh == nil || sesh.IsClosed()) { - sesh = newSeshFunc() + if !singleplex && (cloakClient == nil || cloakClient.IsClosed()) { + cloakClient = newSeshFunc() } - go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) { + go func(sesh *CloakClient, localConn net.Conn, timeout time.Duration) { if singleplex { sesh = newSeshFunc() } @@ -122,7 +121,7 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex boo var zeroTime time.Time _ = localConn.SetReadDeadline(zeroTime) - stream, err := sesh.OpenStream() + stream, err := sesh.Dial() if err != nil { log.Errorf("Failed to open stream: %v", err) localConn.Close() @@ -148,6 +147,6 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex boo if _, err = common.Copy(stream, localConn); err != nil { log.Tracef("copying proxy client to stream: %v", err) } - }(sesh, localConn, streamTimeout) + }(cloakClient, localConn, streamTimeout) } } diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 6df255c..e1a1ad0 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -180,12 +180,12 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024) - clientSeshMaker := func() *mux.Session { + clientSeshMaker := func() *client.CloakClient { ai := ai quad := make([]byte, 4) common.RandRead(ai.WorldState.Rand, quad) ai.SessionId = binary.BigEndian.Uint32(quad) - return client.MakeSession(rcc, ai, netToCkServerD) + return client.NewCloakClient(rcc, ai, netToCkServerD) } var proxyToCkClientD common.Dialer @@ -262,12 +262,12 @@ func TestUDP(t *testing.T) { lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState) sta := basicServerState(worldState) - proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) - if err != nil { - t.Fatal(err) - } - t.Run("simple send", func(t *testing.T) { + proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + pxyClientConn, err := proxyToCkClientD.Dial("udp", "") if err != nil { t.Error(err) @@ -300,6 +300,11 @@ func TestUDP(t *testing.T) { const echoMsgLen = 1024 t.Run("user echo", func(t *testing.T) { + proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + go serveUDPEcho(proxyFromCkServerL) var conn [1]net.Conn conn[0], err = proxyToCkClientD.Dial("udp", "") @@ -379,17 +384,17 @@ func TestTCPMultiplex(t *testing.T) { lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) sta := basicServerState(worldState) - proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta) - if err != nil { - t.Fatal(err) - } - t.Run("user echo single", func(t *testing.T) { for i := 0; i < 18; i += 2 { dataLen := 1 << i writeData := make([]byte, dataLen) rand.Read(writeData) t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) { + proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + go serveTCPEcho(proxyFromCkServerL) conn, err := proxyToCkClientD.Dial("", "") if err != nil { @@ -418,6 +423,11 @@ func TestTCPMultiplex(t *testing.T) { const echoMsgLen = 16384 t.Run("user echo", func(t *testing.T) { + proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + go serveTCPEcho(proxyFromCkServerL) var conns [numConns]net.Conn for i := 0; i < numConns; i++ { @@ -431,6 +441,11 @@ func TestTCPMultiplex(t *testing.T) { }) t.Run("redir echo", func(t *testing.T) { + _, _, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + go serveTCPEcho(redirFromCkServerL) var conns [numConns]net.Conn for i := 0; i < numConns; i++ {