module Crypto.Cipher.ChaChaPoly1305
( State
, Nonce
, nonce12
, nonce8
, incrementNonce
, initialize
, appendAAD
, finalizeAAD
, encrypt
, decrypt
, finalize
) where
import Control.Monad (when)
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Imports
import Crypto.Error
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.MAC.Poly1305 as Poly1305
import Data.Memory.Endian
import qualified Data.ByteArray.Pack as P
import Foreign.Ptr
import Foreign.Storable
-- | A ChaChaPoly1305 AEAD state.
--
-- The state is immutable: 'appendAAD', 'finalizeAAD', 'encrypt' and
-- 'decrypt' each return a new 'State' instead of modifying their input.
data State = State
    !ChaCha.State    -- keystream state consumed by 'encrypt' / 'decrypt'
    !Poly1305.State  -- running MAC over AAD, padding and ciphertext
    !Word64          -- number of AAD bytes appended so far
    !Word64          -- number of plaintext/ciphertext bytes processed so far
-- | A nonce accepted by 'initialize'.
--
-- Values are built with the smart constructors 'nonce8' and 'nonce12'.
data Nonce
    = Nonce8 Bytes   -- 4-byte constant followed by an 8-byte IV ('nonce8')
    | Nonce12 Bytes  -- 12-byte IV ('nonce12')
-- | Expose the raw nonce bytes, whichever constructor holds them.
instance ByteArrayAccess Nonce where
    length (Nonce8  n) = B.length n
    length (Nonce12 n) = B.length n

    withByteArray (Nonce8  n) = B.withByteArray n
    withByteArray (Nonce12 n) = B.withByteArray n
-- | Zero bytes that pad a stream of @n@ bytes up to the next multiple of 16.
--
-- Returns an empty buffer when @n@ is already a multiple of 16, otherwise
-- @16 - (n `mod` 16)@ zero bytes.
pad16 :: Word64 -> Bytes
pad16 n
    | modLen == 0 = B.empty
    | otherwise   = B.replicate (16 - modLen) 0
  where
    -- Reduce in Word64 first so the conversion to Int is always in range.
    modLen :: Int
    modLen = fromIntegral (n `mod` 16)
-- | Build a 'Nonce' from a 12-byte IV.
--
-- Fails with 'CryptoError_IvSizeInvalid' unless @iv@ is exactly 12 bytes
-- long. The IV bytes are copied into a fresh 'Bytes' buffer.
nonce12 :: ByteArrayAccess iv => iv -> CryptoFailable Nonce
nonce12 iv
    | B.length iv /= 12 = CryptoFailed CryptoError_IvSizeInvalid
    | otherwise         = CryptoPassed (Nonce12 (B.convert iv))
-- | Build a 'Nonce' from a 4-byte constant and an 8-byte IV.
--
-- The resulting nonce bytes are the constant followed by the IV. Fails with
-- 'CryptoError_IvSizeInvalid' if either argument has the wrong length.
nonce8 :: ByteArrayAccess ba
       => ba -- ^ 4 bytes constant
       -> ba -- ^ 8 bytes IV
       -> CryptoFailable Nonce
nonce8 constant iv
    | B.length constant /= 4 = CryptoFailed CryptoError_IvSizeInvalid
    | B.length iv       /= 8 = CryptoFailed CryptoError_IvSizeInvalid
    | otherwise              = CryptoPassed (Nonce8 (B.concat [constant, iv]))
