diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index c4d76ff..2485fe2 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -9,7 +9,7 @@ import ( "net/http" _ "net/http/pprof" "os" - //"runtime" + "runtime" "strings" "time" @@ -115,7 +115,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) { newStream, err := sesh.AcceptStream() if err != nil { log.Printf("Failed to get new stream: %v", err) - continue + if err == mux.ErrBrokenSession { + sta.DelSession(arrSID) + return + } else { + continue + } } ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) if err != nil { @@ -131,6 +136,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } func main() { + runtime.SetBlockProfileRate(5) go func() { log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) }() diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go index e4a658b..d858dc3 100644 --- a/internal/multiplex/frameSorter.go +++ b/internal/multiplex/frameSorter.go @@ -56,7 +56,12 @@ func (sh *sorterHeap) Pop() interface{} { func (s *Stream) recvNewFrame() { for { - f := <-s.newFrameCh + var f *Frame + select { + case <-s.die: + return + case f = <-s.newFrameCh: + } if f == nil { log.Println("nil frame") continue diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 49a029e..af7adb7 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -9,14 +9,15 @@ import ( ) const ( - errBrokenSession = "broken session" - errRepeatSessionClosing = "trying to close a closed session" // Copied from smux acceptBacklog = 1024 closeBacklog = 512 ) +var ErrBrokenSession = errors.New("broken session") +var errRepeatSessionClosing = errors.New("trying to close a closed session") + type Session struct { id int @@ -58,6 +59,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([] streams: make(map[uint32]*Stream), acceptCh: make(chan *Stream, acceptBacklog), closeQCh: make(chan uint32, closeBacklog), + die: make(chan struct{}), } sesh.sb = makeSwitchboard(conn, sesh) return sesh @@ -80,7 +82,7 @@ func (sesh *Session) OpenStream() (*Stream, error) { func (sesh *Session) AcceptStream() (*Stream, error) { select { case <-sesh.die: - return nil, errors.New(errBrokenSession) + return nil, ErrBrokenSession case stream := <-sesh.acceptCh: return stream, nil } @@ -122,7 +124,7 @@ func (sesh *Session) Close() error { sesh.closingM.Lock() defer sesh.closingM.Unlock() if sesh.closing { - return errors.New(errRepeatSessionClosing) + return errRepeatSessionClosing } sesh.closing = true close(sesh.die) @@ -138,6 +140,7 @@ func (sesh *Session) Close() error { } sesh.streamsM.Unlock() + close(sesh.sb.die) return nil } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 768e895..5479fed 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -7,10 +7,8 @@ import ( "sync/atomic" ) -const ( - errBrokenStream = "broken stream" - errRepeatStreamClosing = "trying to close a closed stream" -) +var errBrokenStream = errors.New("broken stream") +var errRepeatStreamClosing = errors.New("trying to close a closed stream") type Stream struct { id uint32 @@ -23,14 +21,18 @@ type Stream struct { sh sorterHeap wrapMode bool - newFrameCh chan *Frame + // New frames are received through newFrameCh by frameSorter + newFrameCh chan *Frame + // sortedBufCh are order-sorted data ready to be read raw sortedBufCh chan []byte nextSendSeq uint32 closingM sync.Mutex - die chan struct{} - closing bool + // close(die) is used to notify different goroutines that this stream is closing + die chan struct{} + // to prevent closing a closed channel + closing bool } func makeStream(id uint32, sesh *Session) *Stream { @@ -51,7 +53,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { select { case <-stream.die: log.Printf("Stream %v dying\n", stream.id) - return 0, errors.New(errBrokenStream) + return 0, errBrokenStream default: return 0, nil } @@ -59,7 +61,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { select { case <-stream.die: log.Printf("Stream %v dying\n", stream.id) - return 0, errors.New(errBrokenStream) + return 0, errBrokenStream case data := <-stream.sortedBufCh: if len(buf) < len(data) { log.Println(len(data)) @@ -75,7 +77,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) { select { case <-stream.die: log.Printf("Stream %v dying\n", stream.id) - return 0, errors.New(errBrokenStream) + return 0, errBrokenStream default: } @@ -109,7 +111,7 @@ func (stream *Stream) Close() error { stream.closingM.Lock() defer stream.closingM.Unlock() if stream.closing { - return errors.New(errRepeatStreamClosing) + return errRepeatStreamClosing } stream.closing = true close(stream.die) @@ -127,7 +129,7 @@ func (stream *Stream) closeNoDelMap() error { stream.closingM.Lock() defer stream.closingM.Unlock() if stream.closing { - return errors.New(errRepeatStreamClosing) + return errRepeatStreamClosing } stream.closing = true close(stream.die) diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 3d5808c..a1c7f99 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -20,9 +20,12 @@ type switchboard struct { // For telling dispatcher how many bytes have been sent after Connection.send. sentNotifyCh chan *sentNotifier - dispatCh chan []byte - newConnCh chan net.Conn - closingCECh chan *connEnclave + // dispatCh is used by streams to send new data to remote + dispatCh chan []byte + newConnCh chan net.Conn + closingCECh chan *connEnclave + die chan struct{} + closing bool } // Some data comes from a Stream to be sent through one of the many @@ -57,6 +60,7 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard { dispatCh: make(chan []byte, dispatchBacklog), newConnCh: make(chan net.Conn, newConnBacklog), closingCECh: make(chan *connEnclave, 5), + die: make(chan struct{}), } ce := &connEnclave{ sb: sb, @@ -97,6 +101,7 @@ func (ce *connEnclave) send(data []byte) { // Dispatcher sends data coming from a stream to a remote connection // I used channels here because I didn't want to use mutex func (sb *switchboard) dispatch() { + var dying bool for { select { // dispatCh receives data from stream.Write @@ -123,6 +128,15 @@ func (sb *switchboard) dispatch() { break } } + if len(sb.ces) == 0 && !dying { + sb.session.Close() + } + case <-sb.die: + dying = true + for _, ce := range sb.ces { + ce.remoteConn.Close() + } + return } } } diff --git a/internal/server/state.go b/internal/server/state.go index 8f564bb..7971b8b 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -116,8 +116,8 @@ func (sta *State) ParseConfig(conf string) (err error) { } func (sta *State) GetSession(SID [32]byte) *mux.Session { - sta.sessionsM.Lock() - defer sta.sessionsM.Unlock() + sta.sessionsM.RLock() + defer sta.sessionsM.RUnlock() if sesh, ok := sta.sessions[SID]; ok { return sesh } else { @@ -131,6 +131,12 @@ func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) { sta.sessionsM.Unlock() } +func (sta *State) DelSession(SID [32]byte) { + sta.sessionsM.Lock() + delete(sta.sessions, SID) + sta.sessionsM.Unlock() +} + func (sta *State) getUsedRandom(random [32]byte) int { sta.usedRandomM.Lock() defer sta.usedRandomM.Unlock()