mirror of https://github.com/cbeuw/Cloak
Use a buffered pipe to buffer sorted data
This commit is contained in:
parent
38f3a4a522
commit
0e08683828
|
|
@ -0,0 +1,78 @@
|
||||||
|
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
|
||||||
|
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const BUF_SIZE_LIMIT = 1 << 20 * 500
|
||||||
|
|
||||||
|
type bufferedPipe struct {
|
||||||
|
buf *bytes.Buffer
|
||||||
|
closed bool
|
||||||
|
rwCond *sync.Cond
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBufferedPipe() *bufferedPipe {
|
||||||
|
p := &bufferedPipe{
|
||||||
|
buf: new(bytes.Buffer),
|
||||||
|
rwCond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bufferedPipe) Read(target []byte) (int, error) {
|
||||||
|
p.rwCond.L.Lock()
|
||||||
|
defer p.rwCond.L.Unlock()
|
||||||
|
for {
|
||||||
|
if p.closed && p.buf.Len() == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.buf.Len() > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.rwCond.Wait()
|
||||||
|
}
|
||||||
|
n, err := p.buf.Read(target)
|
||||||
|
// err will always be nil because we have already verified that buf.Len() != 0
|
||||||
|
p.rwCond.Broadcast()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bufferedPipe) Write(input []byte) (int, error) {
|
||||||
|
p.rwCond.L.Lock()
|
||||||
|
defer p.rwCond.L.Unlock()
|
||||||
|
for {
|
||||||
|
if p.closed {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if p.buf.Len() <= BUF_SIZE_LIMIT {
|
||||||
|
// if p.buf gets too large, write() will panic. We don't want this to happen
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.rwCond.Wait()
|
||||||
|
}
|
||||||
|
n, err := p.buf.Write(input)
|
||||||
|
// err will always be nil
|
||||||
|
p.rwCond.Broadcast()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bufferedPipe) Close() error {
|
||||||
|
p.rwCond.L.Lock()
|
||||||
|
defer p.rwCond.L.Unlock()
|
||||||
|
|
||||||
|
p.closed = true
|
||||||
|
p.rwCond.Broadcast()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bufferedPipe) Len() int {
|
||||||
|
p.rwCond.L.Lock()
|
||||||
|
defer p.rwCond.L.Unlock()
|
||||||
|
return p.buf.Len()
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPipeRW(t *testing.T) {
|
||||||
|
pipe := NewBufferedPipe()
|
||||||
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
|
n, err := pipe.Write(b)
|
||||||
|
if n != len(b) {
|
||||||
|
t.Error(
|
||||||
|
"For", "number of bytes written",
|
||||||
|
"expecting", len(b),
|
||||||
|
"got", n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Error(
|
||||||
|
"For", "simple write",
|
||||||
|
"expecting", "nil error",
|
||||||
|
"got", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
b2 := make([]byte, len(b))
|
||||||
|
n, err = pipe.Read(b2)
|
||||||
|
if n != len(b) {
|
||||||
|
t.Error(
|
||||||
|
"For", "number of bytes read",
|
||||||
|
"expecting", len(b),
|
||||||
|
"got", n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Error(
|
||||||
|
"For", "simple read",
|
||||||
|
"expecting", "nil error",
|
||||||
|
"got", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, b2) {
|
||||||
|
t.Error(
|
||||||
|
"For", "simple read",
|
||||||
|
"expecting", b,
|
||||||
|
"got", b2,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadBlock(t *testing.T) {
|
||||||
|
pipe := NewBufferedPipe()
|
||||||
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
|
go func() {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
pipe.Write(b)
|
||||||
|
}()
|
||||||
|
b2 := make([]byte, len(b))
|
||||||
|
n, err := pipe.Read(b2)
|
||||||
|
if n != len(b) {
|
||||||
|
t.Error(
|
||||||
|
"For", "number of bytes read after block",
|
||||||
|
"expecting", len(b),
|
||||||
|
"got", n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Error(
|
||||||
|
"For", "blocked read",
|
||||||
|
"expecting", "nil error",
|
||||||
|
"got", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, b2) {
|
||||||
|
t.Error(
|
||||||
|
"For", "blocked read",
|
||||||
|
"expecting", b,
|
||||||
|
"got", b2,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPartialRead(t *testing.T) {
|
||||||
|
pipe := NewBufferedPipe()
|
||||||
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
|
pipe.Write(b)
|
||||||
|
b1 := make([]byte, 1)
|
||||||
|
n, err := pipe.Read(b1)
|
||||||
|
if n != len(b1) {
|
||||||
|
t.Error(
|
||||||
|
"For", "number of bytes in partial read of 1",
|
||||||
|
"expecting", len(b1),
|
||||||
|
"got", n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Error(
|
||||||
|
"For", "partial read of 1",
|
||||||
|
"expecting", "nil error",
|
||||||
|
"got", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if b1[0] != b[0] {
|
||||||
|
t.Error(
|
||||||
|
"For", "partial read of 1",
|
||||||
|
"expecting", b[0],
|
||||||
|
"got", b1[0],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
b2 := make([]byte, 2)
|
||||||
|
n, err = pipe.Read(b2)
|
||||||
|
if n != len(b2) {
|
||||||
|
t.Error(
|
||||||
|
"For", "number of bytes in partial read of 2",
|
||||||
|
"expecting", len(b2),
|
||||||
|
"got", n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Error(
|
||||||
|
"For", "partial read of 2",
|
||||||
|
"expecting", "nil error",
|
||||||
|
"got", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b[1:], b2) {
|
||||||
|
t.Error(
|
||||||
|
"For", "partial read of 2",
|
||||||
|
"expecting", b[1:],
|
||||||
|
"got", b2,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadAfterClose(t *testing.T) {
|
||||||
|
pipe := NewBufferedPipe()
|
||||||
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
|
pipe.Write(b)
|
||||||
|
b2 := make([]byte, len(b))
|
||||||
|
pipe.Close()
|
||||||
|
n, err := pipe.Read(b2)
|
||||||
|
if n != len(b) {
|
||||||
|
t.Error(
|
||||||
|
"For", "number of bytes read",
|
||||||
|
"expecting", len(b),
|
||||||
|
"got", n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Error(
|
||||||
|
"For", "simple read",
|
||||||
|
"expecting", "nil error",
|
||||||
|
"got", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, b2) {
|
||||||
|
t.Error(
|
||||||
|
"For", "simple read",
|
||||||
|
"expecting", b,
|
||||||
|
"got", b2,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -73,10 +73,10 @@ func (s *Stream) recvNewFrame() {
|
||||||
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
|
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
|
||||||
if f.Closing == 1 {
|
if f.Closing == 1 {
|
||||||
// empty data indicates closing signal
|
// empty data indicates closing signal
|
||||||
s.sortedBufCh <- []byte{}
|
s.passiveClose()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
s.sortedBufCh <- f.Payload
|
s.sortedBuf.Write(f.Payload)
|
||||||
s.nextRecvSeq += 1
|
s.nextRecvSeq += 1
|
||||||
if s.nextRecvSeq == 0 { // getting wrapped
|
if s.nextRecvSeq == 0 { // getting wrapped
|
||||||
s.rev += 1
|
s.rev += 1
|
||||||
|
|
@ -115,10 +115,10 @@ func (s *Stream) recvNewFrame() {
|
||||||
f = heap.Pop(&s.sh).(*frameNode).frame
|
f = heap.Pop(&s.sh).(*frameNode).frame
|
||||||
if f.Closing == 1 {
|
if f.Closing == 1 {
|
||||||
// empty data indicates closing signal
|
// empty data indicates closing signal
|
||||||
s.sortedBufCh <- []byte{}
|
s.passiveClose()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
s.sortedBufCh <- f.Payload
|
s.sortedBuf.Write(f.Payload)
|
||||||
s.nextRecvSeq += 1
|
s.nextRecvSeq += 1
|
||||||
if s.nextRecvSeq == 0 { // getting wrapped
|
if s.nextRecvSeq == 0 { // getting wrapped
|
||||||
s.rev += 1
|
s.rev += 1
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,6 @@ package multiplex
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -29,10 +27,8 @@ type Stream struct {
|
||||||
|
|
||||||
// New frames are received through newFrameCh by frameSorter
|
// New frames are received through newFrameCh by frameSorter
|
||||||
newFrameCh chan *Frame
|
newFrameCh chan *Frame
|
||||||
// sortedBufCh are order-sorted data ready to be read raw
|
|
||||||
sortedBufCh chan []byte
|
sortedBuf *bufferedPipe
|
||||||
feederR *io.PipeReader
|
|
||||||
feederW *io.PipeWriter
|
|
||||||
|
|
||||||
// atomic
|
// atomic
|
||||||
nextSendSeq uint32
|
nextSendSeq uint32
|
||||||
|
|
@ -45,45 +41,18 @@ type Stream struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeStream(id uint32, sesh *Session) *Stream {
|
func makeStream(id uint32, sesh *Session) *Stream {
|
||||||
r, w := io.Pipe()
|
|
||||||
stream := &Stream{
|
stream := &Stream{
|
||||||
id: id,
|
id: id,
|
||||||
session: sesh,
|
session: sesh,
|
||||||
die: make(chan struct{}),
|
die: make(chan struct{}),
|
||||||
sh: []*frameNode{},
|
sh: []*frameNode{},
|
||||||
newFrameCh: make(chan *Frame, 1024),
|
newFrameCh: make(chan *Frame, 1024),
|
||||||
sortedBufCh: make(chan []byte, 1024),
|
sortedBuf: NewBufferedPipe(),
|
||||||
feederR: r,
|
|
||||||
feederW: w,
|
|
||||||
}
|
}
|
||||||
go stream.recvNewFrame()
|
go stream.recvNewFrame()
|
||||||
go stream.feed()
|
|
||||||
return stream
|
return stream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *Stream) feed() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stream.die:
|
|
||||||
return
|
|
||||||
case data := <-stream.sortedBufCh:
|
|
||||||
if len(data) == 0 {
|
|
||||||
stream.passiveClose()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err := stream.feederW.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
if err == io.ErrClosedPipe {
|
|
||||||
stream.Close()
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (stream *Stream) Read(buf []byte) (n int, err error) {
|
func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
select {
|
select {
|
||||||
|
|
@ -95,9 +64,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-stream.die:
|
case <-stream.die:
|
||||||
|
if stream.sortedBuf.Len() == 0 {
|
||||||
return 0, ErrBrokenStream
|
return 0, ErrBrokenStream
|
||||||
|
} else {
|
||||||
|
return stream.sortedBuf.Read(buf)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return stream.feederR.Read(buf)
|
return stream.sortedBuf.Read(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -168,9 +141,8 @@ func (stream *Stream) Close() error {
|
||||||
tlsRecord, _ := stream.session.obfs(f)
|
tlsRecord, _ := stream.session.obfs(f)
|
||||||
stream.session.sb.send(tlsRecord)
|
stream.session.sb.send(tlsRecord)
|
||||||
|
|
||||||
|
stream.sortedBuf.Close()
|
||||||
stream.session.delStream(stream.id)
|
stream.session.delStream(stream.id)
|
||||||
stream.feederW.Close()
|
|
||||||
stream.feederR.Close()
|
|
||||||
//log.Printf("%v actively closed\n", stream.id)
|
//log.Printf("%v actively closed\n", stream.id)
|
||||||
stream.writingM.Unlock()
|
stream.writingM.Unlock()
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue