mirror of https://github.com/cbeuw/Cloak
Use sync.Once to close die ch
This commit is contained in:
parent
85e0e95a4b
commit
3b656c9360
|
|
@ -152,8 +152,9 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
||||||
for {
|
for {
|
||||||
newStream, err := sesh.AcceptStream()
|
newStream, err := sesh.AcceptStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get new stream: %v", err)
|
log.Printf("Failed to get new stream: %v\n", err)
|
||||||
if err == mux.ErrBrokenSession {
|
if err == mux.ErrBrokenSession {
|
||||||
|
log.Printf("Session closed: %x:%v\n", UID, sessionID)
|
||||||
user.DelSession(sessionID)
|
user.DelSession(sessionID)
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -162,7 +163,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
||||||
}
|
}
|
||||||
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
|
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to connect to ssserver: %v", err)
|
log.Printf("Failed to connect to ssserver: %v\n", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go pipe(ssConn, newStream)
|
go pipe(ssConn, newStream)
|
||||||
|
|
|
||||||
|
|
@ -38,10 +38,8 @@ type Session struct {
|
||||||
// For accepting new streams
|
// For accepting new streams
|
||||||
acceptCh chan *Stream
|
acceptCh chan *Stream
|
||||||
|
|
||||||
// TODO: use sync.Once for this
|
|
||||||
closingM sync.Mutex
|
|
||||||
die chan struct{}
|
die chan struct{}
|
||||||
closing bool
|
overdose sync.Once // fentanyl? beware of respiratory depression
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1 conn is needed to make a session
|
// 1 conn is needed to make a session
|
||||||
|
|
@ -123,13 +121,7 @@ func (sesh *Session) addStream(id uint32) *Stream {
|
||||||
|
|
||||||
func (sesh *Session) Close() error {
|
func (sesh *Session) Close() error {
|
||||||
// Because closing a closed channel causes panic
|
// Because closing a closed channel causes panic
|
||||||
sesh.closingM.Lock()
|
sesh.overdose.Do(func() { close(sesh.die) })
|
||||||
if sesh.closing {
|
|
||||||
sesh.closingM.Unlock()
|
|
||||||
return errRepeatSessionClosing
|
|
||||||
}
|
|
||||||
sesh.closing = true
|
|
||||||
close(sesh.die)
|
|
||||||
sesh.streamsM.Lock()
|
sesh.streamsM.Lock()
|
||||||
for id, stream := range sesh.streams {
|
for id, stream := range sesh.streams {
|
||||||
// If we call stream.Close() here, streamsM will result in a deadlock
|
// If we call stream.Close() here, streamsM will result in a deadlock
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var errBrokenStream = errors.New("broken stream")
|
var errBrokenStream = errors.New("broken stream")
|
||||||
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
|
||||||
|
|
||||||
type Stream struct {
|
type Stream struct {
|
||||||
id uint32
|
id uint32
|
||||||
|
|
@ -31,11 +30,11 @@ type Stream struct {
|
||||||
// atomic
|
// atomic
|
||||||
nextSendSeq uint32
|
nextSendSeq uint32
|
||||||
|
|
||||||
closingM sync.RWMutex
|
writingM sync.RWMutex
|
||||||
|
|
||||||
// close(die) is used to notify different goroutines that this stream is closing
|
// close(die) is used to notify different goroutines that this stream is closing
|
||||||
die chan struct{}
|
die chan struct{}
|
||||||
// to prevent closing a closed channel
|
heliumMask sync.Once // my personal fav
|
||||||
closing bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeStream(id uint32, sesh *Session) *Stream {
|
func makeStream(id uint32, sesh *Session) *Stream {
|
||||||
|
|
@ -84,10 +83,10 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||||
// The use of RWMutex is so that the stream will not actively close
|
// The use of RWMutex is so that the stream will not actively close
|
||||||
// in the middle of the execution of Write. This may cause the closing frame
|
// in the middle of the execution of Write. This may cause the closing frame
|
||||||
// to be sent before the data frame and cause loss of packet.
|
// to be sent before the data frame and cause loss of packet.
|
||||||
stream.closingM.RLock()
|
stream.writingM.RLock()
|
||||||
select {
|
select {
|
||||||
case <-stream.die:
|
case <-stream.die:
|
||||||
stream.closingM.RUnlock()
|
stream.writingM.RUnlock()
|
||||||
return 0, errBrokenStream
|
return 0, errBrokenStream
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
@ -101,43 +100,26 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||||
|
|
||||||
tlsRecord := stream.session.obfs(f)
|
tlsRecord := stream.session.obfs(f)
|
||||||
n, err = stream.session.sb.send(tlsRecord)
|
n, err = stream.session.sb.send(tlsRecord)
|
||||||
stream.closingM.RUnlock()
|
stream.writingM.RUnlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *Stream) shutdown() error {
|
|
||||||
// Lock here because closing a closed channel causes panic
|
|
||||||
stream.closingM.Lock()
|
|
||||||
if stream.closing {
|
|
||||||
stream.closingM.Unlock()
|
|
||||||
return errRepeatStreamClosing
|
|
||||||
}
|
|
||||||
stream.closing = true
|
|
||||||
close(stream.die)
|
|
||||||
stream.closingM.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// only close locally. Used when the stream close is notified by the remote
|
// only close locally. Used when the stream close is notified by the remote
|
||||||
func (stream *Stream) passiveClose() error {
|
func (stream *Stream) passiveClose() error {
|
||||||
err := stream.shutdown()
|
stream.heliumMask.Do(func() { close(stream.die) })
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
stream.session.delStream(stream.id)
|
stream.session.delStream(stream.id)
|
||||||
log.Printf("%v passive closing\n", stream.id)
|
log.Printf("%v passive closing\n", stream.id)
|
||||||
|
// TODO: really need to return an error?
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// active close. Close locally and tell the remote that this stream is being closed
|
// active close. Close locally and tell the remote that this stream is being closed
|
||||||
func (stream *Stream) Close() error {
|
func (stream *Stream) Close() error {
|
||||||
|
|
||||||
err := stream.shutdown()
|
stream.writingM.Lock()
|
||||||
if err != nil {
|
stream.heliumMask.Do(func() { close(stream.die) })
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify remote that this stream is closed
|
// Notify remote that this stream is closed
|
||||||
prand.Seed(int64(stream.id))
|
prand.Seed(int64(stream.id))
|
||||||
|
|
@ -151,17 +133,20 @@ func (stream *Stream) Close() error {
|
||||||
Payload: pad,
|
Payload: pad,
|
||||||
}
|
}
|
||||||
tlsRecord := stream.session.obfs(f)
|
tlsRecord := stream.session.obfs(f)
|
||||||
// FIXME: despite sb.send being always called after Write(), the actual TCP sending
|
|
||||||
// may still be out of order
|
|
||||||
stream.session.sb.send(tlsRecord)
|
stream.session.sb.send(tlsRecord)
|
||||||
|
|
||||||
stream.session.delStream(stream.id)
|
stream.session.delStream(stream.id)
|
||||||
log.Printf("%v actively closed\n", stream.id)
|
log.Printf("%v actively closed\n", stream.id)
|
||||||
|
stream.writingM.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as Close() but no call to session.delStream.
|
// Same as passiveClose() but no call to session.delStream.
|
||||||
// This is called in session.Close() to avoid mutex deadlock
|
// This is called in session.Close() to avoid mutex deadlock
|
||||||
|
// We don't notify the remote because session.Close() is always
|
||||||
|
// called when the session is passively closed
|
||||||
func (stream *Stream) closeNoDelMap() error {
|
func (stream *Stream) closeNoDelMap() error {
|
||||||
return stream.shutdown()
|
stream.heliumMask.Do(func() { close(stream.die) })
|
||||||
|
// TODO: really need to return an error?
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue