Use a buffered pipe to buffer sorted data

This commit is contained in:
Qian Wang 2019-07-27 19:53:16 +01:00
parent 38f3a4a522
commit 0e08683828
4 changed files with 263 additions and 47 deletions

View File

@ -0,0 +1,78 @@
// This is based on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
package multiplex
import (
"bytes"
"io"
"sync"
)
// BUF_SIZE_LIMIT caps the amount of data buffered in a bufferedPipe.
// Writers block once the buffer grows past this size (500 MiB).
const BUF_SIZE_LIMIT = 1 << 20 * 500

// bufferedPipe is an in-memory, mutex-guarded FIFO of bytes.
// Readers block while the pipe is empty; writers block while it holds
// more than BUF_SIZE_LIMIT. Close wakes everyone: already-buffered data
// stays readable, after which Read reports io.EOF and Write reports
// io.ErrClosedPipe.
type bufferedPipe struct {
	buf    *bytes.Buffer
	closed bool
	// rwCond's lock guards every field above; the condition signals
	// state changes to blocked readers and writers.
	rwCond *sync.Cond
}

// NewBufferedPipe returns an empty, open pipe ready for use.
func NewBufferedPipe() *bufferedPipe {
	return &bufferedPipe{
		buf:    new(bytes.Buffer),
		rwCond: sync.NewCond(&sync.Mutex{}),
	}
}

// Read copies buffered data into target, blocking while the pipe is
// both empty and open. Once the pipe is closed and fully drained it
// returns io.EOF.
func (p *bufferedPipe) Read(target []byte) (int, error) {
	p.rwCond.L.Lock()
	defer p.rwCond.L.Unlock()
	for p.buf.Len() == 0 {
		if p.closed {
			return 0, io.EOF
		}
		p.rwCond.Wait()
	}
	// The buffer is guaranteed non-empty here, so Read cannot fail.
	n, _ := p.buf.Read(target)
	// A writer may be parked on a full buffer; let it re-check.
	p.rwCond.Broadcast()
	return n, nil
}

// Write appends input to the pipe, blocking while the buffer is over
// BUF_SIZE_LIMIT. It returns io.ErrClosedPipe if the pipe is (or
// becomes) closed before the data is accepted.
func (p *bufferedPipe) Write(input []byte) (int, error) {
	p.rwCond.L.Lock()
	defer p.rwCond.L.Unlock()
	// Letting the buffer grow without bound would eventually make
	// bytes.Buffer.Write panic, hence the cap.
	for !p.closed && p.buf.Len() > BUF_SIZE_LIMIT {
		p.rwCond.Wait()
	}
	if p.closed {
		return 0, io.ErrClosedPipe
	}
	// bytes.Buffer.Write never returns a non-nil error.
	n, _ := p.buf.Write(input)
	// Wake readers parked on an empty buffer.
	p.rwCond.Broadcast()
	return n, nil
}

// Close marks the pipe closed and wakes all blocked readers and
// writers. Buffered data remains readable. Close always returns nil.
func (p *bufferedPipe) Close() error {
	p.rwCond.L.Lock()
	defer p.rwCond.L.Unlock()
	p.closed = true
	p.rwCond.Broadcast()
	return nil
}

// Len reports the number of bytes currently buffered in the pipe.
func (p *bufferedPipe) Len() int {
	p.rwCond.L.Lock()
	defer p.rwCond.L.Unlock()
	return p.buf.Len()
}

View File

