From bf8d373f79c8426eb8ef0c156dd57eb025375b0e Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Sun, 28 Jul 2019 11:58:45 +0100 Subject: [PATCH] Stream optimisations --- go.sum | 5 ++ internal/multiplex/frameSorter.go | 10 +-- internal/multiplex/frameSorter_test.go | 2 +- internal/multiplex/stream.go | 101 ++++++++++++------------ internal/multiplex/stream_test.go | 103 +++++++++++++++++++++++++ 5 files changed, 161 insertions(+), 60 deletions(-) create mode 100644 internal/multiplex/stream_test.go diff --git a/go.sum b/go.sum index 34a9d2b..ae0515a 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,11 @@ github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY= github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b h1:Elez2XeF2p9uyVj0yEUDqQ56NFcDtcBNkYP7yv8YbUE= golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/sys v0.0.0-20190124100055-b90733256f2e h1:3GIlrlVLfkoipSReOMNAgApI0ajnalyLa/EZHHca/XI= diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go index 051aa48..6738893 100644 --- a/internal/multiplex/frameSorter.go +++ b/internal/multiplex/frameSorter.go @@ -58,15 +58,9 @@ func (s *Stream) writeNewFrame(f *Frame) { // cache and order them and send them into sortedBufCh func (s *Stream) recvNewFrame() { for { - var f *Frame - select { - case <-s.die: + f := <-s.newFrameCh + if f == nil { return - case f = <-s.newFrameCh: - } - if f == nil { // This shouldn't happen - //log.Println("nil frame") - continue } // when there's no ooo packages in heap and we receive the next package in order diff --git a/internal/multiplex/frameSorter_test.go b/internal/multiplex/frameSorter_test.go index d448e6c..cfceded 100644 --- a/internal/multiplex/frameSorter_test.go +++ b/internal/multiplex/frameSorter_test.go @@ -50,6 +50,6 @@ func TestRecvNewFrame(t *testing.T) { ) } } - close(stream.die) + stream.newFrameCh <- nil } } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index b56de88..4620562 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -36,15 +36,13 @@ type Stream struct { writingM sync.RWMutex // close(die) is used to notify different goroutines that this stream is closing - die chan struct{} - heliumMask sync.Once // my personal fav + closed uint32 } func makeStream(id uint32, sesh *Session) *Stream { stream := &Stream{ id: id, session: sesh, - die: make(chan struct{}), sh: []*frameNode{}, newFrameCh: make(chan *Frame, 1024), sortedBuf: NewBufferedPipe(), @@ -53,99 +51,100 @@ func makeStream(id uint32, sesh *Session) *Stream { return stream } -func (stream *Stream) Read(buf []byte) (n int, err error) { +func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } + +func (s *Stream) Read(buf []byte) (n int, err error) { if len(buf) == 0 { - select { - case <-stream.die: + if s.isClosed() { return 0, ErrBrokenStream - default: + } else { return 0, nil } } - select { - case <-stream.die: - if stream.sortedBuf.Len() == 0 { + if s.isClosed() { + if s.sortedBuf.Len() == 0 { return 0, ErrBrokenStream } else { - return stream.sortedBuf.Read(buf) + return s.sortedBuf.Read(buf) } - default: - return stream.sortedBuf.Read(buf) + } else { + return s.sortedBuf.Read(buf) } } -func (stream *Stream) Write(in []byte) (n int, err error) { +func (s *Stream) Write(in []byte) (n int, err error) { // RWMutex used here isn't really for RW. // we use it to exploit the fact that RLock doesn't create contention. // The use of RWMutex is so that the stream will not actively close // in the middle of the execution of Write. This may cause the closing frame // to be sent before the data frame and cause loss of packet. - stream.writingM.RLock() - select { - case <-stream.die: - stream.writingM.RUnlock() + s.writingM.RLock() + if s.isClosed() { + s.writingM.RUnlock() return 0, ErrBrokenStream - default: } f := &Frame{ - StreamID: stream.id, - Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1, + StreamID: s.id, + Seq: atomic.AddUint32(&s.nextSendSeq, 1) - 1, Closing: 0, Payload: in, } - tlsRecord, err := stream.session.obfs(f) + tlsRecord, err := s.session.obfs(f) if err != nil { - stream.writingM.RUnlock() + s.writingM.RUnlock() return 0, err } - n, err = stream.session.sb.send(tlsRecord) - stream.writingM.RUnlock() + n, err = s.session.sb.send(tlsRecord) + s.writingM.RUnlock() return } +// the necessary steps to mark the stream as closed and to release resources +func (s *Stream) _close() { + atomic.StoreUint32(&s.closed, 1) + s.newFrameCh <- nil // this will trigger frameSorter to return + s.sortedBuf.Close() +} + // only close locally. Used when the stream close is notified by the remote -func (stream *Stream) passiveClose() { - stream.heliumMask.Do(func() { close(stream.die) }) - stream.session.delStream(stream.id) - stream.sortedBuf.Close() +func (s *Stream) passiveClose() { + s._close() + s.session.delStream(s.id) //log.Printf("%v passive closing\n", stream.id) } // active close. Close locally and tell the remote that this stream is being closed -func (stream *Stream) Close() error { +func (s *Stream) Close() error { - stream.writingM.Lock() - select { - case <-stream.die: - stream.writingM.Unlock() + s.writingM.Lock() + if s.isClosed() { + s.writingM.Unlock() return errors.New("Already Closed") - default: } - stream.heliumMask.Do(func() { close(stream.die) }) // Notify remote that this stream is closed - prand.Seed(int64(stream.id)) + prand.Seed(int64(s.id)) padLen := int(math.Floor(prand.Float64()*200 + 300)) pad := make([]byte, padLen) prand.Read(pad) f := &Frame{ - StreamID: stream.id, - Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1, + StreamID: s.id, + Seq: atomic.AddUint32(&s.nextSendSeq, 1) - 1, Closing: 1, Payload: pad, } - tlsRecord, _ := stream.session.obfs(f) - stream.session.sb.send(tlsRecord) + tlsRecord, _ := s.session.obfs(f) + s.session.sb.send(tlsRecord) - stream.sortedBuf.Close() - stream.session.delStream(stream.id) + s._close() + s.session.delStream(s.id) //log.Printf("%v actively closed\n", stream.id) - stream.writingM.Unlock() + s.writingM.Unlock() return nil } @@ -153,18 +152,18 @@ func (stream *Stream) Close() error { // This is called in session.Close() to avoid mutex deadlock // We don't notify the remote because session.Close() is always // called when the session is passively closed -func (stream *Stream) closeNoDelMap() { - stream.heliumMask.Do(func() { close(stream.die) }) +func (s *Stream) closeNoDelMap() { + s._close() } // the following functions are purely for implementing net.Conn interface. // they are not used var errNotImplemented = errors.New("Not implemented") -func (stream *Stream) LocalAddr() net.Addr { return stream.session.addrs.Load().([]net.Addr)[0] } -func (stream *Stream) RemoteAddr() net.Addr { return stream.session.addrs.Load().([]net.Addr)[1] } +func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } +func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } // TODO: implement the following -func (stream *Stream) SetDeadline(t time.Time) error { return errNotImplemented } -func (stream *Stream) SetReadDeadline(t time.Time) error { return errNotImplemented } -func (stream *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented } +func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented } +func (s *Stream) SetReadDeadline(t time.Time) error { return errNotImplemented } +func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go new file mode 100644 index 0000000..dd2e0f2 --- /dev/null +++ b/internal/multiplex/stream_test.go @@ -0,0 +1,103 @@ +package multiplex + +import ( + "bufio" + "github.com/cbeuw/Cloak/internal/util" + "io/ioutil" + "math/rand" + "net" + "testing" + "time" +) + +func setupSesh() *Session { + UID := make([]byte, 16) + rand.Read(UID) + tthKey := make([]byte, 32) + rand.Read(tthKey) + crypto := &Plain{} + obfs := MakeObfs(tthKey, crypto) + deobfs := MakeDeobfs(tthKey, crypto) + return MakeSession(0, UNLIMITED_VALVE, obfs, deobfs, util.ReadTLS) +} + +type blackhole struct { + hole *bufio.Writer +} + +func newBlackHole() *blackhole { return &blackhole{hole: bufio.NewWriter(ioutil.Discard)} } +func (b *blackhole) Read([]byte) (int, error) { + time.Sleep(1 * time.Hour) + return 0, nil +} +func (b *blackhole) Write(in []byte) (int, error) { return b.hole.Write(in) } +func (b *blackhole) Close() error { return nil } +func (b *blackhole) LocalAddr() net.Addr { + ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") + return ret +} +func (b *blackhole) RemoteAddr() net.Addr { + ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") + return ret +} +func (b *blackhole) SetDeadline(t time.Time) error { return nil } +func (b *blackhole) SetReadDeadline(t time.Time) error { return nil } +func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil } + +const PAYLOAD_LEN = 1 << 20 * 100 + +func BenchmarkStream_Write(b *testing.B) { + hole := newBlackHole() + sesh := setupSesh() + sesh.AddConnection(hole) + testData := make([]byte, PAYLOAD_LEN) + rand.Read(testData) + + stream, _ := sesh.OpenStream() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := stream.Write(testData) + if err != nil { + b.Error( + "For", "stream write", + "got", err, + ) + } + b.SetBytes(PAYLOAD_LEN) + } +} + +/* +func BenchmarkStream_Write(b *testing.B) { + mc := mock_conn.NewConn() + go func(){ + w := bufio.NewWriter(ioutil.Discard) + for { + _, err := w.ReadFrom(mc.Server) + if err != nil { + log.Println(err) + return + } + } + }() + + sesh := setupSesh() + sesh.AddConnection(mc.Client) + testData := make([]byte,PAYLOAD_LEN) + rand.Read(testData) + + stream,_ := sesh.OpenStream() + b.ResetTimer() + for i:=0;i