-- | Increment a nonce, treating its counter bytes as a little-endian integer.
--
-- For a 'nonce8' nonce only the trailing 8 IV bytes form the counter and the
-- 4-byte constant prefix is preserved; for a 'nonce12' nonce all 12 bytes
-- form the counter.
incrementNonce :: Nonce -> Nonce
incrementNonce (Nonce8  n) = Nonce8  (incrementNonce' n 4)
incrementNonce (Nonce12 n) = Nonce12 (incrementNonce' n 0)
-- | Return a copy of @b@ whose bytes @[offset, B.length b)@, read as a
-- little-endian counter, are incremented by one. Bytes before @offset@ are
-- left untouched.
--
-- The increment starts at byte @offset@ and a carry propagates towards higher
-- addresses. When every counter byte is 0xff the counter wraps to all zeros;
-- the carry is bounded by the buffer length and never runs past the end.
incrementNonce' :: Bytes -> Int -> Bytes
incrementNonce' b offset = B.copyAndFreeze b $ \ptr -> carry ptr offset
  where
    len :: Int
    len = B.length b

    carry :: Ptr Word8 -> Int -> IO ()
    carry ptr i
        -- Carried out of the last byte: the counter has wrapped around.
        -- Stopping here is what keeps the write inside the buffer.
        | i >= len  = pure ()
        | otherwise = do
            r <- (+ 1) <$> (peekByteOff ptr i :: IO Word8)
            pokeByteOff ptr i r
            -- Only a byte that overflowed to zero produces a carry.
            when (r == 0) $ carry ptr (i + 1)
-- | Initialize a new ChaChaPoly1305 'State'.
--
-- The key must be 32 bytes (256 bits); otherwise the result is
-- @'CryptoFailed' 'CryptoError_KeySizeInvalid'@. The nonce is obtained from
-- either 'nonce8' or 'nonce12'.
initialize :: ByteArrayAccess key
           => key -> Nonce -> CryptoFailable State
initialize key (Nonce8  nonce) = initialize' key nonce
initialize key (Nonce12 nonce) = initialize' key nonce
-- | Shared implementation of 'initialize', working on the raw nonce bytes.
--
-- Runs 20-round ChaCha on @key@ and @nonce@ and generates one 64-byte
-- keystream block. The first 32 bytes of that block become the Poly1305
-- one-time key; the rest is discarded, so encryption continues from the
-- following keystream position. Both running lengths start at zero.
initialize' :: ByteArrayAccess key
            => key -> Bytes -> CryptoFailable State
initialize' key nonce
    | B.length key /= 32 = CryptoFailed CryptoError_KeySizeInvalid
    | otherwise          = CryptoPassed (State encState polyState 0 0)
  where
    rootState :: ChaCha.State
    rootState = ChaCha.initialize 20 key nonce

    polyKey  :: ScrubbedBytes
    encState :: ChaCha.State
    (polyKey, encState) = ChaCha.generate rootState 64

    -- B.take 32 of the 64-byte block always yields exactly 32 bytes, so
    -- throwCryptoError is not expected to fire here (assuming
    -- Poly1305.initialize only rejects wrong key sizes).
    polyState :: Poly1305.State
    polyState = throwCryptoError (Poly1305.initialize (B.take 32 polyKey))
-- | Append additional authenticated data (AAD) to the 'State' and return the
-- new 'State'.
--
-- The data is fed into the MAC and its length is added to the running AAD
-- length. It may be called repeatedly; once no more AAD will be appended,
-- the caller should call 'finalizeAAD'.
appendAAD :: ByteArrayAccess ba => ba -> State -> State
appendAAD ba (State encState macState aadLength plainLength) =
    State encState newMacState newLength plainLength
  where
    newMacState :: Poly1305.State
    newMacState = Poly1305.update macState ba

    newLength :: Word64
    newLength = aadLength + fromIntegral (B.length ba)
-- | Finalize the additional authenticated data and return the new 'State'.
--
-- Feeds the MAC the zero padding that aligns the AAD to a 16-byte boundary
-- ('pad16' of the AAD length). The cipher state and both lengths are left
-- unchanged.
finalizeAAD :: State -> State
finalizeAAD (State encState macState aadLength plainLength) =
    State encState newMacState aadLength plainLength
  where
    newMacState :: Poly1305.State
    newMacState = Poly1305.update macState (pad16 aadLength)
-- | Encrypt a chunk of plaintext, returning the ciphertext and the updated
-- 'State'.
--
-- The MAC is updated with the produced ciphertext, and the running
-- plaintext length grows by the length of the chunk.
encrypt :: ByteArray ba => ba -> State -> (ba, State)
encrypt input (State encState macState aadLength plainLength) =
    (output, State newEncState newMacState aadLength newPlainLength)
  where
    (output, newEncState) = ChaCha.combine encState input

    -- The tag authenticates ciphertext, which is the output when encrypting.
    newMacState :: Poly1305.State
    newMacState = Poly1305.update macState output

    newPlainLength :: Word64
    newPlainLength = plainLength + fromIntegral (B.length input)
-- | Decrypt a chunk of ciphertext, returning the plaintext and the updated
-- 'State'.
--
-- The MAC is updated with the incoming ciphertext, and the running
-- plaintext length grows by the length of the chunk.
decrypt :: ByteArray ba => ba -> State -> (ba, State)
decrypt input (State encState macState aadLength plainLength) =
    (output, State newEncState newMacState aadLength newPlainLength)
  where
    (output, newEncState) = ChaCha.combine encState input

    -- The tag authenticates ciphertext, which is the input when decrypting.
    newMacState :: Poly1305.State
    newMacState = Poly1305.update macState input

    newPlainLength :: Word64
    newPlainLength = plainLength + fromIntegral (B.length input)
-- | Generate the authentication tag from the 'State'.
--
-- Completes the MAC input with the zero padding for the ciphertext, followed
-- by the AAD length and then the ciphertext length, each written as a
-- little-endian 'Word64'.
finalize :: State -> Poly1305.Auth
finalize (State _ macState aadLength plainLength) =
    Poly1305.finalize (Poly1305.updates macState [pad16 plainLength, lengths])
  where
    -- Two little-endian Word64 values fill the 16-byte buffer exactly, so the
    -- Left branch of P.fill is not expected to occur (P.fill is assumed to
    -- fail only on a size mismatch).
    lengths :: Bytes
    lengths = either (error "finalize: internal error") id $
        P.fill 16 (P.putStorable (toLE aadLength) >> P.putStorable (toLE plainLength))