From c9ac93b0b98e3fcb441eab859232d8fd70392104 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Mon, 21 Dec 2020 20:38:28 +0000 Subject: [PATCH] Refactor session_test.go --- internal/multiplex/session_test.go | 341 +++++++++++++++-------------- 1 file changed, 174 insertions(+), 167 deletions(-) diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index b280895..89bd410 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -12,10 +12,9 @@ import ( "time" ) -var seshConfigOrdered = SessionConfig{} - -var seshConfigUnordered = SessionConfig{ - Unordered: true, +var seshConfigs = map[string]SessionConfig{ + "ordered": {}, + "unordered": {Unordered: true}, } const testPayloadLen = 1024 @@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) { return ret } - sessionTypes := []struct { - name string - config SessionConfig - }{ - {"ordered", - SessionConfig{}}, - {"unordered", - SessionConfig{Unordered: true}}, + encryptionMethods := map[string]Obfuscator{ + "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), + "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), + "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), } - encryptionMethods := []struct { - name string - obfuscator Obfuscator - }{ - { - "plain", - MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), - }, - { - "aes-gcm", - MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), - }, - { - "chacha20-poly1305", - MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), - }, - } - - for _, st := range sessionTypes { - t.Run(st.name, func(t *testing.T) { - for _, em := range encryptionMethods { - t.Run(em.name, func(t *testing.T) { - st.config.Obfuscator = em.obfuscator - sesh := MakeSession(0, st.config) + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + for method, obfuscator := range encryptionMethods { + obfuscator := obfuscator + t.Run(method, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, err := sesh.Obfs(f, obfsBuf, 0) if err != nil { t.Error(err) @@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + + seshConfig := seshConfigs["ordered"] + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) f1 := &Frame{ 1, @@ -245,8 +226,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + + seshConfig := seshConfigs["ordered"] + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) // receive stream 1 closing first f1CloseStream := &Frame{ @@ -300,119 +283,125 @@ func TestParallelStreams(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - numStreams := acceptBacklog - seqs := make([]*uint64, numStreams) - for i := range seqs { - seqs[i] = new(uint64) - } - randFrame := func() *Frame { - id := rand.Intn(numStreams) - return &Frame{ - uint32(id), - atomic.AddUint64(seqs[id], 1) - 1, - uint8(rand.Intn(2)), - []byte{1, 2, 3, 4}, - } - } + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) - const numOfTests = 5000 - tests := make([]struct { - name string - frame *Frame - }, numOfTests) - for i := range tests { - tests[i].name = strconv.Itoa(i) - tests[i].frame = randFrame() - } - - var wg sync.WaitGroup - for _, tc := range tests { - wg.Add(1) - go func(frame *Frame) { - obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(frame, obfsBuf, 0) - obfsBuf = obfsBuf[0:n] - - err := sesh.recvDataFromRemote(obfsBuf) - if err != nil { - t.Error(err) + numStreams := acceptBacklog + seqs := make([]*uint64, numStreams) + for i := range seqs { + seqs[i] = new(uint64) + } + randFrame := func() *Frame { + id := rand.Intn(numStreams) + return &Frame{ + uint32(id), + atomic.AddUint64(seqs[id], 1) - 1, + uint8(rand.Intn(2)), + []byte{1, 2, 3, 4}, + } } - wg.Done() - }(tc.frame) - } - wg.Wait() - sc := int(sesh.streamCount()) - var count int - sesh.streams.Range(func(_, s interface{}) bool { - if s != nil { - count++ - } - return true - }) - if sc != count { - t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) + const numOfTests = 5000 + tests := make([]struct { + name string + frame *Frame + }, numOfTests) + for i := range tests { + tests[i].name = strconv.Itoa(i) + tests[i].frame = randFrame() + } + + var wg sync.WaitGroup + for _, tc := range tests { + wg.Add(1) + go func(frame *Frame) { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(frame, obfsBuf, 0) + obfsBuf = obfsBuf[0:n] + + err := sesh.recvDataFromRemote(obfsBuf) + if err != nil { + t.Error(err) + } + wg.Done() + }(tc.frame) + } + + wg.Wait() + sc := int(sesh.streamCount()) + var count int + sesh.streams.Range(func(_, s interface{}) bool { + if s != nil { + count++ + } + return true + }) + if sc != count { + t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) + } + }) } } func TestStream_SetReadDeadline(t *testing.T) { - var sessionKey [32]byte - rand.Read(sessionKey[:]) - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + sesh.AddConnection(connutil.Discard()) - testReadDeadline := func(sesh *Session) { - t.Run("read after deadline set", func(t *testing.T) { - stream, _ := sesh.OpenStream() - _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) - _, err := stream.Read(make([]byte, 1)) - if err != ErrTimeout { - t.Errorf("expecting error %v, got %v", ErrTimeout, err) - } - }) + t.Run("read after deadline set", func(t *testing.T) { + stream, _ := sesh.OpenStream() + _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) + _, err := stream.Read(make([]byte, 1)) + if err != ErrTimeout { + t.Errorf("expecting error %v, got %v", ErrTimeout, err) + } + }) - t.Run("unblock when deadline passed", func(t *testing.T) { - stream, _ := sesh.OpenStream() + t.Run("unblock when deadline passed", func(t *testing.T) { + stream, _ := sesh.OpenStream() - done := make(chan struct{}) - go func() { - _, _ = stream.Read(make([]byte, 1)) - done <- struct{}{} - }() + done := make(chan struct{}) + go func() { + _, _ = stream.Read(make([]byte, 1)) + done <- struct{}{} + }() - _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - select { - case <-done: - return - case <-time.After(500 * time.Millisecond): - t.Error("Read did not unblock after deadline has passed") - } + select { + case <-done: + return + case <-time.After(500 * time.Millisecond): + t.Error("Read did not unblock after deadline has passed") + } + }) }) } - - sesh := MakeSession(0, seshConfigOrdered) - sesh.AddConnection(connutil.Discard()) - testReadDeadline(sesh) - sesh = MakeSession(0, seshConfigUnordered) - sesh.AddConnection(connutil.Discard()) - testReadDeadline(sesh) } func TestSession_timeoutAfter(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond - sesh := MakeSession(0, seshConfigOrdered) - assert.Eventually(t, func() bool { - return sesh.IsClosed() - }, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out") + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + seshConfig.InactivityTimeout = 100 * time.Millisecond + sesh := MakeSession(0, seshConfig) + + assert.Eventually(t, func() bool { + return sesh.IsClosed() + }, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out") + }) + } } func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { @@ -429,42 +418,60 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) - b.Run("plain", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAESGCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + + for name, ep := range table { + seshConfig := seshConfigs["ordered"] + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, _ := sesh.Obfs(f, obfsBuf, 0) - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) - - b.Run("aes-gcm", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) - - b.Run("chacha20-poly1305", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) + b.Run(name, func(b *testing.B) { + b.SetBytes(int64(len(f.Payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sesh.recvDataFromRemote(obfsBuf[:n]) + } + }) + } +} + +func BenchmarkMultiStreamWrite(b *testing.B) { + var sessionKey [32]byte + rand.Read(sessionKey[:]) + + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAESGCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + + testPayload := make([]byte, testPayloadLen) + + for name, ep := range table { + b.Run(name, func(b *testing.B) { + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + b.Run(seshType, func(b *testing.B) { + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) + sesh.AddConnection(connutil.Discard()) + b.ResetTimer() + b.SetBytes(testPayloadLen) + b.RunParallel(func(pb *testing.PB) { + stream, _ := sesh.OpenStream() + for pb.Next() { + stream.Write(testPayload) + } + }) + }) + } + }) + } }