From 061b10e8023ee480917d2150663b8c6eb441afbf Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 6 Dec 2020 11:14:33 +0000 Subject: [PATCH] Improve tests code quality --- internal/multiplex/mux_test.go | 19 +-- internal/multiplex/session_test.go | 192 ++++++++++++----------------- 2 files changed, 87 insertions(+), 124 deletions(-) diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index e23d8a1..436c407 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -63,22 +63,22 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) { return clientSession, serverSession, paris } -func runEchoTest(t *testing.T, streams []*Stream) { - const testDataLen = 16384 +func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { var wg sync.WaitGroup - for _, stream := range streams { + for _, conn := range conns { wg.Add(1) - go func(stream *Stream) { + go func(conn net.Conn) { + testDataLen := rand.Intn(maxMsgLen) testData := make([]byte, testDataLen) rand.Read(testData) - n, err := stream.Write(testData) + n, err := conn.Write(testData) if n != testDataLen { t.Fatalf("written only %v, err %v", n, err) } recvBuf := make([]byte, testDataLen) - _, err = io.ReadFull(stream, recvBuf) + _, err = io.ReadFull(conn, recvBuf) if err != nil { t.Fatalf("failed to read back: %v", err) } @@ -87,7 +87,7 @@ func runEchoTest(t *testing.T, streams []*Stream) { t.Fatalf("echoed data not correct") } wg.Done() - }(stream) + }(conn) } wg.Wait() } @@ -95,11 +95,12 @@ func runEchoTest(t *testing.T, streams []*Stream) { func TestMultiplex(t *testing.T) { const numStreams = 2000 // -race option limits the number of goroutines to 8192 const numConns = 4 + const maxMsgLen = 16384 clientSession, serverSession, _ := makeSessionPair(numConns) go serveEcho(serverSession) - streams := make([]*Stream, numStreams) + streams := make([]net.Conn, numStreams) for i := 0; i < numStreams; i++ { stream, err := clientSession.OpenStream() if err != nil { @@ -109,7 +110,7 @@ func TestMultiplex(t *testing.T) { } //test echo - runEchoTest(t, streams) + runEchoTest(t, streams, maxMsgLen) if clientSession.streamCount() != numStreams { t.Errorf("client stream count is wrong: %v", clientSession.streamCount()) } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index b40d32c..52fe6a5 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -17,8 +17,10 @@ var seshConfigUnordered = SessionConfig{ Unordered: true, } +const testPayloadLen = 1024 +const obfsBufLen = testPayloadLen * 2 + func TestRecvDataFromRemote(t *testing.T) { - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) f := &Frame{ @@ -27,126 +29,88 @@ func TestRecvDataFromRemote(t *testing.T) { 0, testPayload, } - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) - t.Run("plain ordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - err := sesh.recvDataFromRemote(obfsBuf[:n]) + MakeObfuscatorUnwrap := func(method byte, sessionKey [32]byte) Obfuscator { + ret, err := MakeObfuscator(method, sessionKey) if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return + t.Fatalf("failed to make an obfuscator: %v", err) } + return ret + } - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) - t.Run("aes-gcm ordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) + sessionTypes := []struct { + name string + config SessionConfig + }{ + {"ordered", + SessionConfig{}}, + {"unordered", + SessionConfig{Unordered: true}}, + } - err := sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } + encryptionMethods := []struct { + name string + obfuscator Obfuscator + }{ + { + "plain", + MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), + }, + { + "aes-gcm", + MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), + }, + { + "chacha20-poly1305", + MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), + }, + } - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) - t.Run("chacha20-poly1305 ordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) + for _, st := range sessionTypes { + t.Run(st.name, func(t *testing.T) { + for _, em := range encryptionMethods { + t.Run(em.name, func(t *testing.T) { + st.config.Obfuscator = em.obfuscator + sesh := MakeSession(0, st.config) + n, err := sesh.Obfs(f, obfsBuf, 0) + if err != nil { + t.Error(err) + return + } + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } + stream, err := sesh.Accept() + if err != nil { + t.Error(err) + return + } - err := sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } - - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) - - t.Run("plain unordered", func(t *testing.T) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigUnordered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - err := sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } - - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } - }) + resultPayload := make([]byte, testPayloadLen) + _, err = stream.Read(resultPayload) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(testPayload, resultPayload) { + t.Errorf("Expecting %x, got %x", testPayload, resultPayload) + } + }) + } + }) + } } func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) @@ -273,10 +237,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { // Tests for when the closing frame of a stream is received first before any data frame - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) @@ -354,7 +317,7 @@ func TestParallelStreams(t *testing.T) { } } - numOfTests := 5000 + const numOfTests = 5000 tests := make([]struct { name string frame *Frame @@ -368,11 +331,11 @@ func TestParallelStreams(t *testing.T) { for _, tc := range tests { wg.Add(1) go func(frame *Frame) { - data := make([]byte, 1000) - n, _ := sesh.Obfs(frame, data, 0) - data = data[0:n] + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(frame, obfsBuf, 0) + obfsBuf = obfsBuf[0:n] - err := sesh.recvDataFromRemote(data) + err := sesh.recvDataFromRemote(obfsBuf) if err != nil { t.Error(err) } @@ -452,7 +415,6 @@ func TestSession_timeoutAfter(t *testing.T) { } func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { - testPayloadLen := 1024 testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) f := &Frame{ @@ -461,7 +423,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { 0, testPayload, } - obfsBuf := make([]byte, 17000) + obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:])