From 02fa0729646da7081b044dd9342930441530e7d7 Mon Sep 17 00:00:00 2001
From: Qian Wang
Date: Tue, 16 Oct 2018 21:13:19 +0100
Subject: [PATCH] Fix infinite loop. Baseline

---
 cmd/ck-server/ck-server.go        | 14 ++++++++++---
 config/ckserver.json              |  2 +-
 internal/multiplex/frameSorter.go |  3 +++
 internal/multiplex/session.go     |  6 +++---
 internal/multiplex/stream.go      | 32 +++++++++++++++++-----------
 internal/multiplex/switchboard.go | 35 ++++++++++++++++++-------------
 internal/util/util.go             |  9 ++++----
 7 files changed, 64 insertions(+), 37 deletions(-)

diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go
index e9137d7..2536b12 100644
--- a/cmd/ck-server/ck-server.go
+++ b/cmd/ck-server/ck-server.go
@@ -6,7 +6,10 @@ import (
 	"io"
 	"log"
 	"net"
+	"net/http"
+	_ "net/http/pprof"
 	"os"
+	"runtime"
 	"strings"
 	"time"
 
@@ -19,9 +22,8 @@ var version string
 
 func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
 	for {
-		_, err := io.Copy(dst, src)
-		if err != nil {
-			log.Println(err)
+		i, err := io.Copy(dst, src)
+		if err != nil || i == 0 {
 			go dst.Close()
 			go src.Close()
 			return
@@ -102,10 +104,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
 		newStream, err := sesh.AcceptStream()
 		if err != nil {
 			log.Printf("Failed to get new stream: %v", err)
+			continue
 		}
 		ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
 		if err != nil {
 			log.Printf("Failed to connect to ssserver: %v", err)
+			continue
 		}
 		go pipe(ssConn, newStream)
 		go pipe(newStream, ssConn)
@@ -116,6 +120,10 @@
 }
 
 func main() {
+	runtime.SetBlockProfileRate(2)
+	go func() {
+		log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
+	}()
 	// Should be 127.0.0.1 to listen to ss-server on this machine
 	var localHost string
 	// server_port in ss config, same as remotePort in plugin mode
diff --git a/config/ckserver.json b/config/ckserver.json
index 001a8a7..9d358c3 100644
--- a/config/ckserver.json
+++ b/config/ckserver.json
@@ -1,4 +1,4 @@
 {
     "WebServerAddr":"204.79.197.200:443",
-    "Key":"CN+VRP9OqZR0+Im2X/1y6FvaK7+GBnX6qCiovbo+eVo="
+    "Key":"UGUmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ="
 }
diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go
index de74021..eac1df3 100644
--- a/internal/multiplex/frameSorter.go
+++ b/internal/multiplex/frameSorter.go
@@ -2,6 +2,7 @@ package multiplex
 
 import (
 	"container/heap"
+	"log"
 )
 
 // The data is multiplexed through several TCP connections, therefore the
@@ -57,8 +58,10 @@ func (s *Stream) recvNewFrame() {
 	for {
 		f := <-s.newFrameCh
 		if f == nil {
+			log.Println("nil frame")
 			continue
 		}
+		// For the ease of demonstration, assume seq is uint8, i.e.
 		// it wraps around after 255
 		fs := &frameNode{
 			f.Seq,
diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go
index cb6d8d7..c8805ab 100644
--- a/internal/multiplex/session.go
+++ b/internal/multiplex/session.go
@@ -8,7 +8,7 @@ import (
 
 const (
 	// Copied from smux
-	errBrokenPipe          = "broken pipe"
+	errBrokenPipe          = "broken stream"
 	errRepeatStreamClosing = "trying to close a closed stream"
 
 	acceptBacklog = 1024
@@ -84,9 +84,9 @@ func (sesh *Session) AcceptStream() (*Stream, error) {
 }
 
 func (sesh *Session) delStream(id uint32) {
-	sesh.streamsM.RLock()
+	sesh.streamsM.Lock()
 	delete(sesh.streams, id)
-	sesh.streamsM.RUnlock()
+	sesh.streamsM.Unlock()
 }
 
 func (sesh *Session) isStream(id uint32) bool {
diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go
index 272975c..9169c5a 100644
--- a/internal/multiplex/stream.go
+++ b/internal/multiplex/stream.go
@@ -8,7 +8,7 @@ import (
 )
 
 const (
-	readBuffer = 102400
+	readBuffer = 20480
 )
 
 type Stream struct {
@@ -50,21 +50,29 @@ func makeStream(id uint32, sesh *Session) *Stream {
 }
 
 func (stream *Stream) Read(buf []byte) (n int, err error) {
-	if len(buf) != 0 {
+	if len(buf) == 0 {
 		select {
 		case <-stream.die:
+			log.Printf("Stream %v dying\n", stream.id)
 			return 0, errors.New(errBrokenPipe)
-		case data := <-stream.sortedBufCh:
-			if len(data) > 0 {
-				copy(buf, data)
-				return len(data), nil
-			} else {
-				// TODO: close stream here or not?
-				return 0, io.EOF
-			}
+		default:
+			return 0, nil
 		}
 	}
-	return 0, errors.New(errBrokenPipe)
+	select {
+	case <-stream.die:
+		log.Printf("Stream %v dying\n", stream.id)
+		return 0, errors.New(errBrokenPipe)
+	default:
+	}
+	data := <-stream.sortedBufCh
+	if len(data) > 0 {
+		copy(buf, data)
+		return len(data), nil
+	} else {
+		// TODO: close stream here or not?
+		return 0, io.EOF
+	}
 }
 
@@ -111,8 +119,8 @@ func (stream *Stream) Close() error {
 		return errors.New(errRepeatStreamClosing)
 	}
 	stream.closing = true
-	stream.session.delStream(stream.id)
 	close(stream.die)
+	stream.session.delStream(stream.id)
 	stream.session.closeQCh <- stream.id
 	return nil
 }
diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go
index cf9449a..3d7ed60 100644
--- a/internal/multiplex/switchboard.go
+++ b/internal/multiplex/switchboard.go
@@ -3,7 +3,7 @@ package multiplex
 import (
 	"log"
 	"net"
-	"sort"
+	//"sort"
 )
 
 const (
@@ -79,30 +79,36 @@ type sentNotifier struct {
 
 func (ce *connEnclave) send(data []byte) {
 	// TODO: error handling
-	n, err := ce.remoteConn.Write(data)
+	_, err := ce.remoteConn.Write(data)
 	if err != nil {
+		ce.sb.closingCECh <- ce
 		log.Println(err)
 	}
-	sn := &sentNotifier{
-		ce,
-		n,
-	}
-	ce.sb.sentNotifyCh <- sn
+	/*
+		sn := &sentNotifier{
+			ce,
+			n,
+		}
+		ce.sb.sentNotifyCh <- sn
+	*/
 }
 
 // Dispatcher sends data coming from a stream to a remote connection
 // I used channels here because I didn't want to use mutex
 func (sb *switchboard) dispatch() {
+	var nextCE int
 	for {
 		select {
 		// dispatCh receives data from stream.Write
 		case data := <-sb.dispatCh:
-			go sb.ces[0].send(data)
-			sb.ces[0].sendQueue += len(data)
-		case notified := <-sb.sentNotifyCh:
-			notified.ce.sendQueue -= notified.sent
-			sort.Sort(byQ(sb.ces))
+			go sb.ces[nextCE%len(sb.ces)].send(data)
+			//sb.ces[0].sendQueue += len(data)
+			nextCE += 1
+			/*case notified := <-sb.sentNotifyCh:
+			notified.ce.sendQueue -= notified.sent
+			sort.Sort(byQ(sb.ces))*/
 		case conn := <-sb.newConnCh:
+			log.Println("newConn")
 			newCe := &connEnclave{
 				sb:         sb,
 				remoteConn: conn,
@@ -110,8 +116,9 @@ func (sb *switchboard) dispatch() {
 			}
 			sb.ces = append(sb.ces, newCe)
 			go sb.deplex(newCe)
-			sort.Sort(byQ(sb.ces))
+			//sort.Sort(byQ(sb.ces))
 		case closing := <-sb.closingCECh:
+			log.Println("Closing conn")
 			for i, ce := range sb.ces {
 				if closing == ce {
 					sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
@@ -124,7 +131,7 @@
 }
 
 func (sb *switchboard) deplex(ce *connEnclave) {
-	buf := make([]byte, 204800)
+	buf := make([]byte, 20480)
 	for {
 		i, err := sb.session.obfsedReader(ce.remoteConn, buf)
 		if err != nil {
diff --git a/internal/util/util.go b/internal/util/util.go
index 4960e79..49cbb7f 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -6,6 +6,7 @@ import (
 	"io"
 	prand "math/rand"
 	"net"
+	"strconv"
 	"time"
 )
 
@@ -45,15 +46,15 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
 	}
 	dataLength := BtoInt(buffer[3:5])
+	if dataLength > len(buffer) {
+		err = errors.New("Reading TLS message: message size greater than buffer. message size: " + strconv.Itoa(dataLength))
+		return
+	}
 	left := dataLength
 	readPtr := 5
 	conn.SetReadDeadline(time.Now().Add(3 * time.Second))
 	for left != 0 {
-		if readPtr > len(buffer) || readPtr+left > len(buffer) {
-			err = errors.New("Reading TLS message: message size greater than buffer")
-			return
-		}
 		// If left > buffer size (i.e. our message got segmented), the entire MTU is read
 		// if left = buffer size, the entire buffer is all there left to read
 		// if left < buffer size (i.e. multiple messages came together),
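
Note on the pipe change at the top of this patch: io.Copy treats EOF as normal
termination and returns a nil error, so once src is closed every subsequent call
returns (0, nil) and the bare retry loop spins at full CPU. That is the infinite
loop the commit title refers to; exiting when zero bytes were copied breaks it.
Below is a minimal, self-contained sketch of the pattern. The pipe function
mirrors the patched one; the discard type and the net.Pipe harness are
illustrative stand-ins (not part of the repository) for the stream and the
ss-server connection.

package main

import (
	"io"
	"log"
	"net"
)

// discard is a hypothetical io.ReadWriteCloser for the demo only:
// writes vanish, reads report EOF, Close is a no-op.
type discard struct{}

func (discard) Read(p []byte) (int, error)  { return 0, io.EOF }
func (discard) Write(p []byte) (int, error) { return len(p), nil }
func (discard) Close() error                { return nil }

// pipe mirrors the patched ck-server function. io.Copy returns the byte
// count with a nil error on EOF, so a closed source yields (0, nil);
// without the i == 0 check the for loop would call io.Copy again
// immediately and busy-loop forever.
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
	for {
		i, err := io.Copy(dst, src)
		if err != nil || i == 0 {
			go dst.Close()
			go src.Close()
			return
		}
	}
}

func main() {
	left, right := net.Pipe() // synchronous in-memory connection pair
	done := make(chan struct{})
	go func() {
		pipe(discard{}, left) // drain everything arriving on left
		close(done)
	}()
	right.Write([]byte("hello")) // first io.Copy returns (5, nil)...
	right.Close()                // ...after this close; the next returns (0, nil)
	<-done                       // pipe exits instead of spinning
	log.Println("pipe returned cleanly")
}

Running the sketch prints the final log line and exits; with the old bare
err-only check, the goroutine would keep receiving (0, nil) and never return.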