From 783d016a2970cfc6adc6d8f561348b177a38fc9b Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 12 Apr 2020 01:35:17 +0100 Subject: [PATCH] Fix a race on closing stream --- internal/multiplex/session.go | 3 +-- internal/multiplex/session_test.go | 11 ++++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 868ebab..86451af 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -147,10 +147,9 @@ func (sesh *Session) Accept() (net.Conn, error) { } func (sesh *Session) closeStream(s *Stream, active bool) error { - if s.isClosed() { + if atomic.SwapUint32(&s.closed, 1) == 1 { return fmt.Errorf("stream %v is already closed", s.id) } - atomic.StoreUint32(&s.closed, 1) _ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() if active { diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 5f24d94..3af9069 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -5,6 +5,7 @@ import ( "github.com/cbeuw/connutil" "math/rand" "strconv" + "sync" "sync/atomic" "testing" "time" @@ -344,7 +345,7 @@ func TestParallel(t *testing.T) { seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) - numStreams := 10 + numStreams := acceptBacklog seqs := make([]*uint64, numStreams) for i := range seqs { seqs[i] = new(uint64) @@ -359,7 +360,7 @@ func TestParallel(t *testing.T) { } } - numOfTests := 100 + numOfTests := 5000 tests := make([]struct { name string frame *Frame @@ -369,7 +370,9 @@ func TestParallel(t *testing.T) { tests[i].frame = randFrame() } + var wg sync.WaitGroup for _, tc := range tests { + wg.Add(1) go func(frame *Frame) { data := make([]byte, 1000) n, _ := sesh.Obfs(frame, data) @@ -379,9 +382,12 @@ func TestParallel(t *testing.T) { if err != nil { t.Error(err) } + wg.Done() }(tc.frame) } + wg.Wait() + sc := int(sesh.streamCount()) var count int sesh.streams.Range(func(_, s interface{}) bool { if s != nil { @@ -389,7 +395,6 @@ func TestParallel(t *testing.T) { } return true }) - sc := int(sesh.streamCount()) if sc != count { t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) }