mirror of https://github.com/cbeuw/Cloak
Fix race condition in the use of assert.Eventually
This commit is contained in:
parent
9108794362
commit
53f0116c1d
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -167,10 +168,18 @@ func TestStream_Close(t *testing.T) {
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
readBuf := make([]byte, len(testPayload))
|
readBuf := make([]byte, len(testPayload))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
_, err = io.ReadFull(stream, readBuf)
|
_, err = io.ReadFull(stream, readBuf)
|
||||||
return err == nil
|
if err == nil {
|
||||||
|
wg.Done()
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}, time.Second, 10*time.Millisecond, "can't read residual data", err)
|
}, time.Second, 10*time.Millisecond, "can't read residual data", err)
|
||||||
|
wg.Wait()
|
||||||
if !bytes.Equal(readBuf, testPayload) {
|
if !bytes.Equal(readBuf, testPayload) {
|
||||||
t.Errorf("read wrong data")
|
t.Errorf("read wrong data")
|
||||||
}
|
}
|
||||||
|
|
@ -262,9 +271,6 @@ func TestStream_Read(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var streamID uint32
|
var streamID uint32
|
||||||
buf := make([]byte, 10)
|
|
||||||
|
|
||||||
obfsBuf := make([]byte, 512)
|
|
||||||
|
|
||||||
for name, unordered := range seshes {
|
for name, unordered := range seshes {
|
||||||
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
|
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
|
||||||
|
|
@ -272,6 +278,8 @@ func TestStream_Read(t *testing.T) {
|
||||||
sesh.AddConnection(common.NewTLSConn(rawConn))
|
sesh.AddConnection(common.NewTLSConn(rawConn))
|
||||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
obfsBuf := make([]byte, 512)
|
||||||
t.Run("Plain read", func(t *testing.T) {
|
t.Run("Plain read", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
i, _ := sesh.Obfs(f, obfsBuf, 0)
|
i, _ := sesh.Obfs(f, obfsBuf, 0)
|
||||||
|
|
@ -318,10 +326,18 @@ func TestStream_Read(t *testing.T) {
|
||||||
stream.Close()
|
stream.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
i, err = stream.Read(buf)
|
i, err = stream.Read(buf)
|
||||||
return err == nil
|
if err == nil {
|
||||||
|
wg.Done()
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}, time.Second, 10*time.Millisecond, "failed to read", err)
|
}, time.Second, 10*time.Millisecond, "failed to read", err)
|
||||||
|
wg.Wait()
|
||||||
if i != smallPayloadLen {
|
if i != smallPayloadLen {
|
||||||
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
||||||
}
|
}
|
||||||
|
|
@ -343,10 +359,18 @@ func TestStream_Read(t *testing.T) {
|
||||||
stream, _ := sesh.Accept()
|
stream, _ := sesh.Accept()
|
||||||
sesh.Close()
|
sesh.Close()
|
||||||
var err error
|
var err error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
i, err = stream.Read(buf)
|
i, err = stream.Read(buf)
|
||||||
return err == nil
|
if err == nil {
|
||||||
|
wg.Done()
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}, time.Second, 10*time.Millisecond, "failed to read", err)
|
}, time.Second, 10*time.Millisecond, "failed to read", err)
|
||||||
|
wg.Wait()
|
||||||
if i != smallPayloadLen {
|
if i != smallPayloadLen {
|
||||||
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue