From 3e737717bd9beea0ee65ba487766e6ff8ff19542 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 12:25:58 +0000 Subject: [PATCH 01/18] Use assert.Eventually to correctly handle more timing sensitive tests --- internal/multiplex/stream_test.go | 37 +++++++++++++------------------ 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index c0b86fb..35fae9d 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -152,15 +152,6 @@ func TestStream_Close(t *testing.T) { return } - // we read something to wait for the test frame to reach our recvBuffer. - // if it's empty by the point we call stream.Close(), the incoming - // frame will be dropped - readBuf := make([]byte, len(testPayload)) - _, err = io.ReadFull(stream, readBuf[:1]) - if err != nil { - t.Errorf("can't read any data before active closing") - } - err = stream.Close() if err != nil { t.Error("failed to actively close stream", err) @@ -175,10 +166,11 @@ func TestStream_Close(t *testing.T) { } sesh.streamsM.Unlock() - _, err = io.ReadFull(stream, readBuf[1:]) - if err != nil { - t.Errorf("can't read residual data %v", err) - } + readBuf := make([]byte, len(testPayload)) + assert.Eventually(t, func() bool { + _, err = io.ReadFull(stream, readBuf) + return err == nil + }, time.Second, 10*time.Millisecond, "can't read residual data", err) if !bytes.Equal(readBuf, testPayload) { t.Errorf("read wrong data") } @@ -324,10 +316,12 @@ func TestStream_Read(t *testing.T) { writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() stream.Close() - i, err := stream.Read(buf) - if err != nil { - t.Error("failed to read", err) - } + + var err error + assert.Eventually(t, func() bool { + i, err = stream.Read(buf) + return err == nil + }, time.Second, 10*time.Millisecond, "failed to read", err) if i != smallPayloadLen { t.Errorf("expected read %v, got %v", smallPayloadLen, i) } @@ -348,10 +342,11 @@ func TestStream_Read(t *testing.T) { writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() sesh.Close() - i, err := stream.Read(buf) - if err != nil { - t.Error("failed to read", err) - } + var err error + assert.Eventually(t, func() bool { + i, err = stream.Read(buf) + return err == nil + }, time.Second, 10*time.Millisecond, "failed to read", err) if i != smallPayloadLen { t.Errorf("expected read %v, got %v", smallPayloadLen, i) } From 5c5e9f8c14c83ee9c3bf9fa513e2ce97e8fb3916 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 17:28:28 +0000 Subject: [PATCH 02/18] Prevent unnecessary allocation in stream closing --- internal/multiplex/session.go | 50 +++++++++++++++++------------------ 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 6f165b0..fa9fc14 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -174,26 +174,27 @@ func (sesh *Session) Accept() (net.Conn, error) { } func (sesh *Session) closeStream(s *Stream, active bool) error { - // must be holding s.wirtingM on entry if atomic.SwapUint32(&s.closed, 1) == 1 { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { - // Notify remote that this stream is closed - padding := genRandomPadding() - s.writingFrame.Closing = closingStream - s.writingFrame.Payload = padding - - obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) - - i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0) - s.writingFrame.Seq++ - if err != nil { - return err + // must be holding s.wirtingM on entry + if len(s.obfsBuf) < 256+frameHeaderLength+sesh.Obfuscator.maxOverhead { + s.obfsBuf = make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) } - _, err = sesh.sb.send(obfsBuf[:i], &s.assignedConnId) + + // Notify remote that this stream is closed + common.CryptoRandRead(s.obfsBuf[:1]) + padLen := int(s.obfsBuf[0]) + 1 + payload := s.obfsBuf[frameHeaderLength : padLen+frameHeaderLength] + common.CryptoRandRead(payload) + + s.writingFrame.Closing = closingStream + s.writingFrame.Payload = payload + + err := s.obfuscateAndSend(frameHeaderLength) if err != nil { return err } @@ -304,14 +305,6 @@ func (sesh *Session) passiveClose() error { return nil } -func genRandomPadding() []byte { - lenB := make([]byte, 1) - common.CryptoRandRead(lenB) - pad := make([]byte, int(lenB[0])+1) - common.CryptoRandRead(pad) - return pad -} - func (sesh *Session) Close() error { log.Debugf("attempting to actively close session %v", sesh.id) err := sesh.closeSession(false) @@ -319,19 +312,24 @@ func (sesh *Session) Close() error { return err } // we send a notice frame telling remote to close the session - pad := genRandomPadding() + + padBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) + common.CryptoRandRead(padBuf[:1]) + padLen := int(padBuf[0]) + 1 + payload := padBuf[frameHeaderLength : padLen+frameHeaderLength] + common.CryptoRandRead(payload) + f := &Frame{ StreamID: 0xffffffff, Seq: 0, Closing: closingSession, - Payload: pad, + Payload: payload, } - obfsBuf := make([]byte, len(pad)+frameHeaderLength+sesh.Obfuscator.maxOverhead) - i, err := sesh.Obfs(f, obfsBuf, 0) + i, err := sesh.Obfs(f, padBuf, frameHeaderLength) if err != nil { return err } - _, err = sesh.sb.send(obfsBuf[:i], new(uint32)) + _, err = sesh.sb.send(padBuf[:i], new(uint32)) if err != nil { return err } From 9108794362c31aafce945906d2049dbe132fa163 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 17:42:54 +0000 Subject: [PATCH 03/18] Fix race condition in allocating obfsBuf --- internal/multiplex/stream.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index d827117..90d5c16 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -174,9 +174,11 @@ func (s *Stream) Write(in []byte) (n int, err error) { // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // for readFromTimeout amount of time func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { + s.writingM.Lock() if s.obfsBuf == nil { s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) } + s.writingM.Unlock() for { if s.readFromTimeout != 0 { if rder, ok := r.(net.Conn); !ok { From 53f0116c1d6c351b7b4927b6273d91986bb69be7 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 17:43:18 +0000 Subject: [PATCH 04/18] Fix race condition in the use of assert.Eventually --- internal/multiplex/stream_test.go | 36 +++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 35fae9d..a28cfc6 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "math/rand" + "sync" "testing" "time" @@ -167,10 +168,18 @@ func TestStream_Close(t *testing.T) { sesh.streamsM.Unlock() readBuf := make([]byte, len(testPayload)) + var wg sync.WaitGroup + wg.Add(1) assert.Eventually(t, func() bool { _, err = io.ReadFull(stream, readBuf) - return err == nil + if err == nil { + wg.Done() + return true + } else { + return false + } }, time.Second, 10*time.Millisecond, "can't read residual data", err) + wg.Wait() if !bytes.Equal(readBuf, testPayload) { t.Errorf("read wrong data") } @@ -262,9 +271,6 @@ func TestStream_Read(t *testing.T) { } var streamID uint32 - buf := make([]byte, 10) - - obfsBuf := make([]byte, 512) for name, unordered := range seshes { sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain) @@ -272,6 +278,8 @@ func TestStream_Read(t *testing.T) { sesh.AddConnection(common.NewTLSConn(rawConn)) writingEnd := common.NewTLSConn(rawWritingEnd) t.Run(name, func(t *testing.T) { + buf := make([]byte, 10) + obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf, 0) @@ -318,10 +326,18 @@ func TestStream_Read(t *testing.T) { stream.Close() var err error + var wg sync.WaitGroup + wg.Add(1) assert.Eventually(t, func() bool { i, err = stream.Read(buf) - return err == nil + if err == nil { + wg.Done() + return true + } else { + return false + } }, time.Second, 10*time.Millisecond, "failed to read", err) + wg.Wait() if i != smallPayloadLen { t.Errorf("expected read %v, got %v", smallPayloadLen, i) } @@ -343,10 +359,18 @@ func TestStream_Read(t *testing.T) { stream, _ := sesh.Accept() sesh.Close() var err error + var wg sync.WaitGroup + wg.Add(1) assert.Eventually(t, func() bool { i, err = stream.Read(buf) - return err == nil + if err == nil { + wg.Done() + return true + } else { + return false + } }, time.Second, 10*time.Millisecond, "failed to read", err) + wg.Wait() if i != smallPayloadLen { t.Errorf("expected read %v, got %v", smallPayloadLen, i) } From 0209bcd977ce28273cfed0363adee4c88530edab Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 22:33:01 +0000 Subject: [PATCH 05/18] Fix race condition in steam closing. Fall back to temp buffer allocation --- internal/multiplex/session.go | 20 ++++++++++++-------- internal/multiplex/stream.go | 6 +++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index fa9fc14..0a961f1 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -180,21 +180,25 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { - // must be holding s.wirtingM on entry - if len(s.obfsBuf) < 256+frameHeaderLength+sesh.Obfuscator.maxOverhead { - s.obfsBuf = make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) - } + tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) // Notify remote that this stream is closed - common.CryptoRandRead(s.obfsBuf[:1]) - padLen := int(s.obfsBuf[0]) + 1 - payload := s.obfsBuf[frameHeaderLength : padLen+frameHeaderLength] + common.CryptoRandRead(tmpBuf[:1]) + padLen := int(tmpBuf[0]) + 1 + payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength] common.CryptoRandRead(payload) + // must be holding s.wirtingM on entry s.writingFrame.Closing = closingStream s.writingFrame.Payload = payload - err := s.obfuscateAndSend(frameHeaderLength) + cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength) + s.writingFrame.Seq++ + if err != nil { + return err + } + + _, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId) if err != nil { return err } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 90d5c16..84f106a 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -118,7 +118,6 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { } func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { - var cipherTextLen int cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) if err != nil { return err @@ -174,11 +173,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // for readFromTimeout amount of time func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { - s.writingM.Lock() if s.obfsBuf == nil { s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) } - s.writingM.Unlock() for { if s.readFromTimeout != 0 { if rder, ok := r.(net.Conn); !ok { @@ -191,6 +188,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { if er != nil { return n, er } + + // the above read may have been unblocked by another goroutine calling stream.Close(), so we need + // to check that here if s.isClosed() { return n, ErrBrokenStream } From a97f5759c0c07d7d1244f400c200e0f6273e6633 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 22:34:12 +0000 Subject: [PATCH 06/18] Fix race condition in tests --- internal/multiplex/mux_test.go | 12 +++++++----- internal/test/integration_test.go | 10 ++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index c8c60f4..02b7721 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -10,6 +10,7 @@ import ( "net" "sync" "testing" + "time" ) func serveEcho(l net.Listener) { @@ -19,13 +20,13 @@ func serveEcho(l net.Listener) { // TODO: pass the error back return } - go func() { + go func(conn net.Conn) { _, err := io.Copy(conn, conn) if err != nil { // TODO: pass the error back return } - }() + }(conn) } } @@ -65,18 +66,19 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) { func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) { var wg sync.WaitGroup + testData := make([]byte, msgLen) + rand.Read(testData) + for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { - testData := make([]byte, msgLen) - rand.Read(testData) - n, err := conn.Write(testData) if n != msgLen { t.Fatalf("written only %v, err %v", n, err) } recvBuf := make([]byte, msgLen) + conn.SetReadDeadline(time.Now().Add(time.Second)) _, err = io.ReadFull(conn, recvBuf) if err != nil { t.Fatalf("failed to read back: %v", err) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 5812ba2..9f6c252 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -30,15 +30,14 @@ func serveTCPEcho(l net.Listener) { log.Error(err) return } - go func() { - conn := conn + go func(conn net.Conn) { _, err := io.Copy(conn, conn) if err != nil { conn.Close() log.Error(err) return } - }() + }(conn) } } @@ -50,8 +49,7 @@ func serveUDPEcho(listener *connutil.PipeListener) { return } const bufSize = 32 * 1024 - go func() { - conn := conn + go func(conn net.PacketConn) { defer conn.Close() buf := make([]byte, bufSize) for { @@ -70,7 +68,7 @@ func serveUDPEcho(listener *connutil.PipeListener) { return } } - }() + }(conn) } } From 70a97233778696f4dc1932fde2bb522499901ff1 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 22:34:26 +0000 Subject: [PATCH 07/18] Temp fix to testing reading after closing a stream --- internal/multiplex/stream_test.go | 70 ++++++++++--------------------- 1 file changed, 22 insertions(+), 48 deletions(-) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index a28cfc6..893aa46 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -7,7 +7,6 @@ import ( "io" "io/ioutil" "math/rand" - "sync" "testing" "time" @@ -152,7 +151,7 @@ func TestStream_Close(t *testing.T) { t.Error("failed to accept stream", err) return } - + time.Sleep(500 * time.Millisecond) err = stream.Close() if err != nil { t.Error("failed to actively close stream", err) @@ -168,18 +167,11 @@ func TestStream_Close(t *testing.T) { sesh.streamsM.Unlock() readBuf := make([]byte, len(testPayload)) - var wg sync.WaitGroup - wg.Add(1) - assert.Eventually(t, func() bool { - _, err = io.ReadFull(stream, readBuf) - if err == nil { - wg.Done() - return true - } else { - return false - } - }, time.Second, 10*time.Millisecond, "can't read residual data", err) - wg.Wait() + _, err = io.ReadFull(stream, readBuf) + if err != nil { + t.Errorf("cannot read resiual data: %v", err) + } + if !bytes.Equal(readBuf, testPayload) { t.Errorf("read wrong data") } @@ -323,27 +315,18 @@ func TestStream_Read(t *testing.T) { streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() + + time.Sleep(500 * time.Millisecond) + stream.Close() - var err error - var wg sync.WaitGroup - wg.Add(1) - assert.Eventually(t, func() bool { - i, err = stream.Read(buf) - if err == nil { - wg.Done() - return true - } else { - return false - } - }, time.Second, 10*time.Millisecond, "failed to read", err) - wg.Wait() - if i != smallPayloadLen { - t.Errorf("expected read %v, got %v", smallPayloadLen, i) + _, err := io.ReadFull(stream, buf[:smallPayloadLen]) + if err != nil { + t.Errorf("cannot read residual data: %v", err) } - if !bytes.Equal(buf[:i], testPayload) { + if !bytes.Equal(buf[:smallPayloadLen], testPayload) { t.Error("expected", testPayload, - "got", buf[:i]) + "got", buf[:smallPayloadLen]) } _, err = stream.Read(buf) if err == nil { @@ -357,26 +340,17 @@ func TestStream_Read(t *testing.T) { streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() + + time.Sleep(500 * time.Millisecond) + sesh.Close() - var err error - var wg sync.WaitGroup - wg.Add(1) - assert.Eventually(t, func() bool { - i, err = stream.Read(buf) - if err == nil { - wg.Done() - return true - } else { - return false - } - }, time.Second, 10*time.Millisecond, "failed to read", err) - wg.Wait() - if i != smallPayloadLen { - t.Errorf("expected read %v, got %v", smallPayloadLen, i) + _, err := io.ReadFull(stream, buf[:smallPayloadLen]) + if err != nil { + t.Errorf("cannot read resiual data: %v", err) } - if !bytes.Equal(buf[:i], testPayload) { + if !bytes.Equal(buf[:smallPayloadLen], testPayload) { t.Error("expected", testPayload, - "got", buf[:i]) + "got", buf[:smallPayloadLen]) } _, err = stream.Read(buf) if err == nil { From 3b24c33e78640e86191803eea8964e46ad859b79 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 23:12:51 +0000 Subject: [PATCH 08/18] Remove incorrect concurrent uses of testing.T.Fatal --- internal/multiplex/mux_test.go | 15 +++++++++------ internal/test/integration_test.go | 24 ++++++++++++++---------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index 02b7721..86dec41 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -10,7 +10,6 @@ import ( "net" "sync" "testing" - "time" ) func serveEcho(l net.Listener) { @@ -72,22 +71,26 @@ func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) { for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { + defer wg.Done() + + // we cannot call t.Fatalf in concurrent contexts n, err := conn.Write(testData) if n != msgLen { - t.Fatalf("written only %v, err %v", n, err) + t.Errorf("written only %v, err %v", n, err) + return } recvBuf := make([]byte, msgLen) - conn.SetReadDeadline(time.Now().Add(time.Second)) _, err = io.ReadFull(conn, recvBuf) if err != nil { - t.Fatalf("failed to read back: %v", err) + t.Errorf("failed to read back: %v", err) + return } if !bytes.Equal(testData, recvBuf) { - t.Fatalf("echoed data not correct") + t.Errorf("echoed data not correct") + return } - wg.Done() }(conn) } wg.Wait() diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 9f6c252..1baf700 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -220,30 +220,34 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil } -func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { +func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) { var wg sync.WaitGroup + testData := make([]byte, msgLen) + rand.Read(testData) + for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { - testDataLen := rand.Intn(maxMsgLen) - testData := make([]byte, testDataLen) - rand.Read(testData) + defer wg.Done() + // we cannot call t.Fatalf in concurrent contexts n, err := conn.Write(testData) - if n != testDataLen { - t.Fatalf("written only %v, err %v", n, err) + if n != msgLen { + t.Errorf("written only %v, err %v", n, err) + return } - recvBuf := make([]byte, testDataLen) + recvBuf := make([]byte, msgLen) _, err = io.ReadFull(conn, recvBuf) if err != nil { - t.Fatalf("failed to read back: %v", err) + t.Errorf("failed to read back: %v", err) + return } if !bytes.Equal(testData, recvBuf) { - t.Fatalf("echoed data not correct") + t.Errorf("echoed data not correct") + return } - wg.Done() }(conn) } wg.Wait() From 4209483a48c6297430e6a8009a48a9b0ecbb3d55 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Wed, 23 Dec 2020 23:58:47 +0000 Subject: [PATCH 09/18] Reduce memory pressure in tests --- internal/test/integration_test.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 1baf700..73ef17e 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -296,6 +296,7 @@ func TestUDP(t *testing.T) { } }) + const echoMsgLen = 1024 t.Run("user echo", func(t *testing.T) { go serveUDPEcho(proxyFromCkServerL) var conn [1]net.Conn @@ -304,7 +305,7 @@ func TestUDP(t *testing.T) { t.Error(err) } - runEchoTest(t, conn[:], 1024) + runEchoTest(t, conn[:], echoMsgLen) }) } @@ -319,13 +320,14 @@ func TestTCPSingleplex(t *testing.T) { t.Fatal(err) } + const echoMsgLen = 16384 go serveTCPEcho(proxyFromCkServerL) proxyConn1, err := proxyToCkClientD.Dial("", "") if err != nil { - t.Error(err) + t.Fatal(err) } - runEchoTest(t, []net.Conn{proxyConn1}, 65536) + runEchoTest(t, []net.Conn{proxyConn1}, echoMsgLen) user, err := sta.Panel.GetUser(ai.UID[:]) if err != nil { t.Fatalf("failed to fetch user: %v", err) @@ -337,15 +339,15 @@ func TestTCPSingleplex(t *testing.T) { proxyConn2, err := proxyToCkClientD.Dial("", "") if err != nil { - t.Error(err) + t.Fatal(err) } - runEchoTest(t, []net.Conn{proxyConn2}, 65536) + runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen) if user.NumSession() != 2 { t.Error("no extra session were made on second connection establishment") } // Both conns should work - runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, 65536) + runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, echoMsgLen) proxyConn1.Close() @@ -354,17 +356,17 @@ func TestTCPSingleplex(t *testing.T) { }, time.Second, 10*time.Millisecond, "first session was not closed on connection close") // conn2 should still work - runEchoTest(t, []net.Conn{proxyConn2}, 65536) + runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen) var conns [numConns]net.Conn for i := 0; i < numConns; i++ { conns[i], err = proxyToCkClientD.Dial("", "") if err != nil { - t.Error(err) + t.Fatal(err) } } - runEchoTest(t, conns[:], 65536) + runEchoTest(t, conns[:], echoMsgLen) } @@ -412,6 +414,7 @@ func TestTCPMultiplex(t *testing.T) { } }) + const echoMsgLen = 16384 t.Run("user echo", func(t *testing.T) { go serveTCPEcho(proxyFromCkServerL) var conns [numConns]net.Conn @@ -422,7 +425,7 @@ func TestTCPMultiplex(t *testing.T) { } } - runEchoTest(t, conns[:], 65536) + runEchoTest(t, conns[:], echoMsgLen) }) t.Run("redir echo", func(t *testing.T) { @@ -434,7 +437,7 @@ func TestTCPMultiplex(t *testing.T) { t.Error(err) } } - runEchoTest(t, conns[:], 65536) + runEchoTest(t, conns[:], echoMsgLen) }) } From 5933ad878173dd8641b752994af6313fe1c9d052 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Thu, 24 Dec 2020 10:26:53 +0000 Subject: [PATCH 10/18] Replace bytes.Buffer with vanilla []byte in tls wrapper --- internal/common/tls.go | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/internal/common/tls.go b/internal/common/tls.go index fd2fce4..fb54e97 100644 --- a/internal/common/tls.go +++ b/internal/common/tls.go @@ -1,7 +1,6 @@ package common import ( - "bytes" "encoding/binary" "io" "net" @@ -44,7 +43,9 @@ func NewTLSConn(conn net.Conn) *TLSConn { return &TLSConn{ Conn: conn, writeBufPool: sync.Pool{New: func() interface{} { - return new(bytes.Buffer) + b := make([]byte, 0, initialWriteBufSize) + b = append(b, ApplicationData, byte(VersionTLS13>>8), byte(VersionTLS13&0xFF)) + return &b }}, } } @@ -93,16 +94,13 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) { msgLen := len(in) - writeBuf := tls.writeBufPool.Get().(*bytes.Buffer) - writeBuf.WriteByte(ApplicationData) - writeBuf.WriteByte(byte(VersionTLS13 >> 8)) - writeBuf.WriteByte(byte(VersionTLS13 & 0xFF)) - writeBuf.WriteByte(byte(msgLen >> 8)) - writeBuf.WriteByte(byte(msgLen & 0xFF)) - writeBuf.Write(in) - i, err := writeBuf.WriteTo(tls.Conn) + writeBuf := tls.writeBufPool.Get().(*[]byte) + *writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF)) + *writeBuf = append(*writeBuf, in...) + n, err = tls.Conn.Write(*writeBuf) + *writeBuf = (*writeBuf)[:3] tls.writeBufPool.Put(writeBuf) - return int(i - recordLayerLength), err + return n - recordLayerLength, err } func (tls *TLSConn) Close() error { From 881f6e6f9d24bd307e41be5ad4926ece0755b38d Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Thu, 24 Dec 2020 11:35:29 +0000 Subject: [PATCH 11/18] Use sync.Pool for obfuscation buffer --- internal/multiplex/session.go | 24 +++++++++++++----------- internal/multiplex/stream.go | 32 ++++++++++++-------------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 0a961f1..b918107 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -70,6 +70,8 @@ type Session struct { // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame recvFramePool sync.Pool + streamObfsBufPool sync.Pool + // Switchboard manages all connections to remote sb *switchboard @@ -117,6 +119,11 @@ func MakeSession(id uint32, config SessionConfig) *Session { // todo: validation. this must be smaller than StreamSendBufferSize sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead + sesh.streamObfsBufPool = sync.Pool{New: func() interface{} { + b := make([]byte, sesh.StreamSendBufferSize) + return &b + }} + sesh.sb = makeSwitchboard(sesh) time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout) return sesh @@ -180,25 +187,20 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { - tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) + tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte) // Notify remote that this stream is closed - common.CryptoRandRead(tmpBuf[:1]) - padLen := int(tmpBuf[0]) + 1 - payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength] + common.CryptoRandRead((*tmpBuf)[:1]) + padLen := int((*tmpBuf)[0]) + 1 + payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength] common.CryptoRandRead(payload) // must be holding s.wirtingM on entry s.writingFrame.Closing = closingStream s.writingFrame.Payload = payload - cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength) - s.writingFrame.Seq++ - if err != nil { - return err - } - - _, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId) + err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength) + sesh.streamObfsBufPool.Put(tmpBuf) if err != nil { return err } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 84f106a..b29359f 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -34,11 +34,6 @@ type Stream struct { // atomic closed uint32 - // obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from - // the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste - // memory - obfsBuf []byte - // When we want order guarantee (i.e. session.Unordered is false), // we assign each stream a fixed underlying connection. // If the underlying connections the session uses provide ordering guarantee (most likely TCP), @@ -117,13 +112,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { return n, nil } -func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { - cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) +func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { + cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf) + s.writingFrame.Seq++ if err != nil { return err } - _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) + _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -142,9 +138,6 @@ func (s *Stream) Write(in []byte) (n int, err error) { return 0, ErrBrokenStream } - if s.obfsBuf == nil { - s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) - } for n < len(in) { var framePayload []byte if len(in)-n <= s.session.maxStreamUnitWrite { @@ -160,8 +153,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { framePayload = in[n : s.session.maxStreamUnitWrite+n] } s.writingFrame.Payload = framePayload - err = s.obfuscateAndSend(0) - s.writingFrame.Seq++ + buf := s.session.streamObfsBufPool.Get().(*[]byte) + err = s.obfuscateAndSend(*buf, 0) + s.session.streamObfsBufPool.Put(buf) if err != nil { return } @@ -173,9 +167,6 @@ func (s *Stream) Write(in []byte) (n int, err error) { // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // for readFromTimeout amount of time func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { - if s.obfsBuf == nil { - s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) - } for { if s.readFromTimeout != 0 { if rder, ok := r.(net.Conn); !ok { @@ -184,7 +175,8 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { rder.SetReadDeadline(time.Now().Add(s.readFromTimeout)) } } - read, er := r.Read(s.obfsBuf[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite]) + buf := s.session.streamObfsBufPool.Get().(*[]byte) + read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite]) if er != nil { return n, er } @@ -196,10 +188,10 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { } s.writingM.Lock() - s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read] - err = s.obfuscateAndSend(frameHeaderLength) - s.writingFrame.Seq++ + s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read] + err = s.obfuscateAndSend(*buf, frameHeaderLength) s.writingM.Unlock() + s.session.streamObfsBufPool.Put(buf) if err != nil { return From 4f34e690060d6797849e37413b9cdfeb839f80a3 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Thu, 24 Dec 2020 13:42:22 +0000 Subject: [PATCH 12/18] Use pooled buffer for session closing frame --- internal/multiplex/session.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index b918107..159172e 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -319,10 +319,10 @@ func (sesh *Session) Close() error { } // we send a notice frame telling remote to close the session - padBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) - common.CryptoRandRead(padBuf[:1]) - padLen := int(padBuf[0]) + 1 - payload := padBuf[frameHeaderLength : padLen+frameHeaderLength] + buf := sesh.streamObfsBufPool.Get().(*[]byte) + common.CryptoRandRead((*buf)[:1]) + padLen := int((*buf)[0]) + 1 + payload := (*buf)[frameHeaderLength : padLen+frameHeaderLength] common.CryptoRandRead(payload) f := &Frame{ @@ -331,11 +331,11 @@ func (sesh *Session) Close() error { Closing: closingSession, Payload: payload, } - i, err := sesh.Obfs(f, padBuf, frameHeaderLength) + i, err := sesh.Obfs(f, *buf, frameHeaderLength) if err != nil { return err } - _, err = sesh.sb.send(padBuf[:i], new(uint32)) + _, err = sesh.sb.send((*buf)[:i], new(uint32)) if err != nil { return err } From 2f17841f8579bd5151127aa1839c3e6262c61f18 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Fri, 25 Dec 2020 23:16:57 +0000 Subject: [PATCH 13/18] Use Compare-And-Swap for atomic booleans indicating session and switchboard closed --- internal/multiplex/session.go | 29 ++++++++++++++--------------- internal/multiplex/switchboard.go | 17 +++++++---------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 159172e..0113afa 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -66,6 +66,8 @@ type Session struct { streamsM sync.Mutex streams map[uint32]*Stream + // For accepting new streams + acceptCh chan *Stream // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame recvFramePool sync.Pool @@ -78,9 +80,6 @@ type Session struct { // Used for LocalAddr() and RemoteAddr() etc. addrs atomic.Value - // For accepting new streams - acceptCh chan *Stream - closed uint32 terminalMsg atomic.Value @@ -181,7 +180,7 @@ func (sesh *Session) Accept() (net.Conn, error) { } func (sesh *Session) closeStream(s *Stream, active bool) error { - if atomic.SwapUint32(&s.closed, 1) == 1 { + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } _ = s.getRecvBuf().Close() // recvBuf.Close should not return error @@ -244,6 +243,10 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { } sesh.streamsM.Lock() + if sesh.IsClosed() { + sesh.streamsM.Unlock() + return ErrBrokenSession + } existingStream, existing := sesh.streams[frame.StreamID] if existing { sesh.streamsM.Unlock() @@ -255,10 +258,10 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { } else { newStream := makeStream(sesh, frame.StreamID) sesh.streams[frame.StreamID] = newStream + sesh.acceptCh <- newStream sesh.streamsM.Unlock() // new stream sesh.streamCountIncr() - sesh.acceptCh <- newStream return newStream.recvFrame(frame) } } @@ -276,14 +279,14 @@ func (sesh *Session) TerminalMsg() string { } } -func (sesh *Session) closeSession(closeSwitchboard bool) error { - if atomic.SwapUint32(&sesh.closed, 1) == 1 { +func (sesh *Session) closeSession() error { + if !atomic.CompareAndSwapUint32(&sesh.closed, 0, 1) { log.Debugf("session %v has already been closed", sesh.id) return errRepeatSessionClosing } - sesh.acceptCh <- nil sesh.streamsM.Lock() + close(sesh.acceptCh) for id, stream := range sesh.streams { if stream == nil { continue @@ -294,26 +297,23 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error { sesh.streamCountDecr() } sesh.streamsM.Unlock() - - if closeSwitchboard { - sesh.sb.closeAll() - } return nil } func (sesh *Session) passiveClose() error { log.Debugf("attempting to passively close session %v", sesh.id) - err := sesh.closeSession(true) + err := sesh.closeSession() if err != nil { return err } + sesh.sb.closeAll() log.Debugf("session %v closed gracefully", sesh.id) return nil } func (sesh *Session) Close() error { log.Debugf("attempting to actively close session %v", sesh.id) - err := sesh.closeSession(false) + err := sesh.closeSession() if err != nil { return err } @@ -339,7 +339,6 @@ func (sesh *Session) Close() error { if err != nil { return err } - sesh.sb.closeAll() log.Debugf("session %v closed gracefully", sesh.id) return nil diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 7d2c93c..ea21308 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -71,7 +71,8 @@ func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { n, err = conn.Write(d) if err != nil { sb.conns.Delete(*connId) - sb.close("failed to write to remote " + err.Error()) + sb.session.SetTerminalMsg("failed to write to remote " + err.Error()) + sb.session.passiveClose() return n, err } sb.valve.AddTx(int64(n)) @@ -138,16 +139,11 @@ func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) { return id, conn, nil } -func (sb *switchboard) close(terminalMsg string) { - atomic.StoreUint32(&sb.broken, 1) - if !sb.session.IsClosed() { - sb.session.SetTerminalMsg(terminalMsg) - sb.session.passiveClose() - } -} - // actively triggered by session.Close() func (sb *switchboard) closeAll() { + if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { + return + } sb.conns.Range(func(key, connI interface{}) bool { conn := connI.(net.Conn) conn.Close() @@ -168,7 +164,8 @@ func (sb *switchboard) deplex(connId uint32, conn net.Conn) { log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) sb.conns.Delete(connId) atomic.AddUint32(&sb.numConns, ^uint32(0)) - sb.close("a connection has dropped unexpectedly") + sb.session.SetTerminalMsg("a connection has dropped unexpectedly") + sb.session.passiveClose() return } From 415523f10ab5f4e4dd3e31b53394b2811deb3c5f Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sat, 26 Dec 2020 00:49:36 +0000 Subject: [PATCH 14/18] Refactor obfuscate and deobfuscate functions to reduce a layer of indirection --- internal/multiplex/obfs.go | 189 +++++++++++++---------------- internal/multiplex/obfs_test.go | 86 ++++++++----- internal/multiplex/session.go | 4 +- internal/multiplex/session_test.go | 20 +-- internal/multiplex/stream.go | 2 +- internal/multiplex/stream_test.go | 16 +-- 6 files changed, 163 insertions(+), 154 deletions(-) diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 1379072..97b2cd9 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -11,9 +11,6 @@ import ( "golang.org/x/crypto/salsa20" ) -type Obfser func(*Frame, []byte, int) (int, error) -type Deobfser func(*Frame, []byte) error - var u32 = binary.BigEndian.Uint32 var u64 = binary.BigEndian.Uint64 var putU32 = binary.BigEndian.PutUint32 @@ -30,21 +27,15 @@ const ( // Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. type Obfuscator struct { - // Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header - Obfs Obfser - // Remove TLS header, decrypt and unmarshall frames - Deobfs Deobfser + payloadCipher cipher.AEAD + SessionKey [32]byte maxOverhead int } -// MakeObfs returns a function of type Obfser. An Obfser takes three arguments: -// a *Frame with all the field set correctly, a []byte as buffer to put encrypted -// message in, and an int called payloadOffsetInBuf to be used when *Frame.payload -// is in the byte slice used as buffer (2nd argument). payloadOffsetInBuf specifies -// the index at which data belonging to *Frame.Payload starts in the buffer. -func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { +// obfuscate adds multiplexing headers, encrypt and add TLS header +func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { // The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header // as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use // the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead()) @@ -76,109 +67,99 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { // We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input // is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends // encrypted data. - obfs := func(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { - payloadLen := len(f.Payload) - if payloadLen == 0 { - return 0, errors.New("payload cannot be empty") - } - var extraLen int - if payloadCipher == nil { - extraLen = salsa20NonceSize - payloadLen - if extraLen < 0 { - // if our payload is already greater than 8 bytes - extraLen = 0 - } - } else { - extraLen = payloadCipher.Overhead() - if extraLen < salsa20NonceSize { - return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") - } - } - - usefulLen := frameHeaderLength + payloadLen + extraLen - if len(buf) < usefulLen { - return 0, errors.New("obfs buffer too small") - } - // we do as much in-place as possible to save allocation - payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] - if payloadOffsetInBuf != frameHeaderLength { - // if payload is not at the correct location in buffer - copy(payload, f.Payload) - } - - header := buf[:frameHeaderLength] - putU32(header[0:4], f.StreamID) - putU64(header[4:12], f.Seq) - header[12] = f.Closing - header[13] = byte(extraLen) - - if payloadCipher == nil { - if extraLen != 0 { // read nonce - extra := buf[usefulLen-extraLen : usefulLen] - common.CryptoRandRead(extra) - } - } else { - payloadCipher.Seal(payload[:0], header[:payloadCipher.NonceSize()], payload, nil) - } - - nonce := buf[usefulLen-salsa20NonceSize : usefulLen] - salsa20.XORKeyStream(header, header, nonce, &salsaKey) - - return usefulLen, nil + payloadLen := len(f.Payload) + if payloadLen == 0 { + return 0, errors.New("payload cannot be empty") } - return obfs + var extraLen int + if o.payloadCipher == nil { + extraLen = salsa20NonceSize - payloadLen + if extraLen < 0 { + // if our payload is already greater than 8 bytes + extraLen = 0 + } + } else { + extraLen = o.payloadCipher.Overhead() + if extraLen < salsa20NonceSize { + return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") + } + } + + usefulLen := frameHeaderLength + payloadLen + extraLen + if len(buf) < usefulLen { + return 0, errors.New("obfs buffer too small") + } + // we do as much in-place as possible to save allocation + payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] + if payloadOffsetInBuf != frameHeaderLength { + // if payload is not at the correct location in buffer + copy(payload, f.Payload) + } + + header := buf[:frameHeaderLength] + putU32(header[0:4], f.StreamID) + putU64(header[4:12], f.Seq) + header[12] = f.Closing + header[13] = byte(extraLen) + + if o.payloadCipher == nil { + if extraLen != 0 { // read nonce + extra := buf[usefulLen-extraLen : usefulLen] + common.CryptoRandRead(extra) + } + } else { + o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil) + } + + nonce := buf[usefulLen-salsa20NonceSize : usefulLen] + salsa20.XORKeyStream(header, header, nonce, &o.SessionKey) + + return usefulLen, nil } -// MakeDeobfs returns a function Deobfser. A Deobfser takes in a single byte slice, -// containing the message to be decrypted, and returns a *Frame containing the frame -// information and plaintext -func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { - // frame header length + minimum data size (i.e. nonce size of salsa20) - const minInputLen = frameHeaderLength + salsa20NonceSize - deobfs := func(f *Frame, in []byte) error { - if len(in) < minInputLen { - return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) - } +// deobfuscate removes TLS header, decrypt and unmarshall frames +func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error { + if len(in) < frameHeaderLength+salsa20NonceSize { + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize) + } - header := in[:frameHeaderLength] - pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead + header := in[:frameHeaderLength] + pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead - nonce := in[len(in)-salsa20NonceSize:] - salsa20.XORKeyStream(header, header, nonce, &salsaKey) + nonce := in[len(in)-salsa20NonceSize:] + salsa20.XORKeyStream(header, header, nonce, &o.SessionKey) - streamID := u32(header[0:4]) - seq := u64(header[4:12]) - closing := header[12] - extraLen := header[13] + streamID := u32(header[0:4]) + seq := u64(header[4:12]) + closing := header[12] + extraLen := header[13] - usefulPayloadLen := len(pldWithOverHead) - int(extraLen) - if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") - } + usefulPayloadLen := len(pldWithOverHead) - int(extraLen) + if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + } - var outputPayload []byte + var outputPayload []byte - if payloadCipher == nil { - if extraLen == 0 { - outputPayload = pldWithOverHead - } else { - outputPayload = pldWithOverHead[:usefulPayloadLen] - } + if o.payloadCipher == nil { + if extraLen == 0 { + outputPayload = pldWithOverHead } else { - _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) - if err != nil { - return err - } outputPayload = pldWithOverHead[:usefulPayloadLen] } - - f.StreamID = streamID - f.Seq = seq - f.Closing = closing - f.Payload = outputPayload - return nil + } else { + _, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil) + if err != nil { + return err + } + outputPayload = pldWithOverHead[:usefulPayloadLen] } - return deobfs + + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil } func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { @@ -217,7 +198,5 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu } } - obfuscator.Obfs = MakeObfs(sessionKey, payloadCipher) - obfuscator.Deobfs = MakeDeobfs(sessionKey, payloadCipher) return } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 99f4f5f..2bee633 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -19,14 +19,14 @@ func TestGenerateObfs(t *testing.T) { obfsBuf := make([]byte, 512) _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) - i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) + i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0) if err != nil { ct.Error("failed to obfs ", err) return } var resultFrame Frame - err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) + err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i]) if err != nil { ct.Error("failed to deobfs ", err) return @@ -88,40 +88,57 @@ func BenchmarkObfs(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) + obfuscator := Obfuscator{ + payloadCipher: nil, + SessionKey: key, + maxOverhead: salsa20NonceSize, + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:16]) + payloadCipher, _ := chacha20poly1305.New(key[:]) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) } @@ -143,57 +160,70 @@ func BenchmarkDeobfs(b *testing.B) { b.Run("AES256GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, nil) + obfuscator := Obfuscator{ + payloadCipher: nil, + SessionKey: key, + maxOverhead: salsa20NonceSize, + } + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:16]) + payloadCipher, _ := chacha20poly1305.New(key[:]) - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: nil, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 0113afa..e05e399 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -232,7 +232,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { frame := sesh.recvFramePool.Get().(*Frame) defer sesh.recvFramePool.Put(frame) - err := sesh.Deobfs(frame, data) + err := sesh.deobfuscate(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } @@ -331,7 +331,7 @@ func (sesh *Session) Close() error { Closing: closingSession, Payload: payload, } - i, err := sesh.Obfs(f, *buf, frameHeaderLength) + i, err := sesh.obfuscate(f, *buf, frameHeaderLength) if err != nil { return err } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index f4b32bb..990437a 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -56,7 +56,7 @@ func TestRecvDataFromRemote(t *testing.T) { t.Run(method, func(t *testing.T) { seshConfig.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig) - n, err := sesh.Obfs(f, obfsBuf, 0) + n, err := sesh.obfuscate(f, obfsBuf, 0) if err != nil { t.Error(err) return @@ -107,7 +107,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { testPayload, } // create stream 1 - n, _ := sesh.Obfs(f1, obfsBuf, 0) + n, _ := sesh.obfuscate(f1, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -129,7 +129,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.Obfs(f2, obfsBuf, 0) + n, _ = sesh.obfuscate(f2, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) @@ -151,7 +151,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingStream, testPayload, } - n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) @@ -180,7 +180,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { } // close stream 1 again - n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) @@ -203,7 +203,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { Closing: closingSession, Payload: testPayload, } - n, _ = sesh.Obfs(fCloseSession, obfsBuf, 0) + n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving session closing frame: %v", err) @@ -246,7 +246,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingStream, testPayload, } - n, _ := sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) @@ -268,7 +268,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.Obfs(f1, obfsBuf, 0) + n, _ = sesh.obfuscate(f1, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -330,7 +330,7 @@ func TestParallelStreams(t *testing.T) { wg.Add(1) go func(frame *Frame) { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(frame, obfsBuf, 0) + n, _ := sesh.obfuscate(frame, obfsBuf, 0) obfsBuf = obfsBuf[0:n] err := sesh.recvDataFromRemote(obfsBuf) @@ -446,7 +446,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { binaryFrames := [maxIter][]byte{} for i := 0; i < maxIter; i++ { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(f, obfsBuf, 0) + n, _ := sesh.obfuscate(f, obfsBuf, 0) binaryFrames[i] = obfsBuf[:n] f.Seq++ } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index b29359f..ffd7e23 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -113,7 +113,7 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { } func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { - cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf) + cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf) s.writingFrame.Seq++ if err != nil { return err diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 893aa46..0435557 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -141,7 +141,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, _ := sesh.Obfs(dataFrame, obfsBuf, 0) + i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0) _, err := writingEnd.Write(obfsBuf[:i]) if err != nil { t.Error("failed to write from remote end") @@ -184,7 +184,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, err := sesh.Obfs(dataFrame, obfsBuf, 0) + i, err := sesh.obfuscate(dataFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -206,7 +206,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.Obfs(closingFrame, obfsBuf, 0) + i, err = sesh.obfuscate(closingFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -222,7 +222,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.Obfs(closingFrameDup, obfsBuf, 0) + i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -274,7 +274,7 @@ func TestStream_Read(t *testing.T) { obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, err := sesh.Accept() @@ -299,7 +299,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Nil buf", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -311,7 +311,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -336,7 +336,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after session close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() From 2d08e88efb7205a6136acd86ae329df3d7a92022 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sat, 26 Dec 2020 13:48:42 +0000 Subject: [PATCH 15/18] Use a sync.Pool to remove the global random bottleneck in picking a random conn --- internal/multiplex/switchboard.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index ea21308..6b254f6 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -7,6 +7,7 @@ import ( "net" "sync" "sync/atomic" + "time" ) const ( @@ -31,6 +32,7 @@ type switchboard struct { conns sync.Map numConns uint32 nextConnId uint32 + randPool sync.Pool broken uint32 } @@ -48,6 +50,9 @@ func makeSwitchboard(sesh *Session) *switchboard { strategy: strategy, valve: sesh.Valve, nextConnId: 1, + randPool: sync.Pool{New: func() interface{} { + return rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + }}, } return sb } @@ -121,7 +126,9 @@ func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) { // so if the r > len(sb.conns) at the point of range call, the last visited element is picked var id uint32 var conn net.Conn - r := rand.Intn(connCount) + randReader := sb.randPool.Get().(*rand.Rand) + r := randReader.Intn(connCount) + sb.randPool.Put(randReader) var c int sb.conns.Range(func(connIdI, connI interface{}) bool { if r == c { From 3ad04aa7e9e43bc669e0480222eefb8ce7bf27cf Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sat, 26 Dec 2020 15:36:38 +0000 Subject: [PATCH 16/18] Add latency benchmark --- internal/test/integration_test.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 73ef17e..dbf7c14 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -508,7 +508,7 @@ func TestClosingStreamsFromProxy(t *testing.T) { } } -func BenchmarkThroughput(b *testing.B) { +func BenchmarkIntegration(b *testing.B) { log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) @@ -529,7 +529,7 @@ func BenchmarkThroughput(b *testing.B) { b.Fatal(err) } - b.Run("single stream", func(b *testing.B) { + b.Run("single stream bandwidth", func(b *testing.B) { more := make(chan int, 10) go func() { // sender @@ -553,6 +553,19 @@ func BenchmarkThroughput(b *testing.B) { } }) + b.Run("single stream latency", func(b *testing.B) { + clientConn, _ := proxyToCkClientD.Dial("", "") + buf := []byte{1} + clientConn.Write(buf) + serverConn, _ := proxyFromCkServerL.Accept() + serverConn.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + clientConn.Write(buf) + serverConn.Read(buf) + } + }) + }) } From cbd71fae6d549179222a6f20939d58322b6e0293 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sat, 26 Dec 2020 16:34:25 +0000 Subject: [PATCH 17/18] Control flow optimisations --- internal/multiplex/datagramBufferedPipe.go | 13 ++------- internal/multiplex/streamBufferedPipe.go | 11 +------- internal/multiplex/switchboard.go | 33 ++++++++++------------ 3 files changed, 18 insertions(+), 39 deletions(-) diff --git a/internal/multiplex/datagramBufferedPipe.go b/internal/multiplex/datagramBufferedPipe.go index a7b99e4..7082264 100644 --- a/internal/multiplex/datagramBufferedPipe.go +++ b/internal/multiplex/datagramBufferedPipe.go @@ -13,8 +13,7 @@ import ( // instead of byte-oriented. The integrity of datagrams written into this buffer is preserved. // it won't get chopped up into individual bytes type datagramBufferedPipe struct { - pLens []int - // lazily allocated + pLens []int buf *bytes.Buffer closed bool rwCond *sync.Cond @@ -27,6 +26,7 @@ type datagramBufferedPipe struct { func NewDatagramBufferedPipe() *datagramBufferedPipe { d := &datagramBufferedPipe{ rwCond: sync.NewCond(&sync.Mutex{}), + buf: new(bytes.Buffer), } return d } @@ -34,9 +34,6 @@ func NewDatagramBufferedPipe() *datagramBufferedPipe { func (d *datagramBufferedPipe) Read(target []byte) (int, error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() - if d.buf == nil { - d.buf = new(bytes.Buffer) - } for { if d.closed && len(d.pLens) == 0 { return 0, io.EOF @@ -72,9 +69,6 @@ func (d *datagramBufferedPipe) Read(target []byte) (int, error) { func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() - if d.buf == nil { - d.buf = new(bytes.Buffer) - } for { if d.closed && len(d.pLens) == 0 { return 0, io.EOF @@ -115,9 +109,6 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() - if d.buf == nil { - d.buf = new(bytes.Buffer) - } for { if d.closed { return true, io.ErrClosedPipe diff --git a/internal/multiplex/streamBufferedPipe.go b/internal/multiplex/streamBufferedPipe.go index 66dacec..0dd3e46 100644 --- a/internal/multiplex/streamBufferedPipe.go +++ b/internal/multiplex/streamBufferedPipe.go @@ -11,7 +11,6 @@ import ( // The point of a streamBufferedPipe is that Read() will block until data is available type streamBufferedPipe struct { - // only alloc when on first Read or Write buf *bytes.Buffer closed bool @@ -25,6 +24,7 @@ type streamBufferedPipe struct { func NewStreamBufferedPipe() *streamBufferedPipe { p := &streamBufferedPipe{ rwCond: sync.NewCond(&sync.Mutex{}), + buf: new(bytes.Buffer), } return p } @@ -32,9 +32,6 @@ func NewStreamBufferedPipe() *streamBufferedPipe { func (p *streamBufferedPipe) Read(target []byte) (int, error) { p.rwCond.L.Lock() defer p.rwCond.L.Unlock() - if p.buf == nil { - p.buf = new(bytes.Buffer) - } for { if p.closed && p.buf.Len() == 0 { return 0, io.EOF @@ -64,9 +61,6 @@ func (p *streamBufferedPipe) Read(target []byte) (int, error) { func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { p.rwCond.L.Lock() defer p.rwCond.L.Unlock() - if p.buf == nil { - p.buf = new(bytes.Buffer) - } for { if p.closed && p.buf.Len() == 0 { return 0, io.EOF @@ -104,9 +98,6 @@ func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { func (p *streamBufferedPipe) Write(input []byte) (int, error) { p.rwCond.L.Lock() defer p.rwCond.L.Unlock() - if p.buf == nil { - p.buf = new(bytes.Buffer) - } for { if p.closed { return 0, io.ErrClosedPipe diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 6b254f6..84e43c9 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -72,46 +72,43 @@ func (sb *switchboard) addConn(conn net.Conn) { // a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { - writeAndRegUsage := func(conn net.Conn, d []byte) (int, error) { - n, err = conn.Write(d) - if err != nil { - sb.conns.Delete(*connId) - sb.session.SetTerminalMsg("failed to write to remote " + err.Error()) - sb.session.passiveClose() - return n, err - } - sb.valve.AddTx(int64(n)) - return n, nil - } - sb.valve.txWait(len(data)) if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 { return 0, errBrokenSwitchboard } + var conn net.Conn switch sb.strategy { case UNIFORM_SPREAD: - _, conn, err := sb.pickRandConn() + _, conn, err = sb.pickRandConn() if err != nil { return 0, errBrokenSwitchboard } - return writeAndRegUsage(conn, data) case FIXED_CONN_MAPPING: connI, ok := sb.conns.Load(*connId) if ok { - conn := connI.(net.Conn) - return writeAndRegUsage(conn, data) + conn = connI.(net.Conn) } else { - newConnId, conn, err := sb.pickRandConn() + var newConnId uint32 + newConnId, conn, err = sb.pickRandConn() if err != nil { return 0, errBrokenSwitchboard } *connId = newConnId - return writeAndRegUsage(conn, data) } default: return 0, errors.New("unsupported traffic distribution strategy") } + + n, err = conn.Write(data) + if err != nil { + sb.conns.Delete(*connId) + sb.session.SetTerminalMsg("failed to write to remote " + err.Error()) + sb.session.passiveClose() + return n, err + } + sb.valve.AddTx(int64(n)) + return n, nil } // returns a random connId From d1b05ee9e5fa051c4fc7a55f3c8f99d24a650b14 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 27 Dec 2020 13:26:45 +0000 Subject: [PATCH 18/18] Add new encryption method option aes-128-gcm --- README.md | 10 +++++----- internal/client/state.go | 5 ++++- internal/multiplex/obfs.go | 18 +++++++++++++++--- internal/multiplex/obfs_test.go | 12 ++++++++++-- internal/multiplex/session_test.go | 8 +++++--- internal/multiplex/stream_test.go | 3 ++- internal/server/dispatcher.go | 8 +++++++- internal/test/integration_test.go | 3 ++- 8 files changed, 50 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 285575f..ebf70db 100644 --- a/README.md +++ b/README.md @@ -126,11 +126,11 @@ instead a CDN is used, use `CDN`. `ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the server's `ProxyBook` exactly. -`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-gcm` -and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport security. The point of encryption is to hide -fingerprints of proxy protocols and render the payload statistically random-like. **You may only leave it as `plain` if -you are certain that your underlying proxy tool already provides BOTH encryption and authentication (via AEAD or similar -techniques).** +`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-256-gcm` ( +synonymous to `aes-gcm`), `aes-128-gcm`, and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport +security. The point of encryption is to hide fingerprints of proxy protocols and render the payload statistically +random-like. **You may only leave it as `plain` if you are certain that your underlying proxy tool already provides BOTH +encryption and authentication (via AEAD or similar techniques).** `ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. diff --git a/internal/client/state.go b/internal/client/state.go index f942ad3..c26f839 100644 --- a/internal/client/state.go +++ b/internal/client/state.go @@ -164,7 +164,10 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca case "plain": auth.EncryptionMethod = mux.EncryptionMethodPlain case "aes-gcm": - auth.EncryptionMethod = mux.EncryptionMethodAESGCM + case "aes-256-gcm": + auth.EncryptionMethod = mux.EncryptionMethodAES256GCM + case "aes-128-gcm": + auth.EncryptionMethod = mux.EncryptionMethodAES128GCM case "chacha20-poly1305": auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305 default: diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 97b2cd9..8ec0d5b 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -21,8 +21,9 @@ const salsa20NonceSize = 8 const ( EncryptionMethodPlain = iota - EncryptionMethodAESGCM + EncryptionMethodAES256GCM EncryptionMethodChaha20Poly1305 + EncryptionMethodAES128GCM ) // Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. @@ -171,7 +172,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu case EncryptionMethodPlain: payloadCipher = nil obfuscator.maxOverhead = salsa20NonceSize - case EncryptionMethodAESGCM: + case EncryptionMethodAES256GCM: var c cipher.Block c, err = aes.NewCipher(sessionKey[:]) if err != nil { @@ -182,6 +183,17 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu return } obfuscator.maxOverhead = payloadCipher.Overhead() + case EncryptionMethodAES128GCM: + var c cipher.Block + c, err = aes.NewCipher(sessionKey[:16]) + if err != nil { + return + } + payloadCipher, err = cipher.NewGCM(c) + if err != nil { + return + } + obfuscator.maxOverhead = payloadCipher.Overhead() case EncryptionMethodChaha20Poly1305: payloadCipher, err = chacha20poly1305.New(sessionKey[:]) if err != nil { @@ -189,7 +201,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu } obfuscator.maxOverhead = payloadCipher.Overhead() default: - return obfuscator, errors.New("Unknown encryption method") + return obfuscator, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) } if payloadCipher != nil { diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 2bee633..78a760d 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -46,8 +46,16 @@ func TestGenerateObfs(t *testing.T) { run(obfuscator, t) } }) - t.Run("aes-gcm", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) + t.Run("aes-256-gcm", func(t *testing.T) { + obfuscator, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) + if err != nil { + t.Errorf("failed to generate obfuscator %v", err) + } else { + run(obfuscator, t) + } + }) + t.Run("aes-128-gcm", func(t *testing.T) { + obfuscator, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) if err != nil { t.Errorf("failed to generate obfuscator %v", err) } else { diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 990437a..ebd305e 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -44,7 +44,7 @@ func TestRecvDataFromRemote(t *testing.T) { encryptionMethods := map[string]Obfuscator{ "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), - "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), + "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAES256GCM, sessionKey), "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), } @@ -430,7 +430,8 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { table := map[string]byte{ "plain": EncryptionMethodPlain, - "aes-gcm": EncryptionMethodAESGCM, + "aes-256-gcm": EncryptionMethodAES256GCM, + "aes-128-gcm": EncryptionMethodAES128GCM, "chacha20poly1305": EncryptionMethodChaha20Poly1305, } @@ -466,7 +467,8 @@ func BenchmarkMultiStreamWrite(b *testing.B) { table := map[string]byte{ "plain": EncryptionMethodPlain, - "aes-gcm": EncryptionMethodAESGCM, + "aes-256-gcm": EncryptionMethodAES256GCM, + "aes-128-gcm": EncryptionMethodAES128GCM, "chacha20poly1305": EncryptionMethodChaha20Poly1305, } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 0435557..eb13bc8 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -39,7 +39,8 @@ func BenchmarkStream_Write_Ordered(b *testing.B) { eMethods := map[string]byte{ "plain": EncryptionMethodPlain, "chacha20-poly1305": EncryptionMethodChaha20Poly1305, - "aes-gcm": EncryptionMethodAESGCM, + "aes-256-gcm": EncryptionMethodAES256GCM, + "aes-128-gcm": EncryptionMethodAES128GCM, } for name, method := range eMethods { diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 80ee28c..9daa772 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -175,7 +175,13 @@ func dispatchConnection(conn net.Conn, sta *State) { common.RandRead(sta.WorldState.Rand, sessionKey[:]) obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey) if err != nil { - log.Error(err) + log.WithFields(log.Fields{ + "remoteAddr": conn.RemoteAddr(), + "UID": b64(ci.UID), + "sessionId": ci.SessionId, + "proxyMethod": ci.ProxyMethod, + "encryptionMethod": ci.EncryptionMethod, + }).Error(err) goWeb() return } diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index dbf7c14..22935b2 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -518,7 +518,8 @@ func BenchmarkIntegration(b *testing.B) { encryptionMethods := map[string]byte{ "plain": mux.EncryptionMethodPlain, "chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305, - "aes-gcm": mux.EncryptionMethodAESGCM, + "aes-256-gcm": mux.EncryptionMethodAES256GCM, + "aes-128-gcm": mux.EncryptionMethodAES128GCM, } for name, method := range encryptionMethods {