Merge branch 'master' into notsure2

This commit is contained in:
notsure2 2020-12-27 22:09:46 +02:00
commit 7be088e7c1
15 changed files with 390 additions and 354 deletions

View File

@ -126,11 +126,11 @@ instead a CDN is used, use `CDN`.
`ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the `ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the
server's `ProxyBook` exactly. server's `ProxyBook` exactly.
`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-gcm` `EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-256-gcm` (
and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport security. The point of encryption is to hide synonymous to `aes-gcm`), `aes-128-gcm`, and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport
fingerprints of proxy protocols and render the payload statistically random-like. **You may only leave it as `plain` if security. The point of encryption is to hide fingerprints of proxy protocols and render the payload statistically
you are certain that your underlying proxy tool already provides BOTH encryption and authentication (via AEAD or similar random-like. **You may only leave it as `plain` if you are certain that your underlying proxy tool already provides BOTH
techniques).** encryption and authentication (via AEAD or similar techniques).**
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should `ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.

View File

@ -164,7 +164,10 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca
case "plain": case "plain":
auth.EncryptionMethod = mux.EncryptionMethodPlain auth.EncryptionMethod = mux.EncryptionMethodPlain
case "aes-gcm": case "aes-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAESGCM case "aes-256-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAES256GCM
case "aes-128-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAES128GCM
case "chacha20-poly1305": case "chacha20-poly1305":
auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305 auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305
default: default:

View File

@ -1,7 +1,6 @@
package common package common
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
@ -44,7 +43,9 @@ func NewTLSConn(conn net.Conn) *TLSConn {
return &TLSConn{ return &TLSConn{
Conn: conn, Conn: conn,
writeBufPool: sync.Pool{New: func() interface{} { writeBufPool: sync.Pool{New: func() interface{} {
return new(bytes.Buffer) b := make([]byte, 0, initialWriteBufSize)
b = append(b, ApplicationData, byte(VersionTLS13>>8), byte(VersionTLS13&0xFF))
return &b
}}, }},
} }
} }
@ -93,16 +94,13 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
func (tls *TLSConn) Write(in []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) {
msgLen := len(in) msgLen := len(in)
writeBuf := tls.writeBufPool.Get().(*bytes.Buffer) writeBuf := tls.writeBufPool.Get().(*[]byte)
writeBuf.WriteByte(ApplicationData) *writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
writeBuf.WriteByte(byte(VersionTLS13 >> 8)) *writeBuf = append(*writeBuf, in...)
writeBuf.WriteByte(byte(VersionTLS13 & 0xFF)) n, err = tls.Conn.Write(*writeBuf)
writeBuf.WriteByte(byte(msgLen >> 8)) *writeBuf = (*writeBuf)[:3]
writeBuf.WriteByte(byte(msgLen & 0xFF))
writeBuf.Write(in)
i, err := writeBuf.WriteTo(tls.Conn)
tls.writeBufPool.Put(writeBuf) tls.writeBufPool.Put(writeBuf)
return int(i - recordLayerLength), err return n - recordLayerLength, err
} }
func (tls *TLSConn) Close() error { func (tls *TLSConn) Close() error {

View File

@ -14,7 +14,6 @@ import (
// it won't get chopped up into individual bytes // it won't get chopped up into individual bytes
type datagramBufferedPipe struct { type datagramBufferedPipe struct {
pLens []int pLens []int
// lazily allocated
buf *bytes.Buffer buf *bytes.Buffer
closed bool closed bool
rwCond *sync.Cond rwCond *sync.Cond
@ -27,6 +26,7 @@ type datagramBufferedPipe struct {
func NewDatagramBufferedPipe() *datagramBufferedPipe { func NewDatagramBufferedPipe() *datagramBufferedPipe {
d := &datagramBufferedPipe{ d := &datagramBufferedPipe{
rwCond: sync.NewCond(&sync.Mutex{}), rwCond: sync.NewCond(&sync.Mutex{}),
buf: new(bytes.Buffer),
} }
return d return d
} }
@ -34,9 +34,6 @@ func NewDatagramBufferedPipe() *datagramBufferedPipe {
func (d *datagramBufferedPipe) Read(target []byte) (int, error) { func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
d.rwCond.L.Lock() d.rwCond.L.Lock()
defer d.rwCond.L.Unlock() defer d.rwCond.L.Unlock()
if d.buf == nil {
d.buf = new(bytes.Buffer)
}
for { for {
if d.closed && len(d.pLens) == 0 { if d.closed && len(d.pLens) == 0 {
return 0, io.EOF return 0, io.EOF
@ -72,9 +69,6 @@ func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
d.rwCond.L.Lock() d.rwCond.L.Lock()
defer d.rwCond.L.Unlock() defer d.rwCond.L.Unlock()
if d.buf == nil {
d.buf = new(bytes.Buffer)
}
for { for {
if d.closed && len(d.pLens) == 0 { if d.closed && len(d.pLens) == 0 {
return 0, io.EOF return 0, io.EOF
@ -115,9 +109,6 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) { func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
d.rwCond.L.Lock() d.rwCond.L.Lock()
defer d.rwCond.L.Unlock() defer d.rwCond.L.Unlock()
if d.buf == nil {
d.buf = new(bytes.Buffer)
}
for { for {
if d.closed { if d.closed {
return true, io.ErrClosedPipe return true, io.ErrClosedPipe

View File

@ -19,13 +19,13 @@ func serveEcho(l net.Listener) {
// TODO: pass the error back // TODO: pass the error back
return return
} }
go func() { go func(conn net.Conn) {
_, err := io.Copy(conn, conn) _, err := io.Copy(conn, conn)
if err != nil { if err != nil {
// TODO: pass the error back // TODO: pass the error back
return return
} }
}() }(conn)
} }
} }
@ -65,27 +65,32 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) { func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
testData := make([]byte, msgLen) testData := make([]byte, msgLen)
rand.Read(testData) rand.Read(testData)
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
// we cannot call t.Fatalf in concurrent contexts
n, err := conn.Write(testData) n, err := conn.Write(testData)
if n != msgLen { if n != msgLen {
t.Fatalf("written only %v, err %v", n, err) t.Errorf("written only %v, err %v", n, err)
return
} }
recvBuf := make([]byte, msgLen) recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf) _, err = io.ReadFull(conn, recvBuf)
if err != nil { if err != nil {
t.Fatalf("failed to read back: %v", err) t.Errorf("failed to read back: %v", err)
return
} }
if !bytes.Equal(testData, recvBuf) { if !bytes.Equal(testData, recvBuf) {
t.Fatalf("echoed data not correct") t.Errorf("echoed data not correct")
return
} }
wg.Done()
}(conn) }(conn)
} }
wg.Wait() wg.Wait()

View File

@ -11,9 +11,6 @@ import (
"golang.org/x/crypto/salsa20" "golang.org/x/crypto/salsa20"
) )
type Obfser func(*Frame, []byte, int) (int, error)
type Deobfser func(*Frame, []byte) error
var u32 = binary.BigEndian.Uint32 var u32 = binary.BigEndian.Uint32
var u64 = binary.BigEndian.Uint64 var u64 = binary.BigEndian.Uint64
var putU32 = binary.BigEndian.PutUint32 var putU32 = binary.BigEndian.PutUint32
@ -24,27 +21,22 @@ const salsa20NonceSize = 8
const ( const (
EncryptionMethodPlain = iota EncryptionMethodPlain = iota
EncryptionMethodAESGCM EncryptionMethodAES256GCM
EncryptionMethodChaha20Poly1305 EncryptionMethodChaha20Poly1305
EncryptionMethodAES128GCM
) )
// Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. // Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames.
type Obfuscator struct { type Obfuscator struct {
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header payloadCipher cipher.AEAD
Obfs Obfser
// Remove TLS header, decrypt and unmarshall frames
Deobfs Deobfser
SessionKey [32]byte SessionKey [32]byte
maxOverhead int maxOverhead int
} }
// MakeObfs returns a function of type Obfser. An Obfser takes three arguments: // obfuscate adds multiplexing headers, encrypt and add TLS header
// a *Frame with all the field set correctly, a []byte as buffer to put encrypted func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) {
// message in, and an int called payloadOffsetInBuf to be used when *Frame.payload
// is in the byte slice used as buffer (2nd argument). payloadOffsetInBuf specifies
// the index at which data belonging to *Frame.Payload starts in the buffer.
func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
// The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header // The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header
// as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use // as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use
// the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead()) // the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead())
@ -76,20 +68,19 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
// We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input // We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input
// is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends // is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends
// encrypted data. // encrypted data.
obfs := func(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) {
payloadLen := len(f.Payload) payloadLen := len(f.Payload)
if payloadLen == 0 { if payloadLen == 0 {
return 0, errors.New("payload cannot be empty") return 0, errors.New("payload cannot be empty")
} }
var extraLen int var extraLen int
if payloadCipher == nil { if o.payloadCipher == nil {
extraLen = salsa20NonceSize - payloadLen extraLen = salsa20NonceSize - payloadLen
if extraLen < 0 { if extraLen < 0 {
// if our payload is already greater than 8 bytes // if our payload is already greater than 8 bytes
extraLen = 0 extraLen = 0
} }
} else { } else {
extraLen = payloadCipher.Overhead() extraLen = o.payloadCipher.Overhead()
if extraLen < salsa20NonceSize { if extraLen < salsa20NonceSize {
return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes")
} }
@ -112,39 +103,32 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
header[12] = f.Closing header[12] = f.Closing
header[13] = byte(extraLen) header[13] = byte(extraLen)
if payloadCipher == nil { if o.payloadCipher == nil {
if extraLen != 0 { // read nonce if extraLen != 0 { // read nonce
extra := buf[usefulLen-extraLen : usefulLen] extra := buf[usefulLen-extraLen : usefulLen]
common.CryptoRandRead(extra) common.CryptoRandRead(extra)
} }
} else { } else {
payloadCipher.Seal(payload[:0], header[:payloadCipher.NonceSize()], payload, nil) o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil)
} }
nonce := buf[usefulLen-salsa20NonceSize : usefulLen] nonce := buf[usefulLen-salsa20NonceSize : usefulLen]
salsa20.XORKeyStream(header, header, nonce, &salsaKey) salsa20.XORKeyStream(header, header, nonce, &o.SessionKey)
return usefulLen, nil return usefulLen, nil
}
return obfs
} }
// MakeDeobfs returns a function Deobfser. A Deobfser takes in a single byte slice, // deobfuscate removes TLS header, decrypt and unmarshall frames
// containing the message to be decrypted, and returns a *Frame containing the frame func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error {
// information and plaintext if len(in) < frameHeaderLength+salsa20NonceSize {
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize)
// frame header length + minimum data size (i.e. nonce size of salsa20)
const minInputLen = frameHeaderLength + salsa20NonceSize
deobfs := func(f *Frame, in []byte) error {
if len(in) < minInputLen {
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
} }
header := in[:frameHeaderLength] header := in[:frameHeaderLength]
pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead
nonce := in[len(in)-salsa20NonceSize:] nonce := in[len(in)-salsa20NonceSize:]
salsa20.XORKeyStream(header, header, nonce, &salsaKey) salsa20.XORKeyStream(header, header, nonce, &o.SessionKey)
streamID := u32(header[0:4]) streamID := u32(header[0:4])
seq := u64(header[4:12]) seq := u64(header[4:12])
@ -158,14 +142,14 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
var outputPayload []byte var outputPayload []byte
if payloadCipher == nil { if o.payloadCipher == nil {
if extraLen == 0 { if extraLen == 0 {
outputPayload = pldWithOverHead outputPayload = pldWithOverHead
} else { } else {
outputPayload = pldWithOverHead[:usefulPayloadLen] outputPayload = pldWithOverHead[:usefulPayloadLen]
} }
} else { } else {
_, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) _, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil)
if err != nil { if err != nil {
return err return err
} }
@ -177,8 +161,6 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
f.Closing = closing f.Closing = closing
f.Payload = outputPayload f.Payload = outputPayload
return nil return nil
}
return deobfs
} }
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) {
@ -190,7 +172,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
case EncryptionMethodPlain: case EncryptionMethodPlain:
payloadCipher = nil payloadCipher = nil
obfuscator.maxOverhead = salsa20NonceSize obfuscator.maxOverhead = salsa20NonceSize
case EncryptionMethodAESGCM: case EncryptionMethodAES256GCM:
var c cipher.Block var c cipher.Block
c, err = aes.NewCipher(sessionKey[:]) c, err = aes.NewCipher(sessionKey[:])
if err != nil { if err != nil {
@ -201,6 +183,17 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
return return
} }
obfuscator.maxOverhead = payloadCipher.Overhead() obfuscator.maxOverhead = payloadCipher.Overhead()
case EncryptionMethodAES128GCM:
var c cipher.Block
c, err = aes.NewCipher(sessionKey[:16])
if err != nil {
return
}
payloadCipher, err = cipher.NewGCM(c)
if err != nil {
return
}
obfuscator.maxOverhead = payloadCipher.Overhead()
case EncryptionMethodChaha20Poly1305: case EncryptionMethodChaha20Poly1305:
payloadCipher, err = chacha20poly1305.New(sessionKey[:]) payloadCipher, err = chacha20poly1305.New(sessionKey[:])
if err != nil { if err != nil {
@ -208,7 +201,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
} }
obfuscator.maxOverhead = payloadCipher.Overhead() obfuscator.maxOverhead = payloadCipher.Overhead()
default: default:
return obfuscator, errors.New("Unknown encryption method") return obfuscator, fmt.Errorf("unknown encryption method valued %v", encryptionMethod)
} }
if payloadCipher != nil { if payloadCipher != nil {
@ -217,7 +210,5 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu
} }
} }
obfuscator.Obfs = MakeObfs(sessionKey, payloadCipher)
obfuscator.Deobfs = MakeDeobfs(sessionKey, payloadCipher)
return return
} }

View File

@ -19,14 +19,14 @@ func TestGenerateObfs(t *testing.T) {
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
_testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
testFrame := _testFrame.Interface().(*Frame) testFrame := _testFrame.Interface().(*Frame)
i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0)
if err != nil { if err != nil {
ct.Error("failed to obfs ", err) ct.Error("failed to obfs ", err)
return return
} }
var resultFrame Frame var resultFrame Frame
err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i])
if err != nil { if err != nil {
ct.Error("failed to deobfs ", err) ct.Error("failed to deobfs ", err)
return return
@ -46,8 +46,16 @@ func TestGenerateObfs(t *testing.T) {
run(obfuscator, t) run(obfuscator, t)
} }
}) })
t.Run("aes-gcm", func(t *testing.T) { t.Run("aes-256-gcm", func(t *testing.T) {
obfuscator, err := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) obfuscator, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
})
t.Run("aes-128-gcm", func(t *testing.T) {
obfuscator, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey)
if err != nil { if err != nil {
t.Errorf("failed to generate obfuscator %v", err) t.Errorf("failed to generate obfuscator %v", err)
} else { } else {
@ -88,40 +96,57 @@ func BenchmarkObfs(b *testing.B) {
c, _ := aes.NewCipher(key[:]) c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher) obfuscator := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: key,
maxOverhead: payloadCipher.Overhead(),
}
b.SetBytes(int64(len(testFrame.Payload))) b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfs(testFrame, obfsBuf, 0) obfuscator.obfuscate(testFrame, obfsBuf, 0)
} }
}) })
b.Run("AES128GCM", func(b *testing.B) { b.Run("AES128GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:16]) c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher) obfuscator := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: key,
maxOverhead: payloadCipher.Overhead(),
}
b.SetBytes(int64(len(testFrame.Payload))) b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfs(testFrame, obfsBuf, 0) obfuscator.obfuscate(testFrame, obfsBuf, 0)
} }
}) })
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
obfs := MakeObfs(key, nil) obfuscator := Obfuscator{
payloadCipher: nil,
SessionKey: key,
maxOverhead: salsa20NonceSize,
}
b.SetBytes(int64(len(testFrame.Payload))) b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfs(testFrame, obfsBuf, 0) obfuscator.obfuscate(testFrame, obfsBuf, 0)
} }
}) })
b.Run("chacha20Poly1305", func(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:16]) payloadCipher, _ := chacha20poly1305.New(key[:])
obfs := MakeObfs(key, payloadCipher) obfuscator := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: key,
maxOverhead: payloadCipher.Overhead(),
}
b.SetBytes(int64(len(testFrame.Payload))) b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfs(testFrame, obfsBuf, 0) obfuscator.obfuscate(testFrame, obfsBuf, 0)
} }
}) })
} }
@ -143,57 +168,70 @@ func BenchmarkDeobfs(b *testing.B) {
b.Run("AES256GCM", func(b *testing.B) { b.Run("AES256GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:]) c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: key,
maxOverhead: payloadCipher.Overhead(),
}
obfs := MakeObfs(key, payloadCipher) n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame) frame := new(Frame)
b.SetBytes(int64(n)) b.SetBytes(int64(n))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(frame, obfsBuf[:n]) obfuscator.deobfuscate(frame, obfsBuf[:n])
} }
}) })
b.Run("AES128GCM", func(b *testing.B) { b.Run("AES128GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:16]) c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher) obfuscator := Obfuscator{
n, _ := obfs(testFrame, obfsBuf, 0) payloadCipher: payloadCipher,
deobfs := MakeDeobfs(key, payloadCipher) SessionKey: key,
maxOverhead: payloadCipher.Overhead(),
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame) frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(frame, obfsBuf[:n]) obfuscator.deobfuscate(frame, obfsBuf[:n])
} }
}) })
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
obfs := MakeObfs(key, nil) obfuscator := Obfuscator{
n, _ := obfs(testFrame, obfsBuf, 0) payloadCipher: nil,
deobfs := MakeDeobfs(key, nil) SessionKey: key,
maxOverhead: salsa20NonceSize,
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame) frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(frame, obfsBuf[:n]) obfuscator.deobfuscate(frame, obfsBuf[:n])
} }
}) })
b.Run("chacha20Poly1305", func(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:16]) payloadCipher, _ := chacha20poly1305.New(key[:])
obfs := MakeObfs(key, payloadCipher) obfuscator := Obfuscator{
n, _ := obfs(testFrame, obfsBuf, 0) payloadCipher: nil,
deobfs := MakeDeobfs(key, payloadCipher) SessionKey: key,
maxOverhead: payloadCipher.Overhead(),
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame) frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(frame, obfsBuf[:n]) obfuscator.deobfuscate(frame, obfsBuf[:n])
} }
}) })
} }

View File

@ -66,19 +66,20 @@ type Session struct {
streamsM sync.Mutex streamsM sync.Mutex
streams map[uint32]*Stream streams map[uint32]*Stream
// For accepting new streams
acceptCh chan *Stream
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool recvFramePool sync.Pool
streamObfsBufPool sync.Pool
// Switchboard manages all connections to remote // Switchboard manages all connections to remote
sb *switchboard sb *switchboard
// Used for LocalAddr() and RemoteAddr() etc. // Used for LocalAddr() and RemoteAddr() etc.
addrs atomic.Value addrs atomic.Value
// For accepting new streams
acceptCh chan *Stream
closed uint32 closed uint32
terminalMsg atomic.Value terminalMsg atomic.Value
@ -117,6 +118,11 @@ func MakeSession(id uint32, config SessionConfig) *Session {
// todo: validation. this must be smaller than StreamSendBufferSize // todo: validation. this must be smaller than StreamSendBufferSize
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
b := make([]byte, sesh.StreamSendBufferSize)
return &b
}}
sesh.sb = makeSwitchboard(sesh) sesh.sb = makeSwitchboard(sesh)
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout) time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
return sesh return sesh
@ -174,26 +180,26 @@ func (sesh *Session) Accept() (net.Conn, error) {
} }
func (sesh *Session) closeStream(s *Stream, active bool) error { func (sesh *Session) closeStream(s *Stream, active bool) error {
// must be holding s.wirtingM on entry if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
if atomic.SwapUint32(&s.closed, 1) == 1 {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
} }
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error _ = s.getRecvBuf().Close() // recvBuf.Close should not return error
if active { if active {
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
// Notify remote that this stream is closed // Notify remote that this stream is closed
padding := genRandomPadding() common.CryptoRandRead((*tmpBuf)[:1])
padLen := int((*tmpBuf)[0]) + 1
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
// must be holding s.wirtingM on entry
s.writingFrame.Closing = closingStream s.writingFrame.Closing = closingStream
s.writingFrame.Payload = padding s.writingFrame.Payload = payload
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
sesh.streamObfsBufPool.Put(tmpBuf)
i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0)
s.writingFrame.Seq++
if err != nil {
return err
}
_, err = sesh.sb.send(obfsBuf[:i], &s.assignedConnId)
if err != nil { if err != nil {
return err return err
} }
@ -226,7 +232,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
frame := sesh.recvFramePool.Get().(*Frame) frame := sesh.recvFramePool.Get().(*Frame)
defer sesh.recvFramePool.Put(frame) defer sesh.recvFramePool.Put(frame)
err := sesh.Deobfs(frame, data) err := sesh.deobfuscate(frame, data)
if err != nil { if err != nil {
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
} }
@ -237,6 +243,10 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
} }
sesh.streamsM.Lock() sesh.streamsM.Lock()
if sesh.IsClosed() {
sesh.streamsM.Unlock()
return ErrBrokenSession
}
existingStream, existing := sesh.streams[frame.StreamID] existingStream, existing := sesh.streams[frame.StreamID]
if existing { if existing {
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
@ -248,10 +258,10 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
} else { } else {
newStream := makeStream(sesh, frame.StreamID) newStream := makeStream(sesh, frame.StreamID)
sesh.streams[frame.StreamID] = newStream sesh.streams[frame.StreamID] = newStream
sesh.acceptCh <- newStream
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
// new stream // new stream
sesh.streamCountIncr() sesh.streamCountIncr()
sesh.acceptCh <- newStream
return newStream.recvFrame(frame) return newStream.recvFrame(frame)
} }
} }
@ -269,14 +279,14 @@ func (sesh *Session) TerminalMsg() string {
} }
} }
func (sesh *Session) closeSession(closeSwitchboard bool) error { func (sesh *Session) closeSession() error {
if atomic.SwapUint32(&sesh.closed, 1) == 1 { if !atomic.CompareAndSwapUint32(&sesh.closed, 0, 1) {
log.Debugf("session %v has already been closed", sesh.id) log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing return errRepeatSessionClosing
} }
sesh.acceptCh <- nil
sesh.streamsM.Lock() sesh.streamsM.Lock()
close(sesh.acceptCh)
for id, stream := range sesh.streams { for id, stream := range sesh.streams {
if stream == nil { if stream == nil {
continue continue
@ -287,55 +297,48 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
sesh.streamCountDecr() sesh.streamCountDecr()
} }
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
if closeSwitchboard {
sesh.sb.closeAll()
}
return nil return nil
} }
func (sesh *Session) passiveClose() error { func (sesh *Session) passiveClose() error {
log.Debugf("attempting to passively close session %v", sesh.id) log.Debugf("attempting to passively close session %v", sesh.id)
err := sesh.closeSession(true) err := sesh.closeSession()
if err != nil { if err != nil {
return err return err
} }
sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id) log.Debugf("session %v closed gracefully", sesh.id)
return nil return nil
} }
func genRandomPadding() []byte {
lenB := make([]byte, 1)
common.CryptoRandRead(lenB)
pad := make([]byte, int(lenB[0])+1)
common.CryptoRandRead(pad)
return pad
}
func (sesh *Session) Close() error { func (sesh *Session) Close() error {
log.Debugf("attempting to actively close session %v", sesh.id) log.Debugf("attempting to actively close session %v", sesh.id)
err := sesh.closeSession(false) err := sesh.closeSession()
if err != nil { if err != nil {
return err return err
} }
// we send a notice frame telling remote to close the session // we send a notice frame telling remote to close the session
pad := genRandomPadding()
buf := sesh.streamObfsBufPool.Get().(*[]byte)
common.CryptoRandRead((*buf)[:1])
padLen := int((*buf)[0]) + 1
payload := (*buf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
f := &Frame{ f := &Frame{
StreamID: 0xffffffff, StreamID: 0xffffffff,
Seq: 0, Seq: 0,
Closing: closingSession, Closing: closingSession,
Payload: pad, Payload: payload,
} }
obfsBuf := make([]byte, len(pad)+frameHeaderLength+sesh.Obfuscator.maxOverhead) i, err := sesh.obfuscate(f, *buf, frameHeaderLength)
i, err := sesh.Obfs(f, obfsBuf, 0)
if err != nil { if err != nil {
return err return err
} }
_, err = sesh.sb.send(obfsBuf[:i], new(uint32)) _, err = sesh.sb.send((*buf)[:i], new(uint32))
if err != nil { if err != nil {
return err return err
} }
sesh.sb.closeAll() sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id) log.Debugf("session %v closed gracefully", sesh.id)
return nil return nil

View File

@ -44,7 +44,7 @@ func TestRecvDataFromRemote(t *testing.T) {
encryptionMethods := map[string]Obfuscator{ encryptionMethods := map[string]Obfuscator{
"plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
"aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAES256GCM, sessionKey),
"chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
} }
@ -56,7 +56,7 @@ func TestRecvDataFromRemote(t *testing.T) {
t.Run(method, func(t *testing.T) { t.Run(method, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
n, err := sesh.Obfs(f, obfsBuf, 0) n, err := sesh.obfuscate(f, obfsBuf, 0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -107,7 +107,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
testPayload, testPayload,
} }
// create stream 1 // create stream 1
n, _ := sesh.Obfs(f1, obfsBuf, 0) n, _ := sesh.obfuscate(f1, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n]) err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err) t.Fatalf("receiving normal frame for stream 1: %v", err)
@ -129,7 +129,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
closingNothing, closingNothing,
testPayload, testPayload,
} }
n, _ = sesh.Obfs(f2, obfsBuf, 0) n, _ = sesh.obfuscate(f2, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n]) err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 2: %v", err) t.Fatalf("receiving normal frame for stream 2: %v", err)
@ -151,7 +151,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
closingStream, closingStream,
testPayload, testPayload,
} }
n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n]) err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err) t.Fatalf("receiving stream closing frame for stream 1: %v", err)
@ -180,7 +180,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
} }
// close stream 1 again // close stream 1 again
n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n]) err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving stream closing frame for stream 1 %v", err) t.Fatalf("receiving stream closing frame for stream 1 %v", err)
@ -203,7 +203,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
Closing: closingSession, Closing: closingSession,
Payload: testPayload, Payload: testPayload,
} }
n, _ = sesh.Obfs(fCloseSession, obfsBuf, 0) n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n]) err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving session closing frame: %v", err) t.Fatalf("receiving session closing frame: %v", err)
@ -246,7 +246,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
closingStream, closingStream,
testPayload, testPayload,
} }
n, _ := sesh.Obfs(f1CloseStream, obfsBuf, 0) n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n]) err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
@ -268,7 +268,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
closingNothing, closingNothing,
testPayload, testPayload,
} }
n, _ = sesh.Obfs(f1, obfsBuf, 0) n, _ = sesh.obfuscate(f1, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n]) err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err) t.Fatalf("receiving normal frame for stream 1: %v", err)
@ -330,7 +330,7 @@ func TestParallelStreams(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(frame *Frame) { go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen) obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(frame, obfsBuf, 0) n, _ := sesh.obfuscate(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n] obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf) err := sesh.recvDataFromRemote(obfsBuf)
@ -430,7 +430,8 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
table := map[string]byte{ table := map[string]byte{
"plain": EncryptionMethodPlain, "plain": EncryptionMethodPlain,
"aes-gcm": EncryptionMethodAESGCM, "aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305, "chacha20poly1305": EncryptionMethodChaha20Poly1305,
} }
@ -446,7 +447,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
binaryFrames := [maxIter][]byte{} binaryFrames := [maxIter][]byte{}
for i := 0; i < maxIter; i++ { for i := 0; i < maxIter; i++ {
obfsBuf := make([]byte, obfsBufLen) obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(f, obfsBuf, 0) n, _ := sesh.obfuscate(f, obfsBuf, 0)
binaryFrames[i] = obfsBuf[:n] binaryFrames[i] = obfsBuf[:n]
f.Seq++ f.Seq++
} }
@ -466,7 +467,8 @@ func BenchmarkMultiStreamWrite(b *testing.B) {
table := map[string]byte{ table := map[string]byte{
"plain": EncryptionMethodPlain, "plain": EncryptionMethodPlain,
"aes-gcm": EncryptionMethodAESGCM, "aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305, "chacha20poly1305": EncryptionMethodChaha20Poly1305,
} }

View File

@ -34,11 +34,6 @@ type Stream struct {
// atomic // atomic
closed uint32 closed uint32
// obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
// memory
obfsBuf []byte
// When we want order guarantee (i.e. session.Unordered is false), // When we want order guarantee (i.e. session.Unordered is false),
// we assign each stream a fixed underlying connection. // we assign each stream a fixed underlying connection.
// If the underlying connections the session uses provide ordering guarantee (most likely TCP), // If the underlying connections the session uses provide ordering guarantee (most likely TCP),
@ -117,14 +112,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
return n, nil return n, nil
} }
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
var cipherTextLen int cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf)
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) s.writingFrame.Seq++
if err != nil { if err != nil {
return err return err
} }
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId)
if err != nil { if err != nil {
if err == errBrokenSwitchboard { if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error()) s.session.SetTerminalMsg(err.Error())
@ -143,9 +138,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
return 0, ErrBrokenStream return 0, ErrBrokenStream
} }
if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
}
for n < len(in) { for n < len(in) {
var framePayload []byte var framePayload []byte
if len(in)-n <= s.session.maxStreamUnitWrite { if len(in)-n <= s.session.maxStreamUnitWrite {
@ -161,8 +153,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
framePayload = in[n : s.session.maxStreamUnitWrite+n] framePayload = in[n : s.session.maxStreamUnitWrite+n]
} }
s.writingFrame.Payload = framePayload s.writingFrame.Payload = framePayload
err = s.obfuscateAndSend(0) buf := s.session.streamObfsBufPool.Get().(*[]byte)
s.writingFrame.Seq++ err = s.obfuscateAndSend(*buf, 0)
s.session.streamObfsBufPool.Put(buf)
if err != nil { if err != nil {
return return
} }
@ -174,9 +167,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
// for readFromTimeout amount of time // for readFromTimeout amount of time
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
}
for { for {
if s.readFromTimeout != 0 { if s.readFromTimeout != 0 {
if rder, ok := r.(net.Conn); !ok { if rder, ok := r.(net.Conn); !ok {
@ -185,19 +175,23 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout)) rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
} }
} }
read, er := r.Read(s.obfsBuf[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite]) buf := s.session.streamObfsBufPool.Get().(*[]byte)
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
if er != nil { if er != nil {
return n, er return n, er
} }
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
// to check that here
if s.isClosed() { if s.isClosed() {
return n, ErrBrokenStream return n, ErrBrokenStream
} }
s.writingM.Lock() s.writingM.Lock()
s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read] s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
err = s.obfuscateAndSend(frameHeaderLength) err = s.obfuscateAndSend(*buf, frameHeaderLength)
s.writingFrame.Seq++
s.writingM.Unlock() s.writingM.Unlock()
s.session.streamObfsBufPool.Put(buf)
if err != nil { if err != nil {
return return

View File

@ -11,7 +11,6 @@ import (
// The point of a streamBufferedPipe is that Read() will block until data is available // The point of a streamBufferedPipe is that Read() will block until data is available
type streamBufferedPipe struct { type streamBufferedPipe struct {
// only alloc when on first Read or Write
buf *bytes.Buffer buf *bytes.Buffer
closed bool closed bool
@ -25,6 +24,7 @@ type streamBufferedPipe struct {
func NewStreamBufferedPipe() *streamBufferedPipe { func NewStreamBufferedPipe() *streamBufferedPipe {
p := &streamBufferedPipe{ p := &streamBufferedPipe{
rwCond: sync.NewCond(&sync.Mutex{}), rwCond: sync.NewCond(&sync.Mutex{}),
buf: new(bytes.Buffer),
} }
return p return p
} }
@ -32,9 +32,6 @@ func NewStreamBufferedPipe() *streamBufferedPipe {
func (p *streamBufferedPipe) Read(target []byte) (int, error) { func (p *streamBufferedPipe) Read(target []byte) (int, error) {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
if p.buf == nil {
p.buf = new(bytes.Buffer)
}
for { for {
if p.closed && p.buf.Len() == 0 { if p.closed && p.buf.Len() == 0 {
return 0, io.EOF return 0, io.EOF
@ -64,9 +61,6 @@ func (p *streamBufferedPipe) Read(target []byte) (int, error) {
func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
if p.buf == nil {
p.buf = new(bytes.Buffer)
}
for { for {
if p.closed && p.buf.Len() == 0 { if p.closed && p.buf.Len() == 0 {
return 0, io.EOF return 0, io.EOF
@ -104,9 +98,6 @@ func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
func (p *streamBufferedPipe) Write(input []byte) (int, error) { func (p *streamBufferedPipe) Write(input []byte) (int, error) {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
if p.buf == nil {
p.buf = new(bytes.Buffer)
}
for { for {
if p.closed { if p.closed {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe

View File

@ -39,7 +39,8 @@ func BenchmarkStream_Write_Ordered(b *testing.B) {
eMethods := map[string]byte{ eMethods := map[string]byte{
"plain": EncryptionMethodPlain, "plain": EncryptionMethodPlain,
"chacha20-poly1305": EncryptionMethodChaha20Poly1305, "chacha20-poly1305": EncryptionMethodChaha20Poly1305,
"aes-gcm": EncryptionMethodAESGCM, "aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
} }
for name, method := range eMethods { for name, method := range eMethods {
@ -141,7 +142,7 @@ func TestStream_Close(t *testing.T) {
writingEnd := common.NewTLSConn(rawWritingEnd) writingEnd := common.NewTLSConn(rawWritingEnd)
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
i, _ := sesh.Obfs(dataFrame, obfsBuf, 0) i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0)
_, err := writingEnd.Write(obfsBuf[:i]) _, err := writingEnd.Write(obfsBuf[:i])
if err != nil { if err != nil {
t.Error("failed to write from remote end") t.Error("failed to write from remote end")
@ -151,16 +152,7 @@ func TestStream_Close(t *testing.T) {
t.Error("failed to accept stream", err) t.Error("failed to accept stream", err)
return return
} }
time.Sleep(500 * time.Millisecond)
// we read something to wait for the test frame to reach our recvBuffer.
// if it's empty by the point we call stream.Close(), the incoming
// frame will be dropped
readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf[:1])
if err != nil {
t.Errorf("can't read any data before active closing")
}
err = stream.Close() err = stream.Close()
if err != nil { if err != nil {
t.Error("failed to actively close stream", err) t.Error("failed to actively close stream", err)
@ -175,10 +167,12 @@ func TestStream_Close(t *testing.T) {
} }
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
_, err = io.ReadFull(stream, readBuf[1:]) readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf)
if err != nil { if err != nil {
t.Errorf("can't read residual data %v", err) t.Errorf("cannot read resiual data: %v", err)
} }
if !bytes.Equal(readBuf, testPayload) { if !bytes.Equal(readBuf, testPayload) {
t.Errorf("read wrong data") t.Errorf("read wrong data")
} }
@ -191,7 +185,7 @@ func TestStream_Close(t *testing.T) {
writingEnd := common.NewTLSConn(rawWritingEnd) writingEnd := common.NewTLSConn(rawWritingEnd)
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
i, err := sesh.Obfs(dataFrame, obfsBuf, 0) i, err := sesh.obfuscate(dataFrame, obfsBuf, 0)
if err != nil { if err != nil {
t.Errorf("failed to obfuscate frame %v", err) t.Errorf("failed to obfuscate frame %v", err)
} }
@ -213,7 +207,7 @@ func TestStream_Close(t *testing.T) {
testPayload, testPayload,
} }
i, err = sesh.Obfs(closingFrame, obfsBuf, 0) i, err = sesh.obfuscate(closingFrame, obfsBuf, 0)
if err != nil { if err != nil {
t.Errorf("failed to obfuscate frame %v", err) t.Errorf("failed to obfuscate frame %v", err)
} }
@ -229,7 +223,7 @@ func TestStream_Close(t *testing.T) {
testPayload, testPayload,
} }
i, err = sesh.Obfs(closingFrameDup, obfsBuf, 0) i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0)
if err != nil { if err != nil {
t.Errorf("failed to obfuscate frame %v", err) t.Errorf("failed to obfuscate frame %v", err)
} }
@ -270,9 +264,6 @@ func TestStream_Read(t *testing.T) {
} }
var streamID uint32 var streamID uint32
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
for name, unordered := range seshes { for name, unordered := range seshes {
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain) sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
@ -280,9 +271,11 @@ func TestStream_Read(t *testing.T) {
sesh.AddConnection(common.NewTLSConn(rawConn)) sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd) writingEnd := common.NewTLSConn(rawWritingEnd)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) { t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf, 0) i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) writingEnd.Write(obfsBuf[:i])
stream, err := sesh.Accept() stream, err := sesh.Accept()
@ -307,7 +300,7 @@ func TestStream_Read(t *testing.T) {
}) })
t.Run("Nil buf", func(t *testing.T) { t.Run("Nil buf", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf, 0) i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) writingEnd.Write(obfsBuf[:i])
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
@ -319,21 +312,22 @@ func TestStream_Read(t *testing.T) {
}) })
t.Run("Read after stream close", func(t *testing.T) { t.Run("Read after stream close", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf, 0) i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) writingEnd.Write(obfsBuf[:i])
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
time.Sleep(500 * time.Millisecond)
stream.Close() stream.Close()
i, err := stream.Read(buf)
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
if err != nil { if err != nil {
t.Error("failed to read", err) t.Errorf("cannot read residual data: %v", err)
} }
if i != smallPayloadLen { if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload, t.Error("expected", testPayload,
"got", buf[:i]) "got", buf[:smallPayloadLen])
} }
_, err = stream.Read(buf) _, err = stream.Read(buf)
if err == nil { if err == nil {
@ -343,21 +337,21 @@ func TestStream_Read(t *testing.T) {
}) })
t.Run("Read after session close", func(t *testing.T) { t.Run("Read after session close", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf, 0) i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) writingEnd.Write(obfsBuf[:i])
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
time.Sleep(500 * time.Millisecond)
sesh.Close() sesh.Close()
i, err := stream.Read(buf) _, err := io.ReadFull(stream, buf[:smallPayloadLen])
if err != nil { if err != nil {
t.Error("failed to read", err) t.Errorf("cannot read resiual data: %v", err)
} }
if i != smallPayloadLen { if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload, t.Error("expected", testPayload,
"got", buf[:i]) "got", buf[:smallPayloadLen])
} }
_, err = stream.Read(buf) _, err = stream.Read(buf)
if err == nil { if err == nil {

View File

@ -7,6 +7,7 @@ import (
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
) )
const ( const (
@ -31,6 +32,7 @@ type switchboard struct {
conns sync.Map conns sync.Map
numConns uint32 numConns uint32
nextConnId uint32 nextConnId uint32
randPool sync.Pool
broken uint32 broken uint32
} }
@ -48,6 +50,9 @@ func makeSwitchboard(sesh *Session) *switchboard {
strategy: strategy, strategy: strategy,
valve: sesh.Valve, valve: sesh.Valve,
nextConnId: 1, nextConnId: 1,
randPool: sync.Pool{New: func() interface{} {
return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
}},
} }
return sb return sb
} }
@ -67,45 +72,43 @@ func (sb *switchboard) addConn(conn net.Conn) {
// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable // a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable
func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) {
writeAndRegUsage := func(conn net.Conn, d []byte) (int, error) {
n, err = conn.Write(d)
if err != nil {
sb.conns.Delete(*connId)
sb.close("failed to write to remote " + err.Error())
return n, err
}
sb.valve.AddTx(int64(n))
return n, nil
}
sb.valve.txWait(len(data)) sb.valve.txWait(len(data))
if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 { if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 {
return 0, errBrokenSwitchboard return 0, errBrokenSwitchboard
} }
var conn net.Conn
switch sb.strategy { switch sb.strategy {
case UNIFORM_SPREAD: case UNIFORM_SPREAD:
_, conn, err := sb.pickRandConn() _, conn, err = sb.pickRandConn()
if err != nil { if err != nil {
return 0, errBrokenSwitchboard return 0, errBrokenSwitchboard
} }
return writeAndRegUsage(conn, data)
case FIXED_CONN_MAPPING: case FIXED_CONN_MAPPING:
connI, ok := sb.conns.Load(*connId) connI, ok := sb.conns.Load(*connId)
if ok { if ok {
conn := connI.(net.Conn) conn = connI.(net.Conn)
return writeAndRegUsage(conn, data)
} else { } else {
newConnId, conn, err := sb.pickRandConn() var newConnId uint32
newConnId, conn, err = sb.pickRandConn()
if err != nil { if err != nil {
return 0, errBrokenSwitchboard return 0, errBrokenSwitchboard
} }
*connId = newConnId *connId = newConnId
return writeAndRegUsage(conn, data)
} }
default: default:
return 0, errors.New("unsupported traffic distribution strategy") return 0, errors.New("unsupported traffic distribution strategy")
} }
n, err = conn.Write(data)
if err != nil {
sb.conns.Delete(*connId)
sb.session.SetTerminalMsg("failed to write to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
sb.valve.AddTx(int64(n))
return n, nil
} }
// returns a random connId // returns a random connId
@ -120,7 +123,9 @@ func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) {
// so if the r > len(sb.conns) at the point of range call, the last visited element is picked // so if the r > len(sb.conns) at the point of range call, the last visited element is picked
var id uint32 var id uint32
var conn net.Conn var conn net.Conn
r := rand.Intn(connCount) randReader := sb.randPool.Get().(*rand.Rand)
r := randReader.Intn(connCount)
sb.randPool.Put(randReader)
var c int var c int
sb.conns.Range(func(connIdI, connI interface{}) bool { sb.conns.Range(func(connIdI, connI interface{}) bool {
if r == c { if r == c {
@ -138,16 +143,11 @@ func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) {
return id, conn, nil return id, conn, nil
} }
func (sb *switchboard) close(terminalMsg string) {
atomic.StoreUint32(&sb.broken, 1)
if !sb.session.IsClosed() {
sb.session.SetTerminalMsg(terminalMsg)
sb.session.passiveClose()
}
}
// actively triggered by session.Close() // actively triggered by session.Close()
func (sb *switchboard) closeAll() { func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return
}
sb.conns.Range(func(key, connI interface{}) bool { sb.conns.Range(func(key, connI interface{}) bool {
conn := connI.(net.Conn) conn := connI.(net.Conn)
conn.Close() conn.Close()
@ -168,7 +168,8 @@ func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
sb.conns.Delete(connId) sb.conns.Delete(connId)
atomic.AddUint32(&sb.numConns, ^uint32(0)) atomic.AddUint32(&sb.numConns, ^uint32(0))
sb.close("a connection has dropped unexpectedly") sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
sb.session.passiveClose()
return return
} }

View File

@ -175,7 +175,13 @@ func dispatchConnection(conn net.Conn, sta *State) {
common.RandRead(sta.WorldState.Rand, sessionKey[:]) common.RandRead(sta.WorldState.Rand, sessionKey[:])
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey) obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
if err != nil { if err != nil {
log.Error(err) log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Error(err)
goWeb() goWeb()
return return
} }

View File

@ -30,15 +30,14 @@ func serveTCPEcho(l net.Listener) {
log.Error(err) log.Error(err)
return return
} }
go func() { go func(conn net.Conn) {
conn := conn
_, err := io.Copy(conn, conn) _, err := io.Copy(conn, conn)
if err != nil { if err != nil {
conn.Close() conn.Close()
log.Error(err) log.Error(err)
return return
} }
}() }(conn)
} }
} }
@ -50,8 +49,7 @@ func serveUDPEcho(listener *connutil.PipeListener) {
return return
} }
const bufSize = 32 * 1024 const bufSize = 32 * 1024
go func() { go func(conn net.PacketConn) {
conn := conn
defer conn.Close() defer conn.Close()
buf := make([]byte, bufSize) buf := make([]byte, bufSize)
for { for {
@ -70,7 +68,7 @@ func serveUDPEcho(listener *connutil.PipeListener) {
return return
} }
} }
}() }(conn)
} }
} }
@ -222,30 +220,34 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a
return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
} }
func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup var wg sync.WaitGroup
testData := make([]byte, msgLen)
rand.Read(testData)
for _, conn := range conns { for _, conn := range conns {
wg.Add(1) wg.Add(1)
go func(conn net.Conn) { go func(conn net.Conn) {
testDataLen := rand.Intn(maxMsgLen) defer wg.Done()
testData := make([]byte, testDataLen)
rand.Read(testData)
// we cannot call t.Fatalf in concurrent contexts
n, err := conn.Write(testData) n, err := conn.Write(testData)
if n != testDataLen { if n != msgLen {
t.Fatalf("written only %v, err %v", n, err) t.Errorf("written only %v, err %v", n, err)
return
} }
recvBuf := make([]byte, testDataLen) recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf) _, err = io.ReadFull(conn, recvBuf)
if err != nil { if err != nil {
t.Fatalf("failed to read back: %v", err) t.Errorf("failed to read back: %v", err)
return
} }
if !bytes.Equal(testData, recvBuf) { if !bytes.Equal(testData, recvBuf) {
t.Fatalf("echoed data not correct") t.Errorf("echoed data not correct")
return
} }
wg.Done()
}(conn) }(conn)
} }
wg.Wait() wg.Wait()
@ -294,6 +296,7 @@ func TestUDP(t *testing.T) {
} }
}) })
const echoMsgLen = 1024
t.Run("user echo", func(t *testing.T) { t.Run("user echo", func(t *testing.T) {
go serveUDPEcho(proxyFromCkServerL) go serveUDPEcho(proxyFromCkServerL)
var conn [1]net.Conn var conn [1]net.Conn
@ -302,7 +305,7 @@ func TestUDP(t *testing.T) {
t.Error(err) t.Error(err)
} }
runEchoTest(t, conn[:], 1024) runEchoTest(t, conn[:], echoMsgLen)
}) })
} }
@ -317,13 +320,14 @@ func TestTCPSingleplex(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
const echoMsgLen = 16384
go serveTCPEcho(proxyFromCkServerL) go serveTCPEcho(proxyFromCkServerL)
proxyConn1, err := proxyToCkClientD.Dial("", "") proxyConn1, err := proxyToCkClientD.Dial("", "")
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
runEchoTest(t, []net.Conn{proxyConn1}, 65536) runEchoTest(t, []net.Conn{proxyConn1}, echoMsgLen)
user, err := sta.Panel.GetUser(ai.UID[:]) user, err := sta.Panel.GetUser(ai.UID[:])
if err != nil { if err != nil {
t.Fatalf("failed to fetch user: %v", err) t.Fatalf("failed to fetch user: %v", err)
@ -335,15 +339,15 @@ func TestTCPSingleplex(t *testing.T) {
proxyConn2, err := proxyToCkClientD.Dial("", "") proxyConn2, err := proxyToCkClientD.Dial("", "")
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
runEchoTest(t, []net.Conn{proxyConn2}, 65536) runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
if user.NumSession() != 2 { if user.NumSession() != 2 {
t.Error("no extra session were made on second connection establishment") t.Error("no extra session were made on second connection establishment")
} }
// Both conns should work // Both conns should work
runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, 65536) runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, echoMsgLen)
proxyConn1.Close() proxyConn1.Close()
@ -352,17 +356,17 @@ func TestTCPSingleplex(t *testing.T) {
}, time.Second, 10*time.Millisecond, "first session was not closed on connection close") }, time.Second, 10*time.Millisecond, "first session was not closed on connection close")
// conn2 should still work // conn2 should still work
runEchoTest(t, []net.Conn{proxyConn2}, 65536) runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
var conns [numConns]net.Conn var conns [numConns]net.Conn
for i := 0; i < numConns; i++ { for i := 0; i < numConns; i++ {
conns[i], err = proxyToCkClientD.Dial("", "") conns[i], err = proxyToCkClientD.Dial("", "")
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
} }
runEchoTest(t, conns[:], 65536) runEchoTest(t, conns[:], echoMsgLen)
} }
@ -410,6 +414,7 @@ func TestTCPMultiplex(t *testing.T) {
} }
}) })
const echoMsgLen = 16384
t.Run("user echo", func(t *testing.T) { t.Run("user echo", func(t *testing.T) {
go serveTCPEcho(proxyFromCkServerL) go serveTCPEcho(proxyFromCkServerL)
var conns [numConns]net.Conn var conns [numConns]net.Conn
@ -420,7 +425,7 @@ func TestTCPMultiplex(t *testing.T) {
} }
} }
runEchoTest(t, conns[:], 65536) runEchoTest(t, conns[:], echoMsgLen)
}) })
t.Run("redir echo", func(t *testing.T) { t.Run("redir echo", func(t *testing.T) {
@ -432,7 +437,7 @@ func TestTCPMultiplex(t *testing.T) {
t.Error(err) t.Error(err)
} }
} }
runEchoTest(t, conns[:], 65536) runEchoTest(t, conns[:], echoMsgLen)
}) })
} }
@ -503,7 +508,7 @@ func TestClosingStreamsFromProxy(t *testing.T) {
} }
} }
func BenchmarkThroughput(b *testing.B) { func BenchmarkIntegration(b *testing.B) {
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
@ -513,7 +518,8 @@ func BenchmarkThroughput(b *testing.B) {
encryptionMethods := map[string]byte{ encryptionMethods := map[string]byte{
"plain": mux.EncryptionMethodPlain, "plain": mux.EncryptionMethodPlain,
"chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305, "chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305,
"aes-gcm": mux.EncryptionMethodAESGCM, "aes-256-gcm": mux.EncryptionMethodAES256GCM,
"aes-128-gcm": mux.EncryptionMethodAES128GCM,
} }
for name, method := range encryptionMethods { for name, method := range encryptionMethods {
@ -524,7 +530,7 @@ func BenchmarkThroughput(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
b.Run("single stream", func(b *testing.B) { b.Run("single stream bandwidth", func(b *testing.B) {
more := make(chan int, 10) more := make(chan int, 10)
go func() { go func() {
// sender // sender
@ -548,6 +554,19 @@ func BenchmarkThroughput(b *testing.B) {
} }
}) })
b.Run("single stream latency", func(b *testing.B) {
clientConn, _ := proxyToCkClientD.Dial("", "")
buf := []byte{1}
clientConn.Write(buf)
serverConn, _ := proxyFromCkServerL.Accept()
serverConn.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
clientConn.Write(buf)
serverConn.Read(buf)
}
})
}) })
} }