Use sync.Once to close die ch

This commit is contained in:
Qian Wang 2018-11-23 23:57:35 +00:00
parent 85e0e95a4b
commit 3b656c9360
3 changed files with 23 additions and 45 deletions

View File

@ -152,8 +152,9 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
for { for {
newStream, err := sesh.AcceptStream() newStream, err := sesh.AcceptStream()
if err != nil { if err != nil {
log.Printf("Failed to get new stream: %v", err) log.Printf("Failed to get new stream: %v\n", err)
if err == mux.ErrBrokenSession { if err == mux.ErrBrokenSession {
log.Printf("Session closed: %x:%v\n", UID, sessionID)
user.DelSession(sessionID) user.DelSession(sessionID)
return return
} else { } else {
@ -162,7 +163,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
} }
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil { if err != nil {
log.Printf("Failed to connect to ssserver: %v", err) log.Printf("Failed to connect to ssserver: %v\n", err)
continue continue
} }
go pipe(ssConn, newStream) go pipe(ssConn, newStream)

View File

@ -38,10 +38,8 @@ type Session struct {
// For accepting new streams // For accepting new streams
acceptCh chan *Stream acceptCh chan *Stream
// TODO: use sync.Once for this
closingM sync.Mutex
die chan struct{} die chan struct{}
closing bool overdose sync.Once // fentanyl? beware of respiratory depression
} }
// 1 conn is needed to make a session // 1 conn is needed to make a session
@ -123,13 +121,7 @@ func (sesh *Session) addStream(id uint32) *Stream {
func (sesh *Session) Close() error { func (sesh *Session) Close() error {
// Because closing a closed channel causes panic // Because closing a closed channel causes panic
sesh.closingM.Lock() sesh.overdose.Do(func() { close(sesh.die) })
if sesh.closing {
sesh.closingM.Unlock()
return errRepeatSessionClosing
}
sesh.closing = true
close(sesh.die)
sesh.streamsM.Lock() sesh.streamsM.Lock()
for id, stream := range sesh.streams { for id, stream := range sesh.streams {
// If we call stream.Close() here, streamsM will result in a deadlock // If we call stream.Close() here, streamsM will result in a deadlock

View File

@ -10,7 +10,6 @@ import (
) )
var errBrokenStream = errors.New("broken stream") var errBrokenStream = errors.New("broken stream")
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
type Stream struct { type Stream struct {
id uint32 id uint32
@ -31,11 +30,11 @@ type Stream struct {
// atomic // atomic
nextSendSeq uint32 nextSendSeq uint32
closingM sync.RWMutex writingM sync.RWMutex
// close(die) is used to notify different goroutines that this stream is closing // close(die) is used to notify different goroutines that this stream is closing
die chan struct{} die chan struct{}
// to prevent closing a closed channel heliumMask sync.Once // my personal fav
closing bool
} }
func makeStream(id uint32, sesh *Session) *Stream { func makeStream(id uint32, sesh *Session) *Stream {
@ -84,10 +83,10 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
// The use of RWMutex is so that the stream will not actively close // The use of RWMutex is so that the stream will not actively close
// in the middle of the execution of Write. This may cause the closing frame // in the middle of the execution of Write. This may cause the closing frame
// to be sent before the data frame and cause loss of packet. // to be sent before the data frame and cause loss of packet.
stream.closingM.RLock() stream.writingM.RLock()
select { select {
case <-stream.die: case <-stream.die:
stream.closingM.RUnlock() stream.writingM.RUnlock()
return 0, errBrokenStream return 0, errBrokenStream
default: default:
} }
@ -101,43 +100,26 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
tlsRecord := stream.session.obfs(f) tlsRecord := stream.session.obfs(f)
n, err = stream.session.sb.send(tlsRecord) n, err = stream.session.sb.send(tlsRecord)
stream.closingM.RUnlock() stream.writingM.RUnlock()
return return
} }
func (stream *Stream) shutdown() error {
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()
if stream.closing {
stream.closingM.Unlock()
return errRepeatStreamClosing
}
stream.closing = true
close(stream.die)
stream.closingM.Unlock()
return nil
}
// only close locally. Used when the stream close is notified by the remote // only close locally. Used when the stream close is notified by the remote
func (stream *Stream) passiveClose() error { func (stream *Stream) passiveClose() error {
err := stream.shutdown() stream.heliumMask.Do(func() { close(stream.die) })
if err != nil {
return err
}
stream.session.delStream(stream.id) stream.session.delStream(stream.id)
log.Printf("%v passive closing\n", stream.id) log.Printf("%v passive closing\n", stream.id)
// TODO: really need to return an error?
return nil return nil
} }
// active close. Close locally and tell the remote that this stream is being closed // active close. Close locally and tell the remote that this stream is being closed
func (stream *Stream) Close() error { func (stream *Stream) Close() error {
err := stream.shutdown() stream.writingM.Lock()
if err != nil { stream.heliumMask.Do(func() { close(stream.die) })
return err
}
// Notify remote that this stream is closed // Notify remote that this stream is closed
prand.Seed(int64(stream.id)) prand.Seed(int64(stream.id))
@ -151,17 +133,20 @@ func (stream *Stream) Close() error {
Payload: pad, Payload: pad,
} }
tlsRecord := stream.session.obfs(f) tlsRecord := stream.session.obfs(f)
// FIXME: despite sb.send being always called after Write(), the actual TCP sending
// may still be out of order
stream.session.sb.send(tlsRecord) stream.session.sb.send(tlsRecord)
stream.session.delStream(stream.id) stream.session.delStream(stream.id)
log.Printf("%v actively closed\n", stream.id) log.Printf("%v actively closed\n", stream.id)
stream.writingM.Unlock()
return nil return nil
} }
// Same as Close() but no call to session.delStream. // Same as passiveClose() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock // This is called in session.Close() to avoid mutex deadlock
// We don't notify the remote because session.Close() is always
// called when the session is passively closed
func (stream *Stream) closeNoDelMap() error { func (stream *Stream) closeNoDelMap() error {
return stream.shutdown() stream.heliumMask.Do(func() { close(stream.die) })
// TODO: really need to return an error?
return nil
} }