mirror of https://github.com/cbeuw/Cloak
Merge branch 'master' into notsure2
This commit is contained in:
commit
7be088e7c1
10
README.md
10
README.md
|
|
@ -126,11 +126,11 @@ instead a CDN is used, use `CDN`.
|
||||||
`ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the
|
`ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the
|
||||||
server's `ProxyBook` exactly.
|
server's `ProxyBook` exactly.
|
||||||
|
|
||||||
`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-gcm`
|
`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-256-gcm` (
|
||||||
and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport security. The point of encryption is to hide
|
synonymous to `aes-gcm`), `aes-128-gcm`, and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport
|
||||||
fingerprints of proxy protocols and render the payload statistically random-like. **You may only leave it as `plain` if
|
security. The point of encryption is to hide fingerprints of proxy protocols and render the payload statistically
|
||||||
you are certain that your underlying proxy tool already provides BOTH encryption and authentication (via AEAD or similar
|
random-like. **You may only leave it as `plain` if you are certain that your underlying proxy tool already provides BOTH
|
||||||
techniques).**
|
encryption and authentication (via AEAD or similar techniques).**
|
||||||
|
|
||||||
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
|
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
|
||||||
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.
|
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,10 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca
|
||||||
case "plain":
|
case "plain":
|
||||||
auth.EncryptionMethod = mux.EncryptionMethodPlain
|
auth.EncryptionMethod = mux.EncryptionMethodPlain
|
||||||
case "aes-gcm":
|
case "aes-gcm":
|
||||||
auth.EncryptionMethod = mux.EncryptionMethodAESGCM
|
case "aes-256-gcm":
|
||||||
|
auth.EncryptionMethod = mux.EncryptionMethodAES256GCM
|
||||||
|
case "aes-128-gcm":
|
||||||
|
auth.EncryptionMethod = mux.EncryptionMethodAES128GCM
|
||||||
case "chacha20-poly1305":
|
case "chacha20-poly1305":
|
||||||
auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305
|
auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
|
@ -44,7 +43,9 @@ func NewTLSConn(conn net.Conn) *TLSConn {
|
||||||
return &TLSConn{
|
return &TLSConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
writeBufPool: sync.Pool{New: func() interface{} {
|
writeBufPool: sync.Pool{New: func() interface{} {
|
||||||
return new(bytes.Buffer)
|
b := make([]byte, 0, initialWriteBufSize)
|
||||||
|
b = append(b, ApplicationData, byte(VersionTLS13>>8), byte(VersionTLS13&0xFF))
|
||||||
|
return &b
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -93,16 +94,13 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
|
||||||
|
|
||||||
func (tls *TLSConn) Write(in []byte) (n int, err error) {
|
func (tls *TLSConn) Write(in []byte) (n int, err error) {
|
||||||
msgLen := len(in)
|
msgLen := len(in)
|
||||||
writeBuf := tls.writeBufPool.Get().(*bytes.Buffer)
|
writeBuf := tls.writeBufPool.Get().(*[]byte)
|
||||||
writeBuf.WriteByte(ApplicationData)
|
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
|
||||||
writeBuf.WriteByte(byte(VersionTLS13 >> 8))
|
*writeBuf = append(*writeBuf, in...)
|
||||||
writeBuf.WriteByte(byte(VersionTLS13 & 0xFF))
|
n, err = tls.Conn.Write(*writeBuf)
|
||||||
writeBuf.WriteByte(byte(msgLen >> 8))
|
*writeBuf = (*writeBuf)[:3]
|
||||||
writeBuf.WriteByte(byte(msgLen & 0xFF))
|
|
||||||
writeBuf.Write(in)
|
|
||||||
i, err := writeBuf.WriteTo(tls.Conn)
|
|
||||||
tls.writeBufPool.Put(writeBuf)
|
tls.writeBufPool.Put(writeBuf)
|
||||||
return int(i - recordLayerLength), err
|
return n - recordLayerLength, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tls *TLSConn) Close() error {
|
func (tls *TLSConn) Close() error {
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,7 @@ import (
|
||||||
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
|
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
|
||||||
// it won't get chopped up into individual bytes
|
// it won't get chopped up into individual bytes
|
||||||
type datagramBufferedPipe struct {
|
type datagramBufferedPipe struct {
|
||||||
pLens []int
|
pLens []int
|
||||||
// lazily allocated
|
|
||||||
buf *bytes.Buffer
|
buf *bytes.Buffer
|
||||||
closed bool
|
closed bool
|
||||||
rwCond *sync.Cond
|
rwCond *sync.Cond
|
||||||
|
|
@ -27,6 +26,7 @@ type datagramBufferedPipe struct {
|
||||||
func NewDatagramBufferedPipe() *datagramBufferedPipe {
|
func NewDatagramBufferedPipe() *datagramBufferedPipe {
|
||||||
d := &datagramBufferedPipe{
|
d := &datagramBufferedPipe{
|
||||||
rwCond: sync.NewCond(&sync.Mutex{}),
|
rwCond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
buf: new(bytes.Buffer),
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
@ -34,9 +34,6 @@ func NewDatagramBufferedPipe() *datagramBufferedPipe {
|
||||||
func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
|
func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
|
||||||
d.rwCond.L.Lock()
|
d.rwCond.L.Lock()
|
||||||
defer d.rwCond.L.Unlock()
|
defer d.rwCond.L.Unlock()
|
||||||
if d.buf == nil {
|
|
||||||
d.buf = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if d.closed && len(d.pLens) == 0 {
|
if d.closed && len(d.pLens) == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
|
@ -72,9 +69,6 @@ func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
|
||||||
func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
d.rwCond.L.Lock()
|
d.rwCond.L.Lock()
|
||||||
defer d.rwCond.L.Unlock()
|
defer d.rwCond.L.Unlock()
|
||||||
if d.buf == nil {
|
|
||||||
d.buf = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if d.closed && len(d.pLens) == 0 {
|
if d.closed && len(d.pLens) == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
|
@ -115,9 +109,6 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
|
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
|
||||||
d.rwCond.L.Lock()
|
d.rwCond.L.Lock()
|
||||||
defer d.rwCond.L.Unlock()
|
defer d.rwCond.L.Unlock()
|
||||||
if d.buf == nil {
|
|
||||||
d.buf = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if d.closed {
|
if d.closed {
|
||||||
return true, io.ErrClosedPipe
|
return true, io.ErrClosedPipe
|
||||||
|
|
|
||||||
|
|
@ -19,13 +19,13 @@ func serveEcho(l net.Listener) {
|
||||||
// TODO: pass the error back
|
// TODO: pass the error back
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func(conn net.Conn) {
|
||||||
_, err := io.Copy(conn, conn)
|
_, err := io.Copy(conn, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: pass the error back
|
// TODO: pass the error back
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -65,27 +65,32 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
|
||||||
|
|
||||||
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
|
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
testData := make([]byte, msgLen)
|
||||||
|
rand.Read(testData)
|
||||||
|
|
||||||
for _, conn := range conns {
|
for _, conn := range conns {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(conn net.Conn) {
|
go func(conn net.Conn) {
|
||||||
testData := make([]byte, msgLen)
|
defer wg.Done()
|
||||||
rand.Read(testData)
|
|
||||||
|
|
||||||
|
// we cannot call t.Fatalf in concurrent contexts
|
||||||
n, err := conn.Write(testData)
|
n, err := conn.Write(testData)
|
||||||
if n != msgLen {
|
if n != msgLen {
|
||||||
t.Fatalf("written only %v, err %v", n, err)
|
t.Errorf("written only %v, err %v", n, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
recvBuf := make([]byte, msgLen)
|
recvBuf := make([]byte, msgLen)
|
||||||
_, err = io.ReadFull(conn, recvBuf)
|
_, err = io.ReadFull(conn, recvBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read back: %v", err)
|
t.Errorf("failed to read back: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(testData, recvBuf) {
|
if !bytes.Equal(testData, recvBuf) {
|
||||||
t.Fatalf("echoed data not correct")
|
t.Errorf("echoed data not correct")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
wg.Done()
|
|
||||||
}(conn)
|
}(conn)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,6 @@ import (
|
||||||
"golang.org/x/crypto/salsa20"
|
"golang.org/x/crypto/salsa20"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Obfser func(*Frame, []byte, int) (int, error)
|
|
||||||
type Deobfser func(*Frame, []byte) error
|
|
||||||
|
|
||||||
var u32 = binary.BigEndian.Uint32
|
var u32 = binary.BigEndian.Uint32
|
||||||
var u64 = binary.BigEndian.Uint64
|
var u64 = binary.BigEndian.Uint64
|
||||||
var putU32 = binary.BigEndian.PutUint32
|
var putU32 = binary.BigEndian.PutUint32
|
||||||
|
|
@ -24,27 +21,22 @@ const salsa20NonceSize = 8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EncryptionMethodPlain = iota
|
EncryptionMethodPlain = iota
|
||||||
EncryptionMethodAESGCM
|
EncryptionMethodAES256GCM
|
||||||
EncryptionMethodChaha20Poly1305
|
EncryptionMethodChaha20Poly1305
|
||||||
|
EncryptionMethodAES128GCM
|
||||||
)
|
)
|
||||||
|
|
||||||
// Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames.
|
// Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames.
|
||||||
type Obfuscator struct {
|
type Obfuscator struct {
|
||||||
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
|
payloadCipher cipher.AEAD
|
||||||
Obfs Obfser
|
|
||||||
// Remove TLS header, decrypt and unmarshall frames
|
|
||||||
Deobfs Deobfser
|
|
||||||
SessionKey [32]byte
|
SessionKey [32]byte
|
||||||
|
|
||||||
maxOverhead int
|
maxOverhead int
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeObfs returns a function of type Obfser. An Obfser takes three arguments:
|
// obfuscate adds multiplexing headers, encrypt and add TLS header
|
||||||
// a *Frame with all the field set correctly, a []byte as buffer to put encrypted
|
func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) {
|
||||||
// message in, and an int called payloadOffsetInBuf to be used when *Frame.payload
|
|
||||||
// is in the byte slice used as buffer (2nd argument). payloadOffsetInBuf specifies
|
|
||||||
// the index at which data belonging to *Frame.Payload starts in the buffer.
|
|
||||||
func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
|
|
||||||
// The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header
|
// The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header
|
||||||
// as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use
|
// as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use
|
||||||
// the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead())
|
// the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead())
|
||||||
|
|
@ -76,109 +68,99 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
|
||||||
// We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input
|
// We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input
|
||||||
// is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends
|
// is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends
|
||||||
// encrypted data.
|
// encrypted data.
|
||||||
obfs := func(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) {
|
payloadLen := len(f.Payload)
|
||||||
payloadLen := len(f.Payload)
|
if payloadLen == 0 {
|
||||||
if payloadLen == 0 {
|
return 0, errors.New("payload cannot be empty")
|
||||||
return 0, errors.New("payload cannot be empty")
|
|
||||||
}
|
|
||||||
var extraLen int
|
|
||||||
if payloadCipher == nil {
|
|
||||||
extraLen = salsa20NonceSize - payloadLen
|
|
||||||
if extraLen < 0 {
|
|
||||||
// if our payload is already greater than 8 bytes
|
|
||||||
extraLen = 0
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
extraLen = payloadCipher.Overhead()
|
|
||||||
if extraLen < salsa20NonceSize {
|
|
||||||
return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
usefulLen := frameHeaderLength + payloadLen + extraLen
|
|
||||||
if len(buf) < usefulLen {
|
|
||||||
return 0, errors.New("obfs buffer too small")
|
|
||||||
}
|
|
||||||
// we do as much in-place as possible to save allocation
|
|
||||||
payload := buf[frameHeaderLength : frameHeaderLength+payloadLen]
|
|
||||||
if payloadOffsetInBuf != frameHeaderLength {
|
|
||||||
// if payload is not at the correct location in buffer
|
|
||||||
copy(payload, f.Payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
header := buf[:frameHeaderLength]
|
|
||||||
putU32(header[0:4], f.StreamID)
|
|
||||||
putU64(header[4:12], f.Seq)
|
|
||||||
header[12] = f.Closing
|
|
||||||
header[13] = byte(extraLen)
|
|
||||||
|
|
||||||
if payloadCipher == nil {
|
|
||||||
if extraLen != 0 { // read nonce
|
|
||||||
extra := buf[usefulLen-extraLen : usefulLen]
|
|
||||||
common.CryptoRandRead(extra)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
payloadCipher.Seal(payload[:0], header[:payloadCipher.NonceSize()], payload, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce := buf[usefulLen-salsa20NonceSize : usefulLen]
|
|
||||||
salsa20.XORKeyStream(header, header, nonce, &salsaKey)
|
|
||||||
|
|
||||||
return usefulLen, nil
|
|
||||||
}
|
}
|
||||||
return obfs
|
var extraLen int
|
||||||
|
if o.payloadCipher == nil {
|
||||||
|
extraLen = salsa20NonceSize - payloadLen
|
||||||
|
if extraLen < 0 {
|
||||||
|
// if our payload is already greater than 8 bytes
|
||||||
|
extraLen = 0
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
extraLen = o.payloadCipher.Overhead()
|
||||||
|
if extraLen < salsa20NonceSize {
|
||||||
|
return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
usefulLen := frameHeaderLength + payloadLen + extraLen
|
||||||
|
if len(buf) < usefulLen {
|
||||||
|
return 0, errors.New("obfs buffer too small")
|
||||||
|
}
|
||||||
|
// we do as much in-place as possible to save allocation
|
||||||
|
payload := buf[frameHeaderLength : frameHeaderLength+payloadLen]
|
||||||
|
if payloadOffsetInBuf != frameHeaderLength {
|
||||||
|
// if payload is not at the correct location in buffer
|
||||||
|
copy(payload, f.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
header := buf[:frameHeaderLength]
|
||||||
|
putU32(header[0:4], f.StreamID)
|
||||||
|
putU64(header[4:12], f.Seq)
|
||||||
|
header[12] = f.Closing
|
||||||
|
header[13] = byte(extraLen)
|
||||||
|
|
||||||
|
if o.payloadCipher == nil {
|
||||||
|
if extraLen != 0 { // read nonce
|
||||||
|
extra := buf[usefulLen-extraLen : usefulLen]
|
||||||
|
common.CryptoRandRead(extra)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := buf[usefulLen-salsa20NonceSize : usefulLen]
|
||||||
|
salsa20.XORKeyStream(header, header, nonce, &o.SessionKey)
|
||||||
|
|
||||||
|
return usefulLen, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeDeobfs returns a function Deobfser. A Deobfser takes in a single byte slice,
|
// deobfuscate removes TLS header, decrypt and unmarshall frames
|
||||||
// containing the message to be decrypted, and returns a *Frame containing the frame
|
func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error {
|
||||||
// information and plaintext
|
if len(in) < frameHeaderLength+salsa20NonceSize {
|
||||||
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
|
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize)
|
||||||
// frame header length + minimum data size (i.e. nonce size of salsa20)
|
}
|
||||||
const minInputLen = frameHeaderLength + salsa20NonceSize
|
|
||||||
deobfs := func(f *Frame, in []byte) error {
|
|
||||||
if len(in) < minInputLen {
|
|
||||||
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
header := in[:frameHeaderLength]
|
header := in[:frameHeaderLength]
|
||||||
pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead
|
pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead
|
||||||
|
|
||||||
nonce := in[len(in)-salsa20NonceSize:]
|
nonce := in[len(in)-salsa20NonceSize:]
|
||||||
salsa20.XORKeyStream(header, header, nonce, &salsaKey)
|
salsa20.XORKeyStream(header, header, nonce, &o.SessionKey)
|
||||||
|
|
||||||
streamID := u32(header[0:4])
|
streamID := u32(header[0:4])
|
||||||
seq := u64(header[4:12])
|
seq := u64(header[4:12])
|
||||||
closing := header[12]
|
closing := header[12]
|
||||||
extraLen := header[13]
|
extraLen := header[13]
|
||||||
|
|
||||||
usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
|
usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
|
||||||
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
|
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
|
||||||
return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
|
return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputPayload []byte
|
var outputPayload []byte
|
||||||
|
|
||||||
if payloadCipher == nil {
|
if o.payloadCipher == nil {
|
||||||
if extraLen == 0 {
|
if extraLen == 0 {
|
||||||
outputPayload = pldWithOverHead
|
outputPayload = pldWithOverHead
|
||||||
} else {
|
|
||||||
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
_, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
f.StreamID = streamID
|
_, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil)
|
||||||
f.Seq = seq
|
if err != nil {
|
||||||
f.Closing = closing
|
return err
|
||||||
f.Payload = outputPayload
|
}
|
||||||
return nil
|
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
||||||
}
|
}
|
||||||
return deobfs
|
|
||||||
|
f.StreamID = streamID
|
||||||
|
f.Seq = seq
|
||||||
|
f.Closing = closing
|
||||||
|
f.Payload = outputPayload
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) {
|
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) {
|
||||||
|
|
@ -190,7 +172,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
|
||||||
case EncryptionMethodPlain:
|
case EncryptionMethodPlain:
|
||||||
payloadCipher = nil
|
payloadCipher = nil
|
||||||
obfuscator.maxOverhead = salsa20NonceSize
|
obfuscator.maxOverhead = salsa20NonceSize
|
||||||
case EncryptionMethodAESGCM:
|
case EncryptionMethodAES256GCM:
|
||||||
var c cipher.Block
|
var c cipher.Block
|
||||||
c, err = aes.NewCipher(sessionKey[:])
|
c, err = aes.NewCipher(sessionKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -201,6 +183,17 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
obfuscator.maxOverhead = payloadCipher.Overhead()
|
obfuscator.maxOverhead = payloadCipher.Overhead()
|
||||||
|
case EncryptionMethodAES128GCM:
|
||||||
|
var c cipher.Block
|
||||||
|
c, err = aes.NewCipher(sessionKey[:16])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payloadCipher, err = cipher.NewGCM(c)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
obfuscator.maxOverhead = payloadCipher.Overhead()
|
||||||
case EncryptionMethodChaha20Poly1305:
|
case EncryptionMethodChaha20Poly1305:
|
||||||
payloadCipher, err = chacha20poly1305.New(sessionKey[:])
|
payloadCipher, err = chacha20poly1305.New(sessionKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -208,7 +201,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
|
||||||
}
|
}
|
||||||
obfuscator.maxOverhead = payloadCipher.Overhead()
|
obfuscator.maxOverhead = payloadCipher.Overhead()
|
||||||
default:
|
default:
|
||||||
return obfuscator, errors.New("Unknown encryption method")
|
return obfuscator, fmt.Errorf("unknown encryption method valued %v", encryptionMethod)
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadCipher != nil {
|
if payloadCipher != nil {
|
||||||
|
|
@ -217,7 +210,5 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
obfuscator.Obfs = MakeObfs(sessionKey, payloadCipher)
|
|
||||||
obfuscator.Deobfs = MakeDeobfs(sessionKey, payloadCipher)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,14 +19,14 @@ func TestGenerateObfs(t *testing.T) {
|
||||||
obfsBuf := make([]byte, 512)
|
obfsBuf := make([]byte, 512)
|
||||||
_testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
|
_testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
|
||||||
testFrame := _testFrame.Interface().(*Frame)
|
testFrame := _testFrame.Interface().(*Frame)
|
||||||
i, err := obfuscator.Obfs(testFrame, obfsBuf, 0)
|
i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ct.Error("failed to obfs ", err)
|
ct.Error("failed to obfs ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var resultFrame Frame
|
var resultFrame Frame
|
||||||
err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i])
|
err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ct.Error("failed to deobfs ", err)
|
ct.Error("failed to deobfs ", err)
|
||||||
return
|
return
|
||||||
|
|
@ -46,8 +46,16 @@ func TestGenerateObfs(t *testing.T) {
|
||||||
run(obfuscator, t)
|
run(obfuscator, t)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("aes-gcm", func(t *testing.T) {
|
t.Run("aes-256-gcm", func(t *testing.T) {
|
||||||
obfuscator, err := MakeObfuscator(EncryptionMethodAESGCM, sessionKey)
|
obfuscator, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to generate obfuscator %v", err)
|
||||||
|
} else {
|
||||||
|
run(obfuscator, t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("aes-128-gcm", func(t *testing.T) {
|
||||||
|
obfuscator, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to generate obfuscator %v", err)
|
t.Errorf("failed to generate obfuscator %v", err)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -88,40 +96,57 @@ func BenchmarkObfs(b *testing.B) {
|
||||||
c, _ := aes.NewCipher(key[:])
|
c, _ := aes.NewCipher(key[:])
|
||||||
payloadCipher, _ := cipher.NewGCM(c)
|
payloadCipher, _ := cipher.NewGCM(c)
|
||||||
|
|
||||||
obfs := MakeObfs(key, payloadCipher)
|
obfuscator := Obfuscator{
|
||||||
|
payloadCipher: payloadCipher,
|
||||||
|
SessionKey: key,
|
||||||
|
maxOverhead: payloadCipher.Overhead(),
|
||||||
|
}
|
||||||
|
|
||||||
b.SetBytes(int64(len(testFrame.Payload)))
|
b.SetBytes(int64(len(testFrame.Payload)))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
obfs(testFrame, obfsBuf, 0)
|
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("AES128GCM", func(b *testing.B) {
|
b.Run("AES128GCM", func(b *testing.B) {
|
||||||
c, _ := aes.NewCipher(key[:16])
|
c, _ := aes.NewCipher(key[:16])
|
||||||
payloadCipher, _ := cipher.NewGCM(c)
|
payloadCipher, _ := cipher.NewGCM(c)
|
||||||
|
|
||||||
obfs := MakeObfs(key, payloadCipher)
|
obfuscator := Obfuscator{
|
||||||
|
payloadCipher: payloadCipher,
|
||||||
|
SessionKey: key,
|
||||||
|
maxOverhead: payloadCipher.Overhead(),
|
||||||
|
}
|
||||||
b.SetBytes(int64(len(testFrame.Payload)))
|
b.SetBytes(int64(len(testFrame.Payload)))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
obfs(testFrame, obfsBuf, 0)
|
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("plain", func(b *testing.B) {
|
b.Run("plain", func(b *testing.B) {
|
||||||
obfs := MakeObfs(key, nil)
|
obfuscator := Obfuscator{
|
||||||
|
payloadCipher: nil,
|
||||||
|
SessionKey: key,
|
||||||
|
maxOverhead: salsa20NonceSize,
|
||||||
|
}
|
||||||
b.SetBytes(int64(len(testFrame.Payload)))
|
b.SetBytes(int64(len(testFrame.Payload)))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
obfs(testFrame, obfsBuf, 0)
|
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("chacha20Poly1305", func(b *testing.B) {
|
b.Run("chacha20Poly1305", func(b *testing.B) {
|
||||||
payloadCipher, _ := chacha20poly1305.New(key[:16])
|
payloadCipher, _ := chacha20poly1305.New(key[:])
|
||||||
|
|
||||||
obfs := MakeObfs(key, payloadCipher)
|
obfuscator := Obfuscator{
|
||||||
|
payloadCipher: payloadCipher,
|
||||||
|
SessionKey: key,
|
||||||
|
maxOverhead: payloadCipher.Overhead(),
|
||||||
|
}
|
||||||
b.SetBytes(int64(len(testFrame.Payload)))
|
b.SetBytes(int64(len(testFrame.Payload)))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
obfs(testFrame, obfsBuf, 0)
|
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -143,57 +168,70 @@ func BenchmarkDeobfs(b *testing.B) {
|
||||||
b.Run("AES256GCM", func(b *testing.B) {
|
b.Run("AES256GCM", func(b *testing.B) {
|
||||||
c, _ := aes.NewCipher(key[:])
|
c, _ := aes.NewCipher(key[:])
|
||||||
payloadCipher, _ := cipher.NewGCM(c)
|
payloadCipher, _ := cipher.NewGCM(c)
|
||||||
|
obfuscator := Obfuscator{
|
||||||
|
payloadCipher: payloadCipher,
|
||||||
|
SessionKey: key,
|
||||||
|
maxOverhead: payloadCipher.Overhead(),
|
||||||
|
}
|
||||||
|
|
||||||
obfs := MakeObfs(key, payloadCipher)
|
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
|
||||||
deobfs := MakeDeobfs(key, payloadCipher)
|
|
||||||
|
|
||||||
frame := new(Frame)
|
frame := new(Frame)
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(frame, obfsBuf[:n])
|
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("AES128GCM", func(b *testing.B) {
|
b.Run("AES128GCM", func(b *testing.B) {
|
||||||
c, _ := aes.NewCipher(key[:16])
|
c, _ := aes.NewCipher(key[:16])
|
||||||
payloadCipher, _ := cipher.NewGCM(c)
|
payloadCipher, _ := cipher.NewGCM(c)
|
||||||
|
|
||||||
obfs := MakeObfs(key, payloadCipher)
|
obfuscator := Obfuscator{
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
payloadCipher: payloadCipher,
|
||||||
deobfs := MakeDeobfs(key, payloadCipher)
|
SessionKey: key,
|
||||||
|
maxOverhead: payloadCipher.Overhead(),
|
||||||
|
}
|
||||||
|
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
|
|
||||||
frame := new(Frame)
|
frame := new(Frame)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(frame, obfsBuf[:n])
|
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("plain", func(b *testing.B) {
|
b.Run("plain", func(b *testing.B) {
|
||||||
obfs := MakeObfs(key, nil)
|
obfuscator := Obfuscator{
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
payloadCipher: nil,
|
||||||
deobfs := MakeDeobfs(key, nil)
|
SessionKey: key,
|
||||||
|
maxOverhead: salsa20NonceSize,
|
||||||
|
}
|
||||||
|
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
|
|
||||||
frame := new(Frame)
|
frame := new(Frame)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(frame, obfsBuf[:n])
|
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("chacha20Poly1305", func(b *testing.B) {
|
b.Run("chacha20Poly1305", func(b *testing.B) {
|
||||||
payloadCipher, _ := chacha20poly1305.New(key[:16])
|
payloadCipher, _ := chacha20poly1305.New(key[:])
|
||||||
|
|
||||||
obfs := MakeObfs(key, payloadCipher)
|
obfuscator := Obfuscator{
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
payloadCipher: nil,
|
||||||
deobfs := MakeDeobfs(key, payloadCipher)
|
SessionKey: key,
|
||||||
|
maxOverhead: payloadCipher.Overhead(),
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||||
|
|
||||||
frame := new(Frame)
|
frame := new(Frame)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(frame, obfsBuf[:n])
|
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -66,19 +66,20 @@ type Session struct {
|
||||||
|
|
||||||
streamsM sync.Mutex
|
streamsM sync.Mutex
|
||||||
streams map[uint32]*Stream
|
streams map[uint32]*Stream
|
||||||
|
// For accepting new streams
|
||||||
|
acceptCh chan *Stream
|
||||||
|
|
||||||
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
||||||
recvFramePool sync.Pool
|
recvFramePool sync.Pool
|
||||||
|
|
||||||
|
streamObfsBufPool sync.Pool
|
||||||
|
|
||||||
// Switchboard manages all connections to remote
|
// Switchboard manages all connections to remote
|
||||||
sb *switchboard
|
sb *switchboard
|
||||||
|
|
||||||
// Used for LocalAddr() and RemoteAddr() etc.
|
// Used for LocalAddr() and RemoteAddr() etc.
|
||||||
addrs atomic.Value
|
addrs atomic.Value
|
||||||
|
|
||||||
// For accepting new streams
|
|
||||||
acceptCh chan *Stream
|
|
||||||
|
|
||||||
closed uint32
|
closed uint32
|
||||||
|
|
||||||
terminalMsg atomic.Value
|
terminalMsg atomic.Value
|
||||||
|
|
@ -117,6 +118,11 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
||||||
// todo: validation. this must be smaller than StreamSendBufferSize
|
// todo: validation. this must be smaller than StreamSendBufferSize
|
||||||
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
|
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
|
||||||
|
|
||||||
|
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
|
||||||
|
b := make([]byte, sesh.StreamSendBufferSize)
|
||||||
|
return &b
|
||||||
|
}}
|
||||||
|
|
||||||
sesh.sb = makeSwitchboard(sesh)
|
sesh.sb = makeSwitchboard(sesh)
|
||||||
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
|
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
|
||||||
return sesh
|
return sesh
|
||||||
|
|
@ -174,26 +180,26 @@ func (sesh *Session) Accept() (net.Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) closeStream(s *Stream, active bool) error {
|
func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
// must be holding s.wirtingM on entry
|
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||||
if atomic.SwapUint32(&s.closed, 1) == 1 {
|
|
||||||
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
||||||
}
|
}
|
||||||
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
||||||
|
|
||||||
if active {
|
if active {
|
||||||
|
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
|
||||||
|
|
||||||
// Notify remote that this stream is closed
|
// Notify remote that this stream is closed
|
||||||
padding := genRandomPadding()
|
common.CryptoRandRead((*tmpBuf)[:1])
|
||||||
|
padLen := int((*tmpBuf)[0]) + 1
|
||||||
|
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
|
||||||
|
common.CryptoRandRead(payload)
|
||||||
|
|
||||||
|
// must be holding s.wirtingM on entry
|
||||||
s.writingFrame.Closing = closingStream
|
s.writingFrame.Closing = closingStream
|
||||||
s.writingFrame.Payload = padding
|
s.writingFrame.Payload = payload
|
||||||
|
|
||||||
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
|
||||||
|
sesh.streamObfsBufPool.Put(tmpBuf)
|
||||||
i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0)
|
|
||||||
s.writingFrame.Seq++
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = sesh.sb.send(obfsBuf[:i], &s.assignedConnId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -226,7 +232,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
frame := sesh.recvFramePool.Get().(*Frame)
|
frame := sesh.recvFramePool.Get().(*Frame)
|
||||||
defer sesh.recvFramePool.Put(frame)
|
defer sesh.recvFramePool.Put(frame)
|
||||||
|
|
||||||
err := sesh.Deobfs(frame, data)
|
err := sesh.deobfuscate(frame, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
||||||
}
|
}
|
||||||
|
|
@ -237,6 +243,10 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
sesh.streamsM.Lock()
|
sesh.streamsM.Lock()
|
||||||
|
if sesh.IsClosed() {
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
return ErrBrokenSession
|
||||||
|
}
|
||||||
existingStream, existing := sesh.streams[frame.StreamID]
|
existingStream, existing := sesh.streams[frame.StreamID]
|
||||||
if existing {
|
if existing {
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
|
|
@ -248,10 +258,10 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
} else {
|
} else {
|
||||||
newStream := makeStream(sesh, frame.StreamID)
|
newStream := makeStream(sesh, frame.StreamID)
|
||||||
sesh.streams[frame.StreamID] = newStream
|
sesh.streams[frame.StreamID] = newStream
|
||||||
|
sesh.acceptCh <- newStream
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
// new stream
|
// new stream
|
||||||
sesh.streamCountIncr()
|
sesh.streamCountIncr()
|
||||||
sesh.acceptCh <- newStream
|
|
||||||
return newStream.recvFrame(frame)
|
return newStream.recvFrame(frame)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -269,14 +279,14 @@ func (sesh *Session) TerminalMsg() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) closeSession(closeSwitchboard bool) error {
|
func (sesh *Session) closeSession() error {
|
||||||
if atomic.SwapUint32(&sesh.closed, 1) == 1 {
|
if !atomic.CompareAndSwapUint32(&sesh.closed, 0, 1) {
|
||||||
log.Debugf("session %v has already been closed", sesh.id)
|
log.Debugf("session %v has already been closed", sesh.id)
|
||||||
return errRepeatSessionClosing
|
return errRepeatSessionClosing
|
||||||
}
|
}
|
||||||
sesh.acceptCh <- nil
|
|
||||||
|
|
||||||
sesh.streamsM.Lock()
|
sesh.streamsM.Lock()
|
||||||
|
close(sesh.acceptCh)
|
||||||
for id, stream := range sesh.streams {
|
for id, stream := range sesh.streams {
|
||||||
if stream == nil {
|
if stream == nil {
|
||||||
continue
|
continue
|
||||||
|
|
@ -287,55 +297,48 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
|
||||||
sesh.streamCountDecr()
|
sesh.streamCountDecr()
|
||||||
}
|
}
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
if closeSwitchboard {
|
|
||||||
sesh.sb.closeAll()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) passiveClose() error {
|
func (sesh *Session) passiveClose() error {
|
||||||
log.Debugf("attempting to passively close session %v", sesh.id)
|
log.Debugf("attempting to passively close session %v", sesh.id)
|
||||||
err := sesh.closeSession(true)
|
err := sesh.closeSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
sesh.sb.closeAll()
|
||||||
log.Debugf("session %v closed gracefully", sesh.id)
|
log.Debugf("session %v closed gracefully", sesh.id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRandomPadding() []byte {
|
|
||||||
lenB := make([]byte, 1)
|
|
||||||
common.CryptoRandRead(lenB)
|
|
||||||
pad := make([]byte, int(lenB[0])+1)
|
|
||||||
common.CryptoRandRead(pad)
|
|
||||||
return pad
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sesh *Session) Close() error {
|
func (sesh *Session) Close() error {
|
||||||
log.Debugf("attempting to actively close session %v", sesh.id)
|
log.Debugf("attempting to actively close session %v", sesh.id)
|
||||||
err := sesh.closeSession(false)
|
err := sesh.closeSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// we send a notice frame telling remote to close the session
|
// we send a notice frame telling remote to close the session
|
||||||
pad := genRandomPadding()
|
|
||||||
|
buf := sesh.streamObfsBufPool.Get().(*[]byte)
|
||||||
|
common.CryptoRandRead((*buf)[:1])
|
||||||
|
padLen := int((*buf)[0]) + 1
|
||||||
|
payload := (*buf)[frameHeaderLength : padLen+frameHeaderLength]
|
||||||
|
common.CryptoRandRead(payload)
|
||||||
|
|
||||||
f := &Frame{
|
f := &Frame{
|
||||||
StreamID: 0xffffffff,
|
StreamID: 0xffffffff,
|
||||||
Seq: 0,
|
Seq: 0,
|
||||||
Closing: closingSession,
|
Closing: closingSession,
|
||||||
Payload: pad,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
obfsBuf := make([]byte, len(pad)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
i, err := sesh.obfuscate(f, *buf, frameHeaderLength)
|
||||||
i, err := sesh.Obfs(f, obfsBuf, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = sesh.sb.send(obfsBuf[:i], new(uint32))
|
_, err = sesh.sb.send((*buf)[:i], new(uint32))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sesh.sb.closeAll()
|
sesh.sb.closeAll()
|
||||||
log.Debugf("session %v closed gracefully", sesh.id)
|
log.Debugf("session %v closed gracefully", sesh.id)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ func TestRecvDataFromRemote(t *testing.T) {
|
||||||
|
|
||||||
encryptionMethods := map[string]Obfuscator{
|
encryptionMethods := map[string]Obfuscator{
|
||||||
"plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
|
"plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
|
||||||
"aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
|
"aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAES256GCM, sessionKey),
|
||||||
"chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
|
"chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -56,7 +56,7 @@ func TestRecvDataFromRemote(t *testing.T) {
|
||||||
t.Run(method, func(t *testing.T) {
|
t.Run(method, func(t *testing.T) {
|
||||||
seshConfig.Obfuscator = obfuscator
|
seshConfig.Obfuscator = obfuscator
|
||||||
sesh := MakeSession(0, seshConfig)
|
sesh := MakeSession(0, seshConfig)
|
||||||
n, err := sesh.Obfs(f, obfsBuf, 0)
|
n, err := sesh.obfuscate(f, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
|
|
@ -107,7 +107,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
// create stream 1
|
// create stream 1
|
||||||
n, _ := sesh.Obfs(f1, obfsBuf, 0)
|
n, _ := sesh.obfuscate(f1, obfsBuf, 0)
|
||||||
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||||
|
|
@ -129,7 +129,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
closingNothing,
|
closingNothing,
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
n, _ = sesh.Obfs(f2, obfsBuf, 0)
|
n, _ = sesh.obfuscate(f2, obfsBuf, 0)
|
||||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||||
|
|
@ -151,7 +151,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
closingStream,
|
closingStream,
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0)
|
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
|
||||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||||
|
|
@ -180,7 +180,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// close stream 1 again
|
// close stream 1 again
|
||||||
n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0)
|
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
|
||||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
||||||
|
|
@ -203,7 +203,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
Closing: closingSession,
|
Closing: closingSession,
|
||||||
Payload: testPayload,
|
Payload: testPayload,
|
||||||
}
|
}
|
||||||
n, _ = sesh.Obfs(fCloseSession, obfsBuf, 0)
|
n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0)
|
||||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving session closing frame: %v", err)
|
t.Fatalf("receiving session closing frame: %v", err)
|
||||||
|
|
@ -246,7 +246,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||||
closingStream,
|
closingStream,
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
n, _ := sesh.Obfs(f1CloseStream, obfsBuf, 0)
|
n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0)
|
||||||
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
||||||
|
|
@ -268,7 +268,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||||
closingNothing,
|
closingNothing,
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
n, _ = sesh.Obfs(f1, obfsBuf, 0)
|
n, _ = sesh.obfuscate(f1, obfsBuf, 0)
|
||||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||||
|
|
@ -330,7 +330,7 @@ func TestParallelStreams(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(frame *Frame) {
|
go func(frame *Frame) {
|
||||||
obfsBuf := make([]byte, obfsBufLen)
|
obfsBuf := make([]byte, obfsBufLen)
|
||||||
n, _ := sesh.Obfs(frame, obfsBuf, 0)
|
n, _ := sesh.obfuscate(frame, obfsBuf, 0)
|
||||||
obfsBuf = obfsBuf[0:n]
|
obfsBuf = obfsBuf[0:n]
|
||||||
|
|
||||||
err := sesh.recvDataFromRemote(obfsBuf)
|
err := sesh.recvDataFromRemote(obfsBuf)
|
||||||
|
|
@ -430,7 +430,8 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||||
|
|
||||||
table := map[string]byte{
|
table := map[string]byte{
|
||||||
"plain": EncryptionMethodPlain,
|
"plain": EncryptionMethodPlain,
|
||||||
"aes-gcm": EncryptionMethodAESGCM,
|
"aes-256-gcm": EncryptionMethodAES256GCM,
|
||||||
|
"aes-128-gcm": EncryptionMethodAES128GCM,
|
||||||
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -446,7 +447,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||||
binaryFrames := [maxIter][]byte{}
|
binaryFrames := [maxIter][]byte{}
|
||||||
for i := 0; i < maxIter; i++ {
|
for i := 0; i < maxIter; i++ {
|
||||||
obfsBuf := make([]byte, obfsBufLen)
|
obfsBuf := make([]byte, obfsBufLen)
|
||||||
n, _ := sesh.Obfs(f, obfsBuf, 0)
|
n, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||||
binaryFrames[i] = obfsBuf[:n]
|
binaryFrames[i] = obfsBuf[:n]
|
||||||
f.Seq++
|
f.Seq++
|
||||||
}
|
}
|
||||||
|
|
@ -466,7 +467,8 @@ func BenchmarkMultiStreamWrite(b *testing.B) {
|
||||||
|
|
||||||
table := map[string]byte{
|
table := map[string]byte{
|
||||||
"plain": EncryptionMethodPlain,
|
"plain": EncryptionMethodPlain,
|
||||||
"aes-gcm": EncryptionMethodAESGCM,
|
"aes-256-gcm": EncryptionMethodAES256GCM,
|
||||||
|
"aes-128-gcm": EncryptionMethodAES128GCM,
|
||||||
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,11 +34,6 @@ type Stream struct {
|
||||||
// atomic
|
// atomic
|
||||||
closed uint32
|
closed uint32
|
||||||
|
|
||||||
// obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from
|
|
||||||
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
|
|
||||||
// memory
|
|
||||||
obfsBuf []byte
|
|
||||||
|
|
||||||
// When we want order guarantee (i.e. session.Unordered is false),
|
// When we want order guarantee (i.e. session.Unordered is false),
|
||||||
// we assign each stream a fixed underlying connection.
|
// we assign each stream a fixed underlying connection.
|
||||||
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
|
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
|
||||||
|
|
@ -117,14 +112,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
|
func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
|
||||||
var cipherTextLen int
|
cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf)
|
||||||
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
|
s.writingFrame.Seq++
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
|
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errBrokenSwitchboard {
|
if err == errBrokenSwitchboard {
|
||||||
s.session.SetTerminalMsg(err.Error())
|
s.session.SetTerminalMsg(err.Error())
|
||||||
|
|
@ -143,9 +138,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
return 0, ErrBrokenStream
|
return 0, ErrBrokenStream
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.obfsBuf == nil {
|
|
||||||
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
|
|
||||||
}
|
|
||||||
for n < len(in) {
|
for n < len(in) {
|
||||||
var framePayload []byte
|
var framePayload []byte
|
||||||
if len(in)-n <= s.session.maxStreamUnitWrite {
|
if len(in)-n <= s.session.maxStreamUnitWrite {
|
||||||
|
|
@ -161,8 +153,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
||||||
}
|
}
|
||||||
s.writingFrame.Payload = framePayload
|
s.writingFrame.Payload = framePayload
|
||||||
err = s.obfuscateAndSend(0)
|
buf := s.session.streamObfsBufPool.Get().(*[]byte)
|
||||||
s.writingFrame.Seq++
|
err = s.obfuscateAndSend(*buf, 0)
|
||||||
|
s.session.streamObfsBufPool.Put(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -174,9 +167,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
|
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
|
||||||
// for readFromTimeout amount of time
|
// for readFromTimeout amount of time
|
||||||
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
if s.obfsBuf == nil {
|
|
||||||
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if s.readFromTimeout != 0 {
|
if s.readFromTimeout != 0 {
|
||||||
if rder, ok := r.(net.Conn); !ok {
|
if rder, ok := r.(net.Conn); !ok {
|
||||||
|
|
@ -185,19 +175,23 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
|
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
read, er := r.Read(s.obfsBuf[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
|
buf := s.session.streamObfsBufPool.Get().(*[]byte)
|
||||||
|
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
|
||||||
if er != nil {
|
if er != nil {
|
||||||
return n, er
|
return n, er
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
|
||||||
|
// to check that here
|
||||||
if s.isClosed() {
|
if s.isClosed() {
|
||||||
return n, ErrBrokenStream
|
return n, ErrBrokenStream
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writingM.Lock()
|
s.writingM.Lock()
|
||||||
s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read]
|
s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
|
||||||
err = s.obfuscateAndSend(frameHeaderLength)
|
err = s.obfuscateAndSend(*buf, frameHeaderLength)
|
||||||
s.writingFrame.Seq++
|
|
||||||
s.writingM.Unlock()
|
s.writingM.Unlock()
|
||||||
|
s.session.streamObfsBufPool.Put(buf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
// The point of a streamBufferedPipe is that Read() will block until data is available
|
// The point of a streamBufferedPipe is that Read() will block until data is available
|
||||||
type streamBufferedPipe struct {
|
type streamBufferedPipe struct {
|
||||||
// only alloc when on first Read or Write
|
|
||||||
buf *bytes.Buffer
|
buf *bytes.Buffer
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
|
|
@ -25,6 +24,7 @@ type streamBufferedPipe struct {
|
||||||
func NewStreamBufferedPipe() *streamBufferedPipe {
|
func NewStreamBufferedPipe() *streamBufferedPipe {
|
||||||
p := &streamBufferedPipe{
|
p := &streamBufferedPipe{
|
||||||
rwCond: sync.NewCond(&sync.Mutex{}),
|
rwCond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
buf: new(bytes.Buffer),
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
@ -32,9 +32,6 @@ func NewStreamBufferedPipe() *streamBufferedPipe {
|
||||||
func (p *streamBufferedPipe) Read(target []byte) (int, error) {
|
func (p *streamBufferedPipe) Read(target []byte) (int, error) {
|
||||||
p.rwCond.L.Lock()
|
p.rwCond.L.Lock()
|
||||||
defer p.rwCond.L.Unlock()
|
defer p.rwCond.L.Unlock()
|
||||||
if p.buf == nil {
|
|
||||||
p.buf = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if p.closed && p.buf.Len() == 0 {
|
if p.closed && p.buf.Len() == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
|
@ -64,9 +61,6 @@ func (p *streamBufferedPipe) Read(target []byte) (int, error) {
|
||||||
func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
p.rwCond.L.Lock()
|
p.rwCond.L.Lock()
|
||||||
defer p.rwCond.L.Unlock()
|
defer p.rwCond.L.Unlock()
|
||||||
if p.buf == nil {
|
|
||||||
p.buf = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if p.closed && p.buf.Len() == 0 {
|
if p.closed && p.buf.Len() == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
|
@ -104,9 +98,6 @@ func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
func (p *streamBufferedPipe) Write(input []byte) (int, error) {
|
func (p *streamBufferedPipe) Write(input []byte) (int, error) {
|
||||||
p.rwCond.L.Lock()
|
p.rwCond.L.Lock()
|
||||||
defer p.rwCond.L.Unlock()
|
defer p.rwCond.L.Unlock()
|
||||||
if p.buf == nil {
|
|
||||||
p.buf = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,8 @@ func BenchmarkStream_Write_Ordered(b *testing.B) {
|
||||||
eMethods := map[string]byte{
|
eMethods := map[string]byte{
|
||||||
"plain": EncryptionMethodPlain,
|
"plain": EncryptionMethodPlain,
|
||||||
"chacha20-poly1305": EncryptionMethodChaha20Poly1305,
|
"chacha20-poly1305": EncryptionMethodChaha20Poly1305,
|
||||||
"aes-gcm": EncryptionMethodAESGCM,
|
"aes-256-gcm": EncryptionMethodAES256GCM,
|
||||||
|
"aes-128-gcm": EncryptionMethodAES128GCM,
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, method := range eMethods {
|
for name, method := range eMethods {
|
||||||
|
|
@ -141,7 +142,7 @@ func TestStream_Close(t *testing.T) {
|
||||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||||
|
|
||||||
obfsBuf := make([]byte, 512)
|
obfsBuf := make([]byte, 512)
|
||||||
i, _ := sesh.Obfs(dataFrame, obfsBuf, 0)
|
i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0)
|
||||||
_, err := writingEnd.Write(obfsBuf[:i])
|
_, err := writingEnd.Write(obfsBuf[:i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to write from remote end")
|
t.Error("failed to write from remote end")
|
||||||
|
|
@ -151,16 +152,7 @@ func TestStream_Close(t *testing.T) {
|
||||||
t.Error("failed to accept stream", err)
|
t.Error("failed to accept stream", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
// we read something to wait for the test frame to reach our recvBuffer.
|
|
||||||
// if it's empty by the point we call stream.Close(), the incoming
|
|
||||||
// frame will be dropped
|
|
||||||
readBuf := make([]byte, len(testPayload))
|
|
||||||
_, err = io.ReadFull(stream, readBuf[:1])
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("can't read any data before active closing")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = stream.Close()
|
err = stream.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to actively close stream", err)
|
t.Error("failed to actively close stream", err)
|
||||||
|
|
@ -175,10 +167,12 @@ func TestStream_Close(t *testing.T) {
|
||||||
}
|
}
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
_, err = io.ReadFull(stream, readBuf[1:])
|
readBuf := make([]byte, len(testPayload))
|
||||||
|
_, err = io.ReadFull(stream, readBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("can't read residual data %v", err)
|
t.Errorf("cannot read resiual data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(readBuf, testPayload) {
|
if !bytes.Equal(readBuf, testPayload) {
|
||||||
t.Errorf("read wrong data")
|
t.Errorf("read wrong data")
|
||||||
}
|
}
|
||||||
|
|
@ -191,7 +185,7 @@ func TestStream_Close(t *testing.T) {
|
||||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||||
|
|
||||||
obfsBuf := make([]byte, 512)
|
obfsBuf := make([]byte, 512)
|
||||||
i, err := sesh.Obfs(dataFrame, obfsBuf, 0)
|
i, err := sesh.obfuscate(dataFrame, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to obfuscate frame %v", err)
|
t.Errorf("failed to obfuscate frame %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -213,7 +207,7 @@ func TestStream_Close(t *testing.T) {
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
i, err = sesh.Obfs(closingFrame, obfsBuf, 0)
|
i, err = sesh.obfuscate(closingFrame, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to obfuscate frame %v", err)
|
t.Errorf("failed to obfuscate frame %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -229,7 +223,7 @@ func TestStream_Close(t *testing.T) {
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
i, err = sesh.Obfs(closingFrameDup, obfsBuf, 0)
|
i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to obfuscate frame %v", err)
|
t.Errorf("failed to obfuscate frame %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -270,9 +264,6 @@ func TestStream_Read(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var streamID uint32
|
var streamID uint32
|
||||||
buf := make([]byte, 10)
|
|
||||||
|
|
||||||
obfsBuf := make([]byte, 512)
|
|
||||||
|
|
||||||
for name, unordered := range seshes {
|
for name, unordered := range seshes {
|
||||||
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
|
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
|
||||||
|
|
@ -280,9 +271,11 @@ func TestStream_Read(t *testing.T) {
|
||||||
sesh.AddConnection(common.NewTLSConn(rawConn))
|
sesh.AddConnection(common.NewTLSConn(rawConn))
|
||||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
obfsBuf := make([]byte, 512)
|
||||||
t.Run("Plain read", func(t *testing.T) {
|
t.Run("Plain read", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
i, _ := sesh.Obfs(f, obfsBuf, 0)
|
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||||
streamID++
|
streamID++
|
||||||
writingEnd.Write(obfsBuf[:i])
|
writingEnd.Write(obfsBuf[:i])
|
||||||
stream, err := sesh.Accept()
|
stream, err := sesh.Accept()
|
||||||
|
|
@ -307,7 +300,7 @@ func TestStream_Read(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("Nil buf", func(t *testing.T) {
|
t.Run("Nil buf", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
i, _ := sesh.Obfs(f, obfsBuf, 0)
|
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||||
streamID++
|
streamID++
|
||||||
writingEnd.Write(obfsBuf[:i])
|
writingEnd.Write(obfsBuf[:i])
|
||||||
stream, _ := sesh.Accept()
|
stream, _ := sesh.Accept()
|
||||||
|
|
@ -319,21 +312,22 @@ func TestStream_Read(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("Read after stream close", func(t *testing.T) {
|
t.Run("Read after stream close", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
i, _ := sesh.Obfs(f, obfsBuf, 0)
|
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||||
streamID++
|
streamID++
|
||||||
writingEnd.Write(obfsBuf[:i])
|
writingEnd.Write(obfsBuf[:i])
|
||||||
stream, _ := sesh.Accept()
|
stream, _ := sesh.Accept()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
stream.Close()
|
stream.Close()
|
||||||
i, err := stream.Read(buf)
|
|
||||||
|
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to read", err)
|
t.Errorf("cannot read residual data: %v", err)
|
||||||
}
|
}
|
||||||
if i != smallPayloadLen {
|
if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
|
||||||
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(buf[:i], testPayload) {
|
|
||||||
t.Error("expected", testPayload,
|
t.Error("expected", testPayload,
|
||||||
"got", buf[:i])
|
"got", buf[:smallPayloadLen])
|
||||||
}
|
}
|
||||||
_, err = stream.Read(buf)
|
_, err = stream.Read(buf)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
@ -343,21 +337,21 @@ func TestStream_Read(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("Read after session close", func(t *testing.T) {
|
t.Run("Read after session close", func(t *testing.T) {
|
||||||
f.StreamID = streamID
|
f.StreamID = streamID
|
||||||
i, _ := sesh.Obfs(f, obfsBuf, 0)
|
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||||
streamID++
|
streamID++
|
||||||
writingEnd.Write(obfsBuf[:i])
|
writingEnd.Write(obfsBuf[:i])
|
||||||
stream, _ := sesh.Accept()
|
stream, _ := sesh.Accept()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
sesh.Close()
|
sesh.Close()
|
||||||
i, err := stream.Read(buf)
|
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to read", err)
|
t.Errorf("cannot read resiual data: %v", err)
|
||||||
}
|
}
|
||||||
if i != smallPayloadLen {
|
if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
|
||||||
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(buf[:i], testPayload) {
|
|
||||||
t.Error("expected", testPayload,
|
t.Error("expected", testPayload,
|
||||||
"got", buf[:i])
|
"got", buf[:smallPayloadLen])
|
||||||
}
|
}
|
||||||
_, err = stream.Read(buf)
|
_, err = stream.Read(buf)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -31,6 +32,7 @@ type switchboard struct {
|
||||||
conns sync.Map
|
conns sync.Map
|
||||||
numConns uint32
|
numConns uint32
|
||||||
nextConnId uint32
|
nextConnId uint32
|
||||||
|
randPool sync.Pool
|
||||||
|
|
||||||
broken uint32
|
broken uint32
|
||||||
}
|
}
|
||||||
|
|
@ -48,6 +50,9 @@ func makeSwitchboard(sesh *Session) *switchboard {
|
||||||
strategy: strategy,
|
strategy: strategy,
|
||||||
valve: sesh.Valve,
|
valve: sesh.Valve,
|
||||||
nextConnId: 1,
|
nextConnId: 1,
|
||||||
|
randPool: sync.Pool{New: func() interface{} {
|
||||||
|
return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
return sb
|
return sb
|
||||||
}
|
}
|
||||||
|
|
@ -67,45 +72,43 @@ func (sb *switchboard) addConn(conn net.Conn) {
|
||||||
|
|
||||||
// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable
|
// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable
|
||||||
func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) {
|
func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) {
|
||||||
writeAndRegUsage := func(conn net.Conn, d []byte) (int, error) {
|
|
||||||
n, err = conn.Write(d)
|
|
||||||
if err != nil {
|
|
||||||
sb.conns.Delete(*connId)
|
|
||||||
sb.close("failed to write to remote " + err.Error())
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
sb.valve.AddTx(int64(n))
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.valve.txWait(len(data))
|
sb.valve.txWait(len(data))
|
||||||
if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 {
|
if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 {
|
||||||
return 0, errBrokenSwitchboard
|
return 0, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var conn net.Conn
|
||||||
switch sb.strategy {
|
switch sb.strategy {
|
||||||
case UNIFORM_SPREAD:
|
case UNIFORM_SPREAD:
|
||||||
_, conn, err := sb.pickRandConn()
|
_, conn, err = sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errBrokenSwitchboard
|
return 0, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
return writeAndRegUsage(conn, data)
|
|
||||||
case FIXED_CONN_MAPPING:
|
case FIXED_CONN_MAPPING:
|
||||||
connI, ok := sb.conns.Load(*connId)
|
connI, ok := sb.conns.Load(*connId)
|
||||||
if ok {
|
if ok {
|
||||||
conn := connI.(net.Conn)
|
conn = connI.(net.Conn)
|
||||||
return writeAndRegUsage(conn, data)
|
|
||||||
} else {
|
} else {
|
||||||
newConnId, conn, err := sb.pickRandConn()
|
var newConnId uint32
|
||||||
|
newConnId, conn, err = sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errBrokenSwitchboard
|
return 0, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
*connId = newConnId
|
*connId = newConnId
|
||||||
return writeAndRegUsage(conn, data)
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("unsupported traffic distribution strategy")
|
return 0, errors.New("unsupported traffic distribution strategy")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
n, err = conn.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
sb.conns.Delete(*connId)
|
||||||
|
sb.session.SetTerminalMsg("failed to write to remote " + err.Error())
|
||||||
|
sb.session.passiveClose()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
sb.valve.AddTx(int64(n))
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns a random connId
|
// returns a random connId
|
||||||
|
|
@ -120,7 +123,9 @@ func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) {
|
||||||
// so if the r > len(sb.conns) at the point of range call, the last visited element is picked
|
// so if the r > len(sb.conns) at the point of range call, the last visited element is picked
|
||||||
var id uint32
|
var id uint32
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
r := rand.Intn(connCount)
|
randReader := sb.randPool.Get().(*rand.Rand)
|
||||||
|
r := randReader.Intn(connCount)
|
||||||
|
sb.randPool.Put(randReader)
|
||||||
var c int
|
var c int
|
||||||
sb.conns.Range(func(connIdI, connI interface{}) bool {
|
sb.conns.Range(func(connIdI, connI interface{}) bool {
|
||||||
if r == c {
|
if r == c {
|
||||||
|
|
@ -138,16 +143,11 @@ func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) {
|
||||||
return id, conn, nil
|
return id, conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sb *switchboard) close(terminalMsg string) {
|
|
||||||
atomic.StoreUint32(&sb.broken, 1)
|
|
||||||
if !sb.session.IsClosed() {
|
|
||||||
sb.session.SetTerminalMsg(terminalMsg)
|
|
||||||
sb.session.passiveClose()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// actively triggered by session.Close()
|
// actively triggered by session.Close()
|
||||||
func (sb *switchboard) closeAll() {
|
func (sb *switchboard) closeAll() {
|
||||||
|
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
sb.conns.Range(func(key, connI interface{}) bool {
|
sb.conns.Range(func(key, connI interface{}) bool {
|
||||||
conn := connI.(net.Conn)
|
conn := connI.(net.Conn)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
@ -168,7 +168,8 @@ func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
|
||||||
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
|
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
|
||||||
sb.conns.Delete(connId)
|
sb.conns.Delete(connId)
|
||||||
atomic.AddUint32(&sb.numConns, ^uint32(0))
|
atomic.AddUint32(&sb.numConns, ^uint32(0))
|
||||||
sb.close("a connection has dropped unexpectedly")
|
sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
|
||||||
|
sb.session.passiveClose()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,13 @@ func dispatchConnection(conn net.Conn, sta *State) {
|
||||||
common.RandRead(sta.WorldState.Rand, sessionKey[:])
|
common.RandRead(sta.WorldState.Rand, sessionKey[:])
|
||||||
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
|
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.WithFields(log.Fields{
|
||||||
|
"remoteAddr": conn.RemoteAddr(),
|
||||||
|
"UID": b64(ci.UID),
|
||||||
|
"sessionId": ci.SessionId,
|
||||||
|
"proxyMethod": ci.ProxyMethod,
|
||||||
|
"encryptionMethod": ci.EncryptionMethod,
|
||||||
|
}).Error(err)
|
||||||
goWeb()
|
goWeb()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,15 +30,14 @@ func serveTCPEcho(l net.Listener) {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func(conn net.Conn) {
|
||||||
conn := conn
|
|
||||||
_, err := io.Copy(conn, conn)
|
_, err := io.Copy(conn, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -50,8 +49,7 @@ func serveUDPEcho(listener *connutil.PipeListener) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const bufSize = 32 * 1024
|
const bufSize = 32 * 1024
|
||||||
go func() {
|
go func(conn net.PacketConn) {
|
||||||
conn := conn
|
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
buf := make([]byte, bufSize)
|
buf := make([]byte, bufSize)
|
||||||
for {
|
for {
|
||||||
|
|
@ -70,7 +68,7 @@ func serveUDPEcho(listener *connutil.PipeListener) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -222,30 +220,34 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a
|
||||||
return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
|
return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
|
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
testData := make([]byte, msgLen)
|
||||||
|
rand.Read(testData)
|
||||||
|
|
||||||
for _, conn := range conns {
|
for _, conn := range conns {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(conn net.Conn) {
|
go func(conn net.Conn) {
|
||||||
testDataLen := rand.Intn(maxMsgLen)
|
defer wg.Done()
|
||||||
testData := make([]byte, testDataLen)
|
|
||||||
rand.Read(testData)
|
|
||||||
|
|
||||||
|
// we cannot call t.Fatalf in concurrent contexts
|
||||||
n, err := conn.Write(testData)
|
n, err := conn.Write(testData)
|
||||||
if n != testDataLen {
|
if n != msgLen {
|
||||||
t.Fatalf("written only %v, err %v", n, err)
|
t.Errorf("written only %v, err %v", n, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
recvBuf := make([]byte, testDataLen)
|
recvBuf := make([]byte, msgLen)
|
||||||
_, err = io.ReadFull(conn, recvBuf)
|
_, err = io.ReadFull(conn, recvBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read back: %v", err)
|
t.Errorf("failed to read back: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(testData, recvBuf) {
|
if !bytes.Equal(testData, recvBuf) {
|
||||||
t.Fatalf("echoed data not correct")
|
t.Errorf("echoed data not correct")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
wg.Done()
|
|
||||||
}(conn)
|
}(conn)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
@ -294,6 +296,7 @@ func TestUDP(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const echoMsgLen = 1024
|
||||||
t.Run("user echo", func(t *testing.T) {
|
t.Run("user echo", func(t *testing.T) {
|
||||||
go serveUDPEcho(proxyFromCkServerL)
|
go serveUDPEcho(proxyFromCkServerL)
|
||||||
var conn [1]net.Conn
|
var conn [1]net.Conn
|
||||||
|
|
@ -302,7 +305,7 @@ func TestUDP(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
runEchoTest(t, conn[:], 1024)
|
runEchoTest(t, conn[:], echoMsgLen)
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -317,13 +320,14 @@ func TestTCPSingleplex(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const echoMsgLen = 16384
|
||||||
go serveTCPEcho(proxyFromCkServerL)
|
go serveTCPEcho(proxyFromCkServerL)
|
||||||
|
|
||||||
proxyConn1, err := proxyToCkClientD.Dial("", "")
|
proxyConn1, err := proxyToCkClientD.Dial("", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
runEchoTest(t, []net.Conn{proxyConn1}, 65536)
|
runEchoTest(t, []net.Conn{proxyConn1}, echoMsgLen)
|
||||||
user, err := sta.Panel.GetUser(ai.UID[:])
|
user, err := sta.Panel.GetUser(ai.UID[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to fetch user: %v", err)
|
t.Fatalf("failed to fetch user: %v", err)
|
||||||
|
|
@ -335,15 +339,15 @@ func TestTCPSingleplex(t *testing.T) {
|
||||||
|
|
||||||
proxyConn2, err := proxyToCkClientD.Dial("", "")
|
proxyConn2, err := proxyToCkClientD.Dial("", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
runEchoTest(t, []net.Conn{proxyConn2}, 65536)
|
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
|
||||||
if user.NumSession() != 2 {
|
if user.NumSession() != 2 {
|
||||||
t.Error("no extra session were made on second connection establishment")
|
t.Error("no extra session were made on second connection establishment")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Both conns should work
|
// Both conns should work
|
||||||
runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, 65536)
|
runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, echoMsgLen)
|
||||||
|
|
||||||
proxyConn1.Close()
|
proxyConn1.Close()
|
||||||
|
|
||||||
|
|
@ -352,17 +356,17 @@ func TestTCPSingleplex(t *testing.T) {
|
||||||
}, time.Second, 10*time.Millisecond, "first session was not closed on connection close")
|
}, time.Second, 10*time.Millisecond, "first session was not closed on connection close")
|
||||||
|
|
||||||
// conn2 should still work
|
// conn2 should still work
|
||||||
runEchoTest(t, []net.Conn{proxyConn2}, 65536)
|
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
|
||||||
|
|
||||||
var conns [numConns]net.Conn
|
var conns [numConns]net.Conn
|
||||||
for i := 0; i < numConns; i++ {
|
for i := 0; i < numConns; i++ {
|
||||||
conns[i], err = proxyToCkClientD.Dial("", "")
|
conns[i], err = proxyToCkClientD.Dial("", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
runEchoTest(t, conns[:], 65536)
|
runEchoTest(t, conns[:], echoMsgLen)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -410,6 +414,7 @@ func TestTCPMultiplex(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const echoMsgLen = 16384
|
||||||
t.Run("user echo", func(t *testing.T) {
|
t.Run("user echo", func(t *testing.T) {
|
||||||
go serveTCPEcho(proxyFromCkServerL)
|
go serveTCPEcho(proxyFromCkServerL)
|
||||||
var conns [numConns]net.Conn
|
var conns [numConns]net.Conn
|
||||||
|
|
@ -420,7 +425,7 @@ func TestTCPMultiplex(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
runEchoTest(t, conns[:], 65536)
|
runEchoTest(t, conns[:], echoMsgLen)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("redir echo", func(t *testing.T) {
|
t.Run("redir echo", func(t *testing.T) {
|
||||||
|
|
@ -432,7 +437,7 @@ func TestTCPMultiplex(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
runEchoTest(t, conns[:], 65536)
|
runEchoTest(t, conns[:], echoMsgLen)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -503,7 +508,7 @@ func TestClosingStreamsFromProxy(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
func BenchmarkIntegration(b *testing.B) {
|
||||||
log.SetLevel(log.ErrorLevel)
|
log.SetLevel(log.ErrorLevel)
|
||||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||||
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
||||||
|
|
@ -513,7 +518,8 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
encryptionMethods := map[string]byte{
|
encryptionMethods := map[string]byte{
|
||||||
"plain": mux.EncryptionMethodPlain,
|
"plain": mux.EncryptionMethodPlain,
|
||||||
"chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305,
|
"chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305,
|
||||||
"aes-gcm": mux.EncryptionMethodAESGCM,
|
"aes-256-gcm": mux.EncryptionMethodAES256GCM,
|
||||||
|
"aes-128-gcm": mux.EncryptionMethodAES128GCM,
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, method := range encryptionMethods {
|
for name, method := range encryptionMethods {
|
||||||
|
|
@ -524,7 +530,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Run("single stream", func(b *testing.B) {
|
b.Run("single stream bandwidth", func(b *testing.B) {
|
||||||
more := make(chan int, 10)
|
more := make(chan int, 10)
|
||||||
go func() {
|
go func() {
|
||||||
// sender
|
// sender
|
||||||
|
|
@ -548,6 +554,19 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
b.Run("single stream latency", func(b *testing.B) {
|
||||||
|
clientConn, _ := proxyToCkClientD.Dial("", "")
|
||||||
|
buf := []byte{1}
|
||||||
|
clientConn.Write(buf)
|
||||||
|
serverConn, _ := proxyFromCkServerL.Accept()
|
||||||
|
serverConn.Read(buf)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
clientConn.Write(buf)
|
||||||
|
serverConn.Read(buf)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue