From e9243a2e9fa55851cdba96295d0be0817bb6ab0a Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Fri, 10 Apr 2020 18:48:36 +0100 Subject: [PATCH] Framing in Stream.Write to prevent silent short write --- internal/common/copy.go | 22 ++++++------- internal/multiplex/session.go | 3 ++ internal/multiplex/stream.go | 60 ++++++++++++++++++----------------- 3 files changed, 45 insertions(+), 40 deletions(-) diff --git a/internal/common/copy.go b/internal/common/copy.go index 921d921..1d57b97 100644 --- a/internal/common/copy.go +++ b/internal/common/copy.go @@ -77,17 +77,17 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int } nr, er := src.Read(buf) if nr > 0 { - var offset int - for offset < nr { - nw, ew := dst.Write(buf[offset:nr]) - if nw > 0 { - written += int64(nw) - } - if ew != nil { - err = ew - break - } - offset += nw + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break } } if er != nil { diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index fd0096a..7fecfab 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -59,6 +59,8 @@ type Session struct { closed uint32 terminalMsg atomic.Value + + maxStreamUnitWrite int // the max size passed to Write calls before it splits it into multiple frames } func MakeSession(id uint32, config SessionConfig) *Session { @@ -82,6 +84,7 @@ func MakeSession(id uint32, config SessionConfig) *Session { if config.MaxFrameSize <= 0 { sesh.MaxFrameSize = defaultSendRecvBufSize - 1024 } + sesh.maxStreamUnitWrite = sesh.MaxFrameSize - HEADER_LEN - sesh.Obfuscator.minOverhead sbConfig := switchboardConfig{ valve: sesh.Valve, diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 14a7728..87f0f14 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -96,37 +96,39 @@ func (s *Stream) Write(in []byte) (n int, err error) { return 0, ErrBrokenStream } - var payload []byte - maxDataLen := s.session.MaxFrameSize - HEADER_LEN - s.session.minOverhead - if len(in) <= maxDataLen { - payload = in - } else { - //TODO: short write isn't the correct behaviour - payload = in[:maxDataLen] - } - - f := &Frame{ - StreamID: s.id, - Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, - Closing: C_NOOP, - Payload: payload, - } - - i, err := s.session.Obfs(f, s.obfsBuf) - if err != nil { - return i, err - } - n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId) - log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err) - if err != nil { - if err == errBrokenSwitchboard { - s.session.SetTerminalMsg(err.Error()) - s.session.passiveClose() + for n < len(in) { + var framePayload []byte + if len(in)-n <= s.session.maxStreamUnitWrite { + framePayload = in[n:] + } else { + framePayload = in[n : s.session.maxStreamUnitWrite+n] } - return - } - return len(payload), nil + f := &Frame{ + StreamID: s.id, + Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, + Closing: C_NOOP, + Payload: framePayload, + } + + var cipherTextLen int + cipherTextLen, err = s.session.Obfs(f, s.obfsBuf) + if err != nil { + return 0, err + } + + _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) + log.Tracef("%v sent to remote through stream %v with err %v", len(framePayload), s.id, err) + if err != nil { + if err == errBrokenSwitchboard { + s.session.SetTerminalMsg(err.Error()) + s.session.passiveClose() + } + return + } + n += len(framePayload) + } + return } func (s *Stream) passiveClose() error {