Refactor for easier testing

This commit is contained in:
Andy Wang 2020-04-10 11:07:38 +01:00
parent e5bda61587
commit d53b80208f
5 changed files with 22 additions and 17 deletions

View File

@ -168,6 +168,10 @@ func main() {
if authInfo.Unordered {
client.RouteUDP(localConfig, seshMaker)
} else {
client.RouteTCP(localConfig, seshMaker)
listener, err := net.Listen("tcp", localConfig.LocalAddr)
if err != nil {
log.Fatal(err)
}
client.RouteTCP(listener, localConfig.Timeout, seshMaker)
}
}

View File

@ -155,14 +155,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
for {
conn, err := listener.Accept()
if err != nil {
log.Errorf("%v", err)
continue
}
go server.DispatchConnection(conn, sta)
}
server.Serve(listener, sta)
}
for i, addr := range bindAddr {

View File

@ -15,6 +15,7 @@ import (
func MakeSession(connConfig remoteConnConfig, authInfo authInfo, dialer common.Dialer, isAdmin bool) *mux.Session {
log.Info("Attempting to start a new session")
//TODO: let caller set this
if !isAdmin {
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID.

View File

@ -100,14 +100,10 @@ start:
}
func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) {
tcpListener, err := net.Listen("tcp", localConfig.LocalAddr)
if err != nil {
log.Fatal(err)
}
func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) {
var sesh *mux.Session
for {
localConn, err := tcpListener.Accept()
localConn, err := listener.Accept()
if err != nil {
log.Fatal(err)
continue
@ -142,7 +138,7 @@ func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) {
}
}()
//util.Pipe(stream, localConn, localConfig.Timeout)
if _, err = common.Copy(stream, localConn, localConfig.Timeout); err != nil {
if _, err = common.Copy(stream, localConn, streamTimeout); err != nil {
log.Tracef("copying proxy client to stream: %v", err)
}
}()

View File

@ -16,7 +16,18 @@ import (
var b64 = base64.StdEncoding.EncodeToString
func DispatchConnection(conn net.Conn, sta *State) {
func Serve(l net.Listener, sta *State) {
for {
conn, err := l.Accept()
if err != nil {
log.Errorf("%v", err)
continue
}
go dispatchConnection(conn, sta)
}
}
func dispatchConnection(conn net.Conn, sta *State) {
remoteAddr := conn.RemoteAddr()
var err error
buf := make([]byte, 1500)