mirror of https://github.com/cbeuw/Cloak
Change behaviour of stream.Write(nil)
This commit is contained in:
parent
d3bc3b5a13
commit
a461059b4a
|
|
@ -70,12 +70,8 @@ func (s *Stream) writeFrame(frame Frame) error {
|
||||||
func (s *Stream) Read(buf []byte) (n int, err error) {
|
func (s *Stream) Read(buf []byte) (n int, err error) {
|
||||||
//log.Tracef("attempting to read from stream %v", s.id)
|
//log.Tracef("attempting to read from stream %v", s.id)
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
if s.isClosed() {
|
|
||||||
return 0, ErrBrokenStream
|
|
||||||
} else {
|
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
n, err = s.recvBuf.Read(buf)
|
n, err = s.recvBuf.Read(buf)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package multiplex
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/cbeuw/Cloak/internal/util"
|
"github.com/cbeuw/Cloak/internal/util"
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
@ -121,10 +122,9 @@ func TestStream_Write(t *testing.T) {
|
||||||
func TestStream_Close(t *testing.T) {
|
func TestStream_Close(t *testing.T) {
|
||||||
sesh := setupSesh(false)
|
sesh := setupSesh(false)
|
||||||
testPayload := []byte{42, 42, 42}
|
testPayload := []byte{42, 42, 42}
|
||||||
streamID := uint32(1)
|
|
||||||
|
|
||||||
f := &Frame{
|
f := &Frame{
|
||||||
streamID,
|
1,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
testPayload,
|
testPayload,
|
||||||
|
|
@ -147,10 +147,19 @@ func TestStream_Close(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if sI, _ := sesh.streams.Load(streamID); sI != nil {
|
if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil {
|
||||||
t.Error("stream still exists")
|
t.Error("stream still exists")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, len(testPayload))
|
||||||
|
_, err = io.ReadFull(stream, readBuf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("can't read residual data %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(readBuf, testPayload) {
|
||||||
|
t.Errorf("read wrong data")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStream_Read(t *testing.T) {
|
func TestStream_Read(t *testing.T) {
|
||||||
|
|
@ -210,14 +219,6 @@ func TestStream_Read(t *testing.T) {
|
||||||
t.Error("expecting", 0, nil,
|
t.Error("expecting", 0, nil,
|
||||||
"got", i, err)
|
"got", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.Close()
|
|
||||||
i, err = stream.Read(nil)
|
|
||||||
if i != 0 || err != ErrBrokenStream {
|
|
||||||
t.Error("expecting", 0, ErrBrokenStream,
|
|
||||||
"got", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
})
|
||||||
t.Run("Read after stream close", func(t *testing.T) {
|
t.Run("Read after stream close", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
|
|
@ -325,14 +326,6 @@ func TestStream_UnorderedRead(t *testing.T) {
|
||||||
t.Error("expecting", 0, nil,
|
t.Error("expecting", 0, nil,
|
||||||
"got", i, err)
|
"got", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.Close()
|
|
||||||
i, err = stream.Read(nil)
|
|
||||||
if i != 0 || err != ErrBrokenStream {
|
|
||||||
t.Error("expecting", 0, ErrBrokenStream,
|
|
||||||
"got", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
})
|
||||||
t.Run("Read after stream close", func(t *testing.T) {
|
t.Run("Read after stream close", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue