mirror of https://github.com/cbeuw/Cloak
Fix goroutine leak
This commit is contained in:
parent
0db52a8a26
commit
077eb16dba
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
//"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -115,7 +115,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
||||||
newStream, err := sesh.AcceptStream()
|
newStream, err := sesh.AcceptStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get new stream: %v", err)
|
log.Printf("Failed to get new stream: %v", err)
|
||||||
continue
|
if err == mux.ErrBrokenSession {
|
||||||
|
sta.DelSession(arrSID)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
|
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -131,6 +136,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
runtime.SetBlockProfileRate(5)
|
||||||
go func() {
|
go func() {
|
||||||
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
|
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
|
||||||
}()
|
}()
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,12 @@ func (sh *sorterHeap) Pop() interface{} {
|
||||||
|
|
||||||
func (s *Stream) recvNewFrame() {
|
func (s *Stream) recvNewFrame() {
|
||||||
for {
|
for {
|
||||||
f := <-s.newFrameCh
|
var f *Frame
|
||||||
|
select {
|
||||||
|
case <-s.die:
|
||||||
|
return
|
||||||
|
case f = <-s.newFrameCh:
|
||||||
|
}
|
||||||
if f == nil {
|
if f == nil {
|
||||||
log.Println("nil frame")
|
log.Println("nil frame")
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -9,14 +9,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errBrokenSession = "broken session"
|
|
||||||
errRepeatSessionClosing = "trying to close a closed session"
|
|
||||||
// Copied from smux
|
// Copied from smux
|
||||||
acceptBacklog = 1024
|
acceptBacklog = 1024
|
||||||
|
|
||||||
closeBacklog = 512
|
closeBacklog = 512
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrBrokenSession = errors.New("broken session")
|
||||||
|
var errRepeatSessionClosing = errors.New("trying to close a closed session")
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
id int
|
id int
|
||||||
|
|
||||||
|
|
@ -58,6 +59,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([]
|
||||||
streams: make(map[uint32]*Stream),
|
streams: make(map[uint32]*Stream),
|
||||||
acceptCh: make(chan *Stream, acceptBacklog),
|
acceptCh: make(chan *Stream, acceptBacklog),
|
||||||
closeQCh: make(chan uint32, closeBacklog),
|
closeQCh: make(chan uint32, closeBacklog),
|
||||||
|
die: make(chan struct{}),
|
||||||
}
|
}
|
||||||
sesh.sb = makeSwitchboard(conn, sesh)
|
sesh.sb = makeSwitchboard(conn, sesh)
|
||||||
return sesh
|
return sesh
|
||||||
|
|
@ -80,7 +82,7 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
||||||
func (sesh *Session) AcceptStream() (*Stream, error) {
|
func (sesh *Session) AcceptStream() (*Stream, error) {
|
||||||
select {
|
select {
|
||||||
case <-sesh.die:
|
case <-sesh.die:
|
||||||
return nil, errors.New(errBrokenSession)
|
return nil, ErrBrokenSession
|
||||||
case stream := <-sesh.acceptCh:
|
case stream := <-sesh.acceptCh:
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
@ -122,7 +124,7 @@ func (sesh *Session) Close() error {
|
||||||
sesh.closingM.Lock()
|
sesh.closingM.Lock()
|
||||||
defer sesh.closingM.Unlock()
|
defer sesh.closingM.Unlock()
|
||||||
if sesh.closing {
|
if sesh.closing {
|
||||||
return errors.New(errRepeatSessionClosing)
|
return errRepeatSessionClosing
|
||||||
}
|
}
|
||||||
sesh.closing = true
|
sesh.closing = true
|
||||||
close(sesh.die)
|
close(sesh.die)
|
||||||
|
|
@ -138,6 +140,7 @@ func (sesh *Session) Close() error {
|
||||||
}
|
}
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
|
close(sesh.sb.die)
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,8 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var errBrokenStream = errors.New("broken stream")
|
||||||
errBrokenStream = "broken stream"
|
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
||||||
errRepeatStreamClosing = "trying to close a closed stream"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Stream struct {
|
type Stream struct {
|
||||||
id uint32
|
id uint32
|
||||||
|
|
@ -23,14 +21,18 @@ type Stream struct {
|
||||||
sh sorterHeap
|
sh sorterHeap
|
||||||
wrapMode bool
|
wrapMode bool
|
||||||
|
|
||||||
newFrameCh chan *Frame
|
// New frames are received through newFrameCh by frameSorter
|
||||||
|
newFrameCh chan *Frame
|
||||||
|
// sortedBufCh are order-sorted data ready to be read raw
|
||||||
sortedBufCh chan []byte
|
sortedBufCh chan []byte
|
||||||
|
|
||||||
nextSendSeq uint32
|
nextSendSeq uint32
|
||||||
|
|
||||||
closingM sync.Mutex
|
closingM sync.Mutex
|
||||||
die chan struct{}
|
// close(die) is used to notify different goroutines that this stream is closing
|
||||||
closing bool
|
die chan struct{}
|
||||||
|
// to prevent closing a closed channel
|
||||||
|
closing bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeStream(id uint32, sesh *Session) *Stream {
|
func makeStream(id uint32, sesh *Session) *Stream {
|
||||||
|
|
@ -51,7 +53,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||||
select {
|
select {
|
||||||
case <-stream.die:
|
case <-stream.die:
|
||||||
log.Printf("Stream %v dying\n", stream.id)
|
log.Printf("Stream %v dying\n", stream.id)
|
||||||
return 0, errors.New(errBrokenStream)
|
return 0, errBrokenStream
|
||||||
default:
|
default:
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
@ -59,7 +61,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||||
select {
|
select {
|
||||||
case <-stream.die:
|
case <-stream.die:
|
||||||
log.Printf("Stream %v dying\n", stream.id)
|
log.Printf("Stream %v dying\n", stream.id)
|
||||||
return 0, errors.New(errBrokenStream)
|
return 0, errBrokenStream
|
||||||
case data := <-stream.sortedBufCh:
|
case data := <-stream.sortedBufCh:
|
||||||
if len(buf) < len(data) {
|
if len(buf) < len(data) {
|
||||||
log.Println(len(data))
|
log.Println(len(data))
|
||||||
|
|
@ -75,7 +77,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||||
select {
|
select {
|
||||||
case <-stream.die:
|
case <-stream.die:
|
||||||
log.Printf("Stream %v dying\n", stream.id)
|
log.Printf("Stream %v dying\n", stream.id)
|
||||||
return 0, errors.New(errBrokenStream)
|
return 0, errBrokenStream
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -109,7 +111,7 @@ func (stream *Stream) Close() error {
|
||||||
stream.closingM.Lock()
|
stream.closingM.Lock()
|
||||||
defer stream.closingM.Unlock()
|
defer stream.closingM.Unlock()
|
||||||
if stream.closing {
|
if stream.closing {
|
||||||
return errors.New(errRepeatStreamClosing)
|
return errRepeatStreamClosing
|
||||||
}
|
}
|
||||||
stream.closing = true
|
stream.closing = true
|
||||||
close(stream.die)
|
close(stream.die)
|
||||||
|
|
@ -127,7 +129,7 @@ func (stream *Stream) closeNoDelMap() error {
|
||||||
stream.closingM.Lock()
|
stream.closingM.Lock()
|
||||||
defer stream.closingM.Unlock()
|
defer stream.closingM.Unlock()
|
||||||
if stream.closing {
|
if stream.closing {
|
||||||
return errors.New(errRepeatStreamClosing)
|
return errRepeatStreamClosing
|
||||||
}
|
}
|
||||||
stream.closing = true
|
stream.closing = true
|
||||||
close(stream.die)
|
close(stream.die)
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,12 @@ type switchboard struct {
|
||||||
|
|
||||||
// For telling dispatcher how many bytes have been sent after Connection.send.
|
// For telling dispatcher how many bytes have been sent after Connection.send.
|
||||||
sentNotifyCh chan *sentNotifier
|
sentNotifyCh chan *sentNotifier
|
||||||
dispatCh chan []byte
|
// dispatCh is used by streams to send new data to remote
|
||||||
newConnCh chan net.Conn
|
dispatCh chan []byte
|
||||||
closingCECh chan *connEnclave
|
newConnCh chan net.Conn
|
||||||
|
closingCECh chan *connEnclave
|
||||||
|
die chan struct{}
|
||||||
|
closing bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some data comes from a Stream to be sent through one of the many
|
// Some data comes from a Stream to be sent through one of the many
|
||||||
|
|
@ -57,6 +60,7 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
|
||||||
dispatCh: make(chan []byte, dispatchBacklog),
|
dispatCh: make(chan []byte, dispatchBacklog),
|
||||||
newConnCh: make(chan net.Conn, newConnBacklog),
|
newConnCh: make(chan net.Conn, newConnBacklog),
|
||||||
closingCECh: make(chan *connEnclave, 5),
|
closingCECh: make(chan *connEnclave, 5),
|
||||||
|
die: make(chan struct{}),
|
||||||
}
|
}
|
||||||
ce := &connEnclave{
|
ce := &connEnclave{
|
||||||
sb: sb,
|
sb: sb,
|
||||||
|
|
@ -97,6 +101,7 @@ func (ce *connEnclave) send(data []byte) {
|
||||||
// Dispatcher sends data coming from a stream to a remote connection
|
// Dispatcher sends data coming from a stream to a remote connection
|
||||||
// I used channels here because I didn't want to use mutex
|
// I used channels here because I didn't want to use mutex
|
||||||
func (sb *switchboard) dispatch() {
|
func (sb *switchboard) dispatch() {
|
||||||
|
var dying bool
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// dispatCh receives data from stream.Write
|
// dispatCh receives data from stream.Write
|
||||||
|
|
@ -123,6 +128,15 @@ func (sb *switchboard) dispatch() {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(sb.ces) == 0 && !dying {
|
||||||
|
sb.session.Close()
|
||||||
|
}
|
||||||
|
case <-sb.die:
|
||||||
|
dying = true
|
||||||
|
for _, ce := range sb.ces {
|
||||||
|
ce.remoteConn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -116,8 +116,8 @@ func (sta *State) ParseConfig(conf string) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sta *State) GetSession(SID [32]byte) *mux.Session {
|
func (sta *State) GetSession(SID [32]byte) *mux.Session {
|
||||||
sta.sessionsM.Lock()
|
sta.sessionsM.RLock()
|
||||||
defer sta.sessionsM.Unlock()
|
defer sta.sessionsM.RUnlock()
|
||||||
if sesh, ok := sta.sessions[SID]; ok {
|
if sesh, ok := sta.sessions[SID]; ok {
|
||||||
return sesh
|
return sesh
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -131,6 +131,12 @@ func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) {
|
||||||
sta.sessionsM.Unlock()
|
sta.sessionsM.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sta *State) DelSession(SID [32]byte) {
|
||||||
|
sta.sessionsM.Lock()
|
||||||
|
delete(sta.sessions, SID)
|
||||||
|
sta.sessionsM.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (sta *State) getUsedRandom(random [32]byte) int {
|
func (sta *State) getUsedRandom(random [32]byte) int {
|
||||||
sta.usedRandomM.Lock()
|
sta.usedRandomM.Lock()
|
||||||
defer sta.usedRandomM.Unlock()
|
defer sta.usedRandomM.Unlock()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue