-- |
-- Module      : Crypto.Cipher.AESGCMSIV
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Implementation of AES-GCM-SIV, an AEAD scheme with nonce misuse resistance
-- defined in <https://tools.ietf.org/html/rfc8452 RFC 8452>.
--
-- To achieve the nonce misuse-resistance property, encryption requires two
-- passes on the plaintext, hence no streaming API is provided.  This AEAD
-- operates on complete inputs held in memory.  For simplicity, the
-- implementation of decryption uses a similar pattern, with a performance
-- penalty compared to an implementation able to merge both passes.
--
-- The specification allows inputs up to 2^36 bytes but this implementation
-- requires AAD and plaintext/ciphertext to be both smaller than 2^32 bytes.
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.Cipher.AESGCMSIV
    ( Nonce
    , nonce
    , generateNonce
    , encrypt
    , decrypt
    ) where

import Data.Bits
import Data.Word

import Foreign.C.Types
import Foreign.C.String
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (peekElemOff, poke, pokeElemOff)

import           Data.ByteArray
import qualified Data.ByteArray as B
import           Data.Memory.Endian (toLE)
import           Data.Memory.PtrMethods (memXor)

import Crypto.Cipher.AES.Primitive
import Crypto.Cipher.Types
import Crypto.Error
import Crypto.Internal.Compat (unsafeDoIO)
import Crypto.Random


-- 12-byte nonces

-- | Nonce value for AES-GCM-SIV, always 12 bytes.
--
-- The constructor is not exported: values come either from the
-- length-checking smart constructor 'nonce' or from 'generateNonce'.
newtype Nonce = Nonce Bytes
    deriving (Eq, Show, ByteArrayAccess)

-- | Nonce smart constructor.  Accepts only 12-byte inputs.
--
-- Any other length yields 'CryptoFailed' 'CryptoError_IvSizeInvalid'.
-- Accepted bytes are copied into a fresh buffer owned by the 'Nonce'.
nonce :: ByteArrayAccess iv => iv -> CryptoFailable Nonce
nonce iv
    | B.length iv /= 12 = CryptoFailed CryptoError_IvSizeInvalid
    | otherwise         = CryptoPassed (Nonce (B.convert iv))

-- | Generate a random nonce for use with AES-GCM-SIV.
--
-- Draws exactly 12 bytes from the 'MonadRandom' source.
generateNonce :: MonadRandom m => m Nonce
generateNonce = fmap Nonce (getRandomBytes 12)


-- POLYVAL (mutable context)

-- | Opaque, mutable POLYVAL context whose layout belongs to the C code.
--
-- The buffer is a 'ScrubbedBytes', not a plain 'Bytes'.  It is initialised
-- by the C side from the secret message-authentication key, so it should be
-- wiped when released, just like the key itself and the digest returned by
-- 'polyvalFinalize'.
newtype Polyval = Polyval ScrubbedBytes

-- | Allocate a POLYVAL context and key it with the given hash key (the
-- 16-byte message-authentication key produced by 'deriveKeys').
polyvalInit :: ScrubbedBytes -> IO Polyval
polyvalInit h = Polyval <$> doInit
  where
    -- NOTE(review): must match the size of the C-side POLYVAL context
    -- declared alongside crypton_aes_polyval_init -- confirm if it changes.
    contextSize :: Int
    contextSize = 272

    doInit :: IO ScrubbedBytes
    doInit = B.alloc contextSize $ \pctx ->
        B.withByteArray h $ \ph ->
            c_aes_polyval_init pctx ph

-- | Absorb all bytes of @bs@ into a POLYVAL context, mutating it in place.
--
-- The length crosses the FFI as a 'CUInt', so inputs must stay below 2^32
-- bytes; 'encrypt' and 'decrypt' enforce this through 'lengthInvalid'.
polyvalUpdate :: ByteArrayAccess ba => Polyval -> ba -> IO ()
polyvalUpdate (Polyval ctx) bs =
    B.withByteArray ctx $ \ctxPtr ->
        B.withByteArray bs $ \inPtr ->
            c_aes_polyval_update ctxPtr inPtr inputLen
  where
    inputLen = fromIntegral (B.length bs) :: CUInt

-- | Read the 16-byte POLYVAL result out of a context into a freshly
-- allocated 'ScrubbedBytes' buffer.
polyvalFinalize :: Polyval -> IO ScrubbedBytes
polyvalFinalize (Polyval ctx) =
    B.withByteArray ctx $ \ctxPtr ->
        B.alloc 16 (c_aes_polyval_finalize ctxPtr)

-- | Initialise the POLYVAL context pointed to by the first argument, keyed
-- with the bytes at the second argument (the message-authentication key,
-- 16 bytes in this module).
foreign import ccall unsafe "crypton_aes.h crypton_aes_polyval_init"
    c_aes_polyval_init :: Ptr Polyval -> CString -> IO ()

