diff --git a/internal/client/state.go b/internal/client/state.go index c26f839..0ee914c 100644 --- a/internal/client/state.go +++ b/internal/client/state.go @@ -163,8 +163,7 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca switch strings.ToLower(raw.EncryptionMethod) { case "plain": auth.EncryptionMethod = mux.EncryptionMethodPlain - case "aes-gcm": - case "aes-256-gcm": + case "aes-gcm", "aes-256-gcm": auth.EncryptionMethod = mux.EncryptionMethodAES256GCM case "aes-128-gcm": auth.EncryptionMethod = mux.EncryptionMethodAES128GCM diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 91c9a76..9de2270 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -11,6 +11,9 @@ import ( "golang.org/x/crypto/salsa20" ) +type Obfser func(*Frame, []byte, int) (int, error) +type Deobfser func(*Frame, []byte) error + const frameHeaderLength = 14 const salsa20NonceSize = 8 @@ -23,15 +26,21 @@ const ( // Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. type Obfuscator struct { - payloadCipher cipher.AEAD - + // Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header + Obfs Obfser + // Remove TLS header, decrypt and unmarshall frames + Deobfs Deobfser SessionKey [32]byte maxOverhead int } -// obfuscate adds multiplexing headers, encrypt and add TLS header -func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { +// MakeObfs returns a function of type Obfser. An Obfser takes three arguments: +// a *Frame with all the field set correctly, a []byte as buffer to put encrypted +// message in, and an int called payloadOffsetInBuf to be used when *Frame.payload +// is in the byte slice used as buffer (2nd argument). payloadOffsetInBuf specifies +// the index at which data belonging to *Frame.Payload starts in the buffer. +func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { // The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header // as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use // the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead()) @@ -63,99 +72,109 @@ func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (in // We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input // is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends // encrypted data. - payloadLen := len(f.Payload) - if payloadLen == 0 { - return 0, errors.New("payload cannot be empty") - } - var extraLen int - if o.payloadCipher == nil { - extraLen = salsa20NonceSize - payloadLen - if extraLen < 0 { - // if our payload is already greater than 8 bytes - extraLen = 0 + obfs := func(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { + payloadLen := len(f.Payload) + if payloadLen == 0 { + return 0, errors.New("payload cannot be empty") } - } else { - extraLen = o.payloadCipher.Overhead() - if extraLen < salsa20NonceSize { - return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") + var extraLen int + if payloadCipher == nil { + extraLen = salsa20NonceSize - payloadLen + if extraLen < 0 { + // if our payload is already greater than 8 bytes + extraLen = 0 + } + } else { + extraLen = payloadCipher.Overhead() + if extraLen < salsa20NonceSize { + return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") + } } - } - usefulLen := frameHeaderLength + payloadLen + extraLen - if len(buf) < usefulLen { - return 0, errors.New("obfs buffer too small") - } - // we do as much in-place as possible to save allocation - payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] - if payloadOffsetInBuf != frameHeaderLength { - // if payload is not at the correct location in buffer - copy(payload, f.Payload) - } - - header := buf[:frameHeaderLength] - binary.BigEndian.PutUint32(header[0:4], f.StreamID) - binary.BigEndian.PutUint64(header[4:12], f.Seq) - header[12] = f.Closing - header[13] = byte(extraLen) - - if o.payloadCipher == nil { - if extraLen != 0 { // read nonce - extra := buf[usefulLen-extraLen : usefulLen] - common.CryptoRandRead(extra) + usefulLen := frameHeaderLength + payloadLen + extraLen + if len(buf) < usefulLen { + return 0, errors.New("obfs buffer too small") } - } else { - o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil) + // we do as much in-place as possible to save allocation + payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] + if payloadOffsetInBuf != frameHeaderLength { + // if payload is not at the correct location in buffer + copy(payload, f.Payload) + } + + header := buf[:frameHeaderLength] + binary.BigEndian.PutUint32(header[0:4], f.StreamID) + binary.BigEndian.PutUint64(header[4:12], f.Seq) + header[12] = f.Closing + header[13] = byte(extraLen) + + if payloadCipher == nil { + if extraLen != 0 { // read nonce + extra := buf[usefulLen-extraLen : usefulLen] + common.CryptoRandRead(extra) + } + } else { + payloadCipher.Seal(payload[:0], header[:payloadCipher.NonceSize()], payload, nil) + } + + nonce := buf[usefulLen-salsa20NonceSize : usefulLen] + salsa20.XORKeyStream(header, header, nonce, &salsaKey) + + return usefulLen, nil } - - nonce := buf[usefulLen-salsa20NonceSize : usefulLen] - salsa20.XORKeyStream(header, header, nonce, &o.SessionKey) - - return usefulLen, nil + return obfs } -// deobfuscate removes TLS header, decrypt and unmarshall frames -func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error { - if len(in) < frameHeaderLength+salsa20NonceSize { - return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize) - } +// MakeDeobfs returns a function Deobfser. A Deobfser takes in a single byte slice, +// containing the message to be decrypted, and returns a *Frame containing the frame +// information and plaintext +func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { + // frame header length + minimum data size (i.e. nonce size of salsa20) + const minInputLen = frameHeaderLength + salsa20NonceSize + deobfs := func(f *Frame, in []byte) error { + if len(in) < minInputLen { + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) + } - header := in[:frameHeaderLength] - pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead + header := in[:frameHeaderLength] + pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead - nonce := in[len(in)-salsa20NonceSize:] - salsa20.XORKeyStream(header, header, nonce, &o.SessionKey) + nonce := in[len(in)-salsa20NonceSize:] + salsa20.XORKeyStream(header, header, nonce, &salsaKey) - streamID := binary.BigEndian.Uint32(header[0:4]) - seq := binary.BigEndian.Uint64(header[4:12]) - closing := header[12] - extraLen := header[13] + streamID := binary.BigEndian.Uint32(header[0:4]) + seq := binary.BigEndian.Uint64(header[4:12]) + closing := header[12] + extraLen := header[13] - usefulPayloadLen := len(pldWithOverHead) - int(extraLen) - if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") - } + usefulPayloadLen := len(pldWithOverHead) - int(extraLen) + if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + } - var outputPayload []byte + var outputPayload []byte - if o.payloadCipher == nil { - if extraLen == 0 { - outputPayload = pldWithOverHead + if payloadCipher == nil { + if extraLen == 0 { + outputPayload = pldWithOverHead + } else { + outputPayload = pldWithOverHead[:usefulPayloadLen] + } } else { + _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) + if err != nil { + return err + } outputPayload = pldWithOverHead[:usefulPayloadLen] } - } else { - _, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil) - if err != nil { - return err - } - outputPayload = pldWithOverHead[:usefulPayloadLen] - } - f.StreamID = streamID - f.Seq = seq - f.Closing = closing - f.Payload = outputPayload - return nil + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil + } + return deobfs } func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { @@ -196,7 +215,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu } obfuscator.maxOverhead = payloadCipher.Overhead() default: - return obfuscator, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) + return obfuscator, errors.New("Unknown encryption method") } if payloadCipher != nil { @@ -205,5 +224,7 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu } } + obfuscator.Obfs = MakeObfs(sessionKey, payloadCipher) + obfuscator.Deobfs = MakeDeobfs(sessionKey, payloadCipher) return } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 78a760d..6fd9916 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -19,14 +19,14 @@ func TestGenerateObfs(t *testing.T) { obfsBuf := make([]byte, 512) _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) - i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0) + i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) if err != nil { ct.Error("failed to obfs ", err) return } var resultFrame Frame - err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i]) + err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) if err != nil { ct.Error("failed to deobfs ", err) return @@ -96,57 +96,40 @@ func BenchmarkObfs(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfuscator := Obfuscator{ - payloadCipher: payloadCipher, - SessionKey: key, - maxOverhead: payloadCipher.Overhead(), - } - + obfs := MakeObfs(key, payloadCipher) b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs(testFrame, obfsBuf, 0) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfuscator := Obfuscator{ - payloadCipher: payloadCipher, - SessionKey: key, - maxOverhead: payloadCipher.Overhead(), - } + obfs := MakeObfs(key, payloadCipher) b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs(testFrame, obfsBuf, 0) } }) b.Run("plain", func(b *testing.B) { - obfuscator := Obfuscator{ - payloadCipher: nil, - SessionKey: key, - maxOverhead: salsa20NonceSize, - } + obfs := MakeObfs(key, nil) b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs(testFrame, obfsBuf, 0) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:]) + payloadCipher, _ := chacha20poly1305.New(key[:16]) - obfuscator := Obfuscator{ - payloadCipher: payloadCipher, - SessionKey: key, - maxOverhead: payloadCipher.Overhead(), - } + obfs := MakeObfs(key, payloadCipher) b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs(testFrame, obfsBuf, 0) } }) } @@ -168,70 +151,57 @@ func BenchmarkDeobfs(b *testing.B) { b.Run("AES256GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfuscator := Obfuscator{ - payloadCipher: payloadCipher, - SessionKey: key, - maxOverhead: payloadCipher.Overhead(), - } - n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs := MakeObfs(key, payloadCipher) + n, _ := obfs(testFrame, obfsBuf, 0) + deobfs := MakeDeobfs(key, payloadCipher) frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - obfuscator.deobfuscate(frame, obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfuscator := Obfuscator{ - payloadCipher: payloadCipher, - SessionKey: key, - maxOverhead: payloadCipher.Overhead(), - } - n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs := MakeObfs(key, payloadCipher) + n, _ := obfs(testFrame, obfsBuf, 0) + deobfs := MakeDeobfs(key, payloadCipher) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - obfuscator.deobfuscate(frame, obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { - obfuscator := Obfuscator{ - payloadCipher: nil, - SessionKey: key, - maxOverhead: salsa20NonceSize, - } - n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs := MakeObfs(key, nil) + n, _ := obfs(testFrame, obfsBuf, 0) + deobfs := MakeDeobfs(key, nil) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - obfuscator.deobfuscate(frame, obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:]) + payloadCipher, _ := chacha20poly1305.New(key[:16]) - obfuscator := Obfuscator{ - payloadCipher: nil, - SessionKey: key, - maxOverhead: payloadCipher.Overhead(), - } - - n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) + obfs := MakeObfs(key, payloadCipher) + n, _ := obfs(testFrame, obfsBuf, 0) + deobfs := MakeDeobfs(key, payloadCipher) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - obfuscator.deobfuscate(frame, obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index e05e399..0113afa 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -232,7 +232,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { frame := sesh.recvFramePool.Get().(*Frame) defer sesh.recvFramePool.Put(frame) - err := sesh.deobfuscate(frame, data) + err := sesh.Deobfs(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } @@ -331,7 +331,7 @@ func (sesh *Session) Close() error { Closing: closingSession, Payload: payload, } - i, err := sesh.obfuscate(f, *buf, frameHeaderLength) + i, err := sesh.Obfs(f, *buf, frameHeaderLength) if err != nil { return err } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index dfa3dbb..c942b10 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -4,10 +4,7 @@ import ( "bytes" "github.com/cbeuw/connutil" "github.com/stretchr/testify/assert" - "io" - "io/ioutil" "math/rand" - "net" "strconv" "sync" "sync/atomic" @@ -19,188 +16,73 @@ var seshConfigs = map[string]SessionConfig{ "ordered": {}, "unordered": {Unordered: true}, } -var encryptionMethods = map[string]byte{ - "plain": EncryptionMethodPlain, - "aes-256-gcm": EncryptionMethodAES256GCM, - "aes-128-gcm": EncryptionMethodAES128GCM, - "chacha20poly1305": EncryptionMethodChaha20Poly1305, -} const testPayloadLen = 1024 const obfsBufLen = testPayloadLen * 2 func TestRecvDataFromRemote(t *testing.T) { + testPayload := make([]byte, testPayloadLen) + rand.Read(testPayload) + f := &Frame{ + 1, + 0, + 0, + testPayload, + } + obfsBuf := make([]byte, obfsBufLen) + var sessionKey [32]byte rand.Read(sessionKey[:]) + MakeObfuscatorUnwrap := func(method byte, sessionKey [32]byte) Obfuscator { + ret, err := MakeObfuscator(method, sessionKey) + if err != nil { + t.Fatalf("failed to make an obfuscator: %v", err) + } + return ret + } + + encryptionMethods := map[string]Obfuscator{ + "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), + "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAES256GCM, sessionKey), + "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), + } + for seshType, seshConfig := range seshConfigs { seshConfig := seshConfig t.Run(seshType, func(t *testing.T) { - var err error - seshConfig.Obfuscator, err = MakeObfuscator(EncryptionMethodPlain, sessionKey) - if err != nil { - t.Fatalf("failed to make obfuscator: %v", err) - } - t.Run("initial frame", func(t *testing.T) { - sesh := MakeSession(0, seshConfig) - obfsBuf := make([]byte, obfsBufLen) - f := Frame{ - 1, - 0, - 0, - make([]byte, testPayloadLen), - } - rand.Read(f.Payload) - n, err := sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - stream, err := sesh.Accept() - assert.NoError(t, err) - - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - assert.NoError(t, err) - - assert.EqualValues(t, f.Payload, resultPayload) - }) - - t.Run("two frames in order", func(t *testing.T) { - sesh := MakeSession(0, seshConfig) - obfsBuf := make([]byte, obfsBufLen) - f := Frame{ - 1, - 0, - 0, - make([]byte, testPayloadLen), - } - rand.Read(f.Payload) - n, err := sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - stream, err := sesh.Accept() - assert.NoError(t, err) - - resultPayload := make([]byte, testPayloadLen) - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - - assert.EqualValues(t, f.Payload, resultPayload) - - f.Seq += 1 - rand.Read(f.Payload) - n, err = sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - - assert.EqualValues(t, f.Payload, resultPayload) - }) - - t.Run("two frames in order", func(t *testing.T) { - sesh := MakeSession(0, seshConfig) - obfsBuf := make([]byte, obfsBufLen) - f := Frame{ - 1, - 0, - 0, - make([]byte, testPayloadLen), - } - rand.Read(f.Payload) - n, err := sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - stream, err := sesh.Accept() - assert.NoError(t, err) - - resultPayload := make([]byte, testPayloadLen) - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - - assert.EqualValues(t, f.Payload, resultPayload) - - f.Seq += 1 - rand.Read(f.Payload) - n, err = sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - - assert.EqualValues(t, f.Payload, resultPayload) - }) - - if seshType == "ordered" { - t.Run("frames out of order", func(t *testing.T) { + for method, obfuscator := range encryptionMethods { + obfuscator := obfuscator + t.Run(method, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig) - obfsBuf := make([]byte, obfsBufLen) - f := Frame{ - 1, - 0, - 0, - nil, + n, err := sesh.Obfs(f, obfsBuf, 0) + if err != nil { + t.Error(err) + return + } + err = sesh.recvDataFromRemote(obfsBuf[:n]) + if err != nil { + t.Error(err) + return + } + stream, err := sesh.Accept() + if err != nil { + t.Error(err) + return } - // First frame - seq0 := make([]byte, testPayloadLen) - rand.Read(seq0) - f.Seq = 0 - f.Payload = seq0 - n, err := sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - - // Third frame - seq2 := make([]byte, testPayloadLen) - rand.Read(seq2) - f.Seq = 2 - f.Payload = seq2 - n, err = sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - - // Second frame - seq1 := make([]byte, testPayloadLen) - rand.Read(seq1) - f.Seq = 1 - f.Payload = seq1 - n, err = sesh.obfuscate(&f, obfsBuf, 0) - assert.NoError(t, err) - err = sesh.recvDataFromRemote(obfsBuf[:n]) - assert.NoError(t, err) - - // Expect things to receive in order - stream, err := sesh.Accept() - assert.NoError(t, err) - resultPayload := make([]byte, testPayloadLen) - - // First - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - assert.EqualValues(t, seq0, resultPayload) - - // Second - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - assert.EqualValues(t, seq1, resultPayload) - - // Third - _, err = io.ReadFull(stream, resultPayload) - assert.NoError(t, err) - assert.EqualValues(t, seq2, resultPayload) + _, err = stream.Read(resultPayload) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(testPayload, resultPayload) { + t.Errorf("Expecting %x, got %x", testPayload, resultPayload) + } }) } - }) } } @@ -212,9 +94,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) + obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) seshConfig := seshConfigs["ordered"] - seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey) + seshConfig.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig) f1 := &Frame{ @@ -224,7 +107,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { testPayload, } // create stream 1 - n, _ := sesh.obfuscate(f1, obfsBuf, 0) + n, _ := sesh.Obfs(f1, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -246,7 +129,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.obfuscate(f2, obfsBuf, 0) + n, _ = sesh.Obfs(f2, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) @@ -268,7 +151,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingStream, testPayload, } - n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) + n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) @@ -297,7 +180,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { } // close stream 1 again - n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) + n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) @@ -320,7 +203,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { Closing: closingSession, Payload: testPayload, } - n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0) + n, _ = sesh.Obfs(fCloseSession, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving session closing frame: %v", err) @@ -350,9 +233,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) + obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) seshConfig := seshConfigs["ordered"] - seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey) + seshConfig.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig) // receive stream 1 closing first @@ -362,7 +246,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingStream, testPayload, } - n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0) + n, _ := sesh.Obfs(f1CloseStream, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) @@ -384,7 +268,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.obfuscate(f1, obfsBuf, 0) + n, _ = sesh.Obfs(f1, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -446,7 +330,7 @@ func TestParallelStreams(t *testing.T) { wg.Add(1) go func(frame *Frame) { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.obfuscate(frame, obfsBuf, 0) + n, _ := sesh.Obfs(frame, obfsBuf, 0) obfsBuf = obfsBuf[0:n] err := sesh.recvDataFromRemote(obfsBuf) @@ -531,7 +415,7 @@ func TestSession_timeoutAfter(t *testing.T) { } } -func BenchmarkRecvDataFromRemote(b *testing.B) { +func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) f := &Frame{ @@ -544,34 +428,33 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAES256GCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic - for name, ep := range encryptionMethods { + for name, ep := range table { ep := ep b.Run(name, func(b *testing.B) { - for seshType, seshConfig := range seshConfigs { - b.Run(seshType, func(b *testing.B) { - seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) - sesh := MakeSession(0, seshConfig) + seshConfig := seshConfigs["ordered"] + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) - go func() { - stream, _ := sesh.Accept() - stream.(*Stream).WriteTo(ioutil.Discard) - }() + binaryFrames := [maxIter][]byte{} + for i := 0; i < maxIter; i++ { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(f, obfsBuf, 0) + binaryFrames[i] = obfsBuf[:n] + f.Seq++ + } - binaryFrames := [maxIter][]byte{} - for i := 0; i < maxIter; i++ { - obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.obfuscate(f, obfsBuf, 0) - binaryFrames[i] = obfsBuf[:n] - f.Seq++ - } - - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(binaryFrames[i]) - } - }) + b.SetBytes(int64(len(f.Payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sesh.recvDataFromRemote(binaryFrames[i]) } }) } @@ -581,13 +464,21 @@ func BenchmarkMultiStreamWrite(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAES256GCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + testPayload := make([]byte, testPayloadLen) - for name, ep := range encryptionMethods { + for name, ep := range table { b.Run(name, func(b *testing.B) { for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig b.Run(seshType, func(b *testing.B) { - seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig) sesh.AddConnection(connutil.Discard()) b.ResetTimer() @@ -603,36 +494,3 @@ func BenchmarkMultiStreamWrite(b *testing.B) { }) } } - -func BenchmarkLatency(b *testing.B) { - var sessionKey [32]byte - rand.Read(sessionKey[:]) - - for name, ep := range encryptionMethods { - b.Run(name, func(b *testing.B) { - for seshType, seshConfig := range seshConfigs { - b.Run(seshType, func(b *testing.B) { - seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) - clientSesh := MakeSession(0, seshConfig) - serverSesh := MakeSession(0, seshConfig) - - c, s := net.Pipe() - clientSesh.AddConnection(c) - serverSesh.AddConnection(s) - - buf := make([]byte, 64) - clientStream, _ := clientSesh.OpenStream() - clientStream.Write(buf) - serverStream, _ := serverSesh.Accept() - io.ReadFull(serverStream, buf) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - clientStream.Write(buf) - io.ReadFull(serverStream, buf) - } - }) - } - }) - } -} diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index ffd7e23..b29359f 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -113,7 +113,7 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { } func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { - cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf) + cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf) s.writingFrame.Seq++ if err != nil { return err diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index eb13bc8..895f35e 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -142,7 +142,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0) + i, _ := sesh.Obfs(dataFrame, obfsBuf, 0) _, err := writingEnd.Write(obfsBuf[:i]) if err != nil { t.Error("failed to write from remote end") @@ -185,7 +185,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, err := sesh.obfuscate(dataFrame, obfsBuf, 0) + i, err := sesh.Obfs(dataFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -207,7 +207,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.obfuscate(closingFrame, obfsBuf, 0) + i, err = sesh.Obfs(closingFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -223,7 +223,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0) + i, err = sesh.Obfs(closingFrameDup, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -275,7 +275,7 @@ func TestStream_Read(t *testing.T) { obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.obfuscate(f, obfsBuf, 0) + i, _ := sesh.Obfs(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, err := sesh.Accept() @@ -300,7 +300,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Nil buf", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.obfuscate(f, obfsBuf, 0) + i, _ := sesh.Obfs(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -312,7 +312,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.obfuscate(f, obfsBuf, 0) + i, _ := sesh.Obfs(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -337,7 +337,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after session close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.obfuscate(f, obfsBuf, 0) + i, _ := sesh.Obfs(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept()