mirror of https://github.com/cbeuw/Cloak
Optimise session closing
This commit is contained in:
parent
059a222394
commit
6af97e2c22
|
|
@ -224,7 +224,7 @@ func main() {
|
|||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
if sesh == nil || sesh.IsBroken() {
|
||||
if sesh == nil || sesh.IsClosed() {
|
||||
sesh = makeSession(sta)
|
||||
}
|
||||
go func() {
|
||||
|
|
|
|||
|
|
@ -40,9 +40,7 @@ type Session struct {
|
|||
// For accepting new streams
|
||||
acceptCh chan *Stream
|
||||
|
||||
broken uint32
|
||||
die chan struct{}
|
||||
suicide sync.Once
|
||||
closed uint32
|
||||
|
||||
terminalMsg atomic.Value
|
||||
}
|
||||
|
|
@ -56,7 +54,6 @@ func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, obfsedRe
|
|||
nextStreamID: 1,
|
||||
streams: make(map[uint32]*Stream),
|
||||
acceptCh: make(chan *Stream, acceptBacklog),
|
||||
die: make(chan struct{}),
|
||||
}
|
||||
sesh.addrs.Store([]net.Addr{nil, nil})
|
||||
sesh.sb = makeSwitchboard(sesh, valve)
|
||||
|
|
@ -71,10 +68,8 @@ func (sesh *Session) AddConnection(conn net.Conn) {
|
|||
}
|
||||
|
||||
func (sesh *Session) OpenStream() (*Stream, error) {
|
||||
select {
|
||||
case <-sesh.die:
|
||||
if sesh.IsClosed() {
|
||||
return nil, ErrBrokenSession
|
||||
default:
|
||||
}
|
||||
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
|
||||
// Because atomic.AddUint32 returns the value after incrementation
|
||||
|
|
@ -87,13 +82,14 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
|||
}
|
||||
|
||||
func (sesh *Session) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-sesh.die:
|
||||
if sesh.IsClosed() {
|
||||
return nil, ErrBrokenSession
|
||||
case stream := <-sesh.acceptCh:
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
stream := <-sesh.acceptCh
|
||||
if stream == nil {
|
||||
return nil, ErrBrokenSession
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (sesh *Session) delStream(id uint32) {
|
||||
|
|
@ -143,10 +139,9 @@ func (sesh *Session) TerminalMsg() string {
|
|||
}
|
||||
|
||||
func (sesh *Session) Close() error {
|
||||
// Because closing a closed channel causes panic
|
||||
sesh.suicide.Do(func() { close(sesh.die) })
|
||||
atomic.StoreUint32(&sesh.broken, 1)
|
||||
atomic.StoreUint32(&sesh.closed, 1)
|
||||
sesh.streamsM.Lock()
|
||||
sesh.acceptCh <- nil
|
||||
for id, stream := range sesh.streams {
|
||||
// If we call stream.Close() here, streamsM will result in a deadlock
|
||||
// because stream.Close calls sesh.delStream, which locks the mutex.
|
||||
|
|
@ -162,14 +157,14 @@ func (sesh *Session) Close() error {
|
|||
|
||||
}
|
||||
|
||||
func (sesh *Session) IsBroken() bool {
|
||||
return atomic.LoadUint32(&sesh.broken) == 1
|
||||
func (sesh *Session) IsClosed() bool {
|
||||
return atomic.LoadUint32(&sesh.closed) == 1
|
||||
}
|
||||
|
||||
func (sesh *Session) timeoutAfter(to time.Duration) {
|
||||
time.Sleep(to)
|
||||
sesh.streamsM.Lock()
|
||||
if len(sesh.streams) == 0 && !sesh.IsBroken() {
|
||||
if len(sesh.streams) == 0 && !sesh.IsClosed() {
|
||||
sesh.streamsM.Unlock()
|
||||
sesh.SetTerminalMsg("timeout")
|
||||
sesh.Close()
|
||||
|
|
|
|||
Loading…
Reference in New Issue