Fix race condition in the use of assert.Eventually

This commit is contained in:
Andy Wang 2020-12-23 17:43:18 +00:00
parent 9108794362
commit 53f0116c1d
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
1 changed files with 30 additions and 6 deletions

View File

@ -7,6 +7,7 @@ import (
"io"
"io/ioutil"
"math/rand"
"sync"
"testing"
"time"
@ -167,10 +168,18 @@ func TestStream_Close(t *testing.T) {
sesh.streamsM.Unlock()
readBuf := make([]byte, len(testPayload))
var wg sync.WaitGroup
wg.Add(1)
assert.Eventually(t, func() bool {
_, err = io.ReadFull(stream, readBuf)
return err == nil
if err == nil {
wg.Done()
return true
} else {
return false
}
}, time.Second, 10*time.Millisecond, "can't read residual data", err)
wg.Wait()
if !bytes.Equal(readBuf, testPayload) {
t.Errorf("read wrong data")
}
@ -262,9 +271,6 @@ func TestStream_Read(t *testing.T) {
}
var streamID uint32
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
for name, unordered := range seshes {
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
@ -272,6 +278,8 @@ func TestStream_Read(t *testing.T) {
sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd)
t.Run(name, func(t *testing.T) {
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf, 0)
@ -318,10 +326,18 @@ func TestStream_Read(t *testing.T) {
stream.Close()
var err error
var wg sync.WaitGroup
wg.Add(1)
assert.Eventually(t, func() bool {
i, err = stream.Read(buf)
return err == nil
if err == nil {
wg.Done()
return true
} else {
return false
}
}, time.Second, 10*time.Millisecond, "failed to read", err)
wg.Wait()
if i != smallPayloadLen {
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
}
@ -343,10 +359,18 @@ func TestStream_Read(t *testing.T) {
stream, _ := sesh.Accept()
sesh.Close()
var err error
var wg sync.WaitGroup
wg.Add(1)
assert.Eventually(t, func() bool {
i, err = stream.Read(buf)
return err == nil
if err == nil {
wg.Done()
return true
} else {
return false
}
}, time.Second, 10*time.Millisecond, "failed to read", err)
wg.Wait()
if i != smallPayloadLen {
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
}