Refactor Singleplexing

This commit is contained in:
Andy Wang 2020-10-15 21:32:38 +01:00
parent 4914fba337
commit 9887649b88
6 changed files with 235 additions and 187 deletions

View File

@ -169,20 +169,18 @@ func main() {
} }
} }
useSessionPerConnection := remoteConfig.NumConn == 0
if authInfo.Unordered { if authInfo.Unordered {
acceptor := func() (*net.UDPConn, error) { acceptor := func() (*net.UDPConn, error) {
udpAddr, _ := net.ResolveUDPAddr("udp", localConfig.LocalAddr) udpAddr, _ := net.ResolveUDPAddr("udp", localConfig.LocalAddr)
return net.ListenUDP("udp", udpAddr) return net.ListenUDP("udp", udpAddr)
} }
client.RouteUDP(acceptor, localConfig.Timeout, seshMaker, useSessionPerConnection) client.RouteUDP(acceptor, localConfig.Timeout, seshMaker)
} else { } else {
listener, err := net.Listen("tcp", localConfig.LocalAddr) listener, err := net.Listen("tcp", localConfig.LocalAddr)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
client.RouteTCP(listener, localConfig.Timeout, seshMaker, useSessionPerConnection) client.RouteTCP(listener, localConfig.Timeout, seshMaker)
} }
} }

View File

@ -25,16 +25,10 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D
authInfo.SessionId = 0 authInfo.SessionId = 0
} }
numConn := connConfig.NumConn connsCh := make(chan net.Conn, connConfig.NumConn)
if numConn <= 0 {
log.Infof("Using session per connection (no multiplexing)")
numConn = 1
}
connsCh := make(chan net.Conn, numConn)
var _sessionKey atomic.Value var _sessionKey atomic.Value
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < numConn; i++ { for i := 0; i < connConfig.NumConn; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
makeconn: makeconn:
@ -69,6 +63,7 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D
} }
seshConfig := mux.SessionConfig{ seshConfig := mux.SessionConfig{
Singleplex: connConfig.Singleplex,
Obfuscator: obfuscator, Obfuscator: obfuscator,
Valve: nil, Valve: nil,
Unordered: authInfo.Unordered, Unordered: authInfo.Unordered,
@ -76,7 +71,7 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D
} }
sesh := mux.MakeSession(authInfo.SessionId, seshConfig) sesh := mux.MakeSession(authInfo.SessionId, seshConfig)
for i := 0; i < numConn; i++ { for i := 0; i < connConfig.NumConn; i++ {
conn := <-connsCh conn := <-connsCh
sesh.AddConnection(conn) sesh.AddConnection(conn)
} }

View File

