diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 8014790..f28213d 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -8,10 +8,10 @@ import ( "encoding/binary" "flag" "fmt" + "github.com/cbeuw/Cloak/internal/common" "net" "os" - - "github.com/cbeuw/Cloak/internal/common" + "syscall" "github.com/cbeuw/Cloak/internal/client" mux "github.com/cbeuw/Cloak/internal/multiplex" @@ -154,7 +154,37 @@ func main() { var seshMaker func() *mux.Session - d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive} + control := func(network string, address string, rawConn syscall.RawConn) error { + if !authInfo.Unordered { + sendBufferSize := remoteConfig.TcpSendBuffer + receiveBufferSize := remoteConfig.TcpReceiveBuffer + + err := rawConn.Control(func(fd uintptr) { + if sendBufferSize > 0 { + log.Debugf("Setting remote connection tcp send buffer: %d", sendBufferSize) + err := syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, sendBufferSize) + if err != nil { + log.Errorf("setsocketopt SO_SNDBUF: %s\n", err) + } + } + + if receiveBufferSize > 0 { + log.Debugf("Setting remote connection tcp receive buffer: %d", receiveBufferSize) + err = syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, receiveBufferSize) + if err != nil { + log.Errorf("setsocketopt SO_RCVBUF: %s\n", err) + } + } + }) + if err != nil { + panic(err) + } + } + + return protector(network, address, rawConn) + } + + d := &net.Dialer{Control: control, KeepAlive: remoteConfig.KeepAlive} if adminUID != nil { log.Infof("API base is %v", localConfig.LocalAddr) @@ -199,8 +229,43 @@ func main() { } else { listener, err := net.Listen("tcp", localConfig.LocalAddr) if err != nil { - log.Fatal(err) + panic(err) } - client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, localConfig.TcpSendBuffer, localConfig.TcpReceiveBuffer, seshMaker) + + tcpListener, ok := listener.(*net.TCPListener) + if !ok { + panic("Unknown listener type") + } + + syscallConn, err := tcpListener.SyscallConn() + if err != nil { + panic(err) + } + + sendBufferSize := localConfig.TcpSendBuffer + receiveBufferSize := localConfig.TcpReceiveBuffer + + err = syscallConn.Control(func(fd uintptr) { + if sendBufferSize > 0 { + log.Debugf("Setting remote connection tcp send buffer: %d", sendBufferSize) + err := syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, sendBufferSize) + if err != nil { + log.Errorf("setsocketopt SO_SNDBUF: %s\n", err) + } + } + + if receiveBufferSize > 0 { + log.Debugf("Setting remote connection tcp receive buffer: %d", receiveBufferSize) + err = syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, receiveBufferSize) + if err != nil { + log.Errorf("setsocketopt SO_RCVBUF: %s\n", err) + } + } + }) + if err != nil { + panic(err) + } + + client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, seshMaker) } } diff --git a/internal/client/connector.go b/internal/client/connector.go index d6c1c9a..1b283d8 100644 --- a/internal/client/connector.go +++ b/internal/client/connector.go @@ -32,9 +32,6 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D goto makeconn } - sendBufferSize := connConfig.TcpSendBuffer - receiveBufferSize := connConfig.TcpReceiveBuffer - tcpConn, ok := remoteConn.(*net.TCPConn) if ok { syscallConn, err := tcpConn.SyscallConn() @@ -43,22 +40,6 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D } err = syscallConn.Control(func(fd uintptr) { - if sendBufferSize > 0 { - log.Debugf("Setting remote connection tcp send buffer: %d", sendBufferSize) - err := syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, sendBufferSize) - if err != nil { - log.Errorf("setsocketopt SO_SNDBUF: %s\n", err) - } - } - - if receiveBufferSize > 0 { - log.Debugf("Setting remote connection tcp receive buffer: %d", receiveBufferSize) - err = syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, receiveBufferSize) - if err != nil { - log.Errorf("setsocketopt SO_RCVBUF: %s\n", err) - } - } - err = syscall.SetsockoptInt(common.Platformfd(fd), syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1) if err != nil { log.Errorf("setsocketopt TCP_NODELAY: %s\n", err) diff --git a/internal/client/piper.go b/internal/client/piper.go index 1927012..3d802b1 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -96,7 +96,7 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration } } -func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, sendBufferSize int, receiveBufferSize int, newSeshFunc func() *mux.Session) { +func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) { var sesh *mux.Session for { localConn, err := listener.Accept() @@ -114,22 +114,6 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex boo } err = syscallConn.Control(func(fd uintptr) { - if sendBufferSize > 0 { - log.Debugf("Setting loopback connection tcp send buffer: %d", sendBufferSize) - err := syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, sendBufferSize) - if err != nil { - log.Errorf("setsocketopt SO_SNDBUF: %s\n", err) - } - } - - if receiveBufferSize > 0 { - log.Debugf("Setting loopback connection tcp receive buffer: %d", receiveBufferSize) - err = syscall.SetsockoptInt(common.Platformfd(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, receiveBufferSize) - if err != nil { - log.Errorf("setsocketopt SO_RCVBUF: %s\n", err) - } - } - err = syscall.SetsockoptInt(common.Platformfd(fd), syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1) if err != nil { log.Errorf("setsocketopt TCP_NODELAY: %s\n", err) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 110ca7c..6df255c 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -207,7 +207,7 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a } else { var proxyToCkClientL *connutil.PipeListener proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024) - go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, 0, 0, clientSeshMaker) + go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, clientSeshMaker) } // set up server