@ -0,0 +1,166 @@
package multiplex
import (
"bytes"
"testing"
"time"
)
// TestPipeRW verifies that a simple synchronous Write followed by a
// Read returns the same bytes with no error.
func TestPipeRW(t *testing.T) {
	pipe := NewBufferedPipe()
	b := []byte{0x01, 0x02, 0x03}
	n, err := pipe.Write(b)
	if err != nil {
		t.Errorf("simple write: expected nil error, got %v", err)
	}
	if n != len(b) {
		t.Errorf("simple write: expected %d bytes written, got %d", len(b), n)
	}
	b2 := make([]byte, len(b))
	n, err = pipe.Read(b2)
	if err != nil {
		t.Errorf("simple read: expected nil error, got %v", err)
	}
	if n != len(b) {
		t.Errorf("simple read: expected %d bytes read, got %d", len(b), n)
	}
	if !bytes.Equal(b, b2) {
		t.Errorf("simple read: expected %v, got %v", b, b2)
	}
}
// TestReadBlock verifies that Read blocks on an empty pipe until a
// concurrent Write supplies data.
func TestReadBlock(t *testing.T) {
	pipe := NewBufferedPipe()
	b := []byte{0x01, 0x02, 0x03}
	go func() {
		// Delay so the main goroutine is (very likely) already
		// blocked inside Read when the data arrives.
		time.Sleep(10 * time.Millisecond)
		if _, err := pipe.Write(b); err != nil {
			t.Errorf("concurrent write: expected nil error, got %v", err)
		}
	}()
	b2 := make([]byte, len(b))
	n, err := pipe.Read(b2)
	if err != nil {
		t.Errorf("blocked read: expected nil error, got %v", err)
	}
	if n != len(b) {
		t.Errorf("blocked read: expected %d bytes read, got %d", len(b), n)
	}
	if !bytes.Equal(b, b2) {
		t.Errorf("blocked read: expected %v, got %v", b, b2)
	}
}
// TestPartialRead verifies that reads smaller than the buffered data
// consume it in order across multiple calls.
func TestPartialRead(t *testing.T) {
	pipe := NewBufferedPipe()
	b := []byte{0x01, 0x02, 0x03}
	if _, err := pipe.Write(b); err != nil {
		t.Fatalf("setup write: %v", err)
	}
	b1 := make([]byte, 1)
	n, err := pipe.Read(b1)
	if err != nil {
		t.Errorf("partial read of 1: expected nil error, got %v", err)
	}
	if n != len(b1) {
		t.Errorf("partial read of 1: expected %d bytes, got %d", len(b1), n)
	}
	if b1[0] != b[0] {
		t.Errorf("partial read of 1: expected %v, got %v", b[0], b1[0])
	}
	b2 := make([]byte, 2)
	n, err = pipe.Read(b2)
	if err != nil {
		t.Errorf("partial read of 2: expected nil error, got %v", err)
	}
	if n != len(b2) {
		t.Errorf("partial read of 2: expected %d bytes, got %d", len(b2), n)
	}
	if !bytes.Equal(b[1:], b2) {
		t.Errorf("partial read of 2: expected %v, got %v", b[1:], b2)
	}
}
// TestReadAfterClose verifies that data buffered before Close remains
// readable after the pipe is closed.
func TestReadAfterClose(t *testing.T) {
	pipe := NewBufferedPipe()
	b := []byte{0x01, 0x02, 0x03}
	if _, err := pipe.Write(b); err != nil {
		t.Fatalf("setup write: %v", err)
	}
	b2 := make([]byte, len(b))
	if err := pipe.Close(); err != nil {
		t.Errorf("close: expected nil error, got %v", err)
	}
	n, err := pipe.Read(b2)
	if err != nil {
		t.Errorf("read after close: expected nil error, got %v", err)
	}
	if n != len(b) {
		t.Errorf("read after close: expected %d bytes, got %d", len(b), n)
	}
	if !bytes.Equal(b, b2) {
		t.Errorf("read after close: expected %v, got %v", b, b2)
	}
}

View File

