diff --git a/internal/client/piper.go b/internal/client/piper.go index 28e84aa..e610799 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -128,13 +128,14 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu stream.Close() return } + + stream.SetReadFromTimeout(streamTimeout) // if localConn hasn't sent anything to stream to a period of time, stream closes go func() { - if _, err := common.Copy(localConn, stream, 0); err != nil { + if _, err := common.Copy(localConn, stream); err != nil { log.Tracef("copying stream to proxy client: %v", err) } }() - //util.Pipe(stream, localConn, localConfig.Timeout) - if _, err = common.Copy(stream, localConn, streamTimeout); err != nil { + if _, err = common.Copy(stream, localConn); err != nil { log.Tracef("copying proxy client to stream: %v", err) } }() diff --git a/internal/common/copy.go b/internal/common/copy.go index e09e837..07ab8ae 100644 --- a/internal/common/copy.go +++ b/internal/common/copy.go @@ -35,12 +35,9 @@ package common import ( "io" "net" - "time" ) -// copyBuffer is the actual implementation of Copy and CopyBuffer. -// if buf is nil, one is allocated. -func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int64, err error) { +func Copy(dst net.Conn, src net.Conn) (written int64, err error) { defer func() { src.Close(); dst.Close() }() // If the reader has a WriteTo method, use it to do the copy. @@ -56,13 +53,6 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int size := 32 * 1024 buf := make([]byte, size) for { - if srcReadTimeout != 0 { - // TODO: don't rely on setreaddeadline - err = src.SetReadDeadline(time.Now().Add(srcReadTimeout)) - if err != nil { - break - } - } nr, er := src.Read(buf) if nr > 0 { nw, ew := dst.Write(buf[0:nr]) diff --git a/internal/multiplex/bufferedPipe.go b/internal/multiplex/bufferedPipe.go index 4d2c4eb..7652cc9 100644 --- a/internal/multiplex/bufferedPipe.go +++ b/internal/multiplex/bufferedPipe.go @@ -22,6 +22,7 @@ type bufferedPipe struct { closed bool rwCond *sync.Cond rDeadline time.Time + wtTimeout time.Duration } func NewBufferedPipe() *bufferedPipe { @@ -74,7 +75,14 @@ func (p *bufferedPipe) WriteTo(w io.Writer) (n int64, err error) { if d <= 0 { return 0, ErrTimeout } - time.AfterFunc(d, p.rwCond.Broadcast) + if p.wtTimeout == 0 { + // if there hasn't been a scheduled broadcast + time.AfterFunc(d, p.rwCond.Broadcast) + } + } + if p.wtTimeout != 0 { + p.rDeadline = time.Now().Add(p.wtTimeout) + time.AfterFunc(p.wtTimeout, p.rwCond.Broadcast) } if p.buf.Len() > 0 { written, er := p.buf.WriteTo(w) @@ -127,3 +135,11 @@ func (p *bufferedPipe) SetReadDeadline(t time.Time) { p.rDeadline = t p.rwCond.Broadcast() } + +func (p *bufferedPipe) SetWriteToTimeout(d time.Duration) { + p.rwCond.L.Lock() + defer p.rwCond.L.Unlock() + + p.wtTimeout = d + p.rwCond.Broadcast() +} diff --git a/internal/multiplex/datagramBuffer.go b/internal/multiplex/datagramBuffer.go index 3b8d784..c3ea771 100644 --- a/internal/multiplex/datagramBuffer.go +++ b/internal/multiplex/datagramBuffer.go @@ -17,6 +17,7 @@ type datagramBuffer struct { buf *bytes.Buffer closed bool rwCond *sync.Cond + wtTimeout time.Duration rDeadline time.Time } @@ -72,13 +73,19 @@ func (d *datagramBuffer) WriteTo(w io.Writer) (n int64, err error) { if d.closed && len(d.pLens) == 0 { return 0, io.EOF } - if !d.rDeadline.IsZero() { delta := time.Until(d.rDeadline) if delta <= 0 { return 0, ErrTimeout } - time.AfterFunc(delta, d.rwCond.Broadcast) + if d.wtTimeout == 0 { + // if there hasn't been a scheduled broadcast + time.AfterFunc(delta, d.rwCond.Broadcast) + } + } + if d.wtTimeout != 0 { + d.rDeadline = time.Now().Add(d.wtTimeout) + time.AfterFunc(d.wtTimeout, d.rwCond.Broadcast) } if len(d.pLens) > 0 { @@ -143,3 +150,11 @@ func (d *datagramBuffer) SetReadDeadline(t time.Time) { d.rDeadline = t d.rwCond.Broadcast() } + +func (d *datagramBuffer) SetWriteToTimeout(t time.Duration) { + d.rwCond.L.Lock() + defer d.rwCond.L.Unlock() + + d.wtTimeout = t + d.rwCond.Broadcast() +} diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 9697867..414cae2 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -11,4 +11,5 @@ type recvBuffer interface { io.WriterTo Write(Frame) (toBeClosed bool, err error) SetReadDeadline(time time.Time) + SetWriteToTimeout(d time.Duration) } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 69dde45..253cb2b 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -36,6 +36,8 @@ type Stream struct { // overall the streams in a session should be uniformly distributed across all connections // This is not used in unordered connection mode assignedConnId uint32 + + rfTimeout time.Duration } func makeStream(sesh *Session, id uint32) *Stream { @@ -152,6 +154,13 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { s.obfsBuf = make([]byte, s.session.SendBufferSize) } for { + if s.rfTimeout != 0 { + if rder, ok := r.(net.Conn); !ok { + log.Warn("ReadFrom timeout is set but reader doesn't implement SetReadDeadline") + } else { + rder.SetReadDeadline(time.Now().Add(s.rfTimeout)) + } + } read, er := r.Read(s.obfsBuf[HEADER_LEN : HEADER_LEN+s.session.maxStreamUnitWrite]) if er != nil { return n, er @@ -199,5 +208,7 @@ func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Ad // TODO: implement the following func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented } +func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } +func (s *Stream) SetReadFromTimeout(d time.Duration) { s.rfTimeout = d } func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented } diff --git a/internal/multiplex/streamBuffer.go b/internal/multiplex/streamBuffer.go index 1004cc5..31c5937 100644 --- a/internal/multiplex/streamBuffer.go +++ b/internal/multiplex/streamBuffer.go @@ -106,4 +106,5 @@ func (sb *streamBuffer) Close() error { return sb.buf.Close() } -func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) } +func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) } +func (sb *streamBuffer) SetWriteToTimeout(d time.Duration) { sb.buf.SetWriteToTimeout(d) } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 77a1463..639ed91 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/cbeuw/Cloak/internal/common" "io" + "io/ioutil" "math/rand" "testing" "time" @@ -13,6 +14,8 @@ import ( const payloadLen = 1000 +var emptyKey [32]byte + func setupSesh(unordered bool, key [32]byte, encryptionMethod byte) *Session { obfuscator, _ := MakeObfuscator(encryptionMethod, key) @@ -433,5 +436,53 @@ func TestStream_UnorderedRead(t *testing.T) { "got nil error") } }) - +} + +func TestStream_SetWriteToTimeout(t *testing.T) { + seshes := map[string]*Session{ + "ordered": setupSesh(false, emptyKey, E_METHOD_PLAIN), + "unordered": setupSesh(true, emptyKey, E_METHOD_PLAIN), + } + for name, sesh := range seshes { + t.Run(name, func(t *testing.T) { + stream, _ := sesh.OpenStream() + stream.SetWriteToTimeout(100 * time.Millisecond) + done := make(chan struct{}) + go func() { + stream.WriteTo(ioutil.Discard) + done <- struct{}{} + }() + + select { + case <-done: + return + case <-time.After(500 * time.Millisecond): + t.Error("didn't timeout") + } + }) + } +} + +func TestStream_SetReadFromTimeout(t *testing.T) { + seshes := map[string]*Session{ + "ordered": setupSesh(false, emptyKey, E_METHOD_PLAIN), + "unordered": setupSesh(true, emptyKey, E_METHOD_PLAIN), + } + for name, sesh := range seshes { + t.Run(name, func(t *testing.T) { + stream, _ := sesh.OpenStream() + stream.SetReadFromTimeout(100 * time.Millisecond) + done := make(chan struct{}) + go func() { + stream.ReadFrom(connutil.Discard()) + done <- struct{}{} + }() + select { + case <-done: + return + case <-time.After(500 * time.Millisecond): + t.Error("didn't timeout") + } + }) + } } diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 4f17039..f3179b1 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -184,13 +184,16 @@ func dispatchConnection(conn net.Conn, sta *State) { } log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod) + // if stream has nothing to send to proxy server for sta.Timeout period of time, stream will return error + newStream.(*mux.Stream).SetWriteToTimeout(sta.Timeout) go func() { - if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil { + if _, err := common.Copy(localConn, newStream); err != nil { log.Tracef("copying stream to proxy server: %v", err) } }() + go func() { - if _, err := common.Copy(newStream, localConn, 0); err != nil { + if _, err := common.Copy(newStream, localConn); err != nil { log.Tracef("copying proxy server to stream: %v", err) } }()