Implement stream SetReadDeadline

This commit is contained in:
Andy Wang 2020-04-09 18:56:17 +01:00
parent e41394c83c
commit 86214a1df0
8 changed files with 104 additions and 21 deletions

View File

@ -69,14 +69,10 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
//}
for {
if srcReadTimeout != 0 {
src.SetReadDeadline(time.Now().Add(srcReadTimeout))
/*
err =
if err != nil {
break
}
*/
err = src.SetReadDeadline(time.Now().Add(srcReadTimeout))
if err != nil {
break
}
}
nr, er := src.Read(buf)
if nr > 0 {

View File

@ -4,18 +4,23 @@ package multiplex
import (
"bytes"
"errors"
"io"
"sync"
"sync/atomic"
"time"
)
const BUF_SIZE_LIMIT = 1 << 20 * 500
var ErrTimeout = errors.New("deadline exceeded")
// The point of a bufferedPipe is that Read() will block until data is available
type bufferedPipe struct {
buf *bytes.Buffer
closed uint32
rwCond *sync.Cond
buf *bytes.Buffer
closed uint32
rwCond *sync.Cond
rDeadline time.Time
}
func NewBufferedPipe() *bufferedPipe {
@ -33,7 +38,13 @@ func (p *bufferedPipe) Read(target []byte) (int, error) {
if atomic.LoadUint32(&p.closed) == 1 && p.buf.Len() == 0 {
return 0, io.EOF
}
if !p.rDeadline.IsZero() {
d := time.Until(p.rDeadline)
if d <= 0 {
return 0, ErrTimeout
}
time.AfterFunc(d, p.rwCond.Broadcast)
}
if p.buf.Len() > 0 {
break
}
@ -75,3 +86,11 @@ func (p *bufferedPipe) Len() int {
defer p.rwCond.L.Unlock()
return p.buf.Len()
}
// SetReadDeadline sets the deadline for future Read calls on the pipe.
// A zero t disables the deadline. Goroutines currently blocked in Read
// are woken so they can observe the new deadline immediately.
func (p *bufferedPipe) SetReadDeadline(t time.Time) {
	p.rwCond.L.Lock()
	p.rDeadline = t
	p.rwCond.L.Unlock()
	p.rwCond.Broadcast()
}

View File

@ -6,6 +6,7 @@ import (
"io"
"sync"
"sync/atomic"
"time"
)
const DATAGRAM_NUMBER_LIMIT = 1024
@ -14,9 +15,10 @@ const DATAGRAM_NUMBER_LIMIT = 1024
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
// it won't get chopped up into individual bytes
type datagramBuffer struct {
buf [][]byte
closed uint32
rwCond *sync.Cond
buf [][]byte
closed uint32
rwCond *sync.Cond
rDeadline time.Time
}
func NewDatagramBuffer() *datagramBuffer {
@ -35,6 +37,14 @@ func (d *datagramBuffer) Read(target []byte) (int, error) {
return 0, io.EOF
}
if !d.rDeadline.IsZero() {
delta := time.Until(d.rDeadline)
if delta <= 0 {
return 0, ErrTimeout
}
time.AfterFunc(delta, d.rwCond.Broadcast)
}
if len(d.buf) > 0 {
break
}
@ -84,3 +94,11 @@ func (d *datagramBuffer) Close() error {
d.rwCond.Broadcast()
return nil
}
// SetReadDeadline arms (or, when t is the zero time, disarms) a deadline
// for Read. Readers blocked in Read are broadcast-woken so the change
// takes effect right away.
func (d *datagramBuffer) SetReadDeadline(t time.Time) {
	d.rwCond.L.Lock()
	// Deferred in LIFO order: unlock first, then wake waiting readers.
	defer d.rwCond.Broadcast()
	defer d.rwCond.L.Unlock()
	d.rDeadline = t
}

View File

@ -1,9 +1,13 @@
package multiplex
import "io"
import (
"io"
"time"
)
// recvBuffer is the per-stream receive queue. Both streamBuffer (ordered
// sessions) and datagramBuffer (unordered sessions) implement it.
type recvBuffer interface {
	// Read calls' err must be nil | io.EOF | io.ErrShortBuffer
	io.ReadCloser
	// Write queues one frame; toBeClosed reports whether the frame closes
	// the stream. NOTE(review): exact closing semantics to be confirmed
	// against the implementations.
	Write(Frame) (toBeClosed bool, err error)
	// SetReadDeadline mirrors net.Conn's SetReadDeadline but reports no
	// error. The parameter is renamed from "time" to "t": naming it "time"
	// shadowed the time package inside this signature and clashed with the
	// implementations, which all use "t".
	SetReadDeadline(t time.Time)
}

View File

@ -2,10 +2,12 @@ package multiplex
import (
"bytes"
"github.com/cbeuw/connutil"
"math/rand"
"strconv"
"sync/atomic"
"testing"
"time"
)
var seshConfigOrdered = SessionConfig{
@ -398,6 +400,50 @@ func TestParallel(t *testing.T) {
}
}
// TestStream_SetReadDeadline verifies stream read deadlines for both the
// ordered and the unordered session types: a deadline already in the past
// must fail a Read immediately with ErrTimeout, and a deadline set while a
// Read is blocked must unblock it once the deadline passes.
func TestStream_SetReadDeadline(t *testing.T) {
	var sessionKey [32]byte
	rand.Read(sessionKey[:])
	obfuscator, _ := MakeObfuscator(E_METHOD_PLAIN, sessionKey)
	seshConfigOrdered.Obfuscator = obfuscator
	// Give the unordered config the same obfuscator so both session types
	// run under identical conditions (previously only the ordered config
	// was assigned one here).
	seshConfigUnordered.Obfuscator = obfuscator

	testReadDeadline := func(sesh *Session) {
		t.Run("read after deadline set", func(t *testing.T) {
			stream, _ := sesh.OpenStream()
			// A deadline in the past must make Read fail without blocking.
			_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
			_, err := stream.Read(make([]byte, 1))
			if err != ErrTimeout {
				t.Errorf("expecting error %v, got %v", ErrTimeout, err)
			}
		})
		t.Run("unblock when deadline passed", func(t *testing.T) {
			stream, _ := sesh.OpenStream()
			// Buffered so the reader goroutine can always complete its send
			// and is not leaked if the timeout branch below wins the select.
			done := make(chan struct{}, 1)
			go func() {
				_, _ = stream.Read(make([]byte, 1))
				done <- struct{}{}
			}()
			_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
			select {
			case <-done:
			case <-time.After(500 * time.Millisecond):
				t.Error("Read did not unblock after deadline has passed")
			}
		})
	}

	sesh := MakeSession(0, seshConfigOrdered)
	sesh.AddConnection(connutil.Discard())
	testReadDeadline(sesh)

	sesh = MakeSession(0, seshConfigUnordered)
	sesh.AddConnection(connutil.Discard())
	testReadDeadline(sesh)
}
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen)

View File

@ -146,5 +146,5 @@ func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Ad
// TODO: implement the following
// SetDeadline is part of the net.Conn-style deadline API; not yet implemented (see TODO above).
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetReadDeadline(t time.Time) error { return errNotImplemented }
// SetReadDeadline delegates to the stream's receive buffer and always returns nil.
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
// SetWriteDeadline is part of the net.Conn-style deadline API; not yet implemented (see TODO above).
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }

View File

@ -7,13 +7,14 @@ package multiplex
// remote side before packet0. Cloak have to therefore sequence the packets so that they
// arrive in order as they were sent by the proxy software
//
// Cloak packets will have a 32-bit sequence number on them, so we know in which order
// Cloak packets will have a 64-bit sequence number on them, so we know in which order
// they should be sent to the proxy software. The code in this file provides buffering and sorting.
import (
"container/heap"
"fmt"
"sync"
"time"
)
type sorterHeap []*Frame
@ -57,8 +58,6 @@ func NewStreamBuffer() *streamBuffer {
return sb
}
// recvNewFrame is a forever running loop which receives frames unordered,
// cache and order them and send them into sortedBufCh
func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
sb.recvM.Lock()
defer sb.recvM.Unlock()
@ -100,3 +99,5 @@ func (sb *streamBuffer) Read(buf []byte) (int, error) {
// Close closes the underlying buffered pipe; its result is returned as-is.
func (sb *streamBuffer) Close() error {
	return sb.buf.Close()
}
// SetReadDeadline forwards the read deadline to the underlying buffered pipe.
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }

View File

@ -173,7 +173,6 @@ func DispatchConnection(conn net.Conn, sta *State) {
}
log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
//TODO: stream timeout
go func() {
if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil {
log.Debugf("copying stream to proxy client: %v", err)