From 61b1031da64f25db39fa95fb38df0245d939e9df Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 6 Dec 2020 10:50:45 +0000 Subject: [PATCH 1/4] Reduce code duplication in session closing --- internal/multiplex/session.go | 37 +++++++++++++++-------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index d86fd4e..e1ab275 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -253,8 +253,7 @@ func (sesh *Session) TerminalMsg() string { } } -func (sesh *Session) passiveClose() error { - log.Debugf("attempting to passively close session %v", sesh.id) +func (sesh *Session) closeSession(closeSwitchboard bool) error { if atomic.SwapUint32(&sesh.closed, 1) == 1 { log.Debugf("session %v has already been closed", sesh.id) return errRepeatSessionClosing @@ -273,7 +272,18 @@ func (sesh *Session) passiveClose() error { return true }) - sesh.sb.closeAll() + if closeSwitchboard { + sesh.sb.closeAll() + } + return nil +} + +func (sesh *Session) passiveClose() error { + log.Debugf("attempting to passively close session %v", sesh.id) + err := sesh.closeSession(true) + if err != nil { + return err + } log.Debugf("session %v closed gracefully", sesh.id) return nil } @@ -288,25 +298,10 @@ func genRandomPadding() []byte { func (sesh *Session) Close() error { log.Debugf("attempting to actively close session %v", sesh.id) - if atomic.SwapUint32(&sesh.closed, 1) == 1 { - log.Debugf("session %v has already been closed", sesh.id) - return errRepeatSessionClosing + err := sesh.closeSession(false) + if err != nil { + return err } - sesh.acceptCh <- nil - - // close all streams - sesh.streams.Range(func(key, streamI interface{}) bool { - if streamI == nil { - return true - } - stream := streamI.(*Stream) - atomic.StoreUint32(&stream.closed, 1) - _ = stream.recvBuf.Close() // will not block - sesh.streams.Delete(key) - sesh.streamCountDecr() - return true - }) - // we send a notice frame telling remote to close the session pad := genRandomPadding() f := &Frame{ From 061b10e8023ee480917d2150663b8c6eb441afbf Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 6 Dec 2020 11:14:33 +0000 Subject: [PATCH 2/4] Improve tests code quality --- internal/multiplex/mux_test.go | 19 +-- internal/multiplex/session_test.go | 192 ++++++++++++----------------- 2 files changed, 87 insertions(+), 124 deletions(-) diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index e23d8a1..436c407 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -63,22 +63,22 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) { return clientSession, serverSession, paris } -func runEchoTest(t *testing.T, streams []*Stream) { - const testDataLen = 16384 +func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { var wg sync.WaitGroup - for _, stream := range streams { + for _, conn := range conns { wg.Add(1) - go func(stream *Stream) { + go func(conn net.Conn) { + testDataLen := rand.Intn(maxMsgLen) testData := make([]byte, testDataLen) rand.Read(testData) - n, err := stream.Write(testData) + n, err := conn.Write(testData) if n != testDataLen { t.Fatalf("written only %v, err %v", n, err) } recvBuf := make([]byte, testDataLen) - _, err = io.ReadFull(stream, recvBuf) + _, err = io.ReadFull(conn, recvBuf) if err != nil { t.Fatalf("failed to read back: %v", err) } @@ -87,7 +87,7 @@ func runEchoTest(t *testing.T, streams []*Stream) { t.Fatalf("echoed data not correct") } wg.Done() - }(stream) + }(conn) } wg.Wait() } @@ -95,11 +95,12 @@ func runEchoTest(t *testing.T, streams []*Stream) { func TestMultiplex(t *testing.T) { const numStreams = 2000 // -race option limits the number of goroutines to 8192 const numConns = 4 + const maxMsgLen = 16384 clientSession, serverSession, _ := makeSessionPair(numConns) go serveEcho(serverSession) - streams := make([]*Stream, numStreams) + streams := make([]net.Conn, numStreams) for i := 0; i < numStreams; i++ { stream, err := clientSession.OpenStream() if err != nil { @@ -109,7 +110,7 @@ func TestMultiplex(t *testing.T) { } //test echo - runEchoTest(t, streams) + runEchoTest(t, streams, maxMsgLen) if clientSession.streamCount() != numStreams { t.Errorf("client stream count is wrong: %v", clientSession.streamCount()) } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index b40d32c..52fe6a5 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -17,8 +17,10 @@ var seshConfigUnordered = SessionConfig{ Unordered: true, } +const testPayloadLen = 1024 +const obfsBufLen = testPayloadLen * 2 + func TestRecvDataFromRemote(t *testing.T) { - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) f := &Frame{ @@ -27,126 +29,88 @@ func TestRecvDataFromRemote(t *testing.T) { 0, testPayload, } - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) - t.Run("plain ordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - err := sesh.recvDataFromRemote(obfsBuf[:n]) + MakeObfuscatorUnwrap := func(method byte, sessionKey [32]byte) Obfuscator { + ret, err := MakeObfuscator(method, sessionKey) if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return + t.Fatalf("failed to make an obfuscator: %v", err) } + return ret + } - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) - t.Run("aes-gcm ordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) + sessionTypes := []struct { + name string + config SessionConfig + }{ + {"ordered", + SessionConfig{}}, + {"unordered", + SessionConfig{Unordered: true}}, + } - err := sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } + encryptionMethods := []struct { + name string + obfuscator Obfuscator + }{ + { + "plain", + MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), + }, + { + "aes-gcm", + MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), + }, + { + "chacha20-poly1305", + MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), + }, + } - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) - t.Run("chacha20-poly1305 ordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) + for _, st := range sessionTypes { + t.Run(st.name, func(t *testing.T) { + for _, em := range encryptionMethods { + t.Run(em.name, func(t *testing.T) { + st.config.Obfuscator = em.obfuscator + sesh := MakeSession(0, st.config) + n, err := sesh.Obfs(f, obfsBuf, 0) + if err != nil { + t.Error(err) + return + } + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } + stream, err := sesh.Accept() + if err != nil { + t.Error(err) + return + } - err := sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } - - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) - - t.Run("plain unordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigUnordered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - err := sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } - - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) + resultPayload := make([]byte, testPayloadLen) + _, err = stream.Read(resultPayload) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(testPayload, resultPayload) { + t.Errorf("Expecting %x, got %x", testPayload, resultPayload) + } + }) + } + }) + } } func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) @@ -273,10 +237,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { // Tests for when the closing frame of a stream is received first before any data frame - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) @@ -354,7 +317,7 @@ func TestParallelStreams(t *testing.T) { } } - numOfTests := 5000 + const numOfTests = 5000 tests := make([]struct { name string frame *Frame @@ -368,11 +331,11 @@ func TestParallelStreams(t *testing.T) { for _, tc := range tests { wg.Add(1) go func(frame *Frame) { - data := make([]byte, 1000) - n, _ := sesh.Obfs(frame, data, 0) - data = data[0:n] + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(frame, obfsBuf, 0) + obfsBuf = obfsBuf[0:n] - err := sesh.recvDataFromRemote(data) + err := sesh.recvDataFromRemote(obfsBuf) if err != nil { t.Error(err) } @@ -452,7 +415,6 @@ func TestSession_timeoutAfter(t *testing.T) { } func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) f := &Frame{ @@ -461,7 +423,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { 0, testPayload, } - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) From caca33a1a53958e2fdbc2c4ebbff83ee8cf7c39b Mon Sep 17 00:00:00 2001 From: notsure2 Date: Mon, 7 Dec 2020 22:28:03 +0200 Subject: [PATCH 3/4] Respect user choice of ProxyMethod in shadowsocks plugin mode. --- cmd/ck-client/ck-client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 9c0cafa..d48f788 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -85,7 +85,9 @@ func main() { } if ssPluginMode { - rawConfig.ProxyMethod = "shadowsocks" + if rawConfig.ProxyMethod == "" { + rawConfig.ProxyMethod = "shadowsocks" + } // json takes precedence over environment variables // i.e. if json field isn't empty, use that if rawConfig.RemoteHost == "" { From e77fd4c446f9e0933008c28032e594eaf55631d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=8D=E7=A1=AE=E5=AE=9A?= <35424927+notsure2@users.noreply.github.com> Date: Sat, 12 Dec 2020 19:00:46 +0200 Subject: [PATCH 4/4] Fix regression: termination of long downloads after StreamTimeout seconds (#141) * Fix termination of long downloads after StreamTimeout seconds. - Even if not broadcasting in a loop, we still need to update the read deadline. - Don't enforce the timeout after the first data is written. * When timeout no longer needs to be enforced, no need to schedule a broadcast. * Fix Cloak client. Don't enforce read deadline after first read. * Enforce StreamTimeout on the initial bytes sent by localConn only. * Revert changes to multiplex module. Remove timeout from caller. --- internal/client/piper.go | 8 +++++--- internal/server/dispatcher.go | 2 -- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/client/piper.go b/internal/client/piper.go index fd1746a..295e3e8 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -85,14 +85,17 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu if sesh == nil || sesh.IsClosed() || sesh.Singleplex { sesh = newSeshFunc() } - go func(sesh *mux.Session, localConn net.Conn) { + go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) { data := make([]byte, 10240) + _ = localConn.SetReadDeadline(time.Now().Add(streamTimeout)) i, err := io.ReadAtLeast(localConn, data, 1) if err != nil { log.Errorf("Failed to read first packet from proxy client: %v", err) localConn.Close() return } + var zeroTime time.Time + _ = localConn.SetReadDeadline(zeroTime) stream, err := sesh.OpenStream() if err != nil { @@ -112,7 +115,6 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu return } - stream.SetReadFromTimeout(streamTimeout) // if localConn hasn't sent anything to stream to a period of time, stream closes go func() { if _, err := common.Copy(localConn, stream); err != nil { log.Tracef("copying stream to proxy client: %v", err) @@ -121,7 +123,7 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu if _, err = common.Copy(stream, localConn); err != nil { log.Tracef("copying proxy client to stream: %v", err) } - }(sesh, localConn) + }(sesh, localConn, streamTimeout) } } diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 56a556f..4fa0698 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -275,8 +275,6 @@ func serveSession(sesh *mux.Session, ci ClientInfo, user *ActiveUser, sta *State } log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod) - // if stream has nothing to send to proxy server for sta.Timeout period of time, stream will return error - newStream.(*mux.Stream).SetWriteToTimeout(sta.Timeout) go func() { if _, err := common.Copy(localConn, newStream); err != nil { log.Tracef("copying stream to proxy server: %v", err)