From 881f6e6f9d24bd307e41be5ad4926ece0755b38d Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Thu, 24 Dec 2020 11:35:29 +0000 Subject: [PATCH] Use sync.Pool for obfuscation buffer --- internal/multiplex/session.go | 24 +++++++++++++----------- internal/multiplex/stream.go | 32 ++++++++++++-------------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 0a961f1..b918107 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -70,6 +70,8 @@ type Session struct { // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame recvFramePool sync.Pool + streamObfsBufPool sync.Pool + // Switchboard manages all connections to remote sb *switchboard @@ -117,6 +119,11 @@ func MakeSession(id uint32, config SessionConfig) *Session { // todo: validation. this must be smaller than StreamSendBufferSize sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead + sesh.streamObfsBufPool = sync.Pool{New: func() interface{} { + b := make([]byte, sesh.StreamSendBufferSize) + return &b + }} + sesh.sb = makeSwitchboard(sesh) time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout) return sesh @@ -180,25 +187,20 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { - tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead) + tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte) // Notify remote that this stream is closed - common.CryptoRandRead(tmpBuf[:1]) - padLen := int(tmpBuf[0]) + 1 - payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength] + common.CryptoRandRead((*tmpBuf)[:1]) + padLen := int((*tmpBuf)[0]) + 1 + payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength] common.CryptoRandRead(payload) // must be holding s.wirtingM on entry s.writingFrame.Closing = closingStream s.writingFrame.Payload = payload - cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength) - s.writingFrame.Seq++ - if err != nil { - return err - } - - _, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId) + err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength) + sesh.streamObfsBufPool.Put(tmpBuf) if err != nil { return err } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 84f106a..b29359f 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -34,11 +34,6 @@ type Stream struct { // atomic closed uint32 - // obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from - // the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste - // memory - obfsBuf []byte - // When we want order guarantee (i.e. session.Unordered is false), // we assign each stream a fixed underlying connection. // If the underlying connections the session uses provide ordering guarantee (most likely TCP), @@ -117,13 +112,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { return n, nil } -func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { - cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) +func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { + cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf) + s.writingFrame.Seq++ if err != nil { return err } - _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) + _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -142,9 +138,6 @@ func (s *Stream) Write(in []byte) (n int, err error) { return 0, ErrBrokenStream } - if s.obfsBuf == nil { - s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) - } for n < len(in) { var framePayload []byte if len(in)-n <= s.session.maxStreamUnitWrite { @@ -160,8 +153,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { framePayload = in[n : s.session.maxStreamUnitWrite+n] } s.writingFrame.Payload = framePayload - err = s.obfuscateAndSend(0) - s.writingFrame.Seq++ + buf := s.session.streamObfsBufPool.Get().(*[]byte) + err = s.obfuscateAndSend(*buf, 0) + s.session.streamObfsBufPool.Put(buf) if err != nil { return } @@ -173,9 +167,6 @@ func (s *Stream) Write(in []byte) (n int, err error) { // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // for readFromTimeout amount of time func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { - if s.obfsBuf == nil { - s.obfsBuf = make([]byte, s.session.StreamSendBufferSize) - } for { if s.readFromTimeout != 0 { if rder, ok := r.(net.Conn); !ok { @@ -184,7 +175,8 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { rder.SetReadDeadline(time.Now().Add(s.readFromTimeout)) } } - read, er := r.Read(s.obfsBuf[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite]) + buf := s.session.streamObfsBufPool.Get().(*[]byte) + read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite]) if er != nil { return n, er } @@ -196,10 +188,10 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { } s.writingM.Lock() - s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read] - err = s.obfuscateAndSend(frameHeaderLength) - s.writingFrame.Seq++ + s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read] + err = s.obfuscateAndSend(*buf, frameHeaderLength) s.writingM.Unlock() + s.session.streamObfsBufPool.Put(buf) if err != nil { return