Use 64bit frame Seq to prevent nonce reuse

This commit is contained in:
Andy Wang 2019-08-27 15:06:28 +01:00
parent 2006e5971a
commit 4fb1f55e2d
5 changed files with 30 additions and 75 deletions

View File

@ -2,7 +2,7 @@ package multiplex
type Frame struct { type Frame struct {
StreamID uint32 StreamID uint32
Seq uint32 Seq uint64
Closing uint8 Closing uint8
Payload []byte Payload []byte
} }

View File

@ -8,17 +8,17 @@ import (
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/salsa20" "golang.org/x/crypto/salsa20"
prand "math/rand"
) )
type Obfser func(*Frame, []byte) (int, error) type Obfser func(*Frame, []byte) (int, error)
type Deobfser func([]byte) (*Frame, error) type Deobfser func([]byte) (*Frame, error)
var u32 = binary.BigEndian.Uint32 var u32 = binary.BigEndian.Uint32
var u64 = binary.BigEndian.Uint64
var putU32 = binary.BigEndian.PutUint32 var putU32 = binary.BigEndian.PutUint32
var putU64 = binary.BigEndian.PutUint64
const HEADER_LEN = 12 const HEADER_LEN = 14
const ( const (
E_METHOD_PLAIN = iota E_METHOD_PLAIN = iota
@ -52,24 +52,10 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
header := useful[5 : 5+HEADER_LEN] header := useful[5 : 5+HEADER_LEN]
encryptedPayloadWithExtra := useful[5+HEADER_LEN:] encryptedPayloadWithExtra := useful[5+HEADER_LEN:]
// TODO: Once Seq wraps around, the chance of a nonce reuse will be 1/65536 which is unacceptably low
// prohibit Seq wrap around? simple solution : 2^32 messages per stream may be too little
//
// use uint64 Seq? Vastly reduces the complexity of frameSorter : concern with 64 bit number performance on
// embedded systems (frameSorter already has a non-trivial performance impact on RPi2B, can only be worse on
// mipsle). HOWEVER since frameSorter already deals with uint64, prehaps changing it totally wouldn't matter much?
//
// regular rekey? Improves security in general : when to rekey? Not easy to synchronise, also will add a decent
// amount of complexity
//
// LEANING TOWARDS uint64 Seq. Adds extra 2 bytes of overhead but shouldn't really matter that much
// header: [StreamID 4 bytes][Seq 4 bytes][Closing 1 byte][extraLen 1 bytes][random 2 bytes]
putU32(header[0:4], f.StreamID) putU32(header[0:4], f.StreamID)
putU32(header[4:8], f.Seq) putU64(header[4:12], f.Seq)
header[8] = f.Closing header[12] = f.Closing
header[9] = extraLen header[13] = extraLen
prand.Read(header[10:12])
if payloadCipher == nil { if payloadCipher == nil {
copy(encryptedPayloadWithExtra, f.Payload) copy(encryptedPayloadWithExtra, f.Payload)
@ -77,7 +63,7 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
rand.Read(encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-int(extraLen):]) rand.Read(encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-int(extraLen):])
} }
} else { } else {
ciphertext := payloadCipher.Seal(nil, header, f.Payload, nil) ciphertext := payloadCipher.Seal(nil, header[:12], f.Payload, nil)
copy(encryptedPayloadWithExtra, ciphertext) copy(encryptedPayloadWithExtra, ciphertext)
} }
@ -98,22 +84,22 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
deobfs := func(in []byte) (*Frame, error) { deobfs := func(in []byte) (*Frame, error) {
if len(in) < 5+HEADER_LEN+8 { if len(in) < 5+HEADER_LEN+8 {
return nil, errors.New("Input cannot be shorter than 25 bytes") return nil, errors.New("Input cannot be shorter than 27 bytes")
} }
peeled := make([]byte, len(in)-5) peeled := make([]byte, len(in)-5)
copy(peeled, in[5:]) copy(peeled, in[5:])
header := peeled[:12] header := peeled[:HEADER_LEN]
pldWithOverHead := peeled[12:] // payload + potential overhead pldWithOverHead := peeled[HEADER_LEN:] // payload + potential overhead
nonce := peeled[len(peeled)-8:] nonce := peeled[len(peeled)-8:]
salsa20.XORKeyStream(header, header, nonce, &salsaKey) salsa20.XORKeyStream(header, header, nonce, &salsaKey)
streamID := u32(header[0:4]) streamID := u32(header[0:4])
seq := u32(header[4:8]) seq := u64(header[4:12])
closing := header[8] closing := header[12]
extraLen := header[9] extraLen := header[13]
usefulPayloadLen := len(pldWithOverHead) - int(extraLen) usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
if usefulPayloadLen < 0 { if usefulPayloadLen < 0 {
@ -129,7 +115,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
outputPayload = pldWithOverHead[:usefulPayloadLen] outputPayload = pldWithOverHead[:usefulPayloadLen]
} }
} else { } else {
_, err := payloadCipher.Open(pldWithOverHead[:0], header, pldWithOverHead, nil) _, err := payloadCipher.Open(pldWithOverHead[:0], header[:12], pldWithOverHead, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -30,7 +30,7 @@ type Stream struct {
recvBuf recvBuffer recvBuf recvBuffer
// atomic // atomic
nextSendSeq uint32 nextSendSeq uint64
writingM sync.RWMutex writingM sync.RWMutex
@ -115,7 +115,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
f := &Frame{ f := &Frame{
StreamID: s.id, StreamID: s.id,
Seq: atomic.AddUint32(&s.nextSendSeq, 1) - 1, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: 0, Closing: 0,
Payload: in, Payload: in,
} }
@ -163,7 +163,7 @@ func (s *Stream) Close() error {
prand.Read(pad) prand.Read(pad)
f := &Frame{ f := &Frame{
StreamID: s.id, StreamID: s.id,
Seq: atomic.AddUint32(&s.nextSendSeq, 1) - 1, Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: 1, Closing: 1,
Payload: pad, Payload: pad,
} }

View File

@ -13,17 +13,14 @@ package multiplex
import ( import (
"container/heap" "container/heap"
"errors" "errors"
"fmt"
"sync" "sync"
) )
type frameNode struct { type sorterHeap []*Frame
trueSeq uint64
frame Frame
}
type sorterHeap []*frameNode
func (sh sorterHeap) Less(i, j int) bool { func (sh sorterHeap) Less(i, j int) bool {
return sh[i].trueSeq < sh[j].trueSeq return sh[i].Seq < sh[j].Seq
} }
func (sh sorterHeap) Len() int { func (sh sorterHeap) Len() int {
return len(sh) return len(sh)
@ -33,7 +30,7 @@ func (sh sorterHeap) Swap(i, j int) {
} }
func (sh *sorterHeap) Push(x interface{}) { func (sh *sorterHeap) Push(x interface{}) {
*sh = append(*sh, x.(*frameNode)) *sh = append(*sh, x.(*Frame))
} }
func (sh *sorterHeap) Pop() interface{} { func (sh *sorterHeap) Pop() interface{} {
@ -47,17 +44,16 @@ func (sh *sorterHeap) Pop() interface{} {
type streamBuffer struct { type streamBuffer struct {
recvM sync.Mutex recvM sync.Mutex
nextRecvSeq uint32 nextRecvSeq uint64
rev int rev int
sh sorterHeap sh sorterHeap
wrapMode bool
buf *bufferedPipe buf *bufferedPipe
} }
func NewStreamBuffer() *streamBuffer { func NewStreamBuffer() *streamBuffer {
sb := &streamBuffer{ sb := &streamBuffer{
sh: []*frameNode{}, sh: []*Frame{},
rev: 0, rev: 0,
buf: NewBufferedPipe(), buf: NewBufferedPipe(),
} }
@ -80,41 +76,18 @@ func (sb *streamBuffer) Write(f Frame) error {
} else { } else {
sb.buf.Write(f.Payload) sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1 sb.nextRecvSeq += 1
if sb.nextRecvSeq == 0 { // getting wrapped
sb.rev += 1
sb.wrapMode = false
}
} }
return nil return nil
} }
node := &frameNode{
trueSeq: 0,
frame: f,
}
if f.Seq < sb.nextRecvSeq { if f.Seq < sb.nextRecvSeq {
// For the ease of demonstration, assume seq is uint8, i.e. it wraps around after 255 return fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
// e.g. we are on rev=0 (wrap has not happened yet)
// and we get the order of recv as 253 254 0 1
// after 254, nextN should be 255, but 0 is received and 0 < 255
// now 0 should have a trueSeq of 256
if !sb.wrapMode {
// wrapMode is true when the latest seq is wrapped but nextN is not
sb.wrapMode = true
}
node.trueSeq = uint64(1<<32)*uint64(sb.rev+1) + uint64(f.Seq) + 1
// +1 because wrapped 0 should have trueSeq of 256 instead of 255
// when this bit was run on 1, the trueSeq of 1 would become 256
} else {
node.trueSeq = uint64(1<<32)*uint64(sb.rev) + uint64(f.Seq)
// when this bit was run on 255, the trueSeq of 255 would be 255
} }
heap.Push(&sb.sh, node) heap.Push(&sb.sh, &f)
// Keep popping from the heap until empty or to the point that the wanted seq was not received // Keep popping from the heap until empty or to the point that the wanted seq was not received
for len(sb.sh) > 0 && sb.sh[0].frame.Seq == sb.nextRecvSeq { for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
f = heap.Pop(&sb.sh).(*frameNode).frame f = *heap.Pop(&sb.sh).(*Frame)
if f.Closing == 1 { if f.Closing == 1 {
// empty data indicates closing signal // empty data indicates closing signal
sb.buf.Close() sb.buf.Close()
@ -122,10 +95,6 @@ func (sb *streamBuffer) Write(f Frame) error {
} else { } else {
sb.buf.Write(f.Payload) sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1 sb.nextRecvSeq += 1
if sb.nextRecvSeq == 0 { // getting wrapped
sb.rev += 1
sb.wrapMode = false
}
} }
} }
return nil return nil

View File

@ -17,12 +17,12 @@ func TestRecvNewFrame(t *testing.T) {
test := func(set []uint64, ct *testing.T) { test := func(set []uint64, ct *testing.T) {
sb := NewStreamBuffer() sb := NewStreamBuffer()
sb.nextRecvSeq = uint32(set[0]) sb.nextRecvSeq = set[0]
for _, n := range set { for _, n := range set {
bu64 := make([]byte, 8) bu64 := make([]byte, 8)
binary.BigEndian.PutUint64(bu64, n) binary.BigEndian.PutUint64(bu64, n)
frame := Frame{ frame := Frame{
Seq: uint32(n), Seq: n,
Payload: bu64, Payload: bu64,
} }
sb.Write(frame) sb.Write(frame)