@ -73,10 +73,10 @@ func (s *Stream) recvNewFrame() {
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq { if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
if f.Closing == 1 { if f.Closing == 1 {
// empty data indicates closing signal // empty data indicates closing signal
s.sortedBufCh <- []byte{} s.passiveClose()
return return
} else { } else {
s.sortedBufCh <- f.Payload s.sortedBuf.Write(f.Payload)
s.nextRecvSeq += 1 s.nextRecvSeq += 1
if s.nextRecvSeq == 0 { // getting wrapped if s.nextRecvSeq == 0 { // getting wrapped
s.rev += 1 s.rev += 1
@ -115,10 +115,10 @@ func (s *Stream) recvNewFrame() {
f = heap.Pop(&s.sh).(*frameNode).frame f = heap.Pop(&s.sh).(*frameNode).frame
if f.Closing == 1 { if f.Closing == 1 {
// empty data indicates closing signal // empty data indicates closing signal
s.sortedBufCh <- []byte{} s.passiveClose()
return return
} else { } else {
s.sortedBufCh <- f.Payload s.sortedBuf.Write(f.Payload)
s.nextRecvSeq += 1 s.nextRecvSeq += 1
if s.nextRecvSeq == 0 { // getting wrapped if s.nextRecvSeq == 0 { // getting wrapped
s.rev += 1 s.rev += 1

View File

@ -2,8 +2,6 @@ package multiplex
import ( import (
"errors" "errors"
"io"
"log"
"net" "net"
"time" "time"
@ -29,10 +27,8 @@ type Stream struct {
// New frames are received through newFrameCh by frameSorter // New frames are received through newFrameCh by frameSorter
newFrameCh chan *Frame newFrameCh chan *Frame
// sortedBufCh are order-sorted data ready to be read raw
sortedBufCh chan []byte sortedBuf *bufferedPipe
feederR *io.PipeReader
feederW *io.PipeWriter
// atomic // atomic
nextSendSeq uint32 nextSendSeq uint32
@ -45,45 +41,18 @@ type Stream struct {
} }
func makeStream(id uint32, sesh *Session) *Stream { func makeStream(id uint32, sesh *Session) *Stream {
r, w := io.Pipe()
stream := &Stream{ stream := &Stream{
id: id, id: id,
session: sesh, session: sesh,
die: make(chan struct{}), die: make(chan struct{}),
sh: []*frameNode{}, sh: []*frameNode{},
newFrameCh: make(chan *Frame, 1024), newFrameCh: make(chan *Frame, 1024),
sortedBufCh: make(chan []byte, 1024), sortedBuf: NewBufferedPipe(),
feederR: r,
feederW: w,
} }
go stream.recvNewFrame() go stream.recvNewFrame()
go stream.feed()
return stream return stream
} }
func (stream *Stream) feed() {
for {
select {
case <-stream.die:
return
case data := <-stream.sortedBufCh:
if len(data) == 0 {
stream.passiveClose()
return
}
_, err := stream.feederW.Write(data)
if err != nil {
if err == io.ErrClosedPipe {
stream.Close()
return
} else {
log.Println(err)
}
}
}
}
}
func (stream *Stream) Read(buf []byte) (n int, err error) { func (stream *Stream) Read(buf []byte) (n int, err error) {
if len(buf) == 0 { if len(buf) == 0 {
select { select {
@ -95,9 +64,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
} }
select { select {
case <-stream.die: case <-stream.die:
if stream.sortedBuf.Len() == 0 {
return 0, ErrBrokenStream return 0, ErrBrokenStream
} else {
return stream.sortedBuf.Read(buf)
}
default: default:
return stream.feederR.Read(buf) return stream.sortedBuf.Read(buf)
} }
} }
@ -168,9 +141,8 @@ func (stream *Stream) Close() error {
tlsRecord, _ := stream.session.obfs(f) tlsRecord, _ := stream.session.obfs(f)
stream.session.sb.send(tlsRecord) stream.session.sb.send(tlsRecord)
stream.sortedBuf.Close()
stream.session.delStream(stream.id) stream.session.delStream(stream.id)
stream.feederW.Close()
stream.feederR.Close()
//log.Printf("%v actively closed\n", stream.id) //log.Printf("%v actively closed\n", stream.id)
stream.writingM.Unlock() stream.writingM.Unlock()
return nil return nil