Reduce allocation of frame objects on receiving data

This commit is contained in:
Andy Wang 2020-12-22 14:45:29 +00:00
parent 104117cafb
commit 5a3f63f101
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
3 changed files with 29 additions and 20 deletions

View File

@ -12,7 +12,7 @@ import (
) )
type Obfser func(*Frame, []byte, int) (int, error) type Obfser func(*Frame, []byte, int) (int, error)
type Deobfser func([]byte) (*Frame, error) type Deobfser func(*Frame, []byte) error
var u32 = binary.BigEndian.Uint32 var u32 = binary.BigEndian.Uint32
var u64 = binary.BigEndian.Uint64 var u64 = binary.BigEndian.Uint64
@ -135,9 +135,9 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
// frame header length + minimum data size (i.e. nonce size of salsa20) // frame header length + minimum data size (i.e. nonce size of salsa20)
const minInputLen = frameHeaderLength + salsa20NonceSize const minInputLen = frameHeaderLength + salsa20NonceSize
deobfs := func(in []byte) (*Frame, error) { deobfs := func(f *Frame, in []byte) error {
if len(in) < minInputLen { if len(in) < minInputLen {
return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
} }
header := in[:frameHeaderLength] header := in[:frameHeaderLength]
@ -153,7 +153,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
usefulPayloadLen := len(pldWithOverHead) - int(extraLen) usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
} }
var outputPayload []byte var outputPayload []byte
@ -167,18 +167,16 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
} else { } else {
_, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil)
if err != nil { if err != nil {
return nil, err return err
} }
outputPayload = pldWithOverHead[:usefulPayloadLen] outputPayload = pldWithOverHead[:usefulPayloadLen]
} }
ret := &Frame{ f.StreamID = streamID
StreamID: streamID, f.Seq = seq
Seq: seq, f.Closing = closing
Closing: closing, f.Payload = outputPayload
Payload: outputPayload, return nil
}
return ret, nil
} }
return deobfs return deobfs
} }

View File

@ -17,8 +17,7 @@ func TestGenerateObfs(t *testing.T) {
run := func(obfuscator Obfuscator, ct *testing.T) { run := func(obfuscator Obfuscator, ct *testing.T) {
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
f := &Frame{} _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
_testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42)))
testFrame := _testFrame.Interface().(*Frame) testFrame := _testFrame.Interface().(*Frame)
i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) i, err := obfuscator.Obfs(testFrame, obfsBuf, 0)
if err != nil { if err != nil {
@ -26,7 +25,8 @@ func TestGenerateObfs(t *testing.T) {
return return
} }
resultFrame, err := obfuscator.Deobfs(obfsBuf[:i]) var resultFrame Frame
err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i])
if err != nil { if err != nil {
ct.Error("failed to deobfs ", err) ct.Error("failed to deobfs ", err)
return return
@ -148,10 +148,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher) deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame)
b.SetBytes(int64(n)) b.SetBytes(int64(n))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
b.Run("AES128GCM", func(b *testing.B) { b.Run("AES128GCM", func(b *testing.B) {
@ -162,10 +163,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher) deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
@ -173,10 +175,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, nil) deobfs := MakeDeobfs(key, nil)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
b.Run("chacha20Poly1305", func(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) {
@ -186,10 +189,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher) deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
} }

View File

@ -65,6 +65,9 @@ type Session struct {
activeStreamCount uint32 activeStreamCount uint32
streams sync.Map streams sync.Map
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool
// Switchboard manages all connections to remote // Switchboard manages all connections to remote
sb *switchboard sb *switchboard
@ -89,6 +92,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
SessionConfig: config, SessionConfig: config,
nextStreamID: 1, nextStreamID: 1,
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
} }
sesh.addrs.Store([]net.Addr{nil, nil}) sesh.addrs.Store([]net.Addr{nil, nil})
@ -212,7 +216,10 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new // to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
// stream and then writes to the stream buffer // stream and then writes to the stream buffer
func (sesh *Session) recvDataFromRemote(data []byte) error { func (sesh *Session) recvDataFromRemote(data []byte) error {
frame, err := sesh.Deobfs(data) frame := sesh.recvFramePool.Get().(*Frame)
defer sesh.recvFramePool.Put(frame)
err := sesh.Deobfs(frame, data)
if err != nil { if err != nil {
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
} }