Refactor client for PTSpec Go API

This commit is contained in:
Andy Wang 2022-07-11 22:06:23 +01:00
parent de1c7600c1
commit 4fbf387bbf
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
4 changed files with 79 additions and 40 deletions

View File

@ -14,7 +14,6 @@ import (
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/client" "github.com/cbeuw/Cloak/internal/client"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -152,7 +151,7 @@ func main() {
} }
} }
var seshMaker func() *mux.Session var seshMaker func() *client.CloakClient
d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive} d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive}
@ -162,8 +161,8 @@ func main() {
authInfo.SessionId = 0 authInfo.SessionId = 0
remoteConfig.NumConn = 1 remoteConfig.NumConn = 1
seshMaker = func() *mux.Session { seshMaker = func() *client.CloakClient {
return client.MakeSession(remoteConfig, authInfo, d) return client.NewCloakClient(remoteConfig, authInfo, d)
} }
} else { } else {
var network string var network string
@ -173,7 +172,7 @@ func main() {
network = "TCP" network = "TCP"
} }
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
seshMaker = func() *mux.Session { seshMaker = func() *client.CloakClient {
authInfo := authInfo // copy the struct because we are overwriting SessionId authInfo := authInfo // copy the struct because we are overwriting SessionId
randByte := make([]byte, 1) randByte := make([]byte, 1)
@ -185,7 +184,7 @@ func main() {
quad := make([]byte, 4) quad := make([]byte, 4)
common.RandRead(authInfo.WorldState.Rand, quad) common.RandRead(authInfo.WorldState.Rand, quad)
authInfo.SessionId = binary.BigEndian.Uint32(quad) authInfo.SessionId = binary.BigEndian.Uint32(quad)
return client.MakeSession(remoteConfig, authInfo, d) return client.NewCloakClient(remoteConfig, authInfo, d)
} }
} }

View File

@ -12,8 +12,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// On different invocations to MakeSession, authInfo.SessionId MUST be different type CloakClient struct {
func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *mux.Session { connConfig RemoteConnConfig
authInfo AuthInfo
dialer common.Dialer
session *mux.Session
}
// On different invocations to NewCloakClient, authInfo.SessionId MUST be different
func NewCloakClient(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *CloakClient {
log.Info("Attempting to start a new session") log.Info("Attempting to start a new session")
connsCh := make(chan net.Conn, connConfig.NumConn) connsCh := make(chan net.Conn, connConfig.NumConn)
@ -61,13 +69,31 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D
Unordered: authInfo.Unordered, Unordered: authInfo.Unordered,
MsgOnWireSizeLimit: appDataMaxLength, MsgOnWireSizeLimit: appDataMaxLength,
} }
sesh := mux.MakeSession(authInfo.SessionId, seshConfig) session := mux.MakeSession(authInfo.SessionId, seshConfig)
for i := 0; i < connConfig.NumConn; i++ { for i := 0; i < connConfig.NumConn; i++ {
conn := <-connsCh conn := <-connsCh
sesh.AddConnection(conn) session.AddConnection(conn)
} }
log.Infof("Session %v established", authInfo.SessionId) log.Infof("Session %v established", authInfo.SessionId)
return sesh
return &CloakClient{
connConfig: connConfig,
authInfo: authInfo,
dialer: dialer,
session: session,
}
}
func (client *CloakClient) Dial() (net.Conn, error) {
return client.session.OpenStream()
}
func (client *CloakClient) Close() error {
return client.session.Close()
}
func (client *CloakClient) IsClosed() bool {
return client.session.IsClosed()
} }

View File

