diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 35fae9d..a28cfc6 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "math/rand" + "sync" "testing" "time" @@ -167,10 +168,18 @@ func TestStream_Close(t *testing.T) { sesh.streamsM.Unlock() readBuf := make([]byte, len(testPayload)) + var wg sync.WaitGroup + wg.Add(1) assert.Eventually(t, func() bool { _, err = io.ReadFull(stream, readBuf) - return err == nil + if err == nil { + wg.Done() + return true + } else { + return false + } }, time.Second, 10*time.Millisecond, "can't read residual data", err) + wg.Wait() if !bytes.Equal(readBuf, testPayload) { t.Errorf("read wrong data") } @@ -262,9 +271,6 @@ func TestStream_Read(t *testing.T) { } var streamID uint32 - buf := make([]byte, 10) - - obfsBuf := make([]byte, 512) for name, unordered := range seshes { sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain) @@ -272,6 +278,8 @@ func TestStream_Read(t *testing.T) { sesh.AddConnection(common.NewTLSConn(rawConn)) writingEnd := common.NewTLSConn(rawWritingEnd) t.Run(name, func(t *testing.T) { + buf := make([]byte, 10) + obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf, 0) @@ -318,10 +326,18 @@ func TestStream_Read(t *testing.T) { stream.Close() var err error + var wg sync.WaitGroup + wg.Add(1) assert.Eventually(t, func() bool { i, err = stream.Read(buf) - return err == nil + if err == nil { + wg.Done() + return true + } else { + return false + } }, time.Second, 10*time.Millisecond, "failed to read", err) + wg.Wait() if i != smallPayloadLen { t.Errorf("expected read %v, got %v", smallPayloadLen, i) } @@ -343,10 +359,18 @@ func TestStream_Read(t *testing.T) { stream, _ := sesh.Accept() sesh.Close() var err error + var wg sync.WaitGroup + wg.Add(1) assert.Eventually(t, func() bool { i, err = stream.Read(buf) - return err == nil + if err == nil { + wg.Done() + return true + } else { + return false + } }, time.Second, 10*time.Millisecond, "failed to read", err) + wg.Wait() if i != smallPayloadLen { t.Errorf("expected read %v, got %v", smallPayloadLen, i) }