Refactor for easier testing

This commit is contained in:
Andy Wang 2020-04-10 11:07:38 +01:00
parent e5bda61587
commit d53b80208f
5 changed files with 22 additions and 17 deletions

View File

@ -168,6 +168,10 @@ func main() {
if authInfo.Unordered { if authInfo.Unordered {
client.RouteUDP(localConfig, seshMaker) client.RouteUDP(localConfig, seshMaker)
} else { } else {
client.RouteTCP(localConfig, seshMaker) listener, err := net.Listen("tcp", localConfig.LocalAddr)
if err != nil {
log.Fatal(err)
}
client.RouteTCP(listener, localConfig.Timeout, seshMaker)
} }
} }

View File

@ -155,14 +155,7 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
for { server.Serve(listener, sta)
conn, err := listener.Accept()
if err != nil {
log.Errorf("%v", err)
continue
}
go server.DispatchConnection(conn, sta)
}
} }
for i, addr := range bindAddr { for i, addr := range bindAddr {

View File

@ -15,6 +15,7 @@ import (
func MakeSession(connConfig remoteConnConfig, authInfo authInfo, dialer common.Dialer, isAdmin bool) *mux.Session { func MakeSession(connConfig remoteConnConfig, authInfo authInfo, dialer common.Dialer, isAdmin bool) *mux.Session {
log.Info("Attempting to start a new session") log.Info("Attempting to start a new session")
//TODO: let caller set this
if !isAdmin { if !isAdmin {
// sessionID is usergenerated. There shouldn't be a security concern because the scope of // sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID. // sessionID is limited to its UID.

View File

@ -100,14 +100,10 @@ start:
} }
func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) { func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) {
tcpListener, err := net.Listen("tcp", localConfig.LocalAddr)
if err != nil {
log.Fatal(err)
}
var sesh *mux.Session var sesh *mux.Session
for { for {
localConn, err := tcpListener.Accept() localConn, err := listener.Accept()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
continue continue
@ -142,7 +138,7 @@ func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) {
} }
}() }()
//util.Pipe(stream, localConn, localConfig.Timeout) //util.Pipe(stream, localConn, localConfig.Timeout)
if _, err = common.Copy(stream, localConn, localConfig.Timeout); err != nil { if _, err = common.Copy(stream, localConn, streamTimeout); err != nil {
log.Tracef("copying proxy client to stream: %v", err) log.Tracef("copying proxy client to stream: %v", err)
} }
}() }()

View File

@ -16,7 +16,18 @@ import (
var b64 = base64.StdEncoding.EncodeToString var b64 = base64.StdEncoding.EncodeToString
func DispatchConnection(conn net.Conn, sta *State) { func Serve(l net.Listener, sta *State) {
for {
conn, err := l.Accept()
if err != nil {
log.Errorf("%v", err)
continue
}
go dispatchConnection(conn, sta)
}
}
func dispatchConnection(conn net.Conn, sta *State) {
remoteAddr := conn.RemoteAddr() remoteAddr := conn.RemoteAddr()
var err error var err error
buf := make([]byte, 1500) buf := make([]byte, 1500)