diff --git a/internal/common/crypto.go b/internal/common/crypto.go index f5585c9..71b3f3f 100644 --- a/internal/common/crypto.go +++ b/internal/common/crypto.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "errors" "io" "time" @@ -19,6 +20,11 @@ func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) { if err != nil { return nil, err } + if len(nonce) != aesgcm.NonceSize() { + // check here so it doesn't panic + return nil, errors.New("incorrect nonce size") + } + return aesgcm.Seal(nil, nonce, plaintext, nil), nil } @@ -31,6 +37,10 @@ func AESGCMDecrypt(nonce []byte, key []byte, ciphertext []byte) ([]byte, error) if err != nil { return nil, err } + if len(nonce) != aesgcm.NonceSize() { + // check here so it doesn't panic + return nil, errors.New("incorrect nonce size") + } plain, err := aesgcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, err @@ -51,12 +61,12 @@ func RandRead(randSource io.Reader, buf []byte) { 100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second, 3 * time.Second, 5 * time.Second} for i := 0; i < 10; i++ { - log.Errorf("Failed to get cryptographic random bytes: %v. Retrying...", err) - _, err = rand.Read(buf) + log.Errorf("Failed to get random bytes: %v. Retrying...", err) + _, err = randSource.Read(buf) if err == nil { return } - time.Sleep(time.Millisecond * waitDur[i]) + time.Sleep(waitDur[i]) } - log.Fatal("Cannot get cryptographic random bytes after 10 retries") + log.Fatal("Cannot get random bytes after 10 retries") }