Fix race condition in the use of assert.Eventually

This commit is contained in:
Andy Wang 2020-12-23 17:43:18 +00:00
parent 9108794362
commit 53f0116c1d
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
1 changed files with 30 additions and 6 deletions

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"sync"
"testing" "testing"
"time" "time"
@ -167,10 +168,18 @@ func TestStream_Close(t *testing.T) {
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
readBuf := make([]byte, len(testPayload)) readBuf := make([]byte, len(testPayload))
var wg sync.WaitGroup
wg.Add(1)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
_, err = io.ReadFull(stream, readBuf) _, err = io.ReadFull(stream, readBuf)
return err == nil if err == nil {
wg.Done()
return true
} else {
return false
}
}, time.Second, 10*time.Millisecond, "can't read residual data", err) }, time.Second, 10*time.Millisecond, "can't read residual data", err)
wg.Wait()
if !bytes.Equal(readBuf, testPayload) { if !bytes.Equal(readBuf, testPayload) {
t.Errorf("read wrong data") t.Errorf("read wrong data")
} }
@ -262,9 +271,6 @@ func TestStream_Read(t *testing.T) {
} }
var streamID uint32 var streamID uint32
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
for name, unordered := range seshes { for name, unordered := range seshes {
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain) sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
@ -272,6 +278,8 @@ func TestStream_Read(t *testing.T) {
sesh.AddConnection(common.NewTLSConn(rawConn)) sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd) writingEnd := common.NewTLSConn(rawWritingEnd)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) { t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf, 0) i, _ := sesh.Obfs(f, obfsBuf, 0)
@ -318,10 +326,18 @@ func TestStream_Read(t *testing.T) {
stream.Close() stream.Close()
var err error var err error
var wg sync.WaitGroup
wg.Add(1)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
i, err = stream.Read(buf) i, err = stream.Read(buf)
return err == nil if err == nil {
wg.Done()
return true
} else {
return false
}
}, time.Second, 10*time.Millisecond, "failed to read", err) }, time.Second, 10*time.Millisecond, "failed to read", err)
wg.Wait()
if i != smallPayloadLen { if i != smallPayloadLen {
t.Errorf("expected read %v, got %v", smallPayloadLen, i) t.Errorf("expected read %v, got %v", smallPayloadLen, i)
} }
@ -343,10 +359,18 @@ func TestStream_Read(t *testing.T) {
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
sesh.Close() sesh.Close()
var err error var err error
var wg sync.WaitGroup
wg.Add(1)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
i, err = stream.Read(buf) i, err = stream.Read(buf)
return err == nil if err == nil {
wg.Done()
return true
} else {
return false
}
}, time.Second, 10*time.Millisecond, "failed to read", err) }, time.Second, 10*time.Millisecond, "failed to read", err)
wg.Wait()
if i != smallPayloadLen { if i != smallPayloadLen {
t.Errorf("expected read %v, got %v", smallPayloadLen, i) t.Errorf("expected read %v, got %v", smallPayloadLen, i)
} }