Eliminate some bounds check

This commit is contained in:
Andy Wang 2020-04-12 16:10:48 +01:00
parent f05cc19dbc
commit f0e8b4556e
1 changed files with 17 additions and 17 deletions

View File

@ -43,35 +43,35 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
// we need the encrypted data to be at least 8 bytes to be used as nonce for salsa20 stream header encryption // we need the encrypted data to be at least 8 bytes to be used as nonce for salsa20 stream header encryption
// this will be the case if the encryption method is an AEAD cipher, however for plain, it's well possible // this will be the case if the encryption method is an AEAD cipher, however for plain, it's well possible
// that the frame payload is smaller than 8 bytes, so we need to add on the difference // that the frame payload is smaller than 8 bytes, so we need to add on the difference
var extraLen uint8 var extraLen int
if payloadCipher == nil { if payloadCipher == nil {
if len(f.Payload) < 8 { if extraLen = 8 - len(f.Payload); extraLen < 0 {
extraLen = uint8(8 - len(f.Payload)) extraLen = 0
} }
} else { } else {
extraLen = uint8(payloadCipher.Overhead()) extraLen = payloadCipher.Overhead()
if extraLen < 8 {
return 0, errors.New("AEAD's Overhead cannot be fewer than 8 bytes")
}
} }
// usefulLen is the amount of bytes that will be eventually sent off // usefulLen is the amount of bytes that will be eventually sent off
usefulLen := HEADER_LEN + len(f.Payload) + int(extraLen) usefulLen := HEADER_LEN + len(f.Payload) + extraLen
if len(buf) < usefulLen { if usefulLen < HEADER_LEN || len(buf) < usefulLen { // compiler hint to eliminate bound check
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
// we do as much in-place as possible to save allocation // we do as much in-place as possible to save allocation
useful := buf[:usefulLen] // stream header + payload + potential overhead header := buf[:HEADER_LEN]
header := useful[:HEADER_LEN] encryptedPayloadWithExtra := buf[HEADER_LEN:usefulLen]
encryptedPayloadWithExtra := useful[HEADER_LEN:]
putU32(header[0:4], f.StreamID) putU32(header[0:4], f.StreamID)
putU64(header[4:12], f.Seq) putU64(header[4:12], f.Seq)
header[12] = f.Closing header[12] = f.Closing
header[13] = extraLen header[13] = byte(extraLen)
if payloadCipher == nil { if payloadCipher == nil {
copy(encryptedPayloadWithExtra, f.Payload) copy(encryptedPayloadWithExtra, f.Payload)
if extraLen != 0 { if extraLen != 0 { // read nonce
util.CryptoRandRead(encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-int(extraLen):]) util.CryptoRandRead(encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-extraLen:])
} }
} else { } else {
ciphertext := payloadCipher.Seal(nil, header[:12], f.Payload, nil) ciphertext := payloadCipher.Seal(nil, header[:12], f.Payload, nil)
@ -88,7 +88,7 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
// stream header length + minimum data size (i.e. nonce size of salsa20) // stream header length + minimum data size (i.e. nonce size of salsa20)
minInputLen := HEADER_LEN + 8 const minInputLen = HEADER_LEN + 8
deobfs := func(in []byte) (*Frame, error) { deobfs := func(in []byte) (*Frame, error) {
if len(in) < minInputLen { if len(in) < minInputLen {
return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
@ -106,8 +106,8 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
extraLen := header[13] extraLen := header[13]
usefulPayloadLen := len(pldWithOverHead) - int(extraLen) usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
if usefulPayloadLen < 0 { if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
return nil, errors.New("extra length is greater than total pldWithOverHead length") return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
} }
var outputPayload []byte var outputPayload []byte