@ -10,31 +10,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type ConnWithReadFromTimeout interface { func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, newSeshFunc func() *mux.Session) {
net.Conn
SetReadFromTimeout(d time.Duration)
}
type CloseSessionAfterCloseStream struct {
ConnWithReadFromTimeout
Session *mux.Session
}
func (s *CloseSessionAfterCloseStream) Close() error {
if err := s.ConnWithReadFromTimeout.Close(); err != nil {
return err
}
return s.Session.Close()
}
func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, newSeshFunc func() *mux.Session, useSessionPerConnection bool) {
var sesh *mux.Session var sesh *mux.Session
localConn, err := bindFunc() localConn, err := bindFunc()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
streams := make(map[string]ConnWithReadFromTimeout) streams := make(map[string]*mux.Stream)
data := make([]byte, 8192) data := make([]byte, 8192)
for { for {
@ -44,34 +27,21 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
continue continue
} }
if !useSessionPerConnection && (sesh == nil || sesh.IsClosed()) { if sesh == nil || sesh.IsClosed() || sesh.Singleplex {
sesh = newSeshFunc() sesh = newSeshFunc()
} }
var stream ConnWithReadFromTimeout
stream, ok := streams[addr.String()] stream, ok := streams[addr.String()]
if !ok { if !ok {
connectionSession := sesh stream, err = sesh.OpenStream()
if useSessionPerConnection {
connectionSession = newSeshFunc()
}
stream, err = connectionSession.OpenStream()
if err != nil { if err != nil {
log.Errorf("Failed to open stream: %v", err) log.Errorf("Failed to open stream: %v", err)
if useSessionPerConnection { if sesh.Singleplex {
connectionSession.Close() sesh.Close()
} }
continue continue
} }
if useSessionPerConnection {
stream = &CloseSessionAfterCloseStream{
ConnWithReadFromTimeout: stream,
Session: connectionSession,
}
}
streams[addr.String()] = stream streams[addr.String()] = stream
proxyAddr := addr proxyAddr := addr
go func() { go func() {
@ -104,7 +74,7 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
} }
} }
func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session, useSessionPerConnection bool) { func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc func() *mux.Session) {
var sesh *mux.Session var sesh *mux.Session
for { for {
localConn, err := listener.Accept() localConn, err := listener.Accept()
@ -112,7 +82,7 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu
log.Fatal(err) log.Fatal(err)
continue continue
} }
if !useSessionPerConnection && (sesh == nil || sesh.IsClosed()) { if sesh == nil || sesh.IsClosed() || sesh.Singleplex {
sesh = newSeshFunc() sesh = newSeshFunc()
} }
go func() { go func() {
@ -124,29 +94,16 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu
return return
} }
connectionSession := sesh stream, err := sesh.OpenStream()
if useSessionPerConnection {
connectionSession = newSeshFunc()
}
var stream ConnWithReadFromTimeout
stream, err = connectionSession.OpenStream()
if err != nil { if err != nil {
log.Errorf("Failed to open stream: %v", err) log.Errorf("Failed to open stream: %v", err)
localConn.Close() localConn.Close()
if useSessionPerConnection { if sesh.Singleplex {
connectionSession.Close() sesh.Close()
} }
return return
} }
if useSessionPerConnection {
stream = &CloseSessionAfterCloseStream{
ConnWithReadFromTimeout: stream,
Session: connectionSession,
}
}
_, err = stream.Write(data[:i]) _, err = stream.Write(data[:i])
if err != nil { if err != nil {
log.Errorf("Failed to write to stream: %v", err) log.Errorf("Failed to write to stream: %v", err)

View File

@ -40,6 +40,7 @@ type RawConfig struct {
} }
type RemoteConnConfig struct { type RemoteConnConfig struct {
Singleplex bool
NumConn int NumConn int
KeepAlive time.Duration KeepAlive time.Duration
RemoteAddr string RemoteAddr string
@ -178,9 +179,12 @@ func (raw *RawConfig) SplitConfigs(worldState common.WorldState) (local LocalCon
} }
remote.RemoteAddr = net.JoinHostPort(raw.RemoteHost, raw.RemotePort) remote.RemoteAddr = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
if raw.NumConn <= 0 { if raw.NumConn <= 0 {
raw.NumConn = 0 remote.NumConn = 1
remote.Singleplex = true
} else {
remote.NumConn = raw.NumConn
remote.Singleplex = false
} }
remote.NumConn = raw.NumConn
// Transport and (if TLS mode), browser // Transport and (if TLS mode), browser
switch strings.ToLower(raw.Transport) { switch strings.ToLower(raw.Transport) {

View File

@ -21,6 +21,7 @@ const (
var ErrBrokenSession = errors.New("broken session") var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session") var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream") var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type switchboardStrategy int type switchboardStrategy int
@ -31,6 +32,8 @@ type SessionConfig struct {
Unordered bool Unordered bool
Singleplex bool
MaxFrameSize int // maximum size of the frame, including the header MaxFrameSize int // maximum size of the frame, including the header
SendBufferSize int SendBufferSize int
ReceiveBufferSize int ReceiveBufferSize int
@ -125,6 +128,11 @@ func (sesh *Session) OpenStream() (*Stream, error) {
} }
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1 id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
// Because atomic.AddUint32 returns the value after incrementation // Because atomic.AddUint32 returns the value after incrementation
if sesh.Singleplex && id > 1 {
// if there are more than one streams, which shouldn't happen if we are
// singleplexing
return nil, errNoMultiplex
}
stream := makeStream(sesh, id) stream := makeStream(sesh, id)
sesh.streams.Store(id, stream) sesh.streams.Store(id, stream)
sesh.streamCountIncr() sesh.streamCountIncr()
@ -177,8 +185,12 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
sesh.streams.Store(s.id, nil) // id may or may not exist. if we use Delete(s.id) here it will panic sesh.streams.Store(s.id, nil) // id may or may not exist. if we use Delete(s.id) here it will panic
if sesh.streamCountDecr() == 0 { if sesh.streamCountDecr() == 0 {
log.Debugf("session %v has no active stream left", sesh.id) if sesh.Singleplex {
go sesh.timeoutAfter(30 * time.Second) return sesh.Close()
} else {
log.Debugf("session %v has no active stream left", sesh.id)
go sesh.timeoutAfter(30 * time.Second)
}
} }
return nil return nil
} }

View File

@ -22,6 +22,7 @@ import (
) )
const numConns = 200 // -race option limits the number of goroutines to 8192 const numConns = 200 // -race option limits the number of goroutines to 8192
const delayBeforeTestingConnClose = 500 * time.Millisecond
func serveTCPEcho(l net.Listener) { func serveTCPEcho(l net.Listener) {
for { for {
@ -78,10 +79,26 @@ var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=") var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=")
var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=") var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=")
func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) { var udpClientConfigs = map[string]client.RawConfig{
var clientConfig = client.RawConfig{ "basic": {
ServerName: "www.example.com", ServerName: "www.example.com",
ProxyMethod: "tcp", ProxyMethod: "openvpn",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 4,
UDP: true,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
},
}
var tcpClientConfigs = map[string]client.RawConfig{
"basic": {
ServerName: "www.example.com",
ProxyMethod: "shadowsocks",
EncryptionMethod: "plain", EncryptionMethod: "plain",
UID: bypassUID[:], UID: bypassUID[:],
PublicKey: publicKey, PublicKey: publicKey,
@ -92,8 +109,25 @@ func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client
RemotePort: "9999", RemotePort: "9999",
LocalHost: "127.0.0.1", LocalHost: "127.0.0.1",
LocalPort: "9999", LocalPort: "9999",
} },
lcl, rmt, auth, err := clientConfig.SplitConfigs(state) "singleplex": {
ServerName: "www.example.com",
ProxyMethod: "shadowsocks",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 0,
UDP: false,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
},
}
func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) {
lcl, rmt, auth, err := rawConfig.SplitConfigs(state)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -102,7 +136,7 @@ func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client
func basicServerState(ws common.WorldState, db *os.File) *server.State { func basicServerState(ws common.WorldState, db *os.File) *server.State {
var serverConfig = server.RawConfig{ var serverConfig = server.RawConfig{
ProxyBook: map[string][]string{"tcp": {"tcp", "fake.com:9999"}, "udp": {"udp", "fake.com:9999"}}, ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
BindAddr: []string{"fake.com:9999"}, BindAddr: []string{"fake.com:9999"},
BypassUID: [][]byte{bypassUID[:]}, BypassUID: [][]byte{bypassUID[:]},
RedirAddr: "fake.com:9999", RedirAddr: "fake.com:9999",
@ -133,13 +167,27 @@ func (m *mockUDPDialer) Dial(network, address string) (net.Conn, error) {
} }
func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) { func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) {
// transport // redirecting web server
ckClientDialer, ckServerListener := connutil.DialerListener(10 * 1024) // ^
// |
// |
// redirFromCkServerL
// |
// |
// proxy client ----proxyToCkClientD----> ck-client ------> ck-server ----proxyFromCkServerL----> proxy server
// ^
// |
// |
// netToCkServerD
// |
// |
// whatever connection initiator (including a proper ck-client)
netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024)
clientSeshMaker := func() *mux.Session { clientSeshMaker := func() *mux.Session {
return client.MakeSession(rcc, ai, ckClientDialer, false) return client.MakeSession(rcc, ai, netToCkServerD, false)
} }
useSessionPerConnection := rcc.NumConn == 0
var proxyToCkClientD common.Dialer var proxyToCkClientD common.Dialer
if ai.Unordered { if ai.Unordered {
addrCh := make(chan *net.UDPAddr, 1) addrCh := make(chan *net.UDPAddr, 1)
@ -152,23 +200,23 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a
addrCh <- conn.LocalAddr().(*net.UDPAddr) addrCh <- conn.LocalAddr().(*net.UDPAddr)
return conn, err return conn, err
} }
go client.RouteUDP(acceptor, lcc.Timeout, clientSeshMaker, useSessionPerConnection) go client.RouteUDP(acceptor, lcc.Timeout, clientSeshMaker)
proxyToCkClientD = mDialer proxyToCkClientD = mDialer
} else { } else {
var proxyToCkClientL *connutil.PipeListener var proxyToCkClientL *connutil.PipeListener
proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024) proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024)
go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker, useSessionPerConnection) go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker)
} }
// set up server // set up server
ckServerToProxyD, ckServerToProxyL := connutil.DialerListener(10 * 1024) ckServerToProxyD, proxyFromCkServerL := connutil.DialerListener(10 * 1024)
ckServerToWebD, ckServerToWebL := connutil.DialerListener(10 * 1024) ckServerToWebD, redirFromCkServerL := connutil.DialerListener(10 * 1024)
serverState.ProxyDialer = ckServerToProxyD serverState.ProxyDialer = ckServerToProxyD
serverState.RedirDialer = ckServerToWebD serverState.RedirDialer = ckServerToWebD
go server.Serve(ckServerListener, serverState) go server.Serve(ckServerListener, serverState)
return proxyToCkClientD, ckServerToProxyL, ckClientDialer, ckServerToWebL, nil return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
} }
func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
@ -206,18 +254,16 @@ func TestUDP(t *testing.T) {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState) lcc, rcc, ai := generateClientConfigs(udpClientConfigs["basic"], worldState)
ai.ProxyMethod = "udp"
ai.Unordered = true
sta := basicServerState(worldState, tmpDB) sta := basicServerState(worldState, tmpDB)
pxyClientD, pxyServerL, _, _, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Run("simple send", func(t *testing.T) { t.Run("simple send", func(t *testing.T) {
pxyClientConn, err := pxyClientD.Dial("udp", "") pxyClientConn, err := proxyToCkClientD.Dial("udp", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -233,7 +279,7 @@ func TestUDP(t *testing.T) {
t.Error(err) t.Error(err)
} }
pxyServerConn, err := pxyServerL.ListenPacket("", "") pxyServerConn, err := proxyFromCkServerL.ListenPacket("", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -248,9 +294,9 @@ func TestUDP(t *testing.T) {
}) })
t.Run("user echo", func(t *testing.T) { t.Run("user echo", func(t *testing.T) {
go serveUDPEcho(pxyServerL) go serveUDPEcho(proxyFromCkServerL)
var conn [1]net.Conn var conn [1]net.Conn
conn[0], err = pxyClientD.Dial("udp", "") conn[0], err = proxyToCkClientD.Dial("udp", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -260,16 +306,70 @@ func TestUDP(t *testing.T) {
} }
func TestTCP(t *testing.T) { func TestTCPSingleplex(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(tcpClientConfigs["singleplex"], worldState)
var tmpDB, _ = ioutil.TempFile("", "ck_user_info") var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name()) defer os.Remove(tmpDB.Name())
log.SetLevel(log.ErrorLevel) sta := basicServerState(worldState, tmpDB)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
go serveTCPEcho(proxyFromCkServerL)
proxyConn1, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Error(err)
}
_, err = proxyConn1.Write([]byte("hello"))
if err != nil {
t.Error(err)
}
// make sure the server has accepted the connection before fetching the server
proxyConn1.Read(make([]byte, 10))
user, err := sta.Panel.GetUser(ai.UID[:])
if err != nil {
t.Fatalf("failed to fetch user: %v", err)
}
if user.NumSession() != 1 {
t.Error("no session were made on first connection establishment")
}
proxyConn2, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Error(err)
}
proxyConn2.Write([]byte("hello"))
// make sure the server has accepted the connection before fetching the server
proxyConn2.Read(make([]byte, 10))
if user.NumSession() != 2 {
t.Error("no extra session were made on second connection establishment")
}
proxyConn1.Close()
time.Sleep(delayBeforeTestingConnClose)
if user.NumSession() != 1 {
t.Error("first session was not closed on connection close")
}
}
func TestTCPMultiplex(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState)
lcc, rcc, ai := generateClientConfigs(tcpClientConfigs["basic"], worldState)
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
sta := basicServerState(worldState, tmpDB) sta := basicServerState(worldState, tmpDB)
pxyClientD, pxyServerL, dialerToCkServer, rdirServerL, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -280,8 +380,8 @@ func TestTCP(t *testing.T) {
writeData := make([]byte, dataLen) writeData := make([]byte, dataLen)
rand.Read(writeData) rand.Read(writeData)
t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) { t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) {
go serveTCPEcho(pxyServerL) go serveTCPEcho(proxyFromCkServerL)
conn, err := pxyClientD.Dial("", "") conn, err := proxyToCkClientD.Dial("", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -307,10 +407,10 @@ func TestTCP(t *testing.T) {
}) })
t.Run("user echo", func(t *testing.T) { t.Run("user echo", func(t *testing.T) {
go serveTCPEcho(pxyServerL) go serveTCPEcho(proxyFromCkServerL)
var conns [numConns]net.Conn var conns [numConns]net.Conn
for i := 0; i < numConns; i++ { for i := 0; i < numConns; i++ {
conns[i], err = pxyClientD.Dial("", "") conns[i], err = proxyToCkClientD.Dial("", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -320,10 +420,10 @@ func TestTCP(t *testing.T) {
}) })
t.Run("redir echo", func(t *testing.T) { t.Run("redir echo", func(t *testing.T) {
go serveTCPEcho(rdirServerL) go serveTCPEcho(redirFromCkServerL)
var conns [numConns]net.Conn var conns [numConns]net.Conn
for i := 0; i < numConns; i++ { for i := 0; i < numConns; i++ {
conns[i], err = dialerToCkServer.Dial("", "") conns[i], err = netToCkServerD.Dial("", "")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -333,62 +433,70 @@ func TestTCP(t *testing.T) {
} }
func TestClosingStreamsFromProxy(t *testing.T) { func TestClosingStreamsFromProxy(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState)
sta := basicServerState(worldState, tmpDB) for clientConfigName, clientConfig := range tcpClientConfigs {
pxyClientD, pxyServerL, _, _, err := establishSession(lcc, rcc, ai, sta) clientConfig := clientConfig
if err != nil { clientConfigName := clientConfigName
t.Fatal(err) t.Run(clientConfigName, func(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
sta := basicServerState(worldState, tmpDB)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("closing from server", func(t *testing.T) {
clientConn, _ := proxyToCkClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := proxyFromCkServerL.Accept()
serverConn.Close()
time.Sleep(delayBeforeTestingConnClose)
if _, err := clientConn.Read(make([]byte, 16)); err == nil {
t.Errorf("closing stream on server side is not reflected to the client: %v", err)
}
})
t.Run("closing from client", func(t *testing.T) {
// closing stream on client side
clientConn, _ := proxyToCkClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := proxyFromCkServerL.Accept()
clientConn.Close()
time.Sleep(delayBeforeTestingConnClose)
if _, err := serverConn.Read(make([]byte, 16)); err == nil {
t.Errorf("closing stream on client side is not reflected to the server: %v", err)
}
})
t.Run("send then close", func(t *testing.T) {
testData := make([]byte, 24*1024)
rand.Read(testData)
clientConn, _ := proxyToCkClientD.Dial("", "")
go func() {
clientConn.Write(testData)
// TODO: this is time dependent. It could be due to the time it took for this
// connutil.StreamPipe's Close to be reflected on the copy function, instead of inherent bad sync
// in multiplexer
time.Sleep(10 * time.Millisecond)
clientConn.Close()
}()
readBuf := make([]byte, len(testData))
serverConn, _ := proxyFromCkServerL.Accept()
_, err := io.ReadFull(serverConn, readBuf)
if err != nil {
t.Errorf("failed to read data sent before closing: %v", err)
}
})
})
} }
t.Run("closing from server", func(t *testing.T) {
clientConn, _ := pxyClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := pxyServerL.Accept()
serverConn.Close()
time.Sleep(500 * time.Millisecond)
if _, err := clientConn.Read(make([]byte, 16)); err == nil {
t.Errorf("closing stream on server side is not reflected to the client: %v", err)
}
})
t.Run("closing from client", func(t *testing.T) {
// closing stream on client side
clientConn, _ := pxyClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := pxyServerL.Accept()
clientConn.Close()
time.Sleep(500 * time.Millisecond)
if _, err := serverConn.Read(make([]byte, 16)); err == nil {
t.Errorf("closing stream on client side is not reflected to the server: %v", err)
}
})
t.Run("send then close", func(t *testing.T) {
testData := make([]byte, 24*1024)
rand.Read(testData)
clientConn, _ := pxyClientD.Dial("", "")
go func() {
clientConn.Write(testData)
// TODO: this is time dependent. It could be due to the time it took for this
// connutil.StreamPipe's Close to be reflected on the copy function, instead of inherent bad sync
// in multiplexer
time.Sleep(10 * time.Millisecond)
clientConn.Close()
}()
readBuf := make([]byte, len(testData))
serverConn, _ := pxyServerL.Accept()
_, err := io.ReadFull(serverConn, readBuf)
if err != nil {
t.Errorf("failed to read data sent before closing: %v", err)
}
})
} }
func BenchmarkThroughput(b *testing.B) { func BenchmarkThroughput(b *testing.B) {
@ -396,7 +504,7 @@ func BenchmarkThroughput(b *testing.B) {
defer os.Remove(tmpDB.Name()) defer os.Remove(tmpDB.Name())
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState) lcc, rcc, ai := generateClientConfigs(tcpClientConfigs["basic"], worldState)
sta := basicServerState(worldState, tmpDB) sta := basicServerState(worldState, tmpDB)
const bufSize = 16 * 1024 const bufSize = 16 * 1024
@ -409,7 +517,7 @@ func BenchmarkThroughput(b *testing.B) {
for name, method := range encryptionMethods { for name, method := range encryptionMethods {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
ai.EncryptionMethod = method ai.EncryptionMethod = method
pxyClientD, pxyServerL, _, _, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -418,13 +526,13 @@ func BenchmarkThroughput(b *testing.B) {
more := make(chan int, 10) more := make(chan int, 10)
go func() { go func() {
writeBuf := make([]byte, bufSize+100) writeBuf := make([]byte, bufSize+100)
serverConn, _ := pxyServerL.Accept() serverConn, _ := proxyFromCkServerL.Accept()
for { for {
serverConn.Write(writeBuf) serverConn.Write(writeBuf)
<-more <-more
} }
}() }()
clientConn, _ := pxyClientD.Dial("", "") clientConn, _ := proxyToCkClientD.Dial("", "")
readBuf := make([]byte, bufSize) readBuf := make([]byte, bufSize)
clientConn.Write([]byte{1}) // to make server accept clientConn.Write([]byte{1}) // to make server accept
b.SetBytes(bufSize) b.SetBytes(bufSize)
@ -435,32 +543,6 @@ func BenchmarkThroughput(b *testing.B) {
} }
}) })
/*
b.Run("multiconn", func(b *testing.B) {
writeBuf := make([]byte, bufSize)
b.SetBytes(bufSize)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ready := make(chan int, 10)
go func() {
serverConn, _ := pxyServerL.Accept()
for {
serverConn.Write(writeBuf)
<-ready
}
}()
readBuf := make([]byte, bufSize)
clientConn, _ := pxyClientD.Dial("", "")
clientConn.Write([]byte{1}) // to make server accept
for pb.Next() {
io.ReadFull(clientConn,readBuf)
ready <- 0
}
})
})
*/
}) })
} }