mirror of https://github.com/cbeuw/Cloak
Implement WriteTo and ReadFrom timeouts
This commit is contained in:
parent
4a81683e44
commit
e202d8d03b
|
|
@ -128,13 +128,14 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu
|
|||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
stream.SetReadFromTimeout(streamTimeout) // if localConn hasn't sent anything to stream to a period of time, stream closes
|
||||
go func() {
|
||||
if _, err := common.Copy(localConn, stream, 0); err != nil {
|
||||
if _, err := common.Copy(localConn, stream); err != nil {
|
||||
log.Tracef("copying stream to proxy client: %v", err)
|
||||
}
|
||||
}()
|
||||
//util.Pipe(stream, localConn, localConfig.Timeout)
|
||||
if _, err = common.Copy(stream, localConn, streamTimeout); err != nil {
|
||||
if _, err = common.Copy(stream, localConn); err != nil {
|
||||
log.Tracef("copying proxy client to stream: %v", err)
|
||||
}
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -35,12 +35,9 @@ package common
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// copyBuffer is the actual implementation of Copy and CopyBuffer.
|
||||
// if buf is nil, one is allocated.
|
||||
func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int64, err error) {
|
||||
func Copy(dst net.Conn, src net.Conn) (written int64, err error) {
|
||||
defer func() { src.Close(); dst.Close() }()
|
||||
|
||||
// If the reader has a WriteTo method, use it to do the copy.
|
||||
|
|
@ -56,13 +53,6 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
|
|||
size := 32 * 1024
|
||||
buf := make([]byte, size)
|
||||
for {
|
||||
if srcReadTimeout != 0 {
|
||||
// TODO: don't rely on setreaddeadline
|
||||
err = src.SetReadDeadline(time.Now().Add(srcReadTimeout))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ type bufferedPipe struct {
|
|||
closed bool
|
||||
rwCond *sync.Cond
|
||||
rDeadline time.Time
|
||||
wtTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewBufferedPipe() *bufferedPipe {
|
||||
|
|
@ -74,7 +75,14 @@ func (p *bufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if d <= 0 {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
time.AfterFunc(d, p.rwCond.Broadcast)
|
||||
if p.wtTimeout == 0 {
|
||||
// if there hasn't been a scheduled broadcast
|
||||
time.AfterFunc(d, p.rwCond.Broadcast)
|
||||
}
|
||||
}
|
||||
if p.wtTimeout != 0 {
|
||||
p.rDeadline = time.Now().Add(p.wtTimeout)
|
||||
time.AfterFunc(p.wtTimeout, p.rwCond.Broadcast)
|
||||
}
|
||||
if p.buf.Len() > 0 {
|
||||
written, er := p.buf.WriteTo(w)
|
||||
|
|
@ -127,3 +135,11 @@ func (p *bufferedPipe) SetReadDeadline(t time.Time) {
|
|||
p.rDeadline = t
|
||||
p.rwCond.Broadcast()
|
||||
}
|
||||
|
||||
func (p *bufferedPipe) SetWriteToTimeout(d time.Duration) {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
|
||||
p.wtTimeout = d
|
||||
p.rwCond.Broadcast()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ type datagramBuffer struct {
|
|||
buf *bytes.Buffer
|
||||
closed bool
|
||||
rwCond *sync.Cond
|
||||
wtTimeout time.Duration
|
||||
rDeadline time.Time
|
||||
}
|
||||
|
||||
|
|
@ -72,13 +73,19 @@ func (d *datagramBuffer) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if d.closed && len(d.pLens) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if !d.rDeadline.IsZero() {
|
||||
delta := time.Until(d.rDeadline)
|
||||
if delta <= 0 {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
time.AfterFunc(delta, d.rwCond.Broadcast)
|
||||
if d.wtTimeout == 0 {
|
||||
// if there hasn't been a scheduled broadcast
|
||||
time.AfterFunc(delta, d.rwCond.Broadcast)
|
||||
}
|
||||
}
|
||||
if d.wtTimeout != 0 {
|
||||
d.rDeadline = time.Now().Add(d.wtTimeout)
|
||||
time.AfterFunc(d.wtTimeout, d.rwCond.Broadcast)
|
||||
}
|
||||
|
||||
if len(d.pLens) > 0 {
|
||||
|
|
@ -143,3 +150,11 @@ func (d *datagramBuffer) SetReadDeadline(t time.Time) {
|
|||
d.rDeadline = t
|
||||
d.rwCond.Broadcast()
|
||||
}
|
||||
|
||||
func (d *datagramBuffer) SetWriteToTimeout(t time.Duration) {
|
||||
d.rwCond.L.Lock()
|
||||
defer d.rwCond.L.Unlock()
|
||||
|
||||
d.wtTimeout = t
|
||||
d.rwCond.Broadcast()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,4 +11,5 @@ type recvBuffer interface {
|
|||
io.WriterTo
|
||||
Write(Frame) (toBeClosed bool, err error)
|
||||
SetReadDeadline(time time.Time)
|
||||
SetWriteToTimeout(d time.Duration)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ type Stream struct {
|
|||
// overall the streams in a session should be uniformly distributed across all connections
|
||||
// This is not used in unordered connection mode
|
||||
assignedConnId uint32
|
||||
|
||||
rfTimeout time.Duration
|
||||
}
|
||||
|
||||
func makeStream(sesh *Session, id uint32) *Stream {
|
||||
|
|
@ -152,6 +154,13 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
s.obfsBuf = make([]byte, s.session.SendBufferSize)
|
||||
}
|
||||
for {
|
||||
if s.rfTimeout != 0 {
|
||||
if rder, ok := r.(net.Conn); !ok {
|
||||
log.Warn("ReadFrom timeout is set but reader doesn't implement SetReadDeadline")
|
||||
} else {
|
||||
rder.SetReadDeadline(time.Now().Add(s.rfTimeout))
|
||||
}
|
||||
}
|
||||
read, er := r.Read(s.obfsBuf[HEADER_LEN : HEADER_LEN+s.session.maxStreamUnitWrite])
|
||||
if er != nil {
|
||||
return n, er
|
||||
|
|
@ -199,5 +208,7 @@ func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Ad
|
|||
|
||||
// TODO: implement the following
|
||||
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
|
||||
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
|
||||
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.rfTimeout = d }
|
||||
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|
||||
|
|
|
|||
|
|
@ -106,4 +106,5 @@ func (sb *streamBuffer) Close() error {
|
|||
return sb.buf.Close()
|
||||
}
|
||||
|
||||
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }
|
||||
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }
|
||||
func (sb *streamBuffer) SetWriteToTimeout(d time.Duration) { sb.buf.SetWriteToTimeout(d) }
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -13,6 +14,8 @@ import (
|
|||
|
||||
const payloadLen = 1000
|
||||
|
||||
var emptyKey [32]byte
|
||||
|
||||
func setupSesh(unordered bool, key [32]byte, encryptionMethod byte) *Session {
|
||||
obfuscator, _ := MakeObfuscator(encryptionMethod, key)
|
||||
|
||||
|
|
@ -433,5 +436,53 @@ func TestStream_UnorderedRead(t *testing.T) {
|
|||
"got nil error")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestStream_SetWriteToTimeout(t *testing.T) {
|
||||
seshes := map[string]*Session{
|
||||
"ordered": setupSesh(false, emptyKey, E_METHOD_PLAIN),
|
||||
"unordered": setupSesh(true, emptyKey, E_METHOD_PLAIN),
|
||||
}
|
||||
for name, sesh := range seshes {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
stream, _ := sesh.OpenStream()
|
||||
stream.SetWriteToTimeout(100 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
stream.WriteTo(ioutil.Discard)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("didn't timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_SetReadFromTimeout(t *testing.T) {
|
||||
seshes := map[string]*Session{
|
||||
"ordered": setupSesh(false, emptyKey, E_METHOD_PLAIN),
|
||||
"unordered": setupSesh(true, emptyKey, E_METHOD_PLAIN),
|
||||
}
|
||||
for name, sesh := range seshes {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
stream, _ := sesh.OpenStream()
|
||||
stream.SetReadFromTimeout(100 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
stream.ReadFrom(connutil.Discard())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("didn't timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,13 +184,16 @@ func dispatchConnection(conn net.Conn, sta *State) {
|
|||
}
|
||||
log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
|
||||
|
||||
// if stream has nothing to send to proxy server for sta.Timeout period of time, stream will return error
|
||||
newStream.(*mux.Stream).SetWriteToTimeout(sta.Timeout)
|
||||
go func() {
|
||||
if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil {
|
||||
if _, err := common.Copy(localConn, newStream); err != nil {
|
||||
log.Tracef("copying stream to proxy server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := common.Copy(newStream, localConn, 0); err != nil {
|
||||
if _, err := common.Copy(newStream, localConn); err != nil {
|
||||
log.Tracef("copying proxy server to stream: %v", err)
|
||||
}
|
||||
}()
|
||||
|
|
|
|||
Loading…
Reference in New Issue