diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 9de2270..9fa614d 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -11,9 +11,6 @@ import ( "golang.org/x/crypto/salsa20" ) -type Obfser func(*Frame, []byte, int) (int, error) -type Deobfser func(*Frame, []byte) error - const frameHeaderLength = 14 const salsa20NonceSize = 8 @@ -26,25 +23,19 @@ const ( // Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. type Obfuscator struct { - // Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header - Obfs Obfser - // Remove TLS header, decrypt and unmarshall frames - Deobfs Deobfser - SessionKey [32]byte + payloadCipher cipher.AEAD + + sessionKey [32]byte maxOverhead int } -// MakeObfs returns a function of type Obfser. An Obfser takes three arguments: -// a *Frame with all the field set correctly, a []byte as buffer to put encrypted -// message in, and an int called payloadOffsetInBuf to be used when *Frame.payload -// is in the byte slice used as buffer (2nd argument). payloadOffsetInBuf specifies -// the index at which data belonging to *Frame.Payload starts in the buffer. -func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { +// obfuscate adds multiplexing headers, encrypt and add TLS header +func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { // The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header // as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use // the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead()) - // as nonce for Salsa20 to encrypt the frame header. Both with SessionKey as keys. + // as nonce for Salsa20 to encrypt the frame header. Both with sessionKey as keys. // // Several cryptographic guarantees we have made here: that payloadCipher, as an AEAD, is given a unique // iv/nonce each time, relative to its key; that the frame header encryptor Salsa20 is given a unique @@ -72,159 +63,146 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { // We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input // is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends // encrypted data. - obfs := func(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { - payloadLen := len(f.Payload) - if payloadLen == 0 { - return 0, errors.New("payload cannot be empty") - } - var extraLen int - if payloadCipher == nil { - extraLen = salsa20NonceSize - payloadLen - if extraLen < 0 { - // if our payload is already greater than 8 bytes - extraLen = 0 - } - } else { - extraLen = payloadCipher.Overhead() - if extraLen < salsa20NonceSize { - return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") - } - } - - usefulLen := frameHeaderLength + payloadLen + extraLen - if len(buf) < usefulLen { - return 0, errors.New("obfs buffer too small") - } - // we do as much in-place as possible to save allocation - payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] - if payloadOffsetInBuf != frameHeaderLength { - // if payload is not at the correct location in buffer - copy(payload, f.Payload) - } - - header := buf[:frameHeaderLength] - binary.BigEndian.PutUint32(header[0:4], f.StreamID) - binary.BigEndian.PutUint64(header[4:12], f.Seq) - header[12] = f.Closing - header[13] = byte(extraLen) - - if payloadCipher == nil { - if extraLen != 0 { // read nonce - extra := buf[usefulLen-extraLen : usefulLen] - common.CryptoRandRead(extra) - } - } else { - payloadCipher.Seal(payload[:0], header[:payloadCipher.NonceSize()], payload, nil) - } - - nonce := buf[usefulLen-salsa20NonceSize : usefulLen] - salsa20.XORKeyStream(header, header, nonce, &salsaKey) - - return usefulLen, nil + payloadLen := len(f.Payload) + if payloadLen == 0 { + return 0, errors.New("payload cannot be empty") } - return obfs + var extraLen int + if o.payloadCipher == nil { + extraLen = salsa20NonceSize - payloadLen + if extraLen < 0 { + // if our payload is already greater than 8 bytes + extraLen = 0 + } + } else { + extraLen = o.payloadCipher.Overhead() + if extraLen < salsa20NonceSize { + return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") + } + } + + usefulLen := frameHeaderLength + payloadLen + extraLen + if len(buf) < usefulLen { + return 0, errors.New("obfs buffer too small") + } + // we do as much in-place as possible to save allocation + payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] + if payloadOffsetInBuf != frameHeaderLength { + // if payload is not at the correct location in buffer + copy(payload, f.Payload) + } + + header := buf[:frameHeaderLength] + binary.BigEndian.PutUint32(header[0:4], f.StreamID) + binary.BigEndian.PutUint64(header[4:12], f.Seq) + header[12] = f.Closing + header[13] = byte(extraLen) + + if o.payloadCipher == nil { + if extraLen != 0 { // read nonce + extra := buf[usefulLen-extraLen : usefulLen] + common.CryptoRandRead(extra) + } + } else { + o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil) + } + + nonce := buf[usefulLen-salsa20NonceSize : usefulLen] + salsa20.XORKeyStream(header, header, nonce, &o.sessionKey) + + return usefulLen, nil } -// MakeDeobfs returns a function Deobfser. A Deobfser takes in a single byte slice, -// containing the message to be decrypted, and returns a *Frame containing the frame -// information and plaintext -func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { - // frame header length + minimum data size (i.e. nonce size of salsa20) - const minInputLen = frameHeaderLength + salsa20NonceSize - deobfs := func(f *Frame, in []byte) error { - if len(in) < minInputLen { - return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) - } +// deobfuscate removes TLS header, decrypt and unmarshall frames +func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error { + if len(in) < frameHeaderLength+salsa20NonceSize { + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize) + } - header := in[:frameHeaderLength] - pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead + header := in[:frameHeaderLength] + pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead - nonce := in[len(in)-salsa20NonceSize:] - salsa20.XORKeyStream(header, header, nonce, &salsaKey) + nonce := in[len(in)-salsa20NonceSize:] + salsa20.XORKeyStream(header, header, nonce, &o.sessionKey) - streamID := binary.BigEndian.Uint32(header[0:4]) - seq := binary.BigEndian.Uint64(header[4:12]) - closing := header[12] - extraLen := header[13] + streamID := binary.BigEndian.Uint32(header[0:4]) + seq := binary.BigEndian.Uint64(header[4:12]) + closing := header[12] + extraLen := header[13] - usefulPayloadLen := len(pldWithOverHead) - int(extraLen) - if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") - } + usefulPayloadLen := len(pldWithOverHead) - int(extraLen) + if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + } - var outputPayload []byte + var outputPayload []byte - if payloadCipher == nil { - if extraLen == 0 { - outputPayload = pldWithOverHead - } else { - outputPayload = pldWithOverHead[:usefulPayloadLen] - } + if o.payloadCipher == nil { + if extraLen == 0 { + outputPayload = pldWithOverHead } else { - _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) - if err != nil { - return err - } outputPayload = pldWithOverHead[:usefulPayloadLen] } - - f.StreamID = streamID - f.Seq = seq - f.Closing = closing - f.Payload = outputPayload - return nil + } else { + _, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil) + if err != nil { + return err + } + outputPayload = pldWithOverHead[:usefulPayloadLen] } - return deobfs + + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil } -func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { - obfuscator = Obfuscator{ - SessionKey: sessionKey, +func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (o Obfuscator, err error) { + o = Obfuscator{ + sessionKey: sessionKey, } - var payloadCipher cipher.AEAD switch encryptionMethod { case EncryptionMethodPlain: - payloadCipher = nil - obfuscator.maxOverhead = salsa20NonceSize + o.payloadCipher = nil + o.maxOverhead = salsa20NonceSize case EncryptionMethodAES256GCM: var c cipher.Block c, err = aes.NewCipher(sessionKey[:]) if err != nil { return } - payloadCipher, err = cipher.NewGCM(c) + o.payloadCipher, err = cipher.NewGCM(c) if err != nil { return } - obfuscator.maxOverhead = payloadCipher.Overhead() + o.maxOverhead = o.payloadCipher.Overhead() case EncryptionMethodAES128GCM: var c cipher.Block c, err = aes.NewCipher(sessionKey[:16]) if err != nil { return } - payloadCipher, err = cipher.NewGCM(c) + o.payloadCipher, err = cipher.NewGCM(c) if err != nil { return } - obfuscator.maxOverhead = payloadCipher.Overhead() + o.maxOverhead = o.payloadCipher.Overhead() case EncryptionMethodChaha20Poly1305: - payloadCipher, err = chacha20poly1305.New(sessionKey[:]) + o.payloadCipher, err = chacha20poly1305.New(sessionKey[:]) if err != nil { return } - obfuscator.maxOverhead = payloadCipher.Overhead() + o.maxOverhead = o.payloadCipher.Overhead() default: - return obfuscator, errors.New("Unknown encryption method") + return o, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) } - if payloadCipher != nil { - if payloadCipher.NonceSize() > frameHeaderLength { - return obfuscator, errors.New("payload AEAD's nonce size cannot be greater than size of frame header") + if o.payloadCipher != nil { + if o.payloadCipher.NonceSize() > frameHeaderLength { + return o, errors.New("payload AEAD's nonce size cannot be greater than size of frame header") } } - obfuscator.Obfs = MakeObfs(sessionKey, payloadCipher) - obfuscator.Deobfs = MakeDeobfs(sessionKey, payloadCipher) return } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 6fd9916..2dad728 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -1,9 +1,9 @@ package multiplex import ( - "bytes" "crypto/aes" "crypto/cipher" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/chacha20poly1305" "math/rand" "reflect" @@ -15,69 +15,119 @@ func TestGenerateObfs(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) - run := func(obfuscator Obfuscator, ct *testing.T) { + run := func(o Obfuscator, t *testing.T) { obfsBuf := make([]byte, 512) - _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) - testFrame := _testFrame.Interface().(*Frame) - i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) - if err != nil { - ct.Error("failed to obfs ", err) - return - } - + _testFrame, _ := quick.Value(reflect.TypeOf(Frame{}), rand.New(rand.NewSource(42))) + testFrame := _testFrame.Interface().(Frame) + i, err := o.obfuscate(&testFrame, obfsBuf, 0) + assert.NoError(t, err) var resultFrame Frame - err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) - if err != nil { - ct.Error("failed to deobfs ", err) - return - } - if !bytes.Equal(testFrame.Payload, resultFrame.Payload) || testFrame.StreamID != resultFrame.StreamID { - ct.Error("expecting", testFrame, - "got", resultFrame) - return - } + + err = o.deobfuscate(&resultFrame, obfsBuf[:i]) + assert.NoError(t, err) + assert.EqualValues(t, testFrame, resultFrame) } t.Run("plain", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodPlain, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodPlain, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("aes-256-gcm", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("aes-128-gcm", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("chacha20-poly1305", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("unknown encryption method", func(t *testing.T) { _, err := MakeObfuscator(0xff, sessionKey) - if err == nil { - t.Errorf("unknown encryption mehtod error expected") - } + assert.Error(t, err) }) } +func TestObfuscate(t *testing.T) { + var sessionKey [32]byte + rand.Read(sessionKey[:]) + + const testPayloadLen = 1024 + testPayload := make([]byte, testPayloadLen) + rand.Read(testPayload) + f := Frame{ + StreamID: 0, + Seq: 0, + Closing: 0, + Payload: testPayload, + } + + runTest := func(t *testing.T, o Obfuscator) { + obfsBuf := make([]byte, testPayloadLen*2) + n, err := o.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + + resultFrame := Frame{} + err = o.deobfuscate(&resultFrame, obfsBuf[:n]) + assert.NoError(t, err) + + assert.EqualValues(t, f, resultFrame) + } + + t.Run("plain", func(t *testing.T) { + o := Obfuscator{ + payloadCipher: nil, + sessionKey: sessionKey, + maxOverhead: salsa20NonceSize, + } + runTest(t, o) + }) + + t.Run("aes-128-gcm", func(t *testing.T) { + c, err := aes.NewCipher(sessionKey[:16]) + assert.NoError(t, err) + payloadCipher, err := cipher.NewGCM(c) + assert.NoError(t, err) + o := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: sessionKey, + maxOverhead: payloadCipher.Overhead(), + } + runTest(t, o) + }) + + t.Run("aes-256-gcm", func(t *testing.T) { + c, err := aes.NewCipher(sessionKey[:]) + assert.NoError(t, err) + payloadCipher, err := cipher.NewGCM(c) + assert.NoError(t, err) + o := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: sessionKey, + maxOverhead: payloadCipher.Overhead(), + } + runTest(t, o) + }) + + t.Run("chacha20-poly1305", func(t *testing.T) { + payloadCipher, err := chacha20poly1305.New(sessionKey[:]) + assert.NoError(t, err) + o := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: sessionKey, + maxOverhead: payloadCipher.Overhead(), + } + runTest(t, o) + }) + +} + func BenchmarkObfs(b *testing.B) { testPayload := make([]byte, 1024) rand.Read(testPayload) @@ -96,40 +146,57 @@ func BenchmarkObfs(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) + obfuscator := Obfuscator{ + payloadCipher: nil, + sessionKey: key, + maxOverhead: salsa20NonceSize, + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:16]) + payloadCipher, _ := chacha20poly1305.New(key[:]) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) } @@ -151,57 +218,70 @@ func BenchmarkDeobfs(b *testing.B) { b.Run("AES256GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + sessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, nil) + obfuscator := Obfuscator{ + payloadCipher: nil, + sessionKey: key, + maxOverhead: salsa20NonceSize, + } + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:16]) + payloadCipher, _ := chacha20poly1305.New(key[:]) - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: nil, + sessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 0113afa..6abc90e 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -128,6 +128,10 @@ func MakeSession(id uint32, config SessionConfig) *Session { return sesh } +func (sesh *Session) GetSessionKey() [32]byte { + return sesh.sessionKey +} + func (sesh *Session) streamCountIncr() uint32 { return atomic.AddUint32(&sesh.activeStreamCount, 1) } @@ -232,7 +236,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { frame := sesh.recvFramePool.Get().(*Frame) defer sesh.recvFramePool.Put(frame) - err := sesh.Deobfs(frame, data) + err := sesh.deobfuscate(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } @@ -331,7 +335,7 @@ func (sesh *Session) Close() error { Closing: closingSession, Payload: payload, } - i, err := sesh.Obfs(f, *buf, frameHeaderLength) + i, err := sesh.obfuscate(f, *buf, frameHeaderLength) if err != nil { return err } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index c942b10..dfa3dbb 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -4,7 +4,10 @@ import ( "bytes" "github.com/cbeuw/connutil" "github.com/stretchr/testify/assert" + "io" + "io/ioutil" "math/rand" + "net" "strconv" "sync" "sync/atomic" @@ -16,73 +19,188 @@ var seshConfigs = map[string]SessionConfig{ "ordered": {}, "unordered": {Unordered: true}, } +var encryptionMethods = map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-256-gcm": EncryptionMethodAES256GCM, + "aes-128-gcm": EncryptionMethodAES128GCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, +} const testPayloadLen = 1024 const obfsBufLen = testPayloadLen * 2 func TestRecvDataFromRemote(t *testing.T) { - testPayload := make([]byte, testPayloadLen) - rand.Read(testPayload) - f := &Frame{ - 1, - 0, - 0, - testPayload, - } - obfsBuf := make([]byte, obfsBufLen) - var sessionKey [32]byte rand.Read(sessionKey[:]) - MakeObfuscatorUnwrap := func(method byte, sessionKey [32]byte) Obfuscator { - ret, err := MakeObfuscator(method, sessionKey) - if err != nil { - t.Fatalf("failed to make an obfuscator: %v", err) - } - return ret - } - - encryptionMethods := map[string]Obfuscator{ - "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), - "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAES256GCM, sessionKey), - "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), - } - for seshType, seshConfig := range seshConfigs { seshConfig := seshConfig t.Run(seshType, func(t *testing.T) { - for method, obfuscator := range encryptionMethods { - obfuscator := obfuscator - t.Run(method, func(t *testing.T) { - seshConfig.Obfuscator = obfuscator + var err error + seshConfig.Obfuscator, err = MakeObfuscator(EncryptionMethodPlain, sessionKey) + if err != nil { + t.Fatalf("failed to make obfuscator: %v", err) + } + t.Run("initial frame", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + make([]byte, testPayloadLen), + } + rand.Read(f.Payload) + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + stream, err := sesh.Accept() + assert.NoError(t, err) + + resultPayload := make([]byte, testPayloadLen) + _, err = stream.Read(resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + }) + + t.Run("two frames in order", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + make([]byte, testPayloadLen), + } + rand.Read(f.Payload) + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + stream, err := sesh.Accept() + assert.NoError(t, err) + + resultPayload := make([]byte, testPayloadLen) + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + + f.Seq += 1 + rand.Read(f.Payload) + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + }) + + t.Run("two frames in order", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + make([]byte, testPayloadLen), + } + rand.Read(f.Payload) + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + stream, err := sesh.Accept() + assert.NoError(t, err) + + resultPayload := make([]byte, testPayloadLen) + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + + f.Seq += 1 + rand.Read(f.Payload) + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + }) + + if seshType == "ordered" { + t.Run("frames out of order", func(t *testing.T) { sesh := MakeSession(0, seshConfig) - n, err := sesh.Obfs(f, obfsBuf, 0) - if err != nil { - t.Error(err) - return - } - err = sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + nil, } + // First frame + seq0 := make([]byte, testPayloadLen) + rand.Read(seq0) + f.Seq = 0 + f.Payload = seq0 + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + // Third frame + seq2 := make([]byte, testPayloadLen) + rand.Read(seq2) + f.Seq = 2 + f.Payload = seq2 + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + // Second frame + seq1 := make([]byte, testPayloadLen) + rand.Read(seq1) + f.Seq = 1 + f.Payload = seq1 + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + // Expect things to receive in order + stream, err := sesh.Accept() + assert.NoError(t, err) + resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) - } + + // First + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + assert.EqualValues(t, seq0, resultPayload) + + // Second + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + assert.EqualValues(t, seq1, resultPayload) + + // Third + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + assert.EqualValues(t, seq2, resultPayload) }) } + }) } } @@ -94,10 +212,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) seshConfig := seshConfigs["ordered"] - seshConfig.Obfuscator = obfuscator + seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey) sesh := MakeSession(0, seshConfig) f1 := &Frame{ @@ -107,7 +224,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { testPayload, } // create stream 1 - n, _ := sesh.Obfs(f1, obfsBuf, 0) + n, _ := sesh.obfuscate(f1, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -129,7 +246,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.Obfs(f2, obfsBuf, 0) + n, _ = sesh.obfuscate(f2, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) @@ -151,7 +268,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingStream, testPayload, } - n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) @@ -180,7 +297,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { } // close stream 1 again - n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) @@ -203,7 +320,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { Closing: closingSession, Payload: testPayload, } - n, _ = sesh.Obfs(fCloseSession, obfsBuf, 0) + n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving session closing frame: %v", err) @@ -233,10 +350,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) seshConfig := seshConfigs["ordered"] - seshConfig.Obfuscator = obfuscator + seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey) sesh := MakeSession(0, seshConfig) // receive stream 1 closing first @@ -246,7 +362,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingStream, testPayload, } - n, _ := sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) @@ -268,7 +384,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.Obfs(f1, obfsBuf, 0) + n, _ = sesh.obfuscate(f1, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -330,7 +446,7 @@ func TestParallelStreams(t *testing.T) { wg.Add(1) go func(frame *Frame) { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(frame, obfsBuf, 0) + n, _ := sesh.obfuscate(frame, obfsBuf, 0) obfsBuf = obfsBuf[0:n] err := sesh.recvDataFromRemote(obfsBuf) @@ -415,7 +531,7 @@ func TestSession_timeoutAfter(t *testing.T) { } } -func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { +func BenchmarkRecvDataFromRemote(b *testing.B) { testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) f := &Frame{ @@ -428,33 +544,34 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) - table := map[string]byte{ - "plain": EncryptionMethodPlain, - "aes-gcm": EncryptionMethodAES256GCM, - "chacha20poly1305": EncryptionMethodChaha20Poly1305, - } - const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic - for name, ep := range table { + for name, ep := range encryptionMethods { ep := ep b.Run(name, func(b *testing.B) { - seshConfig := seshConfigs["ordered"] - obfuscator, _ := MakeObfuscator(ep, sessionKey) - seshConfig.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfig) + for seshType, seshConfig := range seshConfigs { + b.Run(seshType, func(b *testing.B) { + seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) + sesh := MakeSession(0, seshConfig) - binaryFrames := [maxIter][]byte{} - for i := 0; i < maxIter; i++ { - obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(f, obfsBuf, 0) - binaryFrames[i] = obfsBuf[:n] - f.Seq++ - } + go func() { + stream, _ := sesh.Accept() + stream.(*Stream).WriteTo(ioutil.Discard) + }() - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(binaryFrames[i]) + binaryFrames := [maxIter][]byte{} + for i := 0; i < maxIter; i++ { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.obfuscate(f, obfsBuf, 0) + binaryFrames[i] = obfsBuf[:n] + f.Seq++ + } + + b.SetBytes(int64(len(f.Payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sesh.recvDataFromRemote(binaryFrames[i]) + } + }) } }) } @@ -464,21 +581,13 @@ func BenchmarkMultiStreamWrite(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) - table := map[string]byte{ - "plain": EncryptionMethodPlain, - "aes-gcm": EncryptionMethodAES256GCM, - "chacha20poly1305": EncryptionMethodChaha20Poly1305, - } - testPayload := make([]byte, testPayloadLen) - for name, ep := range table { + for name, ep := range encryptionMethods { b.Run(name, func(b *testing.B) { for seshType, seshConfig := range seshConfigs { - seshConfig := seshConfig b.Run(seshType, func(b *testing.B) { - obfuscator, _ := MakeObfuscator(ep, sessionKey) - seshConfig.Obfuscator = obfuscator + seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) sesh := MakeSession(0, seshConfig) sesh.AddConnection(connutil.Discard()) b.ResetTimer() @@ -494,3 +603,36 @@ func BenchmarkMultiStreamWrite(b *testing.B) { }) } } + +func BenchmarkLatency(b *testing.B) { + var sessionKey [32]byte + rand.Read(sessionKey[:]) + + for name, ep := range encryptionMethods { + b.Run(name, func(b *testing.B) { + for seshType, seshConfig := range seshConfigs { + b.Run(seshType, func(b *testing.B) { + seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) + clientSesh := MakeSession(0, seshConfig) + serverSesh := MakeSession(0, seshConfig) + + c, s := net.Pipe() + clientSesh.AddConnection(c) + serverSesh.AddConnection(s) + + buf := make([]byte, 64) + clientStream, _ := clientSesh.OpenStream() + clientStream.Write(buf) + serverStream, _ := serverSesh.Accept() + io.ReadFull(serverStream, buf) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + clientStream.Write(buf) + io.ReadFull(serverStream, buf) + } + }) + } + }) + } +} diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index b29359f..ffd7e23 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -113,7 +113,7 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { } func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { - cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf) + cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf) s.writingFrame.Seq++ if err != nil { return err diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 895f35e..eb13bc8 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -142,7 +142,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, _ := sesh.Obfs(dataFrame, obfsBuf, 0) + i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0) _, err := writingEnd.Write(obfsBuf[:i]) if err != nil { t.Error("failed to write from remote end") @@ -185,7 +185,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, err := sesh.Obfs(dataFrame, obfsBuf, 0) + i, err := sesh.obfuscate(dataFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -207,7 +207,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.Obfs(closingFrame, obfsBuf, 0) + i, err = sesh.obfuscate(closingFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -223,7 +223,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.Obfs(closingFrameDup, obfsBuf, 0) + i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -275,7 +275,7 @@ func TestStream_Read(t *testing.T) { obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, err := sesh.Accept() @@ -300,7 +300,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Nil buf", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -312,7 +312,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -337,7 +337,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after session close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 9daa772..287e363 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -236,7 +236,7 @@ func dispatchConnection(conn net.Conn, sta *State) { return } - preparedConn, err := finishHandshake(conn, sesh.SessionKey, sta.WorldState.Rand) + preparedConn, err := finishHandshake(conn, sesh.GetSessionKey(), sta.WorldState.Rand) if err != nil { log.Error(err) return