Refactor session_test.go

This commit is contained in:
Andy Wang 2020-12-21 20:38:28 +00:00
parent de0daac123
commit c9ac93b0b9
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
1 changed files with 174 additions and 167 deletions

View File

@ -12,10 +12,9 @@ import (
"time" "time"
) )
var seshConfigOrdered = SessionConfig{} var seshConfigs = map[string]SessionConfig{
"ordered": {},
var seshConfigUnordered = SessionConfig{ "unordered": {Unordered: true},
Unordered: true,
} }
const testPayloadLen = 1024 const testPayloadLen = 1024
@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) {
return ret return ret
} }
sessionTypes := []struct { encryptionMethods := map[string]Obfuscator{
name string "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
config SessionConfig "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
}{ "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
{"ordered",
SessionConfig{}},
{"unordered",
SessionConfig{Unordered: true}},
} }
encryptionMethods := []struct { for seshType, seshConfig := range seshConfigs {
name string seshConfig := seshConfig
obfuscator Obfuscator t.Run(seshType, func(t *testing.T) {
}{ for method, obfuscator := range encryptionMethods {
{ obfuscator := obfuscator
"plain", t.Run(method, func(t *testing.T) {
MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), seshConfig.Obfuscator = obfuscator
}, sesh := MakeSession(0, seshConfig)
{
"aes-gcm",
MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
},
{
"chacha20-poly1305",
MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
},
}
for _, st := range sessionTypes {
t.Run(st.name, func(t *testing.T) {
for _, em := range encryptionMethods {
t.Run(em.name, func(t *testing.T) {
st.config.Obfuscator = em.obfuscator
sesh := MakeSession(0, st.config)
n, err := sesh.Obfs(f, obfsBuf, 0) n, err := sesh.Obfs(f, obfsBuf, 0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered) seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
f1 := &Frame{ f1 := &Frame{
1, 1,
@ -245,8 +226,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered) seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
// receive stream 1 closing first // receive stream 1 closing first
f1CloseStream := &Frame{ f1CloseStream := &Frame{
@ -300,119 +283,125 @@ func TestParallelStreams(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
numStreams := acceptBacklog for seshType, seshConfig := range seshConfigs {
seqs := make([]*uint64, numStreams) seshConfig := seshConfig
for i := range seqs { t.Run(seshType, func(t *testing.T) {
seqs[i] = new(uint64) seshConfig.Obfuscator = obfuscator
} sesh := MakeSession(0, seshConfig)
randFrame := func() *Frame {
id := rand.Intn(numStreams)
return &Frame{
uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
uint8(rand.Intn(2)),
[]byte{1, 2, 3, 4},
}
}
const numOfTests = 5000 numStreams := acceptBacklog
tests := make([]struct { seqs := make([]*uint64, numStreams)
name string for i := range seqs {
frame *Frame seqs[i] = new(uint64)
}, numOfTests) }
for i := range tests { randFrame := func() *Frame {
tests[i].name = strconv.Itoa(i) id := rand.Intn(numStreams)
tests[i].frame = randFrame() return &Frame{
} uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
var wg sync.WaitGroup uint8(rand.Intn(2)),
for _, tc := range tests { []byte{1, 2, 3, 4},
wg.Add(1) }
go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf)
if err != nil {
t.Error(err)
} }
wg.Done()
}(tc.frame)
}
wg.Wait() const numOfTests = 5000
sc := int(sesh.streamCount()) tests := make([]struct {
var count int name string
sesh.streams.Range(func(_, s interface{}) bool { frame *Frame
if s != nil { }, numOfTests)
count++ for i := range tests {
} tests[i].name = strconv.Itoa(i)
return true tests[i].frame = randFrame()
}) }
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) var wg sync.WaitGroup
for _, tc := range tests {
wg.Add(1)
go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf)
if err != nil {
t.Error(err)
}
wg.Done()
}(tc.frame)
}
wg.Wait()
sc := int(sesh.streamCount())
var count int
sesh.streams.Range(func(_, s interface{}) bool {
if s != nil {
count++
}
return true
})
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
}
})
} }
} }
func TestStream_SetReadDeadline(t *testing.T) { func TestStream_SetReadDeadline(t *testing.T) {
var sessionKey [32]byte for seshType, seshConfig := range seshConfigs {
rand.Read(sessionKey[:]) seshConfig := seshConfig
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) t.Run(seshType, func(t *testing.T) {
seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
testReadDeadline := func(sesh *Session) { t.Run("read after deadline set", func(t *testing.T) {
t.Run("read after deadline set", func(t *testing.T) { stream, _ := sesh.OpenStream()
stream, _ := sesh.OpenStream() _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) _, err := stream.Read(make([]byte, 1))
_, err := stream.Read(make([]byte, 1)) if err != ErrTimeout {
if err != ErrTimeout { t.Errorf("expecting error %v, got %v", ErrTimeout, err)
t.Errorf("expecting error %v, got %v", ErrTimeout, err) }
} })
})
t.Run("unblock when deadline passed", func(t *testing.T) { t.Run("unblock when deadline passed", func(t *testing.T) {
stream, _ := sesh.OpenStream() stream, _ := sesh.OpenStream()
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
_, _ = stream.Read(make([]byte, 1)) _, _ = stream.Read(make([]byte, 1))
done <- struct{}{} done <- struct{}{}
}() }()
_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
select { select {
case <-done: case <-done:
return return
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):
t.Error("Read did not unblock after deadline has passed") t.Error("Read did not unblock after deadline has passed")
} }
})
}) })
} }
sesh := MakeSession(0, seshConfigOrdered)
sesh.AddConnection(connutil.Discard())
testReadDeadline(sesh)
sesh = MakeSession(0, seshConfigUnordered)
sesh.AddConnection(connutil.Discard())
testReadDeadline(sesh)
} }
func TestSession_timeoutAfter(t *testing.T) { func TestSession_timeoutAfter(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfigOrdered)
assert.Eventually(t, func() bool { for seshType, seshConfig := range seshConfigs {
return sesh.IsClosed() seshConfig := seshConfig
}, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out") t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
seshConfig.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfig)
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
})
}
} }
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
@ -429,42 +418,60 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
b.Run("plain", func(b *testing.B) { table := map[string]byte{
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) "plain": EncryptionMethodPlain,
seshConfigOrdered.Obfuscator = obfuscator "aes-gcm": EncryptionMethodAESGCM,
sesh := MakeSession(0, seshConfigOrdered) "chacha20poly1305": EncryptionMethodChaha20Poly1305,
}
for name, ep := range table {
seshConfig := seshConfigs["ordered"]
obfuscator, _ := MakeObfuscator(ep, sessionKey)
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
n, _ := sesh.Obfs(f, obfsBuf, 0) n, _ := sesh.Obfs(f, obfsBuf, 0)
b.SetBytes(int64(len(f.Payload))) b.Run(name, func(b *testing.B) {
b.ResetTimer() b.SetBytes(int64(len(f.Payload)))
for i := 0; i < b.N; i++ { b.ResetTimer()
sesh.recvDataFromRemote(obfsBuf[:n]) for i := 0; i < b.N; i++ {
} sesh.recvDataFromRemote(obfsBuf[:n])
}) }
})
b.Run("aes-gcm", func(b *testing.B) { }
obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) }
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered) func BenchmarkMultiStreamWrite(b *testing.B) {
n, _ := sesh.Obfs(f, obfsBuf, 0) var sessionKey [32]byte
rand.Read(sessionKey[:])
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer() table := map[string]byte{
for i := 0; i < b.N; i++ { "plain": EncryptionMethodPlain,
sesh.recvDataFromRemote(obfsBuf[:n]) "aes-gcm": EncryptionMethodAESGCM,
} "chacha20poly1305": EncryptionMethodChaha20Poly1305,
}) }
b.Run("chacha20-poly1305", func(b *testing.B) { testPayload := make([]byte, testPayloadLen)
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator for name, ep := range table {
sesh := MakeSession(0, seshConfigOrdered) b.Run(name, func(b *testing.B) {
n, _ := sesh.Obfs(f, obfsBuf, 0) for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
b.SetBytes(int64(len(f.Payload))) b.Run(seshType, func(b *testing.B) {
b.ResetTimer() obfuscator, _ := MakeObfuscator(ep, sessionKey)
for i := 0; i < b.N; i++ { seshConfig.Obfuscator = obfuscator
sesh.recvDataFromRemote(obfsBuf[:n]) sesh := MakeSession(0, seshConfig)
} sesh.AddConnection(connutil.Discard())
}) b.ResetTimer()
b.SetBytes(testPayloadLen)
b.RunParallel(func(pb *testing.PB) {
stream, _ := sesh.OpenStream()
for pb.Next() {
stream.Write(testPayload)
}
})
})
}
})
}
} }