Correctly assign payloadCipher to Obfuscator field, and add test for this issue

This commit is contained in:
Andy Wang 2020-12-28 12:04:32 +00:00
parent 5cb54aa3c9
commit 2c709f92df
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
2 changed files with 109 additions and 60 deletions

View File

@ -158,50 +158,49 @@ func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error {
return nil return nil
} }
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (o Obfuscator, err error) {
obfuscator = Obfuscator{ o = Obfuscator{
SessionKey: sessionKey, SessionKey: sessionKey,
} }
var payloadCipher cipher.AEAD
switch encryptionMethod { switch encryptionMethod {
case EncryptionMethodPlain: case EncryptionMethodPlain:
payloadCipher = nil o.payloadCipher = nil
obfuscator.maxOverhead = salsa20NonceSize o.maxOverhead = salsa20NonceSize
case EncryptionMethodAES256GCM: case EncryptionMethodAES256GCM:
var c cipher.Block var c cipher.Block
c, err = aes.NewCipher(sessionKey[:]) c, err = aes.NewCipher(sessionKey[:])
if err != nil { if err != nil {
return return
} }
payloadCipher, err = cipher.NewGCM(c) o.payloadCipher, err = cipher.NewGCM(c)
if err != nil { if err != nil {
return return
} }
obfuscator.maxOverhead = payloadCipher.Overhead() o.maxOverhead = o.payloadCipher.Overhead()
case EncryptionMethodAES128GCM: case EncryptionMethodAES128GCM:
var c cipher.Block var c cipher.Block
c, err = aes.NewCipher(sessionKey[:16]) c, err = aes.NewCipher(sessionKey[:16])
if err != nil { if err != nil {
return return
} }
payloadCipher, err = cipher.NewGCM(c) o.payloadCipher, err = cipher.NewGCM(c)
if err != nil { if err != nil {
return return
} }
obfuscator.maxOverhead = payloadCipher.Overhead() o.maxOverhead = o.payloadCipher.Overhead()
case EncryptionMethodChaha20Poly1305: case EncryptionMethodChaha20Poly1305:
payloadCipher, err = chacha20poly1305.New(sessionKey[:]) o.payloadCipher, err = chacha20poly1305.New(sessionKey[:])
if err != nil { if err != nil {
return return
} }
obfuscator.maxOverhead = payloadCipher.Overhead() o.maxOverhead = o.payloadCipher.Overhead()
default: default:
return obfuscator, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) return o, fmt.Errorf("unknown encryption method valued %v", encryptionMethod)
} }
if payloadCipher != nil { if o.payloadCipher != nil {
if payloadCipher.NonceSize() > frameHeaderLength { if o.payloadCipher.NonceSize() > frameHeaderLength {
return obfuscator, errors.New("payload AEAD's nonce size cannot be greater than size of frame header") return o, errors.New("payload AEAD's nonce size cannot be greater than size of frame header")
} }
} }

View File

@ -1,9 +1,9 @@
package multiplex package multiplex
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"math/rand" "math/rand"
"reflect" "reflect"
@ -15,69 +15,119 @@ func TestGenerateObfs(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
run := func(obfuscator Obfuscator, ct *testing.T) { run := func(o Obfuscator, t *testing.T) {
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
_testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
testFrame := _testFrame.Interface().(*Frame) testFrame := _testFrame.Interface().(*Frame)
i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0) i, err := o.obfuscate(testFrame, obfsBuf, 0)
if err != nil { assert.NoError(t, err)
ct.Error("failed to obfs ", err)
return
}
var resultFrame Frame var resultFrame Frame
err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i])
if err != nil { err = o.deobfuscate(&resultFrame, obfsBuf[:i])
ct.Error("failed to deobfs ", err) assert.NoError(t, err)
return assert.EqualValues(t, testFrame, resultFrame)
}
if !bytes.Equal(testFrame.Payload, resultFrame.Payload) || testFrame.StreamID != resultFrame.StreamID {
ct.Error("expecting", testFrame,
"got", resultFrame)
return
}
} }
t.Run("plain", func(t *testing.T) { t.Run("plain", func(t *testing.T) {
obfuscator, err := MakeObfuscator(EncryptionMethodPlain, sessionKey) o, err := MakeObfuscator(EncryptionMethodPlain, sessionKey)
if err != nil { assert.NoError(t, err)
t.Errorf("failed to generate obfuscator %v", err) run(o, t)
} else {
run(obfuscator, t)
}
}) })
t.Run("aes-256-gcm", func(t *testing.T) { t.Run("aes-256-gcm", func(t *testing.T) {
obfuscator, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) o, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey)
if err != nil { assert.NoError(t, err)
t.Errorf("failed to generate obfuscator %v", err) run(o, t)
} else {
run(obfuscator, t)
}
}) })
t.Run("aes-128-gcm", func(t *testing.T) { t.Run("aes-128-gcm", func(t *testing.T) {
obfuscator, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) o, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey)
if err != nil { assert.NoError(t, err)
t.Errorf("failed to generate obfuscator %v", err) run(o, t)
} else {
run(obfuscator, t)
}
}) })
t.Run("chacha20-poly1305", func(t *testing.T) { t.Run("chacha20-poly1305", func(t *testing.T) {
obfuscator, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) o, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
if err != nil { assert.NoError(t, err)
t.Errorf("failed to generate obfuscator %v", err) run(o, t)
} else {
run(obfuscator, t)
}
}) })
t.Run("unknown encryption method", func(t *testing.T) { t.Run("unknown encryption method", func(t *testing.T) {
_, err := MakeObfuscator(0xff, sessionKey) _, err := MakeObfuscator(0xff, sessionKey)
if err == nil { assert.Error(t, err)
t.Errorf("unknown encryption mehtod error expected")
}
}) })
} }
func TestObfuscate(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
const testPayloadLen = 1024
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
f := Frame{
StreamID: 0,
Seq: 0,
Closing: 0,
Payload: testPayload,
}
runTest := func(t *testing.T, o Obfuscator) {
obfsBuf := make([]byte, testPayloadLen*2)
n, err := o.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
resultFrame := Frame{}
err = o.deobfuscate(&resultFrame, obfsBuf[:n])
assert.NoError(t, err)
assert.EqualValues(t, f, resultFrame)
}
t.Run("plain", func(t *testing.T) {
o := Obfuscator{
payloadCipher: nil,
SessionKey: sessionKey,
maxOverhead: salsa20NonceSize,
}
runTest(t, o)
})
t.Run("aes-128-gcm", func(t *testing.T) {
c, err := aes.NewCipher(sessionKey[:16])
assert.NoError(t, err)
payloadCipher, err := cipher.NewGCM(c)
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: sessionKey,
maxOverhead: payloadCipher.Overhead(),
}
runTest(t, o)
})
t.Run("aes-256-gcm", func(t *testing.T) {
c, err := aes.NewCipher(sessionKey[:])
assert.NoError(t, err)
payloadCipher, err := cipher.NewGCM(c)
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: sessionKey,
maxOverhead: payloadCipher.Overhead(),
}
runTest(t, o)
})
t.Run("chacha20-poly1305", func(t *testing.T) {
payloadCipher, err := chacha20poly1305.New(sessionKey[:])
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
SessionKey: sessionKey,
maxOverhead: payloadCipher.Overhead(),
}
runTest(t, o)
})
}
func BenchmarkObfs(b *testing.B) { func BenchmarkObfs(b *testing.B) {
testPayload := make([]byte, 1024) testPayload := make([]byte, 1024)
rand.Read(testPayload) rand.Read(testPayload)