mirror of https://github.com/cbeuw/Cloak
Fix a race on closing stream
This commit is contained in:
parent
58e0797578
commit
783d016a29
|
|
@ -147,10 +147,9 @@ func (sesh *Session) Accept() (net.Conn, error) {
|
|||
}
|
||||
|
||||
func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||
if s.isClosed() {
|
||||
if atomic.SwapUint32(&s.closed, 1) == 1 {
|
||||
return fmt.Errorf("stream %v is already closed", s.id)
|
||||
}
|
||||
atomic.StoreUint32(&s.closed, 1)
|
||||
_ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
|
||||
|
||||
if active {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"github.com/cbeuw/connutil"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -344,7 +345,7 @@ func TestParallel(t *testing.T) {
|
|||
seshConfigOrdered.Obfuscator = obfuscator
|
||||
sesh := MakeSession(0, seshConfigOrdered)
|
||||
|
||||
numStreams := 10
|
||||
numStreams := acceptBacklog
|
||||
seqs := make([]*uint64, numStreams)
|
||||
for i := range seqs {
|
||||
seqs[i] = new(uint64)
|
||||
|
|
@ -359,7 +360,7 @@ func TestParallel(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
numOfTests := 100
|
||||
numOfTests := 5000
|
||||
tests := make([]struct {
|
||||
name string
|
||||
frame *Frame
|
||||
|
|
@ -369,7 +370,9 @@ func TestParallel(t *testing.T) {
|
|||
tests[i].frame = randFrame()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, tc := range tests {
|
||||
wg.Add(1)
|
||||
go func(frame *Frame) {
|
||||
data := make([]byte, 1000)
|
||||
n, _ := sesh.Obfs(frame, data)
|
||||
|
|
@ -379,9 +382,12 @@ func TestParallel(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wg.Done()
|
||||
}(tc.frame)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
sc := int(sesh.streamCount())
|
||||
var count int
|
||||
sesh.streams.Range(func(_, s interface{}) bool {
|
||||
if s != nil {
|
||||
|
|
@ -389,7 +395,6 @@ func TestParallel(t *testing.T) {
|
|||
}
|
||||
return true
|
||||
})
|
||||
sc := int(sesh.streamCount())
|
||||
if sc != count {
|
||||
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue