From 6af97e2c22423ae51c8ae2d801ff3f50a6554b58 Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Sun, 28 Jul 2019 23:27:59 +0100 Subject: [PATCH] Optimise session closing --- cmd/ck-client/ck-client.go | 2 +- internal/multiplex/session.go | 31 +++++++++++++------------------ 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index d8272ee..ff5ca5e 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -224,7 +224,7 @@ func main() { log.Println(err) continue } - if sesh == nil || sesh.IsBroken() { + if sesh == nil || sesh.IsClosed() { sesh = makeSession(sta) } go func() { diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 6ecc6f2..76c2747 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -40,9 +40,7 @@ type Session struct { // For accepting new streams acceptCh chan *Stream - broken uint32 - die chan struct{} - suicide sync.Once + closed uint32 terminalMsg atomic.Value } @@ -56,7 +54,6 @@ func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, obfsedRe nextStreamID: 1, streams: make(map[uint32]*Stream), acceptCh: make(chan *Stream, acceptBacklog), - die: make(chan struct{}), } sesh.addrs.Store([]net.Addr{nil, nil}) sesh.sb = makeSwitchboard(sesh, valve) @@ -71,10 +68,8 @@ func (sesh *Session) AddConnection(conn net.Conn) { } func (sesh *Session) OpenStream() (*Stream, error) { - select { - case <-sesh.die: + if sesh.IsClosed() { return nil, ErrBrokenSession - default: } id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1 // Because atomic.AddUint32 returns the value after incrementation @@ -87,13 +82,14 @@ func (sesh *Session) OpenStream() (*Stream, error) { } func (sesh *Session) Accept() (net.Conn, error) { - select { - case <-sesh.die: + if sesh.IsClosed() { return nil, ErrBrokenSession - case stream := <-sesh.acceptCh: - return stream, nil } - + stream := <-sesh.acceptCh + if stream == nil { + return nil, ErrBrokenSession + } + return stream, nil } func (sesh *Session) delStream(id uint32) { @@ -143,10 +139,9 @@ func (sesh *Session) TerminalMsg() string { } func (sesh *Session) Close() error { - // Because closing a closed channel causes panic - sesh.suicide.Do(func() { close(sesh.die) }) - atomic.StoreUint32(&sesh.broken, 1) + atomic.StoreUint32(&sesh.closed, 1) sesh.streamsM.Lock() + sesh.acceptCh <- nil for id, stream := range sesh.streams { // If we call stream.Close() here, streamsM will result in a deadlock // because stream.Close calls sesh.delStream, which locks the mutex. @@ -162,14 +157,14 @@ func (sesh *Session) Close() error { } -func (sesh *Session) IsBroken() bool { - return atomic.LoadUint32(&sesh.broken) == 1 +func (sesh *Session) IsClosed() bool { + return atomic.LoadUint32(&sesh.closed) == 1 } func (sesh *Session) timeoutAfter(to time.Duration) { time.Sleep(to) sesh.streamsM.Lock() - if len(sesh.streams) == 0 && !sesh.IsBroken() { + if len(sesh.streams) == 0 && !sesh.IsClosed() { sesh.streamsM.Unlock() sesh.SetTerminalMsg("timeout") sesh.Close()