diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 0c1f8c6..1379072 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -12,7 +12,7 @@ import ( ) type Obfser func(*Frame, []byte, int) (int, error) -type Deobfser func([]byte) (*Frame, error) +type Deobfser func(*Frame, []byte) error var u32 = binary.BigEndian.Uint32 var u64 = binary.BigEndian.Uint64 @@ -135,9 +135,9 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { // frame header length + minimum data size (i.e. nonce size of salsa20) const minInputLen = frameHeaderLength + salsa20NonceSize - deobfs := func(in []byte) (*Frame, error) { + deobfs := func(f *Frame, in []byte) error { if len(in) < minInputLen { - return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) } header := in[:frameHeaderLength] @@ -153,7 +153,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { usefulPayloadLen := len(pldWithOverHead) - int(extraLen) if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") } var outputPayload []byte @@ -167,18 +167,16 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { } else { _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) if err != nil { - return nil, err + return err } outputPayload = pldWithOverHead[:usefulPayloadLen] } - ret := &Frame{ - StreamID: streamID, - Seq: seq, - Closing: closing, - Payload: outputPayload, - } - return ret, nil + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil } return deobfs } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 6cbbb5b..99f4f5f 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -17,8 +17,7 @@ func TestGenerateObfs(t *testing.T) { run := func(obfuscator Obfuscator, ct *testing.T) { obfsBuf := make([]byte, 512) - f := &Frame{} - _testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42))) + _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) if err != nil { @@ -26,7 +25,8 @@ func TestGenerateObfs(t *testing.T) { return } - resultFrame, err := obfuscator.Deobfs(obfsBuf[:i]) + var resultFrame Frame + err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) if err != nil { ct.Error("failed to deobfs ", err) return @@ -148,10 +148,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { @@ -162,10 +163,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { @@ -173,10 +175,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, nil) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { @@ -186,10 +189,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 9964f74..6a88aa3 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -65,6 +65,9 @@ type Session struct { activeStreamCount uint32 streams sync.Map + // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame + recvFramePool sync.Pool + // Switchboard manages all connections to remote sb *switchboard @@ -89,6 +92,7 @@ func MakeSession(id uint32, config SessionConfig) *Session { SessionConfig: config, nextStreamID: 1, acceptCh: make(chan *Stream, acceptBacklog), + recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }}, } sesh.addrs.Store([]net.Addr{nil, nil}) @@ -212,7 +216,10 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { // to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new // stream and then writes to the stream buffer func (sesh *Session) recvDataFromRemote(data []byte) error { - frame, err := sesh.Deobfs(data) + frame := sesh.recvFramePool.Get().(*Frame) + defer sesh.recvFramePool.Put(frame) + + err := sesh.Deobfs(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) }