diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 1379072..97b2cd9 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -11,9 +11,6 @@ import ( "golang.org/x/crypto/salsa20" ) -type Obfser func(*Frame, []byte, int) (int, error) -type Deobfser func(*Frame, []byte) error - var u32 = binary.BigEndian.Uint32 var u64 = binary.BigEndian.Uint64 var putU32 = binary.BigEndian.PutUint32 @@ -30,21 +27,15 @@ const ( // Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. type Obfuscator struct { - // Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header - Obfs Obfser - // Remove TLS header, decrypt and unmarshall frames - Deobfs Deobfser + payloadCipher cipher.AEAD + SessionKey [32]byte maxOverhead int } -// MakeObfs returns a function of type Obfser. An Obfser takes three arguments: -// a *Frame with all the field set correctly, a []byte as buffer to put encrypted -// message in, and an int called payloadOffsetInBuf to be used when *Frame.payload -// is in the byte slice used as buffer (2nd argument). payloadOffsetInBuf specifies -// the index at which data belonging to *Frame.Payload starts in the buffer. -func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { +// obfuscate adds multiplexing headers, encrypt and add TLS header +func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { // The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header // as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use // the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead()) @@ -76,109 +67,99 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { // We can't ensure its uniqueness ourselves, which is why plaintext mode must only be used when the user input // is already random-like. For Cloak it would normally mean that the user is using a proxy protocol that sends // encrypted data. - obfs := func(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { - payloadLen := len(f.Payload) - if payloadLen == 0 { - return 0, errors.New("payload cannot be empty") - } - var extraLen int - if payloadCipher == nil { - extraLen = salsa20NonceSize - payloadLen - if extraLen < 0 { - // if our payload is already greater than 8 bytes - extraLen = 0 - } - } else { - extraLen = payloadCipher.Overhead() - if extraLen < salsa20NonceSize { - return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") - } - } - - usefulLen := frameHeaderLength + payloadLen + extraLen - if len(buf) < usefulLen { - return 0, errors.New("obfs buffer too small") - } - // we do as much in-place as possible to save allocation - payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] - if payloadOffsetInBuf != frameHeaderLength { - // if payload is not at the correct location in buffer - copy(payload, f.Payload) - } - - header := buf[:frameHeaderLength] - putU32(header[0:4], f.StreamID) - putU64(header[4:12], f.Seq) - header[12] = f.Closing - header[13] = byte(extraLen) - - if payloadCipher == nil { - if extraLen != 0 { // read nonce - extra := buf[usefulLen-extraLen : usefulLen] - common.CryptoRandRead(extra) - } - } else { - payloadCipher.Seal(payload[:0], header[:payloadCipher.NonceSize()], payload, nil) - } - - nonce := buf[usefulLen-salsa20NonceSize : usefulLen] - salsa20.XORKeyStream(header, header, nonce, &salsaKey) - - return usefulLen, nil + payloadLen := len(f.Payload) + if payloadLen == 0 { + return 0, errors.New("payload cannot be empty") } - return obfs + var extraLen int + if o.payloadCipher == nil { + extraLen = salsa20NonceSize - payloadLen + if extraLen < 0 { + // if our payload is already greater than 8 bytes + extraLen = 0 + } + } else { + extraLen = o.payloadCipher.Overhead() + if extraLen < salsa20NonceSize { + return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes") + } + } + + usefulLen := frameHeaderLength + payloadLen + extraLen + if len(buf) < usefulLen { + return 0, errors.New("obfs buffer too small") + } + // we do as much in-place as possible to save allocation + payload := buf[frameHeaderLength : frameHeaderLength+payloadLen] + if payloadOffsetInBuf != frameHeaderLength { + // if payload is not at the correct location in buffer + copy(payload, f.Payload) + } + + header := buf[:frameHeaderLength] + putU32(header[0:4], f.StreamID) + putU64(header[4:12], f.Seq) + header[12] = f.Closing + header[13] = byte(extraLen) + + if o.payloadCipher == nil { + if extraLen != 0 { // read nonce + extra := buf[usefulLen-extraLen : usefulLen] + common.CryptoRandRead(extra) + } + } else { + o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil) + } + + nonce := buf[usefulLen-salsa20NonceSize : usefulLen] + salsa20.XORKeyStream(header, header, nonce, &o.SessionKey) + + return usefulLen, nil } -// MakeDeobfs returns a function Deobfser. A Deobfser takes in a single byte slice, -// containing the message to be decrypted, and returns a *Frame containing the frame -// information and plaintext -func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { - // frame header length + minimum data size (i.e. nonce size of salsa20) - const minInputLen = frameHeaderLength + salsa20NonceSize - deobfs := func(f *Frame, in []byte) error { - if len(in) < minInputLen { - return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) - } +// deobfuscate removes TLS header, decrypt and unmarshall frames +func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error { + if len(in) < frameHeaderLength+salsa20NonceSize { + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize) + } - header := in[:frameHeaderLength] - pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead + header := in[:frameHeaderLength] + pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead - nonce := in[len(in)-salsa20NonceSize:] - salsa20.XORKeyStream(header, header, nonce, &salsaKey) + nonce := in[len(in)-salsa20NonceSize:] + salsa20.XORKeyStream(header, header, nonce, &o.SessionKey) - streamID := u32(header[0:4]) - seq := u64(header[4:12]) - closing := header[12] - extraLen := header[13] + streamID := u32(header[0:4]) + seq := u64(header[4:12]) + closing := header[12] + extraLen := header[13] - usefulPayloadLen := len(pldWithOverHead) - int(extraLen) - if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") - } + usefulPayloadLen := len(pldWithOverHead) - int(extraLen) + if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + } - var outputPayload []byte + var outputPayload []byte - if payloadCipher == nil { - if extraLen == 0 { - outputPayload = pldWithOverHead - } else { - outputPayload = pldWithOverHead[:usefulPayloadLen] - } + if o.payloadCipher == nil { + if extraLen == 0 { + outputPayload = pldWithOverHead } else { - _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) - if err != nil { - return err - } outputPayload = pldWithOverHead[:usefulPayloadLen] } - - f.StreamID = streamID - f.Seq = seq - f.Closing = closing - f.Payload = outputPayload - return nil + } else { + _, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil) + if err != nil { + return err + } + outputPayload = pldWithOverHead[:usefulPayloadLen] } - return deobfs + + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil } func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfuscator, err error) { @@ -217,7 +198,5 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator Obfu } } - obfuscator.Obfs = MakeObfs(sessionKey, payloadCipher) - obfuscator.Deobfs = MakeDeobfs(sessionKey, payloadCipher) return } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 99f4f5f..2bee633 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -19,14 +19,14 @@ func TestGenerateObfs(t *testing.T) { obfsBuf := make([]byte, 512) _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) - i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) + i, err := obfuscator.obfuscate(testFrame, obfsBuf, 0) if err != nil { ct.Error("failed to obfs ", err) return } var resultFrame Frame - err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) + err = obfuscator.deobfuscate(&resultFrame, obfsBuf[:i]) if err != nil { ct.Error("failed to deobfs ", err) return @@ -88,40 +88,57 @@ func BenchmarkObfs(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) + obfuscator := Obfuscator{ + payloadCipher: nil, + SessionKey: key, + maxOverhead: salsa20NonceSize, + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:16]) + payloadCipher, _ := chacha20poly1305.New(key[:]) - obfs := MakeObfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } b.SetBytes(int64(len(testFrame.Payload))) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame, obfsBuf, 0) + obfuscator.obfuscate(testFrame, obfsBuf, 0) } }) } @@ -143,57 +160,70 @@ func BenchmarkDeobfs(b *testing.B) { b.Run("AES256GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: payloadCipher, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, nil) + obfuscator := Obfuscator{ + payloadCipher: nil, + SessionKey: key, + maxOverhead: salsa20NonceSize, + } + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { - payloadCipher, _ := chacha20poly1305.New(key[:16]) + payloadCipher, _ := chacha20poly1305.New(key[:]) - obfs := MakeObfs(key, payloadCipher) - n, _ := obfs(testFrame, obfsBuf, 0) - deobfs := MakeDeobfs(key, payloadCipher) + obfuscator := Obfuscator{ + payloadCipher: nil, + SessionKey: key, + maxOverhead: payloadCipher.Overhead(), + } + + n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(frame, obfsBuf[:n]) + obfuscator.deobfuscate(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 0113afa..e05e399 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -232,7 +232,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { frame := sesh.recvFramePool.Get().(*Frame) defer sesh.recvFramePool.Put(frame) - err := sesh.Deobfs(frame, data) + err := sesh.deobfuscate(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } @@ -331,7 +331,7 @@ func (sesh *Session) Close() error { Closing: closingSession, Payload: payload, } - i, err := sesh.Obfs(f, *buf, frameHeaderLength) + i, err := sesh.obfuscate(f, *buf, frameHeaderLength) if err != nil { return err } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index f4b32bb..990437a 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -56,7 +56,7 @@ func TestRecvDataFromRemote(t *testing.T) { t.Run(method, func(t *testing.T) { seshConfig.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig) - n, err := sesh.Obfs(f, obfsBuf, 0) + n, err := sesh.obfuscate(f, obfsBuf, 0) if err != nil { t.Error(err) return @@ -107,7 +107,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { testPayload, } // create stream 1 - n, _ := sesh.Obfs(f1, obfsBuf, 0) + n, _ := sesh.obfuscate(f1, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -129,7 +129,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.Obfs(f2, obfsBuf, 0) + n, _ = sesh.obfuscate(f2, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) @@ -151,7 +151,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { closingStream, testPayload, } - n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) @@ -180,7 +180,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { } // close stream 1 again - n, _ = sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) @@ -203,7 +203,7 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { Closing: closingSession, Payload: testPayload, } - n, _ = sesh.Obfs(fCloseSession, obfsBuf, 0) + n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving session closing frame: %v", err) @@ -246,7 +246,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingStream, testPayload, } - n, _ := sesh.Obfs(f1CloseStream, obfsBuf, 0) + n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0) err := sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) @@ -268,7 +268,7 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { closingNothing, testPayload, } - n, _ = sesh.Obfs(f1, obfsBuf, 0) + n, _ = sesh.obfuscate(f1, obfsBuf, 0) err = sesh.recvDataFromRemote(obfsBuf[:n]) if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) @@ -330,7 +330,7 @@ func TestParallelStreams(t *testing.T) { wg.Add(1) go func(frame *Frame) { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(frame, obfsBuf, 0) + n, _ := sesh.obfuscate(frame, obfsBuf, 0) obfsBuf = obfsBuf[0:n] err := sesh.recvDataFromRemote(obfsBuf) @@ -446,7 +446,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { binaryFrames := [maxIter][]byte{} for i := 0; i < maxIter; i++ { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(f, obfsBuf, 0) + n, _ := sesh.obfuscate(f, obfsBuf, 0) binaryFrames[i] = obfsBuf[:n] f.Seq++ } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index b29359f..ffd7e23 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -113,7 +113,7 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { } func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { - cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf) + cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf) s.writingFrame.Seq++ if err != nil { return err diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 893aa46..0435557 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -141,7 +141,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, _ := sesh.Obfs(dataFrame, obfsBuf, 0) + i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0) _, err := writingEnd.Write(obfsBuf[:i]) if err != nil { t.Error("failed to write from remote end") @@ -184,7 +184,7 @@ func TestStream_Close(t *testing.T) { writingEnd := common.NewTLSConn(rawWritingEnd) obfsBuf := make([]byte, 512) - i, err := sesh.Obfs(dataFrame, obfsBuf, 0) + i, err := sesh.obfuscate(dataFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -206,7 +206,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.Obfs(closingFrame, obfsBuf, 0) + i, err = sesh.obfuscate(closingFrame, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -222,7 +222,7 @@ func TestStream_Close(t *testing.T) { testPayload, } - i, err = sesh.Obfs(closingFrameDup, obfsBuf, 0) + i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0) if err != nil { t.Errorf("failed to obfuscate frame %v", err) } @@ -274,7 +274,7 @@ func TestStream_Read(t *testing.T) { obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, err := sesh.Accept() @@ -299,7 +299,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Nil buf", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -311,7 +311,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept() @@ -336,7 +336,7 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after session close", func(t *testing.T) { f.StreamID = streamID - i, _ := sesh.Obfs(f, obfsBuf, 0) + i, _ := sesh.obfuscate(f, obfsBuf, 0) streamID++ writingEnd.Write(obfsBuf[:i]) stream, _ := sesh.Accept()