-- |
-- Module      : Crypto.Cipher.Salsa
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : stable
-- Portability : good
--
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.Cipher.Salsa
    ( initialize
    , combine
    , generate
    , State(..)
    ) where

import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import           Crypto.Internal.Compat
import           Crypto.Internal.Imports
import           Foreign.Ptr
import           Foreign.C.Types

-- | Opaque Salsa cipher context.
--
-- The wrapped 'ScrubbedBytes' buffer is handed to the C primitives as a
-- @Ptr State@; it is never inspected from Haskell.
newtype State = State ScrubbedBytes
    deriving NFData

-- | Build a fresh Salsa context from a round count, a key and a nonce.
--
-- Calls 'error' when:
--
-- * the key is not 16 or 32 bytes long,
-- * the nonce is not 8 or 12 bytes long,
-- * the round count is not 8, 12 or 20.
--
-- The checks run in that order, so the first failing one decides the message.
initialize :: (ByteArrayAccess key, ByteArrayAccess nonce)
           => Int    -- ^ number of rounds (8,12,20)
           -> key    -- ^ the key (128 or 256 bits)
           -> nonce  -- ^ the nonce (64 or 96 bits)
           -> State  -- ^ the initial Salsa state
initialize nbRounds key nonce
    | keyLen `notElem` [16, 32]      = error "Salsa: key length should be 128 or 256 bits"
    | ivLen `notElem` [8, 12]        = error "Salsa: nonce length should be 64 or 96 bits"
    | nbRounds `notElem` [8, 12, 20] = error "Salsa: rounds should be 8, 12 or 20"
    | otherwise                      = State (unsafeDoIO buildContext)
  where
    keyLen, ivLen :: Int
    keyLen = B.length key
    ivLen  = B.length nonce

    -- Size in bytes of the buffer handed to the C initialiser as the context.
    contextSize :: Int
    contextSize = 132

    -- Allocate the context buffer and let the C side fill it in from the
    -- (pinned) key and nonce bytes.
    buildContext :: IO ScrubbedBytes
    buildContext =
        B.alloc contextSize $ \ctxPtr ->
        B.withByteArray key $ \keyPtr ->
        B.withByteArray nonce $ \ivPtr ->
            ccrypton_salsa_init ctxPtr nbRounds keyLen keyPtr ivLen ivPtr

-- | XOR the Salsa keystream with an arbitrary message.
--
-- Returns the combined output and the advanced state. The input state is not
-- modified; its memory is copied before the C primitive runs. An empty
-- message yields an empty output and the unchanged state.
--
-- The C primitive takes its byte count as a 'CUInt'. The message is therefore
-- fed in bounded chunks, so messages of 4 GiB or more are fully processed.
-- Previously such lengths were truncated modulo 2^32 and left the tail of the
-- output buffer unwritten.
combine :: ByteArray ba
        => State      -- ^ the current Salsa state
        -> ba         -- ^ the source to xor with the generator
        -> (ba, State)
combine prevSt@(State prevStMem) src
    | B.null src = (B.empty, prevSt)
    | otherwise  = unsafeDoIO $ do
        (out, st) <- B.copyRet prevStMem $ \ctx ->
            B.alloc srcLen $ \dstPtr ->
            B.withByteArray src $ \srcPtr ->
                xorChunks ctx dstPtr srcPtr srcLen
        return (out, State st)
  where
    srcLen :: Int
    srcLen = B.length src

    -- XOR @remaining@ bytes, at most 'maxChunk' per C call. Successive calls on
    -- the same context continue the keystream, just as successive 'combine'
    -- calls do.
    xorChunks :: Ptr State -> Ptr Word8 -> Ptr Word8 -> Int -> IO ()
    xorChunks ctx dstPtr srcPtr remaining
        | remaining <= 0 = return ()
        | otherwise      = do
            let n = min remaining maxChunk
            ccrypton_salsa_combine dstPtr ctx srcPtr (fromIntegral n)
            xorChunks ctx (dstPtr `plusPtr` n) (srcPtr `plusPtr` n) (remaining - n)

    -- 2^30: fits in a CUInt (and in a 32-bit Int), and is a multiple of the
    -- 64-byte Salsa block.
    maxChunk :: Int
    maxChunk = 0x40000000

-- | Produce @len@ bytes of raw Salsa keystream.
--
-- Returns the bytes and the advanced state. The input state is not modified;
-- its memory is copied before the C primitive runs. A non-positive length
-- yields an empty output and the unchanged state.
--
-- The C primitive takes its byte count as a 'CUInt'. Output is therefore
-- produced in bounded chunks, so requests of 4 GiB or more are fully filled.
-- Previously such lengths were truncated modulo 2^32 and left the tail of the
-- output buffer uninitialised.
generate :: ByteArray ba
         => State -- ^ the current Salsa state
         -> Int   -- ^ the length of data to generate
         -> (ba, State)
generate prevSt@(State prevStMem) len
    | len <= 0  = (B.empty, prevSt)
    | otherwise = unsafeDoIO $ do
        (out, st) <- B.copyRet prevStMem $ \ctx ->
            B.alloc len $ \dstPtr ->
                fillChunks ctx dstPtr len
        return (out, State st)
  where
    -- Write @remaining@ keystream bytes, at most 'maxChunk' per C call.
    -- Successive calls on the same context continue the keystream, just as
    -- successive 'generate' calls do.
    fillChunks :: Ptr State -> Ptr Word8 -> Int -> IO ()
    fillChunks ctx dstPtr remaining
        | remaining <= 0 = return ()
        | otherwise      = do
            let n = min remaining maxChunk
            ccrypton_salsa_generate dstPtr ctx (fromIntegral n)
            fillChunks ctx (dstPtr `plusPtr` n) (remaining - n)

    -- 2^30: fits in a CUInt (and in a 32-bit Int), and is a multiple of the
    -- 64-byte Salsa block.
    maxChunk :: Int
    maxChunk = 0x40000000

-- | Initialise a Salsa context in place.
-- Arguments, as used by 'initialize': context pointer, round count,
-- key length, key pointer, nonce length, nonce pointer.
-- NOTE(review): round count and lengths are passed as Haskell 'Int'; confirm
-- the C prototype's parameter widths match (e.g. if it declares uint32_t).
foreign import ccall "crypton_salsa_init"
    ccrypton_salsa_init :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()

-- | XOR keystream into a buffer.
-- Arguments, as used by 'combine': destination pointer, context pointer
-- (mutated in place), source pointer, byte count.
foreign import ccall "crypton_salsa_combine"
    ccrypton_salsa_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()

-- | Write raw keystream into a buffer.
-- Arguments, as used by 'generate': destination pointer, context pointer
-- (mutated in place), byte count.
foreign import ccall "crypton_salsa_generate"
    ccrypton_salsa_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO ()