From 42f36b94d3c5704b8f2a3a9b13d65366ac13e7e6 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 22 Dec 2020 13:16:48 +0000 Subject: [PATCH] Achieve zero allocation when writing data through stream --- internal/multiplex/session.go | 16 +++++++--------- internal/multiplex/stream.go | 36 +++++++++++++++-------------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index c808e2b..a32b6bf 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -165,6 +165,7 @@ func (sesh *Session) Accept() (net.Conn, error) { } func (sesh *Session) closeStream(s *Stream, active bool) error { + // must be holding s.wirtingM on entry if atomic.SwapUint32(&s.closed, 1) == 1 { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } @@ -173,16 +174,13 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if active { // Notify remote that this stream is closed padding := genRandomPadding() - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingStream, - Payload: padding, - } - s.nextSendSeq++ + s.writingFrame.Closing = closingStream + s.writingFrame.Payload = padding obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) - i, err := sesh.Obfs(f, obfsBuf, 0) + + i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0) + s.writingFrame.Seq++ if err != nil { return err } @@ -190,7 +188,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if err != nil { return err } - log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq) + log.Tracef("stream %v actively closed.", s.id) } else { log.Tracef("stream %v passively closed", s.id) } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index beee2b8..d64628f 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -27,8 +27,8 @@ type Stream struct { // been read by the consumer through Read or WriteTo recvBuf recvBuffer - writingM sync.Mutex - nextSendSeq uint64 + writingM sync.Mutex + writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom // atomic closed uint32 @@ -63,6 +63,11 @@ func makeStream(sesh *Session, id uint32) *Stream { id: id, session: sesh, recvBuf: recvBuf, + writingFrame: Frame{ + StreamID: id, + Seq: 0, + Closing: closingNothing, + }, } return stream @@ -110,15 +115,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { return n, nil } -func (s *Stream) obfuscateAndSend(f *Frame, payloadOffsetInObfsBuf int) error { +func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { var cipherTextLen int - cipherTextLen, err := s.session.Obfs(f, s.obfsBuf, payloadOffsetInObfsBuf) + cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) if err != nil { return err } _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) - log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -154,14 +158,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { } framePayload = in[n : s.session.maxStreamUnitWrite+n] } - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingNothing, - Payload: framePayload, - } - s.nextSendSeq++ - err = s.obfuscateAndSend(f, 0) + s.writingFrame.Payload = framePayload + err = s.obfuscateAndSend(0) + s.writingFrame.Seq++ if err != nil { return } @@ -193,14 +192,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { } s.writingM.Lock() - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingNothing, - Payload: s.obfsBuf[frameHeaderLength : frameHeaderLength+read], - } - s.nextSendSeq++ - err = s.obfuscateAndSend(f, frameHeaderLength) + s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read] + err = s.obfuscateAndSend(frameHeaderLength) + s.writingFrame.Seq++ s.writingM.Unlock() if err != nil {