mirror of https://github.com/cbeuw/Cloak
Optimise session closing
This commit is contained in:
parent
059a222394
commit
6af97e2c22
|
|
@ -224,7 +224,7 @@ func main() {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if sesh == nil || sesh.IsBroken() {
|
if sesh == nil || sesh.IsClosed() {
|
||||||
sesh = makeSession(sta)
|
sesh = makeSession(sta)
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
||||||
|
|
@ -40,9 +40,7 @@ type Session struct {
|
||||||
// For accepting new streams
|
// For accepting new streams
|
||||||
acceptCh chan *Stream
|
acceptCh chan *Stream
|
||||||
|
|
||||||
broken uint32
|
closed uint32
|
||||||
die chan struct{}
|
|
||||||
suicide sync.Once
|
|
||||||
|
|
||||||
terminalMsg atomic.Value
|
terminalMsg atomic.Value
|
||||||
}
|
}
|
||||||
|
|
@ -56,7 +54,6 @@ func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, obfsedRe
|
||||||
nextStreamID: 1,
|
nextStreamID: 1,
|
||||||
streams: make(map[uint32]*Stream),
|
streams: make(map[uint32]*Stream),
|
||||||
acceptCh: make(chan *Stream, acceptBacklog),
|
acceptCh: make(chan *Stream, acceptBacklog),
|
||||||
die: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
sesh.addrs.Store([]net.Addr{nil, nil})
|
sesh.addrs.Store([]net.Addr{nil, nil})
|
||||||
sesh.sb = makeSwitchboard(sesh, valve)
|
sesh.sb = makeSwitchboard(sesh, valve)
|
||||||
|
|
@ -71,10 +68,8 @@ func (sesh *Session) AddConnection(conn net.Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) OpenStream() (*Stream, error) {
|
func (sesh *Session) OpenStream() (*Stream, error) {
|
||||||
select {
|
if sesh.IsClosed() {
|
||||||
case <-sesh.die:
|
|
||||||
return nil, ErrBrokenSession
|
return nil, ErrBrokenSession
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
|
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
|
||||||
// Because atomic.AddUint32 returns the value after incrementation
|
// Because atomic.AddUint32 returns the value after incrementation
|
||||||
|
|
@ -87,13 +82,14 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) Accept() (net.Conn, error) {
|
func (sesh *Session) Accept() (net.Conn, error) {
|
||||||
select {
|
if sesh.IsClosed() {
|
||||||
case <-sesh.die:
|
|
||||||
return nil, ErrBrokenSession
|
return nil, ErrBrokenSession
|
||||||
case stream := <-sesh.acceptCh:
|
|
||||||
return stream, nil
|
|
||||||
}
|
}
|
||||||
|
stream := <-sesh.acceptCh
|
||||||
|
if stream == nil {
|
||||||
|
return nil, ErrBrokenSession
|
||||||
|
}
|
||||||
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) delStream(id uint32) {
|
func (sesh *Session) delStream(id uint32) {
|
||||||
|
|
@ -143,10 +139,9 @@ func (sesh *Session) TerminalMsg() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) Close() error {
|
func (sesh *Session) Close() error {
|
||||||
// Because closing a closed channel causes panic
|
atomic.StoreUint32(&sesh.closed, 1)
|
||||||
sesh.suicide.Do(func() { close(sesh.die) })
|
|
||||||
atomic.StoreUint32(&sesh.broken, 1)
|
|
||||||
sesh.streamsM.Lock()
|
sesh.streamsM.Lock()
|
||||||
|
sesh.acceptCh <- nil
|
||||||
for id, stream := range sesh.streams {
|
for id, stream := range sesh.streams {
|
||||||
// If we call stream.Close() here, streamsM will result in a deadlock
|
// If we call stream.Close() here, streamsM will result in a deadlock
|
||||||
// because stream.Close calls sesh.delStream, which locks the mutex.
|
// because stream.Close calls sesh.delStream, which locks the mutex.
|
||||||
|
|
@ -162,14 +157,14 @@ func (sesh *Session) Close() error {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) IsBroken() bool {
|
func (sesh *Session) IsClosed() bool {
|
||||||
return atomic.LoadUint32(&sesh.broken) == 1
|
return atomic.LoadUint32(&sesh.closed) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) timeoutAfter(to time.Duration) {
|
func (sesh *Session) timeoutAfter(to time.Duration) {
|
||||||
time.Sleep(to)
|
time.Sleep(to)
|
||||||
sesh.streamsM.Lock()
|
sesh.streamsM.Lock()
|
||||||
if len(sesh.streams) == 0 && !sesh.IsBroken() {
|
if len(sesh.streams) == 0 && !sesh.IsClosed() {
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
sesh.SetTerminalMsg("timeout")
|
sesh.SetTerminalMsg("timeout")
|
||||||
sesh.Close()
|
sesh.Close()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue