mirror of https://github.com/cbeuw/Cloak
Use a buffered pipe to buffer sorted data
This commit is contained in:
parent
38f3a4a522
commit
0e08683828
|
|
@ -0,0 +1,78 @@
|
|||
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
|
||||
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const BUF_SIZE_LIMIT = 1 << 20 * 500
|
||||
|
||||
type bufferedPipe struct {
|
||||
buf *bytes.Buffer
|
||||
closed bool
|
||||
rwCond *sync.Cond
|
||||
}
|
||||
|
||||
func NewBufferedPipe() *bufferedPipe {
|
||||
p := &bufferedPipe{
|
||||
buf: new(bytes.Buffer),
|
||||
rwCond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *bufferedPipe) Read(target []byte) (int, error) {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
for {
|
||||
if p.closed && p.buf.Len() == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if p.buf.Len() > 0 {
|
||||
break
|
||||
}
|
||||
p.rwCond.Wait()
|
||||
}
|
||||
n, err := p.buf.Read(target)
|
||||
// err will always be nil because we have already verified that buf.Len() != 0
|
||||
p.rwCond.Broadcast()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (p *bufferedPipe) Write(input []byte) (int, error) {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
for {
|
||||
if p.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if p.buf.Len() <= BUF_SIZE_LIMIT {
|
||||
// if p.buf gets too large, write() will panic. We don't want this to happen
|
||||
break
|
||||
}
|
||||
p.rwCond.Wait()
|
||||
}
|
||||
n, err := p.buf.Write(input)
|
||||
// err will always be nil
|
||||
p.rwCond.Broadcast()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (p *bufferedPipe) Close() error {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
|
||||
p.closed = true
|
||||
p.rwCond.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bufferedPipe) Len() int {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
return p.buf.Len()
|
||||
}
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPipeRW(t *testing.T) {
|
||||
pipe := NewBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
n, err := pipe.Write(b)
|
||||
if n != len(b) {
|
||||
t.Error(
|
||||
"For", "number of bytes written",
|
||||
"expecting", len(b),
|
||||
"got", n,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "simple write",
|
||||
"expecting", "nil error",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
|
||||
b2 := make([]byte, len(b))
|
||||
n, err = pipe.Read(b2)
|
||||
if n != len(b) {
|
||||
t.Error(
|
||||
"For", "number of bytes read",
|
||||
"expecting", len(b),
|
||||
"got", n,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "simple read",
|
||||
"expecting", "nil error",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
if !bytes.Equal(b, b2) {
|
||||
t.Error(
|
||||
"For", "simple read",
|
||||
"expecting", b,
|
||||
"got", b2,
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestReadBlock(t *testing.T) {
|
||||
pipe := NewBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
pipe.Write(b)
|
||||
}()
|
||||
b2 := make([]byte, len(b))
|
||||
n, err := pipe.Read(b2)
|
||||
if n != len(b) {
|
||||
t.Error(
|
||||
"For", "number of bytes read after block",
|
||||
"expecting", len(b),
|
||||
"got", n,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "blocked read",
|
||||
"expecting", "nil error",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
if !bytes.Equal(b, b2) {
|
||||
t.Error(
|
||||
"For", "blocked read",
|
||||
"expecting", b,
|
||||
"got", b2,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartialRead(t *testing.T) {
|
||||
pipe := NewBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
pipe.Write(b)
|
||||
b1 := make([]byte, 1)
|
||||
n, err := pipe.Read(b1)
|
||||
if n != len(b1) {
|
||||
t.Error(
|
||||
"For", "number of bytes in partial read of 1",
|
||||
"expecting", len(b1),
|
||||
"got", n,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "partial read of 1",
|
||||
"expecting", "nil error",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
if b1[0] != b[0] {
|
||||
t.Error(
|
||||
"For", "partial read of 1",
|
||||
"expecting", b[0],
|
||||
"got", b1[0],
|
||||
)
|
||||
}
|
||||
b2 := make([]byte, 2)
|
||||
n, err = pipe.Read(b2)
|
||||
if n != len(b2) {
|
||||
t.Error(
|
||||
"For", "number of bytes in partial read of 2",
|
||||
"expecting", len(b2),
|
||||
"got", n,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "partial read of 2",
|
||||
"expecting", "nil error",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
if !bytes.Equal(b[1:], b2) {
|
||||
t.Error(
|
||||
"For", "partial read of 2",
|
||||
"expecting", b[1:],
|
||||
"got", b2,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadAfterClose(t *testing.T) {
|
||||
pipe := NewBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
pipe.Write(b)
|
||||
b2 := make([]byte, len(b))
|
||||
pipe.Close()
|
||||
n, err := pipe.Read(b2)
|
||||
if n != len(b) {
|
||||
t.Error(
|
||||
"For", "number of bytes read",
|
||||
"expecting", len(b),
|
||||
"got", n,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "simple read",
|
||||
"expecting", "nil error",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
if !bytes.Equal(b, b2) {
|
||||
t.Error(
|
||||
"For", "simple read",
|
||||
"expecting", b,
|
||||
"got", b2,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -73,10 +73,10 @@ func (s *Stream) recvNewFrame() {
|
|||
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
|
||||
if f.Closing == 1 {
|
||||
// empty data indicates closing signal
|
||||
s.sortedBufCh <- []byte{}
|
||||
s.passiveClose()
|
||||
return
|
||||
} else {
|
||||
s.sortedBufCh <- f.Payload
|
||||
s.sortedBuf.Write(f.Payload)
|
||||
s.nextRecvSeq += 1
|
||||
if s.nextRecvSeq == 0 { // getting wrapped
|
||||
s.rev += 1
|
||||
|
|
@ -115,10 +115,10 @@ func (s *Stream) recvNewFrame() {
|
|||
f = heap.Pop(&s.sh).(*frameNode).frame
|
||||
if f.Closing == 1 {
|
||||
// empty data indicates closing signal
|
||||
s.sortedBufCh <- []byte{}
|
||||
s.passiveClose()
|
||||
return
|
||||
} else {
|
||||
s.sortedBufCh <- f.Payload
|
||||
s.sortedBuf.Write(f.Payload)
|
||||
s.nextRecvSeq += 1
|
||||
if s.nextRecvSeq == 0 { // getting wrapped
|
||||
s.rev += 1
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ package multiplex
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
|
@ -29,10 +27,8 @@ type Stream struct {
|
|||
|
||||
// New frames are received through newFrameCh by frameSorter
|
||||
newFrameCh chan *Frame
|
||||
// sortedBufCh are order-sorted data ready to be read raw
|
||||
sortedBufCh chan []byte
|
||||
feederR *io.PipeReader
|
||||
feederW *io.PipeWriter
|
||||
|
||||
sortedBuf *bufferedPipe
|
||||
|
||||
// atomic
|
||||
nextSendSeq uint32
|
||||
|
|
@ -45,45 +41,18 @@ type Stream struct {
|
|||
}
|
||||
|
||||
func makeStream(id uint32, sesh *Session) *Stream {
|
||||
r, w := io.Pipe()
|
||||
stream := &Stream{
|
||||
id: id,
|
||||
session: sesh,
|
||||
die: make(chan struct{}),
|
||||
sh: []*frameNode{},
|
||||
newFrameCh: make(chan *Frame, 1024),
|
||||
sortedBufCh: make(chan []byte, 1024),
|
||||
feederR: r,
|
||||
feederW: w,
|
||||
id: id,
|
||||
session: sesh,
|
||||
die: make(chan struct{}),
|
||||
sh: []*frameNode{},
|
||||
newFrameCh: make(chan *Frame, 1024),
|
||||
sortedBuf: NewBufferedPipe(),
|
||||
}
|
||||
go stream.recvNewFrame()
|
||||
go stream.feed()
|
||||
return stream
|
||||
}
|
||||
|
||||
func (stream *Stream) feed() {
|
||||
for {
|
||||
select {
|
||||
case <-stream.die:
|
||||
return
|
||||
case data := <-stream.sortedBufCh:
|
||||
if len(data) == 0 {
|
||||
stream.passiveClose()
|
||||
return
|
||||
}
|
||||
_, err := stream.feederW.Write(data)
|
||||
if err != nil {
|
||||
if err == io.ErrClosedPipe {
|
||||
stream.Close()
|
||||
return
|
||||
} else {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||
if len(buf) == 0 {
|
||||
select {
|
||||
|
|
@ -95,9 +64,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
|||
}
|
||||
select {
|
||||
case <-stream.die:
|
||||
return 0, ErrBrokenStream
|
||||
if stream.sortedBuf.Len() == 0 {
|
||||
return 0, ErrBrokenStream
|
||||
} else {
|
||||
return stream.sortedBuf.Read(buf)
|
||||
}
|
||||
default:
|
||||
return stream.feederR.Read(buf)
|
||||
return stream.sortedBuf.Read(buf)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -168,9 +141,8 @@ func (stream *Stream) Close() error {
|
|||
tlsRecord, _ := stream.session.obfs(f)
|
||||
stream.session.sb.send(tlsRecord)
|
||||
|
||||
stream.sortedBuf.Close()
|
||||
stream.session.delStream(stream.id)
|
||||
stream.feederW.Close()
|
||||
stream.feederR.Close()
|
||||
//log.Printf("%v actively closed\n", stream.id)
|
||||
stream.writingM.Unlock()
|
||||
return nil
|
||||
|
|
|
|||
Loading…
Reference in New Issue