Framing in Stream.Write to prevent silent short write

This commit is contained in:
Andy Wang 2020-04-10 18:48:36 +01:00
parent 17d57d9369
commit e9243a2e9f
3 changed files with 45 additions and 40 deletions

View File

@ -77,9 +77,7 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
} }
nr, er := src.Read(buf) nr, er := src.Read(buf)
if nr > 0 { if nr > 0 {
var offset int nw, ew := dst.Write(buf[0:nr])
for offset < nr {
nw, ew := dst.Write(buf[offset:nr])
if nw > 0 { if nw > 0 {
written += int64(nw) written += int64(nw)
} }
@ -87,7 +85,9 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
err = ew err = ew
break break
} }
offset += nw if nr != nw {
err = io.ErrShortWrite
break
} }
} }
if er != nil { if er != nil {

View File

@ -59,6 +59,8 @@ type Session struct {
closed uint32 closed uint32
terminalMsg atomic.Value terminalMsg atomic.Value
maxStreamUnitWrite int // the max size passed to Write calls before it splits it into multiple frames
} }
func MakeSession(id uint32, config SessionConfig) *Session { func MakeSession(id uint32, config SessionConfig) *Session {
@ -82,6 +84,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
if config.MaxFrameSize <= 0 { if config.MaxFrameSize <= 0 {
sesh.MaxFrameSize = defaultSendRecvBufSize - 1024 sesh.MaxFrameSize = defaultSendRecvBufSize - 1024
} }
sesh.maxStreamUnitWrite = sesh.MaxFrameSize - HEADER_LEN - sesh.Obfuscator.minOverhead
sbConfig := switchboardConfig{ sbConfig := switchboardConfig{
valve: sesh.Valve, valve: sesh.Valve,

View File

@ -96,28 +96,29 @@ func (s *Stream) Write(in []byte) (n int, err error) {
return 0, ErrBrokenStream return 0, ErrBrokenStream
} }
var payload []byte for n < len(in) {
maxDataLen := s.session.MaxFrameSize - HEADER_LEN - s.session.minOverhead var framePayload []byte
if len(in) <= maxDataLen { if len(in)-n <= s.session.maxStreamUnitWrite {
payload = in framePayload = in[n:]
} else { } else {
//TODO: short write isn't the correct behaviour framePayload = in[n : s.session.maxStreamUnitWrite+n]
payload = in[:maxDataLen]
} }
f := &Frame{ f := &Frame{
StreamID: s.id, StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: C_NOOP, Closing: C_NOOP,
Payload: payload, Payload: framePayload,
} }
i, err := s.session.Obfs(f, s.obfsBuf) var cipherTextLen int
cipherTextLen, err = s.session.Obfs(f, s.obfsBuf)
if err != nil { if err != nil {
return i, err return 0, err
} }
n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err) _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v", len(framePayload), s.id, err)
if err != nil { if err != nil {
if err == errBrokenSwitchboard { if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error()) s.session.SetTerminalMsg(err.Error())
@ -125,8 +126,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
} }
return return
} }
return len(payload), nil n += len(framePayload)
}
return
} }
func (s *Stream) passiveClose() error { func (s *Stream) passiveClose() error {