-- | Absorb @len@ input bytes into the context.
--
-- Unlike its siblings, this is imported as a safe call (no @unsafe@).  This
-- is presumably because its input can be up to 2^32 bytes long, and an
-- unsafe call would stall the GHC runtime (including garbage collection)
-- for the whole call.
--
-- NOTE(review): 'getSs' calls this once each for the AAD, the plaintext and
-- the length block.  It therefore relies on each call zero-padding its input
-- to a 16-byte boundary, as RFC 8452 requires.  Confirm in the C code.
foreign import ccall "crypton_aes.h crypton_aes_polyval_update"
    c_aes_polyval_update :: Ptr Polyval -> CString -> CUInt -> IO ()

-- | Write the POLYVAL result of the context into the buffer at the second
-- argument ('polyvalFinalize' supplies a 16-byte buffer).
foreign import ccall unsafe "crypton_aes.h crypton_aes_polyval_finalize"
    c_aes_polyval_finalize :: Ptr Polyval -> CString -> IO ()


-- Key Generation

-- | Build the 16-byte block @LE32(n) || nonce@.  It holds a little-endian
-- 32-bit counter in bytes 0-3, followed by the 12 nonce bytes.
le32iv :: Word32 -> Nonce -> Bytes
le32iv n (Nonce iv) = B.allocAndFreeze 16 fill
  where
    fill ptr = do
        poke ptr (toLE n)
        copyByteArrayToPtr iv (ptr `plusPtr` 4)

-- | Derive the per-nonce keys from the key-generating key @aes@, as in
-- RFC 8452.
--
-- Each 8-byte key chunk is the first half of the ECB encryption of
-- @'le32iv' counter nonce@ under @aes@:
--
-- * counters 0 and 1 give the 16-byte message-authentication key, returned
--   as 'ScrubbedBytes';
-- * counters 2 .. keySize/8 + 1 give the message-encryption key, returned
--   as an initialised 'AES' cipher.
--
-- Calls 'error' when @aes@ does not report a fixed key size that is a
-- multiple of 8.
deriveKeys :: BlockCipher128 aes => aes -> Nonce -> (ScrubbedBytes, AES)
deriveKeys aes iv =
    case cipherKeySize aes of
        KeySizeFixed sz | sz `mod` 8 == 0 ->
            let mak = buildKey [0 .. 1]
                key = buildKey [2 .. fromIntegral (sz `div` 8) + 1]
                mek = throwCryptoError (cipherInit key)
             in (mak, mek)
        _ -> error "AESGCMSIV: invalid cipher"
  where
    -- Encrypt into 'ScrubbedBytes' rather than 'Bytes'.  Every output block
    -- is secret key material (including the 8 bytes dropped by 'takeView'),
    -- so it must not linger in an unscrubbed buffer before being copied
    -- into the final keys.
    idx :: Word32 -> View ScrubbedBytes
    idx n = ecbEncrypt aes (B.convert (le32iv n iv) :: ScrubbedBytes) `takeView` 8

    buildKey :: [Word32] -> ScrubbedBytes
    buildKey = B.concat . map idx


-- Encryption and decryption

-- | True when the input is 2^32 bytes or longer, which this implementation
-- does not support.
--
-- When 'Int' has at most 32 bits, no byte array can reach that length, so
-- the comparison (and the shift it needs) is skipped entirely.
lengthInvalid :: ByteArrayAccess ba => ba -> Bool
lengthInvalid bs = finiteBitSize len > 32 && len >= limit
  where
    len = B.length bs
    limit = 1 `unsafeShiftL` 32 :: Int

-- | AEAD encryption with the specified key and nonce.  The key must be given
-- as an initialized 'Crypto.Cipher.AES.AES128' or 'Crypto.Cipher.AES.AES256'
-- cipher.
--
-- Returns the 16-byte authentication tag together with the ciphertext.
--
-- Lengths of additional data and plaintext must be less than 2^32 bytes,
-- otherwise an exception is thrown.
encrypt :: (BlockCipher128 aes, ByteArrayAccess aad, ByteArray ba)
        => aes -> Nonce -> aad -> ba -> (AuthTag, ba)
encrypt aes iv aad plaintext
    | lengthInvalid aad       = error "AESGCMSIV: aad is too large"
    | lengthInvalid plaintext = error "AESGCMSIV: plaintext is too large"
    | otherwise               = (AuthTag tag, ciphertext)
  where
    (mak, mek) = deriveKeys aes iv
    -- First pass: authenticate the plaintext.  Second pass: encrypt it in
    -- counter mode, using a tweaked copy of the tag as the initial counter
    -- block.
    tag        = buildTag mek (getSs mak aad plaintext) iv
    ciphertext = combineC32 mek (transformTag tag) plaintext

