Refactor session_test.go

This commit is contained in:
Andy Wang 2020-12-21 20:38:28 +00:00
parent de0daac123
commit c9ac93b0b9
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
1 changed files with 174 additions and 167 deletions

View File

@ -12,10 +12,9 @@ import (
"time"
)
var seshConfigOrdered = SessionConfig{}
var seshConfigUnordered = SessionConfig{
Unordered: true,
var seshConfigs = map[string]SessionConfig{
"ordered": {},
"unordered": {Unordered: true},
}
const testPayloadLen = 1024
@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) {
return ret
}
sessionTypes := []struct {
name string
config SessionConfig
}{
{"ordered",
SessionConfig{}},
{"unordered",
SessionConfig{Unordered: true}},
encryptionMethods := map[string]Obfuscator{
"plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
"aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
"chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
}
encryptionMethods := []struct {
name string
obfuscator Obfuscator
}{
{
"plain",
MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
},
{
"aes-gcm",
MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
},
{
"chacha20-poly1305",
MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
},
}
for _, st := range sessionTypes {
t.Run(st.name, func(t *testing.T) {
for _, em := range encryptionMethods {
t.Run(em.name, func(t *testing.T) {
st.config.Obfuscator = em.obfuscator
sesh := MakeSession(0, st.config)
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
for method, obfuscator := range encryptionMethods {
obfuscator := obfuscator
t.Run(method, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
n, err := sesh.Obfs(f, obfsBuf, 0)
if err != nil {
t.Error(err)
@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
f1 := &Frame{
1,
@ -245,8 +226,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
// receive stream 1 closing first
f1CloseStream := &Frame{
@ -300,119 +283,125 @@ func TestParallelStreams(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
numStreams := acceptBacklog
seqs := make([]*uint64, numStreams)
for i := range seqs {
seqs[i] = new(uint64)
}
randFrame := func() *Frame {
id := rand.Intn(numStreams)
return &Frame{
uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
uint8(rand.Intn(2)),
[]byte{1, 2, 3, 4},
}
}
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
const numOfTests = 5000
tests := make([]struct {
name string
frame *Frame
}, numOfTests)
for i := range tests {
tests[i].name = strconv.Itoa(i)
tests[i].frame = randFrame()
}
var wg sync.WaitGroup
for _, tc := range tests {
wg.Add(1)
go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf)
if err != nil {
t.Error(err)
numStreams := acceptBacklog
seqs := make([]*uint64, numStreams)
for i := range seqs {
seqs[i] = new(uint64)
}
randFrame := func() *Frame {
id := rand.Intn(numStreams)
return &Frame{
uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
uint8(rand.Intn(2)),
[]byte{1, 2, 3, 4},
}
}
wg.Done()
}(tc.frame)
}
wg.Wait()
sc := int(sesh.streamCount())
var count int
sesh.streams.Range(func(_, s interface{}) bool {
if s != nil {
count++
}
return true
})
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
const numOfTests = 5000
tests := make([]struct {
name string
frame *Frame
}, numOfTests)
for i := range tests {
tests[i].name = strconv.Itoa(i)
tests[i].frame = randFrame()
}
var wg sync.WaitGroup
for _, tc := range tests {
wg.Add(1)
go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf)
if err != nil {
t.Error(err)
}
wg.Done()
}(tc.frame)
}
wg.Wait()
sc := int(sesh.streamCount())
var count int
sesh.streams.Range(func(_, s interface{}) bool {
if s != nil {
count++
}
return true
})
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
}
})
}
}
func TestStream_SetReadDeadline(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
testReadDeadline := func(sesh *Session) {
t.Run("read after deadline set", func(t *testing.T) {
stream, _ := sesh.OpenStream()
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
_, err := stream.Read(make([]byte, 1))
if err != ErrTimeout {
t.Errorf("expecting error %v, got %v", ErrTimeout, err)
}
})
t.Run("read after deadline set", func(t *testing.T) {
stream, _ := sesh.OpenStream()
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
_, err := stream.Read(make([]byte, 1))
if err != ErrTimeout {
t.Errorf("expecting error %v, got %v", ErrTimeout, err)
}
})
t.Run("unblock when deadline passed", func(t *testing.T) {
stream, _ := sesh.OpenStream()
t.Run("unblock when deadline passed", func(t *testing.T) {
stream, _ := sesh.OpenStream()
done := make(chan struct{})
go func() {
_, _ = stream.Read(make([]byte, 1))
done <- struct{}{}
}()
done := make(chan struct{})
go func() {
_, _ = stream.Read(make([]byte, 1))
done <- struct{}{}
}()
_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
select {
case <-done:
return
case <-time.After(500 * time.Millisecond):
t.Error("Read did not unblock after deadline has passed")
}
select {
case <-done:
return
case <-time.After(500 * time.Millisecond):
t.Error("Read did not unblock after deadline has passed")
}
})
})
}
sesh := MakeSession(0, seshConfigOrdered)
sesh.AddConnection(connutil.Discard())
testReadDeadline(sesh)
sesh = MakeSession(0, seshConfigUnordered)
sesh.AddConnection(connutil.Discard())
testReadDeadline(sesh)
}
func TestSession_timeoutAfter(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfigOrdered)
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out")
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
seshConfig.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfig)
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
})
}
}
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
@ -429,42 +418,60 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
b.Run("plain", func(b *testing.B) {
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
table := map[string]byte{
"plain": EncryptionMethodPlain,
"aes-gcm": EncryptionMethodAESGCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
}
for name, ep := range table {
seshConfig := seshConfigs["ordered"]
obfuscator, _ := MakeObfuscator(ep, sessionKey)
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
n, _ := sesh.Obfs(f, obfsBuf, 0)
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
}
})
b.Run("aes-gcm", func(b *testing.B) {
obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
}
})
b.Run("chacha20-poly1305", func(b *testing.B) {
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
}
})
b.Run(name, func(b *testing.B) {
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
}
})
}
}
func BenchmarkMultiStreamWrite(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
table := map[string]byte{
"plain": EncryptionMethodPlain,
"aes-gcm": EncryptionMethodAESGCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
}
testPayload := make([]byte, testPayloadLen)
for name, ep := range table {
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
b.Run(seshType, func(b *testing.B) {
obfuscator, _ := MakeObfuscator(ep, sessionKey)
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
b.ResetTimer()
b.SetBytes(testPayloadLen)
b.RunParallel(func(pb *testing.PB) {
stream, _ := sesh.OpenStream()
for pb.Next() {
stream.Write(testPayload)
}
})
})
}
})
}
}