diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 91c9a76..8fdf9ce 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -158,50 +158,49 @@ func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error { return nil } -func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { - obfuscator = Obfuscator{ +func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (o Obfuscator, err error) { + o = Obfuscator{ SessionKey: sessionKey, } - var payloadCipher cipher.AEAD switch encryptionMethod { case EncryptionMethodPlain: - payloadCipher = nil - obfuscator.maxOverhead = salsa20NonceSize + o.payloadCipher = nil + o.maxOverhead = salsa20NonceSize case EncryptionMethodAES256GCM: var c cipher.Block c, err = aes.NewCipher(sessionKey[:]) if err != nil { return } - payloadCipher, err = cipher.NewGCM(c) + o.payloadCipher, err = cipher.NewGCM(c) if err != nil { return } - obfuscator.maxOverhead = payloadCipher.Overhead() + o.maxOverhead = o.payloadCipher.Overhead() case EncryptionMethodAES128GCM: var c cipher.Block c, err = aes.NewCipher(sessionKey[:16]) if err != nil { return } - payloadCipher, err = cipher.NewGCM(c) + o.payloadCipher, err = cipher.NewGCM(c) if err != nil { return } - obfuscator.maxOverhead = payloadCipher.Overhead() + o.maxOverhead = o.payloadCipher.Overhead() case EncryptionMethodChaha20Poly1305: - payloadCipher, err = chacha20poly1305.New(sessionKey[:]) + o.payloadCipher, err = chacha20poly1305.New(sessionKey[:]) if err != nil { return } - obfuscator.maxOverhead = payloadCipher.Overhead() + o.maxOverhead = o.payloadCipher.Overhead() default: - return obfuscator, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) + return o, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) } - if payloadCipher != nil { - if payloadCipher.NonceSize() > frameHeaderLength { - return obfuscator, errors.New("payload AEAD's nonce size cannot be greater than size of frame header") + if o.payloadCipher != nil { + if o.payloadCipher.NonceSize() > frameHeaderLength { + return o, errors.New("payload AEAD's nonce size cannot be greater than size of frame header") } } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 78a760d..21c29fd 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -1,9 +1,9 @@ package multiplex import ( - "bytes" "crypto/aes" "crypto/cipher" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/chacha20poly1305" "math/rand" "reflect" @@ -15,69 +15,119 @@ func TestGenerateObfs(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) - run := func(obfuscator Obfuscator, ct *testing.T) { + run := func(o Obfuscator, t *testing.T) { obfsBuf := make([]byte, 512) _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) - i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0) - if err != nil { - ct.Error("failed to obfs ", err) - return - } - + i, err := o.obfuscate(testFrame, obfsBuf, 0) + assert.NoError(t, err) var resultFrame Frame - err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i]) - if err != nil { - ct.Error("failed to deobfs ", err) - return - } - if !bytes.Equal(testFrame.Payload, resultFrame.Payload) || testFrame.StreamID != resultFrame.StreamID { - ct.Error("expecting", testFrame, - "got", resultFrame) - return - } + + err = o.deobfuscate(&resultFrame, obfsBuf[:i]) + assert.NoError(t, err) + assert.EqualValues(t, testFrame, resultFrame) } t.Run("plain", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodPlain, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodPlain, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("aes-256-gcm", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("aes-128-gcm", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("chacha20-poly1305", func(t *testing.T) { - obfuscator, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - if err != nil { - t.Errorf("failed to generate obfuscator %v", err) - } else { - run(obfuscator, t) - } + o, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) + assert.NoError(t, err) + run(o, t) }) t.Run("unknown encryption method", func(t *testing.T) { _, err := MakeObfuscator(0xff, sessionKey) - if err == nil { - t.Errorf("unknown encryption mehtod error expected") - } + assert.Error(t, err) }) } +func TestObfuscate(t *testing.T) { + var sessionKey [32]byte + rand.Read(sessionKey[:]) + + const testPayloadLen = 1024 + testPayload := make([]byte, testPayloadLen) + rand.Read(testPayload) + f := Frame{ + StreamID: 0, + Seq: 0, + Closing: 0, + Payload: testPayload, + } + + runTest := func(t *testing.T, o Obfuscator) { + obfsBuf := make([]byte, testPayloadLen*2) + n, err := o.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + + resultFrame := Frame{} + err = o.deobfuscate(&resultFrame, obfsBuf[:n]) + assert.NoError(t, err) + + assert.EqualValues(t, f, resultFrame) + } + + t.Run("plain", func(t *testing.T) { + o := Obfuscator{ + payloadCipher: nil, + SessionKey: sessionKey, + maxOverhead: salsa20NonceSize, + } + runTest(t, o) + }) + + t.Run("aes-128-gcm", func(t *testing.T) { + c, err := aes.NewCipher(sessionKey[:16]) + assert.NoError(t, err) + payloadCipher, err := cipher.NewGCM(c) + assert.NoError(t, err) + o := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: sessionKey, + maxOverhead: payloadCipher.Overhead(), + } + runTest(t, o) + }) + + t.Run("aes-256-gcm", func(t *testing.T) { + c, err := aes.NewCipher(sessionKey[:]) + assert.NoError(t, err) + payloadCipher, err := cipher.NewGCM(c) + assert.NoError(t, err) + o := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: sessionKey, + maxOverhead: payloadCipher.Overhead(), + } + runTest(t, o) + }) + + t.Run("chacha20-poly1305", func(t *testing.T) { + payloadCipher, err := chacha20poly1305.New(sessionKey[:]) + assert.NoError(t, err) + o := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: sessionKey, + maxOverhead: payloadCipher.Overhead(), + } + runTest(t, o) + }) + +} + func BenchmarkObfs(b *testing.B) { testPayload := make([]byte, 1024) rand.Read(testPayload)