-- | AEAD decryption with the specified key and nonce.  The key must be given
-- as an initialized 'Crypto.Cipher.AES.AES128' or 'Crypto.Cipher.AES.AES256'
-- cipher.
--
-- Returns 'Nothing' when authentication fails.  This includes any tag whose
-- length is not 16 bytes.
--
-- Lengths of additional data and ciphertext must be less than 2^32 bytes,
-- otherwise an exception is thrown.
decrypt :: (BlockCipher128 aes, ByteArrayAccess aad, ByteArray ba)
        => aes -> Nonce -> aad -> ba -> AuthTag -> Maybe ba
decrypt aes iv aad ciphertext (AuthTag tag)
    | lengthInvalid aad        = error "AESGCMSIV: aad is too large"
    | lengthInvalid ciphertext = error "AESGCMSIV: ciphertext is too large"
      -- The tag is caller-controlled.  'transformTag' patches byte 15 of a
      -- copy sized to the tag, so only a full 16-byte tag may reach it.
      -- A tag of any other length can never match the 16-byte tag computed
      -- below anyway, so reject it as an authentication failure.
    | B.length tag /= tagSize  = Nothing
    | tag `constEq` buildTag mek ss iv = Just plaintext
    | otherwise                = Nothing
  where
    tagSize :: Int
    tagSize = 16
    (mak, mek) = deriveKeys aes iv
    -- S_s is computed over the recovered plaintext, so decryption has to
    -- happen before the tag can be checked.
    ss = getSs mak aad plaintext
    plaintext = combineC32 mek (transformTag tag) ciphertext

-- Calculate S_s = POLYVAL(mak, X_1, X_2, ...).  The input is the AAD, then
-- the plaintext, then a 16-byte length block holding the bit lengths of both
-- as little-endian 64-bit integers.
--
-- unsafeDoIO is acceptable here: the POLYVAL context is created, fed and
-- finalized entirely inside this action, so the result depends only on the
-- arguments.
getSs :: (ByteArrayAccess aad, ByteArrayAccess ba)
      => ScrubbedBytes -> aad -> ba -> ScrubbedBytes
getSs mak aad plaintext = unsafeDoIO $ do
    ctx <- polyvalInit mak
    polyvalUpdate ctx aad
    polyvalUpdate ctx plaintext
    polyvalUpdate ctx lengthBlock
    polyvalFinalize ctx
  where
    lengthBlock :: Bytes
    lengthBlock = B.allocAndFreeze 16 $ \ptr ->
        mapM_ (uncurry (pokeElemOff ptr))
              [ (0, bitLengthLE (B.length aad))
              , (1, bitLengthLE (B.length plaintext))
              ]
    -- byte count -> bit count, stored little-endian
    bitLengthLE n = toLE (fromIntegral n * 8 :: Word64)

-- XOR the first 12 bytes of S_s with the nonce and clear the most significant
-- bit of the last byte.  S_s itself is left untouched; the result is a new
-- 16-byte buffer.
tagInput :: ScrubbedBytes -> Nonce -> Bytes
tagInput ss (Nonce iv) = B.copyAndFreeze ss mangle
  where
    mangle :: Ptr Word8 -> IO ()
    mangle dst = B.withByteArray iv $ \src -> do
        memXor dst dst src 12
        lastByte <- peekElemOff dst 15
        pokeElemOff dst 15 (clearBit lastByte 7)

-- Encrypt the masked S_s (see 'tagInput') with AES using the
-- message-encryption key to produce the tag.  A single 16-byte block
-- encrypted in ECB mode yields a 16-byte tag.
buildTag :: BlockCipher128 aes => aes -> ScrubbedBytes -> Nonce -> Bytes
buildTag mek ss = ecbEncrypt mek . tagInput ss

-- The initial counter block is the tag with the most significant bit of the
-- last byte set to one.
--
-- The tag must be exactly 16 bytes (one AES block).  The length is checked
-- before byte 15 is touched, so a short tag can never cause an out-of-bounds
-- read or write.  Any other length is reported with 'error' rather than
-- through a partial pattern match.
transformTag :: Bytes -> IV AES
transformTag tag
    | B.length tag /= 16 = invalidTag
    | otherwise = case makeIV counterBlock of
        Just iv -> iv
        Nothing -> invalidTag
  where
    invalidTag :: IV AES
    invalidTag = error "AESGCMSIV: tag must be 16 bytes"

    counterBlock :: Bytes
    counterBlock = B.copyAndFreeze tag $ \ptr -> do
        lastByte <- peekElemOff ptr 15
        pokeElemOff ptr 15 (lastByte .|. (0x80 :: Word8))