Fix goroutine leak

This commit is contained in:
Qian Wang 2018-10-27 15:27:43 +01:00
parent 0db52a8a26
commit 077eb16dba
6 changed files with 60 additions and 24 deletions

View File

@ -9,7 +9,7 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
//"runtime" "runtime"
"strings" "strings"
"time" "time"
@ -115,7 +115,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
newStream, err := sesh.AcceptStream() newStream, err := sesh.AcceptStream()
if err != nil { if err != nil {
log.Printf("Failed to get new stream: %v", err) log.Printf("Failed to get new stream: %v", err)
continue if err == mux.ErrBrokenSession {
sta.DelSession(arrSID)
return
} else {
continue
}
} }
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil { if err != nil {
@ -131,6 +136,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
} }
func main() { func main() {
runtime.SetBlockProfileRate(5)
go func() { go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}() }()

View File

@ -56,7 +56,12 @@ func (sh *sorterHeap) Pop() interface{} {
func (s *Stream) recvNewFrame() { func (s *Stream) recvNewFrame() {
for { for {
f := <-s.newFrameCh var f *Frame
select {
case <-s.die:
return
case f = <-s.newFrameCh:
}
if f == nil { if f == nil {
log.Println("nil frame") log.Println("nil frame")
continue continue

View File

@ -9,14 +9,15 @@ import (
) )
const ( const (
errBrokenSession = "broken session"
errRepeatSessionClosing = "trying to close a closed session"
// Copied from smux // Copied from smux
acceptBacklog = 1024 acceptBacklog = 1024
closeBacklog = 512 closeBacklog = 512
) )
var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session")
type Session struct { type Session struct {
id int id int
@ -58,6 +59,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([]
streams: make(map[uint32]*Stream), streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
closeQCh: make(chan uint32, closeBacklog), closeQCh: make(chan uint32, closeBacklog),
die: make(chan struct{}),
} }
sesh.sb = makeSwitchboard(conn, sesh) sesh.sb = makeSwitchboard(conn, sesh)
return sesh return sesh
@ -80,7 +82,7 @@ func (sesh *Session) OpenStream() (*Stream, error) {
func (sesh *Session) AcceptStream() (*Stream, error) { func (sesh *Session) AcceptStream() (*Stream, error) {
select { select {
case <-sesh.die: case <-sesh.die:
return nil, errors.New(errBrokenSession) return nil, ErrBrokenSession
case stream := <-sesh.acceptCh: case stream := <-sesh.acceptCh:
return stream, nil return stream, nil
} }
@ -122,7 +124,7 @@ func (sesh *Session) Close() error {
sesh.closingM.Lock() sesh.closingM.Lock()
defer sesh.closingM.Unlock() defer sesh.closingM.Unlock()
if sesh.closing { if sesh.closing {
return errors.New(errRepeatSessionClosing) return errRepeatSessionClosing
} }
sesh.closing = true sesh.closing = true
close(sesh.die) close(sesh.die)
@ -138,6 +140,7 @@ func (sesh *Session) Close() error {
} }
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
close(sesh.sb.die)
return nil return nil
} }

View File

@ -7,10 +7,8 @@ import (
"sync/atomic" "sync/atomic"
) )
const ( var errBrokenStream = errors.New("broken stream")
errBrokenStream = "broken stream" var errRepeatStreamClosing = errors.New("trying to close a closed stream")
errRepeatStreamClosing = "trying to close a closed stream"
)
type Stream struct { type Stream struct {
id uint32 id uint32
@ -23,14 +21,18 @@ type Stream struct {
sh sorterHeap sh sorterHeap
wrapMode bool wrapMode bool
newFrameCh chan *Frame // New frames are received through newFrameCh by frameSorter
newFrameCh chan *Frame
// sortedBufCh are order-sorted data ready to be read raw
sortedBufCh chan []byte sortedBufCh chan []byte
nextSendSeq uint32 nextSendSeq uint32
closingM sync.Mutex closingM sync.Mutex
die chan struct{} // close(die) is used to notify different goroutines that this stream is closing
closing bool die chan struct{}
// to prevent closing a closed channel
closing bool
} }
func makeStream(id uint32, sesh *Session) *Stream { func makeStream(id uint32, sesh *Session) *Stream {
@ -51,7 +53,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
select { select {
case <-stream.die: case <-stream.die:
log.Printf("Stream %v dying\n", stream.id) log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenStream) return 0, errBrokenStream
default: default:
return 0, nil return 0, nil
} }
@ -59,7 +61,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
select { select {
case <-stream.die: case <-stream.die:
log.Printf("Stream %v dying\n", stream.id) log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenStream) return 0, errBrokenStream
case data := <-stream.sortedBufCh: case data := <-stream.sortedBufCh:
if len(buf) < len(data) { if len(buf) < len(data) {
log.Println(len(data)) log.Println(len(data))
@ -75,7 +77,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
select { select {
case <-stream.die: case <-stream.die:
log.Printf("Stream %v dying\n", stream.id) log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenStream) return 0, errBrokenStream
default: default:
} }
@ -109,7 +111,7 @@ func (stream *Stream) Close() error {
stream.closingM.Lock() stream.closingM.Lock()
defer stream.closingM.Unlock() defer stream.closingM.Unlock()
if stream.closing { if stream.closing {
return errors.New(errRepeatStreamClosing) return errRepeatStreamClosing
} }
stream.closing = true stream.closing = true
close(stream.die) close(stream.die)
@ -127,7 +129,7 @@ func (stream *Stream) closeNoDelMap() error {
stream.closingM.Lock() stream.closingM.Lock()
defer stream.closingM.Unlock() defer stream.closingM.Unlock()
if stream.closing { if stream.closing {
return errors.New(errRepeatStreamClosing) return errRepeatStreamClosing
} }
stream.closing = true stream.closing = true
close(stream.die) close(stream.die)

View File

@ -20,9 +20,12 @@ type switchboard struct {
// For telling dispatcher how many bytes have been sent after Connection.send. // For telling dispatcher how many bytes have been sent after Connection.send.
sentNotifyCh chan *sentNotifier sentNotifyCh chan *sentNotifier
dispatCh chan []byte // dispatCh is used by streams to send new data to remote
newConnCh chan net.Conn dispatCh chan []byte
closingCECh chan *connEnclave newConnCh chan net.Conn
closingCECh chan *connEnclave
die chan struct{}
closing bool
} }
// Some data comes from a Stream to be sent through one of the many // Some data comes from a Stream to be sent through one of the many
@ -57,6 +60,7 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
dispatCh: make(chan []byte, dispatchBacklog), dispatCh: make(chan []byte, dispatchBacklog),
newConnCh: make(chan net.Conn, newConnBacklog), newConnCh: make(chan net.Conn, newConnBacklog),
closingCECh: make(chan *connEnclave, 5), closingCECh: make(chan *connEnclave, 5),
die: make(chan struct{}),
} }
ce := &connEnclave{ ce := &connEnclave{
sb: sb, sb: sb,
@ -97,6 +101,7 @@ func (ce *connEnclave) send(data []byte) {
// Dispatcher sends data coming from a stream to a remote connection // Dispatcher sends data coming from a stream to a remote connection
// I used channels here because I didn't want to use mutex // I used channels here because I didn't want to use mutex
func (sb *switchboard) dispatch() { func (sb *switchboard) dispatch() {
var dying bool
for { for {
select { select {
// dispatCh receives data from stream.Write // dispatCh receives data from stream.Write
@ -123,6 +128,15 @@ func (sb *switchboard) dispatch() {
break break
} }
} }
if len(sb.ces) == 0 && !dying {
sb.session.Close()
}
case <-sb.die:
dying = true
for _, ce := range sb.ces {
ce.remoteConn.Close()
}
return
} }
} }
} }

View File

@ -116,8 +116,8 @@ func (sta *State) ParseConfig(conf string) (err error) {
} }
func (sta *State) GetSession(SID [32]byte) *mux.Session { func (sta *State) GetSession(SID [32]byte) *mux.Session {
sta.sessionsM.Lock() sta.sessionsM.RLock()
defer sta.sessionsM.Unlock() defer sta.sessionsM.RUnlock()
if sesh, ok := sta.sessions[SID]; ok { if sesh, ok := sta.sessions[SID]; ok {
return sesh return sesh
} else { } else {
@ -131,6 +131,12 @@ func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) {
sta.sessionsM.Unlock() sta.sessionsM.Unlock()
} }
func (sta *State) DelSession(SID [32]byte) {
sta.sessionsM.Lock()
delete(sta.sessions, SID)
sta.sessionsM.Unlock()
}
func (sta *State) getUsedRandom(random [32]byte) int { func (sta *State) getUsedRandom(random [32]byte) int {
sta.usedRandomM.Lock() sta.usedRandomM.Lock()
defer sta.usedRandomM.Unlock() defer sta.usedRandomM.Unlock()