Achieve zero allocation when writing data through stream

This commit is contained in:
Andy Wang 2020-12-22 13:16:48 +00:00
parent 3633c9a03c
commit 42f36b94d3
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
2 changed files with 22 additions and 30 deletions

View File

@ -165,6 +165,7 @@ func (sesh *Session) Accept() (net.Conn, error) {
} }
func (sesh *Session) closeStream(s *Stream, active bool) error { func (sesh *Session) closeStream(s *Stream, active bool) error {
// must be holding s.wirtingM on entry
if atomic.SwapUint32(&s.closed, 1) == 1 { if atomic.SwapUint32(&s.closed, 1) == 1 {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
} }
@ -173,16 +174,13 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
if active { if active {
// Notify remote that this stream is closed // Notify remote that this stream is closed
padding := genRandomPadding() padding := genRandomPadding()
f := &Frame{ s.writingFrame.Closing = closingStream
StreamID: s.id, s.writingFrame.Payload = padding
Seq: s.nextSendSeq,
Closing: closingStream,
Payload: padding,
}
s.nextSendSeq++
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
i, err := sesh.Obfs(f, obfsBuf, 0)
i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0)
s.writingFrame.Seq++
if err != nil { if err != nil {
return err return err
} }
@ -190,7 +188,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
if err != nil { if err != nil {
return err return err
} }
log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq) log.Tracef("stream %v actively closed.", s.id)
} else { } else {
log.Tracef("stream %v passively closed", s.id) log.Tracef("stream %v passively closed", s.id)
} }

View File

@ -28,7 +28,7 @@ type Stream struct {
recvBuf recvBuffer recvBuf recvBuffer
writingM sync.Mutex writingM sync.Mutex
nextSendSeq uint64 writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom
// atomic // atomic
closed uint32 closed uint32
@ -63,6 +63,11 @@ func makeStream(sesh *Session, id uint32) *Stream {
id: id, id: id,
session: sesh, session: sesh,
recvBuf: recvBuf, recvBuf: recvBuf,
writingFrame: Frame{
StreamID: id,
Seq: 0,
Closing: closingNothing,
},
} }
return stream return stream
@ -110,15 +115,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
return n, nil return n, nil
} }
func (s *Stream) obfuscateAndSend(f *Frame, payloadOffsetInObfsBuf int) error { func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
var cipherTextLen int var cipherTextLen int
cipherTextLen, err := s.session.Obfs(f, s.obfsBuf, payloadOffsetInObfsBuf) cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
if err != nil { if err != nil {
return err return err
} }
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq)
if err != nil { if err != nil {
if err == errBrokenSwitchboard { if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error()) s.session.SetTerminalMsg(err.Error())
@ -154,14 +158,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
} }
framePayload = in[n : s.session.maxStreamUnitWrite+n] framePayload = in[n : s.session.maxStreamUnitWrite+n]
} }
f := &Frame{ s.writingFrame.Payload = framePayload
StreamID: s.id, err = s.obfuscateAndSend(0)
Seq: s.nextSendSeq, s.writingFrame.Seq++
Closing: closingNothing,
Payload: framePayload,
}
s.nextSendSeq++
err = s.obfuscateAndSend(f, 0)
if err != nil { if err != nil {
return return
} }
@ -193,14 +192,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
} }
s.writingM.Lock() s.writingM.Lock()
f := &Frame{ s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read]
StreamID: s.id, err = s.obfuscateAndSend(frameHeaderLength)
Seq: s.nextSendSeq, s.writingFrame.Seq++
Closing: closingNothing,
Payload: s.obfsBuf[frameHeaderLength : frameHeaderLength+read],
}
s.nextSendSeq++
err = s.obfuscateAndSend(f, frameHeaderLength)
s.writingM.Unlock() s.writingM.Unlock()
if err != nil { if err != nil {