diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index dc83e40..9284b50 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -168,6 +168,10 @@ func main() { if authInfo.Unordered { client.RouteUDP(localConfig, seshMaker) } else { - client.RouteTCP(localConfig, seshMaker) + listener, err := net.Listen("tcp", localConfig.LocalAddr) + if err != nil { + log.Fatal(err) + } + client.RouteTCP(listener, localConfig.Timeout, seshMaker) } } diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 38872b1..858c4fe 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -155,14 +155,7 @@ func main() { if err != nil { log.Fatal(err) } - for { - conn, err := listener.Accept() - if err != nil { - log.Errorf("%v", err) - continue - } - go server.DispatchConnection(conn, sta) - } + server.Serve(listener, sta) } for i, addr := range bindAddr { diff --git a/internal/client/connector.go b/internal/client/connector.go index b56b2e3..c399274 100644 --- a/internal/client/connector.go +++ b/internal/client/connector.go @@ -15,6 +15,7 @@ import ( func MakeSession(connConfig remoteConnConfig, authInfo authInfo, dialer common.Dialer, isAdmin bool) *mux.Session { log.Info("Attempting to start a new session") + //TODO: let caller set this if !isAdmin { // sessionID is usergenerated. There shouldn't be a security concern because the scope of // sessionID is limited to its UID. diff --git a/internal/client/piper.go b/internal/client/piper.go index fa8d01b..7a3e3dd 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -100,14 +100,10 @@ start: } -func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) { - tcpListener, err := net.Listen("tcp", localConfig.LocalAddr) - if err != nil { - log.Fatal(err) - } +func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) { var sesh *mux.Session for { - localConn, err := tcpListener.Accept() + localConn, err := listener.Accept() if err != nil { log.Fatal(err) continue @@ -142,7 +138,7 @@ func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) { } }() //util.Pipe(stream, localConn, localConfig.Timeout) - if _, err = common.Copy(stream, localConn, localConfig.Timeout); err != nil { + if _, err = common.Copy(stream, localConn, streamTimeout); err != nil { log.Tracef("copying proxy client to stream: %v", err) } }() diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index dbe45e6..87b4701 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -16,7 +16,18 @@ import ( var b64 = base64.StdEncoding.EncodeToString -func DispatchConnection(conn net.Conn, sta *State) { +func Serve(l net.Listener, sta *State) { + for { + conn, err := l.Accept() + if err != nil { + log.Errorf("%v", err) + continue + } + go dispatchConnection(conn, sta) + } +} + +func dispatchConnection(conn net.Conn, sta *State) { remoteAddr := conn.RemoteAddr() var err error buf := make([]byte, 1500)