@ -8,18 +8,17 @@ import (
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *CloakClient) {
var sesh *mux.Session var cloakClient *CloakClient
localConn, err := bindFunc() localConn, err := bindFunc()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
streams := make(map[string]*mux.Stream) streams := make(map[string]net.Conn)
var streamsMutex sync.Mutex var streamsMutex sync.Mutex
data := make([]byte, 8192) data := make([]byte, 8192)
@ -30,21 +29,21 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
continue continue
} }
if !singleplex && (sesh == nil || sesh.IsClosed()) { if !singleplex && (cloakClient == nil || cloakClient.IsClosed()) {
sesh = newSeshFunc() cloakClient = newSeshFunc()
} }
streamsMutex.Lock() streamsMutex.Lock()
stream, ok := streams[addr.String()] stream, ok := streams[addr.String()]
if !ok { if !ok {
if singleplex { if singleplex {
sesh = newSeshFunc() cloakClient = newSeshFunc()
} }
stream, err = sesh.OpenStream() stream, err = cloakClient.Dial()
if err != nil { if err != nil {
if singleplex { if singleplex {
sesh.Close() cloakClient.Close()
} }
log.Errorf("Failed to open stream: %v", err) log.Errorf("Failed to open stream: %v", err)
streamsMutex.Unlock() streamsMutex.Unlock()
@ -56,7 +55,7 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) _ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
proxyAddr := addr proxyAddr := addr
go func(stream *mux.Stream, localConn *net.UDPConn) { go func(stream net.Conn, localConn *net.UDPConn) {
buf := make([]byte, 8192) buf := make([]byte, 8192)
for { for {
n, err := stream.Read(buf) n, err := stream.Read(buf)
@ -95,18 +94,18 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
} }
} }
func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *CloakClient) {
var sesh *mux.Session var cloakClient *CloakClient
for { for {
localConn, err := listener.Accept() localConn, err := listener.Accept()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
continue continue
} }
if !singleplex && (sesh == nil || sesh.IsClosed()) { if !singleplex && (cloakClient == nil || cloakClient.IsClosed()) {
sesh = newSeshFunc() cloakClient = newSeshFunc()
} }
go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) { go func(sesh *CloakClient, localConn net.Conn, timeout time.Duration) {
if singleplex { if singleplex {
sesh = newSeshFunc() sesh = newSeshFunc()
} }
@ -122,7 +121,7 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex boo
var zeroTime time.Time var zeroTime time.Time
_ = localConn.SetReadDeadline(zeroTime) _ = localConn.SetReadDeadline(zeroTime)
stream, err := sesh.OpenStream() stream, err := sesh.Dial()
if err != nil { if err != nil {
log.Errorf("Failed to open stream: %v", err) log.Errorf("Failed to open stream: %v", err)
localConn.Close() localConn.Close()
@ -148,6 +147,6 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex boo
if _, err = common.Copy(stream, localConn); err != nil { if _, err = common.Copy(stream, localConn); err != nil {
log.Tracef("copying proxy client to stream: %v", err) log.Tracef("copying proxy client to stream: %v", err)
} }
}(sesh, localConn, streamTimeout) }(cloakClient, localConn, streamTimeout)
} }
} }

View File

@ -180,12 +180,12 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a
netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024) netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024)
clientSeshMaker := func() *mux.Session { clientSeshMaker := func() *client.CloakClient {
ai := ai ai := ai
quad := make([]byte, 4) quad := make([]byte, 4)
common.RandRead(ai.WorldState.Rand, quad) common.RandRead(ai.WorldState.Rand, quad)
ai.SessionId = binary.BigEndian.Uint32(quad) ai.SessionId = binary.BigEndian.Uint32(quad)
return client.MakeSession(rcc, ai, netToCkServerD) return client.NewCloakClient(rcc, ai, netToCkServerD)
} }
var proxyToCkClientD common.Dialer var proxyToCkClientD common.Dialer
@ -262,12 +262,12 @@ func TestUDP(t *testing.T) {
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState) lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
sta := basicServerState(worldState) sta := basicServerState(worldState)
t.Run("simple send", func(t *testing.T) {
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Run("simple send", func(t *testing.T) {
pxyClientConn, err := proxyToCkClientD.Dial("udp", "") pxyClientConn, err := proxyToCkClientD.Dial("udp", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -300,6 +300,11 @@ func TestUDP(t *testing.T) {
const echoMsgLen = 1024 const echoMsgLen = 1024
t.Run("user echo", func(t *testing.T) { t.Run("user echo", func(t *testing.T) {
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
go serveUDPEcho(proxyFromCkServerL) go serveUDPEcho(proxyFromCkServerL)
var conn [1]net.Conn var conn [1]net.Conn
conn[0], err = proxyToCkClientD.Dial("udp", "") conn[0], err = proxyToCkClientD.Dial("udp", "")
@ -379,17 +384,17 @@ func TestTCPMultiplex(t *testing.T) {
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
sta := basicServerState(worldState) sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("user echo single", func(t *testing.T) { t.Run("user echo single", func(t *testing.T) {
for i := 0; i < 18; i += 2 { for i := 0; i < 18; i += 2 {
dataLen := 1 << i dataLen := 1 << i
writeData := make([]byte, dataLen) writeData := make([]byte, dataLen)
rand.Read(writeData) rand.Read(writeData)
t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) { t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) {
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
go serveTCPEcho(proxyFromCkServerL) go serveTCPEcho(proxyFromCkServerL)
conn, err := proxyToCkClientD.Dial("", "") conn, err := proxyToCkClientD.Dial("", "")
if err != nil { if err != nil {
@ -418,6 +423,11 @@ func TestTCPMultiplex(t *testing.T) {
const echoMsgLen = 16384 const echoMsgLen = 16384
t.Run("user echo", func(t *testing.T) { t.Run("user echo", func(t *testing.T) {
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
go serveTCPEcho(proxyFromCkServerL) go serveTCPEcho(proxyFromCkServerL)
var conns [numConns]net.Conn var conns [numConns]net.Conn
for i := 0; i < numConns; i++ { for i := 0; i < numConns; i++ {
@ -431,6 +441,11 @@ func TestTCPMultiplex(t *testing.T) {
}) })
t.Run("redir echo", func(t *testing.T) { t.Run("redir echo", func(t *testing.T) {
_, _, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
go serveTCPEcho(redirFromCkServerL) go serveTCPEcho(redirFromCkServerL)
var conns [numConns]net.Conn var conns [numConns]net.Conn
for i := 0; i < numConns; i++ { for i := 0; i < numConns; i++ {