{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.Cipher.ChaCha
( initialize
, initializeX
, combine
, generate
, State
, initializeSimple
, generateSimple
, StateSimple
) where
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Foreign.Ptr
import Foreign.C.Types
-- | ChaCha / XChaCha cipher context.
--
-- An opaque buffer that is only read and written by the C implementation.
-- 'initialize' and 'initializeX' allocate it as 132 bytes of
-- 'ScrubbedBytes'. 'NFData' is obtained from the underlying 'ScrubbedBytes'
-- instance via newtype deriving.
newtype State = State ScrubbedBytes
    deriving (NFData)
-- | State of the simple ChaCha-based random generator.
--
-- An opaque buffer that is only read and written by the C implementation.
-- 'initializeSimple' allocates it as 64 bytes of 'ScrubbedBytes'.
-- 'NFData' is obtained from the underlying 'ScrubbedBytes' instance via
-- newtype deriving.
newtype StateSimple = StateSimple ScrubbedBytes
    deriving (NFData)
-- | Set up a ChaCha context from a round count, a key and a nonce.
--
-- The arguments are validated in this order, each failing with 'error':
--
-- * the key must be 16 or 32 bytes (128 or 256 bits);
-- * the nonce must be 8 or 12 bytes (64 or 96 bits);
-- * the round count must be 8, 12 or 20.
--
-- On success a fresh 132-byte 'ScrubbedBytes' context is allocated and
-- filled by the C initialiser, which receives the rounds, both lengths and
-- pointers to the key and nonce bytes.
initialize :: (ByteArrayAccess key, ByteArrayAccess nonce)
           => Int   -- ^ number of rounds (8, 12 or 20)
           -> key   -- ^ the key (128 or 256 bits)
           -> nonce -- ^ the nonce (64 or 96 bits)
           -> State -- ^ the initial ChaCha state
initialize rounds key nonce
    | keyBytes `notElem` [16, 32] =
        error "ChaCha: key length should be 128 or 256 bits"
    | nonceBytes `notElem` [8, 12] =
        error "ChaCha: nonce length should be 64 or 96 bits"
    | rounds `notElem` [8, 12, 20] =
        error "ChaCha: rounds should be 8, 12 or 20"
    | otherwise = unsafeDoIO (State <$> B.alloc 132 fill)
  where
    keyBytes, nonceBytes :: Int
    keyBytes   = B.length key
    nonceBytes = B.length nonce

    -- Populate the freshly allocated context; the key and nonce buffers are
    -- only borrowed for the duration of the C call.
    fill ctxPtr =
        B.withByteArray key $ \keyPtr ->
        B.withByteArray nonce $ \noncePtr ->
            ccrypton_chacha_init ctxPtr rounds keyBytes keyPtr nonceBytes noncePtr
-- | Set up an XChaCha context from a round count, a key and an extended
-- nonce.
--
-- The arguments are validated in this order, each failing with 'error':
--
-- * the key must be exactly 32 bytes (256 bits);
-- * the nonce must be exactly 24 bytes (192 bits);
-- * the round count must be 8, 12 or 20.
--
-- On success a fresh 132-byte 'ScrubbedBytes' context is allocated and
-- filled by the C XChaCha initialiser. No lengths are passed to that
-- routine; they are fixed by the checks above.
initializeX :: (ByteArrayAccess key, ByteArrayAccess nonce)
            => Int   -- ^ number of rounds (8, 12 or 20)
            -> key   -- ^ the key (256 bits)
            -> nonce -- ^ the nonce (192 bits)
            -> State -- ^ the initial XChaCha state
initializeX rounds key nonce
    | keyBytes /= 32 =
        error "XChaCha: key length should be 256 bits"
    | nonceBytes /= 24 =
        error "XChaCha: nonce length should be 192 bits"
    | rounds `notElem` [8, 12, 20] =
        error "XChaCha: rounds should be 8, 12 or 20"
    | otherwise = unsafeDoIO (State <$> B.alloc 132 fill)
  where
    keyBytes, nonceBytes :: Int
    keyBytes   = B.length key
    nonceBytes = B.length nonce

    -- Populate the freshly allocated context from the borrowed key/nonce.
    fill ctxPtr =
        B.withByteArray key $ \keyPtr ->
        B.withByteArray nonce $ \noncePtr ->
            ccrypton_xchacha_init ctxPtr rounds keyPtr noncePtr
-- | Set up the simple ChaCha-based generator state from a seed.
--
-- The seed must contain at least 40 bytes; a shorter seed fails with
-- 'error'. Longer seeds are accepted, but only the first 40 bytes are
-- handed to the C core initialiser:
--
-- * bytes 0-31 as a 32-byte key;
-- * bytes 32-39 as an 8-byte nonce.
--
-- The resulting state lives in a freshly allocated 64-byte 'ScrubbedBytes'
-- buffer.
initializeSimple :: ByteArrayAccess seed
                 => seed        -- ^ seed material (at least 40 bytes)
                 -> StateSimple -- ^ the initial generator state
initializeSimple seed
    | B.length seed < 40 = error "ChaCha Random: seed length should be 40 bytes"
    | otherwise          = unsafeDoIO (StateSimple <$> B.alloc 64 setup)
  where
    -- Split the seed buffer into its key and nonce parts in place (no copy)
    -- and initialise the state from them.
    setup statePtr =
        B.withByteArray seed $ \seedPtr ->
            let keyPtr   = seedPtr
                noncePtr = seedPtr `plusPtr` 32
             in ccrypton_chacha_init_core statePtr 32 keyPtr 8 noncePtr
-- | XOR the ChaCha keystream with @src@, returning the result together with
-- the advanced state.
--
-- An empty input yields 'B.empty' and the state unchanged. The input state
-- is never mutated: the C routine works on a fresh copy of its memory, and
-- that updated copy becomes the returned 'State'.
combine :: ByteArray ba
        => State       -- ^ the current ChaCha state
        -> ba          -- ^ the source to XOR with the keystream
        -> (ba, State) -- ^ the XORed output and the new state
combine prevSt@(State prevStMem) src
    | B.null src = (B.empty, prevSt)
    | otherwise  = unsafeDoIO $ do
        (out, st) <- B.copyRet prevStMem $ \ctx ->
            B.alloc srcLen $ \dstPtr ->
            B.withByteArray src $ \srcPtr ->
                combineChunks ctx dstPtr srcPtr srcLen
        return (out, State st)
  where
    srcLen :: Int
    srcLen = B.length src

    -- The C routine takes a 32-bit ('CUInt') byte count. Passing the whole
    -- length through 'fromIntegral' would silently wrap for inputs of 4 GiB
    -- or more, so only @len mod 2^32@ bytes would be processed and the tail
    -- of the output buffer would be left unwritten. Instead, feed bounded
    -- chunks against the same context. Each call updates the context in
    -- place, exactly as the state is threaded between successive 'combine'
    -- calls, so consecutive chunks continue the same keystream.
    combineChunks ctx dst s remaining
        | remaining <= 0 = return ()
        | otherwise = do
            let n = min remaining maxChunk
            ccrypton_chacha_combine dst ctx s (fromIntegral n)
            combineChunks ctx (dst `plusPtr` n) (s `plusPtr` n) (remaining - n)

    -- 1 GiB: representable both as a 'CUInt' and as a 32-bit 'Int'.
    maxChunk :: Int
    maxChunk = 2 ^ (30 :: Int)
-- | Produce @len@ bytes of raw ChaCha keystream, returning them together
-- with the advanced state.
--
-- A non-positive length yields 'B.empty' and the state unchanged. The input
-- state is never mutated: the C routine works on a fresh copy of its memory,
-- and that updated copy becomes the returned 'State'.
generate :: ByteArray ba
         => State       -- ^ the current ChaCha state
         -> Int         -- ^ number of bytes to generate
         -> (ba, State) -- ^ the keystream bytes and the new state
generate prevSt@(State prevStMem) len
    | len <= 0  = (B.empty, prevSt)
    | otherwise = unsafeDoIO $ do
        (out, st) <- B.copyRet prevStMem $ \ctx ->
            B.alloc len $ \dstPtr ->
                generateChunks ctx dstPtr len
        return (out, State st)
  where
    -- The C routine takes a 32-bit ('CUInt') byte count. Passing @len@
    -- through 'fromIntegral' would silently wrap for requests of 4 GiB or
    -- more, leaving the tail of the output buffer unwritten. Instead, fill
    -- the buffer in bounded chunks against the same context. Each call
    -- updates the context in place, as when the state is threaded between
    -- successive 'generate' calls.
    generateChunks ctx dst remaining
        | remaining <= 0 = return ()
        | otherwise = do
            let n = min remaining maxChunk
            ccrypton_chacha_generate dst ctx (fromIntegral n)
            generateChunks ctx (dst `plusPtr` n) (remaining - n)

    -- 1 GiB: representable both as a 'CUInt' and as a 32-bit 'Int'.
    maxChunk :: Int
    maxChunk = 2 ^ (30 :: Int)
-- | Produce @nbBytes@ pseudo-random bytes from the simple generator state,
-- returning them together with the next state.
--
-- A non-positive byte count yields 'B.empty' and the state unchanged, the
-- same convention 'generate' uses. Previously such a count fell through to
-- 'B.alloc' with a negative size. Otherwise the C routine runs on a fresh
-- copy of the state memory, so the input state is never mutated; the
-- literal @8@ passed first is presumably the round count (TODO confirm
-- against the C prototype).
--
-- NOTE(review): the byte count is narrowed to 'CUInt' for the C call, so
-- requests of 4 GiB or more would wrap. Consider chunking once it is
-- confirmed that split calls to the C random routine are equivalent to a
-- single call.
generateSimple :: ByteArray ba
               => StateSimple       -- ^ the current generator state
               -> Int               -- ^ number of bytes to produce
               -> (ba, StateSimple) -- ^ the random bytes and the next state
generateSimple prevSt@(StateSimple prevStMem) nbBytes
    | nbBytes <= 0 = (B.empty, prevSt)
    | otherwise    = unsafeDoIO $ do
        (output, newSt) <- B.copyRet prevStMem $ \stPtr ->
            B.alloc nbBytes $ \dstPtr ->
                ccrypton_chacha_random 8 dstPtr stPtr (fromIntegral nbBytes)
        return (output, StateSimple newSt)
-- Raw bindings to the C ChaCha implementation.
-- NOTE(review): round counts and lengths are passed as Haskell 'Int' while
-- byte counts use 'CUInt'; confirm these agree with the C prototypes.

-- | Initialise a 'State' context.
-- Arguments: context, rounds, key length, key, nonce length, nonce.
foreign import ccall "crypton_chacha_init"
    ccrypton_chacha_init
        :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()

-- | Initialise a 'State' context for XChaCha.
-- Arguments: context, rounds, key, nonce (lengths fixed by the caller).
foreign import ccall "crypton_xchacha_init"
    ccrypton_xchacha_init
        :: Ptr State -> Int -> Ptr Word8 -> Ptr Word8 -> IO ()

-- | XOR keystream into a buffer.
-- Arguments: destination, context, source, byte count.
foreign import ccall "crypton_chacha_combine"
    ccrypton_chacha_combine
        :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()

-- | Write raw keystream.
-- Arguments: destination, context, byte count.
foreign import ccall "crypton_chacha_generate"
    ccrypton_chacha_generate
        :: Ptr Word8 -> Ptr State -> CUInt -> IO ()

-- | Initialise a 'StateSimple' core.
-- Arguments: state, key length, key, nonce length, nonce.
foreign import ccall "crypton_chacha_init_core"
    ccrypton_chacha_init_core
        :: Ptr StateSimple -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()

-- | Produce random bytes from a 'StateSimple'.
-- Arguments: a count passed as 8 by the caller (presumably rounds),
-- destination, state, byte count.
foreign import ccall "crypton_chacha_random"
    ccrypton_chacha_random
        :: Int -> Ptr Word8 -> Ptr StateSimple -> CUInt -> IO ()