-- |
-- Module      : Crypto.Cipher.ChaChaPoly1305
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : stable
-- Portability : good
--
-- A simple AEAD scheme using ChaCha20 and Poly1305. See
-- <https://tools.ietf.org/html/rfc7539 RFC 7539>.
--
-- The State is not modified in place, so each function changing the State,
-- returns a new State.
--
-- Authenticated Data needs to be added before any call to 'encrypt' or 'decrypt',
-- and once all the data has been added, 'finalizeAAD' needs to be called.
--
-- Once 'finalizeAAD' has been called, no further 'appendAAD' call should be made.
--
-- >import Data.ByteString.Char8 as B
-- >import Data.ByteArray
-- >import Crypto.Error
-- >import Crypto.Cipher.ChaChaPoly1305 as C
-- >
-- >encrypt
-- >    :: ByteString -- nonce (12 random bytes)
-- >    -> ByteString -- symmetric key
-- >    -> ByteString -- optional associated data (won't be encrypted)
-- >    -> ByteString -- input plaintext to be encrypted
-- >    -> CryptoFailable ByteString -- ciphertext with a 128-bit tag attached
-- >encrypt nonce key header plaintext = do
-- >    st1 <- C.nonce12 nonce >>= C.initialize key
-- >    let
-- >        st2 = C.finalizeAAD $ C.appendAAD header st1
-- >        (out, st3) = C.encrypt plaintext st2
-- >        auth = C.finalize st3
-- >    return $ out `B.append` Data.ByteArray.convert auth
--
module Crypto.Cipher.ChaChaPoly1305
    ( State
    , Nonce
    , nonce12
    , nonce8
    , incrementNonce
    , initialize
    , appendAAD
    , finalizeAAD
    , encrypt
    , decrypt
    , finalize
    ) where

import           Control.Monad             (when)
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import           Crypto.Internal.Imports
import           Crypto.Error
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.MAC.Poly1305  as Poly1305
import           Data.Memory.Endian
import qualified Data.ByteArray.Pack as P
import           Foreign.Ptr
import           Foreign.Storable

-- | A ChaChaPoly1305 State.
--
-- The state is immutable: functions that advance it return a new 'State'
-- rather than modifying the one they were given. It bundles the cipher
-- state, the running MAC state and the two length counters that are fed to
-- the MAC when the tag is produced.
data State = State !ChaCha.State
                   !Poly1305.State
                   !Word64 -- AAD length
                   !Word64 -- ciphertext length

-- | Valid Nonce for ChaChaPoly1305.
--
-- It can be created with 'nonce8' or 'nonce12'. A value built by 'nonce8'
-- stores the 4-byte constant followed by the 8-byte IV; a value built by
-- 'nonce12' stores the 12-byte IV as given.
data Nonce = Nonce8 Bytes | Nonce12 Bytes

-- | A 'Nonce' is accessed as the raw bytes it wraps, whichever constructor
-- was used to build it.
instance ByteArrayAccess Nonce where
  length (Nonce8  n) = B.length n
  length (Nonce12 n) = B.length n

  withByteArray (Nonce8  n) = B.withByteArray n
  withByteArray (Nonce12 n) = B.withByteArray n

-- Based on the following pseudo code:
--
-- chacha20_aead_encrypt(aad, key, iv, constant, plaintext):
--     nonce = constant | iv
--     otk = poly1305_key_gen(key, nonce)
--     ciphertext = chacha20_encrypt(key, 1, nonce, plaintext)
--     mac_data = aad | pad16(aad)
--     mac_data |= ciphertext | pad16(ciphertext)
--     mac_data |= num_to_8_le_bytes(aad.length)
--     mac_data |= num_to_8_le_bytes(ciphertext.length)
--     tag = poly1305_mac(mac_data, otk)
--     return (ciphertext, tag)

-- | Zero padding that brings a running length up to the next multiple of 16.
--
-- Returns an empty buffer when @n@ is already a multiple of 16, otherwise
-- @16 - (n mod 16)@ zero bytes.
pad16 :: Word64 -> Bytes
pad16 n
    | modLen == 0 = B.empty
    | otherwise   = B.replicate (16 - modLen) 0
  where
    -- Always in the range 0..15, so the conversion to 'Int' cannot overflow.
    modLen :: Int
    modLen = fromIntegral (n `mod` 16)

-- | Nonce smart constructor for a 12 bytes IV.
--
-- Fails with 'CryptoError_IvSizeInvalid' unless the input is exactly
-- 12 bytes long.
nonce12 :: ByteArrayAccess iv => iv -> CryptoFailable Nonce
nonce12 iv
    | B.length iv /= 12 = CryptoFailed CryptoError_IvSizeInvalid
    | otherwise         = CryptoPassed . Nonce12 . B.convert $ iv

-- | Nonce smart constructor for a 4 bytes constant and an 8 bytes IV.
--
-- The stored nonce is the constant followed by the IV. Fails with
-- 'CryptoError_IvSizeInvalid' if either part has the wrong length.
nonce8 :: ByteArrayAccess ba
       => ba -- ^ 4 bytes constant
       -> ba -- ^ 8 bytes IV
       -> CryptoFailable Nonce
nonce8 constant iv
    | B.length constant /= 4 = CryptoFailed CryptoError_IvSizeInvalid
    | B.length iv       /= 8 = CryptoFailed CryptoError_IvSizeInvalid
    | otherwise              = CryptoPassed . Nonce8 . B.concat $ [constant, iv]

-- | Increment a nonce.
--
-- For a nonce built with 'nonce8' the leading 4-byte constant is left
-- untouched and only the IV part is incremented; for a nonce built with
-- 'nonce12' the increment starts at the first byte.
incrementNonce :: Nonce -> Nonce
incrementNonce (Nonce8  n) = Nonce8  $ incrementNonce' n 4
incrementNonce (Nonce12 n) = Nonce12 $ incrementNonce' n 0

-- | Increment the bytes of @b@ from @offset@ to the end, treated as a
-- little-endian counter, and return the result in a fresh buffer.
--
-- The byte at @offset@ is incremented first; the carry moves on to the next
-- byte only when the current one wraps to zero. If every counter byte wraps,
-- the counter rolls over to all zeroes and the final carry is dropped.
-- Bytes before @offset@ are never modified.
incrementNonce' :: Bytes -> Int -> Bytes
incrementNonce' b offset = B.copyAndFreeze b $ \s ->
    loop (s `plusPtr` offset) (B.length b - offset)
  where
    -- @remaining@ is the number of counter bytes not yet visited. Bounding
    -- the walk by it keeps the carry from being written past the end of the
    -- buffer when all counter bytes are 0xff.
    loop :: Ptr Word8 -> Int -> IO ()
    loop p remaining
        | remaining <= 0 = return ()
        | otherwise      = do
            r <- (+ 1) <$> peek p
            poke p r
            when (r == 0) $ loop (p `plusPtr` 1) (remaining - 1)

-- | Initialize a new ChaChaPoly1305 State
--
-- The key length needs to be 256 bits, and the nonce must have been
-- obtained from either 'nonce8' or 'nonce12'. A key of any other length
-- yields 'CryptoError_KeySizeInvalid'.
initialize :: ByteArrayAccess key
           => key -> Nonce -> CryptoFailable State
initialize key (Nonce8  nonce) = initialize' key nonce
initialize key (Nonce12 nonce) = initialize' key nonce

-- | Build the initial 'State' from a key and the raw nonce bytes.
--
-- Fails with 'CryptoError_KeySizeInvalid' unless the key is exactly 32 bytes
-- long. The first 64 bytes generated by the 20-round ChaCha state are
-- consumed here: their first 32 bytes become the Poly1305 key, and the
-- ChaCha state that follows them is the one later used for encryption.
initialize' :: ByteArrayAccess key
            => key -> Bytes -> CryptoFailable State
initialize' key nonce
    | B.length key /= 32 = CryptoFailed CryptoError_KeySizeInvalid
    | otherwise          = do
        -- Propagate a Poly1305 initialisation failure through
        -- 'CryptoFailable' instead of throwing from pure code.
        polyState <- Poly1305.initialize (B.take 32 polyKey :: ScrubbedBytes)
        return $ State encState polyState 0 0
  where
    rootState           = ChaCha.initialize 20 key nonce
    (polyKey, encState) = ChaCha.generate rootState 64

-- | Append Authenticated Data to the State and return
-- the new modified State.
--
-- The data is fed to the MAC and its length is added to the AAD length
-- counter; the cipher state is left untouched.
--
-- Once no further call to this function needs to be made,
-- the user should call 'finalizeAAD'
appendAAD :: ByteArrayAccess ba => ba -> State -> State
appendAAD ba (State encState macState aadLength plainLength) =
    State encState newMacState newLength plainLength
  where
    newMacState = Poly1305.update macState ba
    newLength   = aadLength + fromIntegral (B.length ba)

-- | Finalize the Authenticated Data and return the finalized State
--
-- Feeds the MAC the zero padding that brings the accumulated AAD length up
-- to a multiple of 16. The length counters and the cipher state are
-- unchanged.
finalizeAAD :: State -> State
finalizeAAD (State encState macState aadLength plainLength) =
    State encState newMacState aadLength plainLength
  where
    newMacState = Poly1305.update macState $ pad16 aadLength

-- | Encrypt a piece of data and returns the encrypted Data and the
-- updated State.
--
-- The MAC is updated with the produced ciphertext, and the ciphertext
-- length counter grows by the length of the input.
encrypt :: ByteArray ba => ba -> State -> (ba, State)
encrypt input (State encState macState aadLength plainLength) =
    (output, State newEncState newMacState aadLength newPlainLength)
  where
    (output, newEncState) = ChaCha.combine encState input
    newMacState           = Poly1305.update macState output
    newPlainLength        = plainLength + fromIntegral (B.length input)

-- | Decrypt a piece of data and returns the decrypted Data and the
-- updated State.
--
-- The MAC is updated with the input (the ciphertext), not with the
-- decrypted output. No authentication check is performed here: the caller
-- has to compare the tag returned by 'finalize' with the expected one.
decrypt :: ByteArray ba => ba -> State -> (ba, State)
decrypt input (State encState macState aadLength plainLength) =
    (output, State newEncState newMacState aadLength newPlainLength)
  where
    (output, newEncState) = ChaCha.combine encState input
    newMacState           = Poly1305.update macState input
    newPlainLength        = plainLength + fromIntegral (B.length input)

-- | Generate an authentication tag from the State.
--
-- Before the tag is produced, the MAC is fed the zero padding for the
-- ciphertext length followed by a 16-byte block holding the AAD length and
-- the ciphertext length, each as a little-endian 'Word64'.
finalize :: State -> Poly1305.Auth
finalize (State _ macState aadLength plainLength) =
    Poly1305.finalize $ Poly1305.updates macState
        [ pad16 plainLength
        , lengthBlock
        ]
  where
    -- Two 8-byte values fill the 16-byte buffer exactly, so the packer is
    -- not expected to fail; the 'error' only guards that invariant.
    lengthBlock :: Bytes
    lengthBlock =
        either (error "finalize: internal error") id $
            P.fill 16 (P.putStorable (toLE aadLength) >> P.putStorable (toLE plainLength))