From 86095ba5e63d8fb801b25c99f3c852cfec4f4b9a Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Thu, 9 Apr 2020 00:34:02 +0100 Subject: [PATCH] Refactor out Dialer --- cmd/ck-client/ck-client.go | 8 +++++--- internal/client/connector.go | 5 ++--- internal/client/state.go | 2 -- internal/server/dispatcher.go | 5 ++--- internal/server/state.go | 24 ++++++++++++++---------- internal/util/dialer.go | 7 +++++++ 6 files changed, 30 insertions(+), 21 deletions(-) create mode 100644 internal/util/dialer.go diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index e4cde4f..45f5800 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "flag" "fmt" + "net" "os" "github.com/cbeuw/Cloak/internal/client" @@ -129,7 +130,6 @@ func main() { if err != nil { log.Fatal(err) } - remoteConfig.Protector = protector var adminUID []byte if b64AdminUID != "" { @@ -141,13 +141,15 @@ func main() { var seshMaker func() *mux.Session + d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive} + if adminUID != nil { log.Infof("API base is %v", localConfig.LocalAddr) authInfo.UID = adminUID remoteConfig.NumConn = 1 seshMaker = func() *mux.Session { - return client.MakeSession(remoteConfig, authInfo, true) + return client.MakeSession(remoteConfig, authInfo, d, true) } } else { var network string @@ -158,7 +160,7 @@ func main() { } log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) seshMaker = func() *mux.Session { - return client.MakeSession(remoteConfig, authInfo, false) + return client.MakeSession(remoteConfig, authInfo, d, false) } } diff --git a/internal/client/connector.go b/internal/client/connector.go index 70f8f2c..edc7b80 100644 --- a/internal/client/connector.go +++ b/internal/client/connector.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" ) -func MakeSession(connConfig remoteConnConfig, authInfo authInfo, isAdmin bool) *mux.Session { +func MakeSession(connConfig remoteConnConfig, authInfo authInfo, dialer util.Dialer, isAdmin bool) *mux.Session { log.Info("Attempting to start a new session") if !isAdmin { // sessionID is usergenerated. There shouldn't be a security concern because the scope of @@ -24,7 +24,6 @@ func MakeSession(connConfig remoteConnConfig, authInfo authInfo, isAdmin bool) * authInfo.SessionId = 0 } - d := net.Dialer{Control: connConfig.Protector, KeepAlive: connConfig.KeepAlive} connsCh := make(chan net.Conn, connConfig.NumConn) var _sessionKey atomic.Value var wg sync.WaitGroup @@ -32,7 +31,7 @@ func MakeSession(connConfig remoteConnConfig, authInfo authInfo, isAdmin bool) * wg.Add(1) go func() { makeconn: - remoteConn, err := d.Dial("tcp", connConfig.RemoteAddr) + remoteConn, err := dialer.Dial("tcp", connConfig.RemoteAddr) if err != nil { log.Errorf("Failed to establish new connections to remote: %v", err) // TODO increase the interval if failed multiple times diff --git a/internal/client/state.go b/internal/client/state.go index a24ed8b..d8855dd 100644 --- a/internal/client/state.go +++ b/internal/client/state.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "net" "strings" - "syscall" "time" "github.com/cbeuw/Cloak/internal/ecdh" @@ -41,7 +40,6 @@ type rawConfig struct { type remoteConnConfig struct { NumConn int KeepAlive time.Duration - Protector func(string, string, syscall.RawConn) error RemoteAddr string TransportMaker func() Transport } diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 3591b18..33179ee 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -37,7 +37,7 @@ func DispatchConnection(conn net.Conn, sta *State) { if redirPort == "" { _, redirPort, _ = net.SplitHostPort(conn.LocalAddr().String()) } - webConn, err := net.Dial("tcp", net.JoinHostPort(sta.RedirHost.String(), redirPort)) + webConn, err := sta.RedirDialer.Dial("tcp", net.JoinHostPort(sta.RedirHost.String(), redirPort)) if err != nil { log.Errorf("Making connection to redirection server: %v", err) return @@ -165,8 +165,7 @@ func DispatchConnection(conn net.Conn, sta *State) { } } proxyAddr := sta.ProxyBook[ci.ProxyMethod] - d := net.Dialer{KeepAlive: sta.KeepAlive} - localConn, err := d.Dial(proxyAddr.Network(), proxyAddr.String()) + localConn, err := sta.ProxyDialer.Dial(proxyAddr.Network(), proxyAddr.String()) if err != nil { log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err) user.CloseSession(ci.SessionId, "Failed to connect to proxy server") diff --git a/internal/server/state.go b/internal/server/state.go index 2118039..5fa8aa4 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/cbeuw/Cloak/internal/server/usermanager" + "github.com/cbeuw/Cloak/internal/util" "io/ioutil" "net" "strings" @@ -30,19 +31,22 @@ type rawConfig struct { // State type stores the global state of the program type State struct { - BindAddr []net.Addr - ProxyBook map[string]net.Addr + BindAddr []net.Addr + ProxyBook map[string]net.Addr + ProxyDialer util.Dialer - Now func() time.Time - AdminUID []byte - Timeout time.Duration - KeepAlive time.Duration + Now func() time.Time + AdminUID []byte + Timeout time.Duration + //KeepAlive time.Duration BypassUID map[[16]byte]struct{} staticPv crypto.PrivateKey - RedirHost net.Addr - RedirPort string + // TODO: this doesn't have to be a net.Addr; resolution is done in Dial automatically + RedirHost net.Addr + RedirPort string + RedirDialer util.Dialer usedRandomM sync.RWMutex usedRandom map[[32]byte]int64 @@ -176,9 +180,9 @@ func (sta *State) ParseConfig(conf string) (err error) { } if preParse.KeepAlive <= 0 { - sta.KeepAlive = -1 + sta.ProxyDialer = &net.Dialer{KeepAlive: -1} } else { - sta.KeepAlive = time.Duration(preParse.KeepAlive) * time.Second + sta.ProxyDialer = &net.Dialer{KeepAlive: time.Duration(preParse.KeepAlive) * time.Second} } sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr) diff --git a/internal/util/dialer.go b/internal/util/dialer.go new file mode 100644 index 0000000..cba620f --- /dev/null +++ b/internal/util/dialer.go @@ -0,0 +1,7 @@ +package util + +import "net" + +type Dialer interface { + Dial(network, address string) (net.Conn, error) +}