Optimise session closing

This commit is contained in:
Qian Wang 2019-07-28 23:27:59 +01:00
parent 059a222394
commit 6af97e2c22
2 changed files with 14 additions and 19 deletions

View File

@ -224,7 +224,7 @@ func main() {
log.Println(err)
continue
}
if sesh == nil || sesh.IsBroken() {
if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta)
}
go func() {

View File

@ -40,9 +40,7 @@ type Session struct {
// For accepting new streams
acceptCh chan *Stream
broken uint32
die chan struct{}
suicide sync.Once
closed uint32
terminalMsg atomic.Value
}
@ -56,7 +54,6 @@ func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, obfsedRe
nextStreamID: 1,
streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog),
die: make(chan struct{}),
}
sesh.addrs.Store([]net.Addr{nil, nil})
sesh.sb = makeSwitchboard(sesh, valve)
@ -71,10 +68,8 @@ func (sesh *Session) AddConnection(conn net.Conn) {
}
func (sesh *Session) OpenStream() (*Stream, error) {
select {
case <-sesh.die:
if sesh.IsClosed() {
return nil, ErrBrokenSession
default:
}
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
// Because atomic.AddUint32 returns the value after incrementation
@ -87,13 +82,14 @@ func (sesh *Session) OpenStream() (*Stream, error) {
}
func (sesh *Session) Accept() (net.Conn, error) {
select {
case <-sesh.die:
if sesh.IsClosed() {
return nil, ErrBrokenSession
case stream := <-sesh.acceptCh:
return stream, nil
}
stream := <-sesh.acceptCh
if stream == nil {
return nil, ErrBrokenSession
}
return stream, nil
}
func (sesh *Session) delStream(id uint32) {
@ -143,10 +139,9 @@ func (sesh *Session) TerminalMsg() string {
}
func (sesh *Session) Close() error {
// Because closing a closed channel causes panic
sesh.suicide.Do(func() { close(sesh.die) })
atomic.StoreUint32(&sesh.broken, 1)
atomic.StoreUint32(&sesh.closed, 1)
sesh.streamsM.Lock()
sesh.acceptCh <- nil
for id, stream := range sesh.streams {
// If we call stream.Close() here, streamsM will result in a deadlock
// because stream.Close calls sesh.delStream, which locks the mutex.
@ -162,14 +157,14 @@ func (sesh *Session) Close() error {
}
func (sesh *Session) IsBroken() bool {
return atomic.LoadUint32(&sesh.broken) == 1
func (sesh *Session) IsClosed() bool {
return atomic.LoadUint32(&sesh.closed) == 1
}
func (sesh *Session) timeoutAfter(to time.Duration) {
time.Sleep(to)
sesh.streamsM.Lock()
if len(sesh.streams) == 0 && !sesh.IsBroken() {
if len(sesh.streams) == 0 && !sesh.IsClosed() {
sesh.streamsM.Unlock()
sesh.SetTerminalMsg("timeout")
sesh.Close()