module Crypto.Cipher.Salsa
( initialize
, combine
, generate
, State
) where
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Foreign.Ptr
import Foreign.C.Types
-- | Salsa cipher state.
--
-- Wraps the 'ScrubbedBytes' buffer that 'initialize' allocates and that
-- 'combine' and 'generate' hand to the C routines by pointer. The module
-- exports the type without its constructor.
newtype State = State ScrubbedBytes
deriving (NFData)
-- | Build a new Salsa 'State' from a round count, a key and a nonce.
--
-- The arguments are validated in this order, and the first failing check
-- calls 'error':
--
-- 1. the key length must be 16 or 32,
-- 2. the nonce length must be 8 or 12,
-- 3. the number of rounds must be 8, 12 or 20.
initialize :: (ByteArrayAccess key, ByteArrayAccess nonce)
           => Int   -- ^ number of rounds (8, 12 or 20)
           -> key   -- ^ the key (length 16 or 32)
           -> nonce -- ^ the nonce (length 8 or 12)
           -> State -- ^ the initialized Salsa state
initialize nbRounds key nonce
    | keySize `notElem` [16, 32]     = error "Salsa: key length should be 128 or 256 bits"
    | nonceSize `notElem` [8, 12]    = error "Salsa: nonce length should be 64 or 96 bits"
    | nbRounds `notElem` [8, 12, 20] = error "Salsa: rounds should be 8, 12 or 20"
    | otherwise                      = unsafeDoIO $ fmap State $ B.alloc stateSize fillState
  where
    keySize   = B.length key
    nonceSize = B.length nonce

    -- Size of the buffer handed to cryptonite_salsa_init.
    stateSize = 132

    -- Let the C initializer populate the freshly allocated state buffer
    -- while the key and nonce are exposed as raw pointers.
    fillState ctxPtr =
        B.withByteArray nonce $ \noncePtr ->
        B.withByteArray key   $ \keyPtr   ->
            ccryptonite_salsa_init ctxPtr (fromIntegral nbRounds) keySize keyPtr nonceSize noncePtr
-- | Run the C combine routine over a source buffer.
--
-- Returns an output buffer of the same length as the source together with
-- the 'State' obtained from 'B.copyRet'. An empty source is returned as an
-- empty buffer with the state unchanged, without calling into C.
--
-- NOTE(review): the source length is converted to 'CUInt' with
-- 'fromIntegral', so a length that does not fit in a 'CUInt' would wrap
-- around - confirm callers never pass buffers that large.
combine :: ByteArray ba
        => State       -- ^ the current Salsa state
        -> ba          -- ^ the source to combine
        -> (ba, State) -- ^ the output buffer and the next state
combine prevSt@(State prevStMem) src
    | B.null src = (B.empty, prevSt)
    | otherwise  = unsafeDoIO $ do
        (dst, nextMem) <- B.copyRet prevStMem runCombine
        return (dst, State nextMem)
  where
    srcLen = B.length src

    -- Allocate the destination and let the C routine fill it from the
    -- source, using the state pointer supplied by B.copyRet.
    runCombine ctxPtr =
        B.alloc srcLen      $ \dstPtr ->
        B.withByteArray src $ \srcPtr ->
            ccryptonite_salsa_combine dstPtr ctxPtr srcPtr (fromIntegral srcLen)
-- | Run the C generate routine to produce an output buffer of the
-- requested length.
--
-- Returns the buffer together with the 'State' obtained from 'B.copyRet'.
-- A length of zero or less yields an empty buffer and the unchanged state,
-- without calling into C.
--
-- NOTE(review): the length is converted to 'CUInt' with 'fromIntegral', so
-- a length that does not fit in a 'CUInt' would wrap around - confirm
-- callers never request that much output at once.
generate :: ByteArray ba
         => State       -- ^ the current Salsa state
         -> Int         -- ^ the length of data to generate
         -> (ba, State) -- ^ the generated buffer and the next state
generate prevSt@(State prevStMem) len
    | len > 0   = unsafeDoIO $ do
        (bytes, nextMem) <- B.copyRet prevStMem fillOutput
        return (bytes, State nextMem)
    | otherwise = (B.empty, prevSt)
  where
    -- Allocate the destination and let the C routine fill it, using the
    -- state pointer supplied by B.copyRet.
    fillOutput ctxPtr =
        B.alloc len $ \dstPtr ->
            ccryptonite_salsa_generate dstPtr ctxPtr (fromIntegral len)
-- Binding to the C function @cryptonite_salsa_init@.
--
-- As called from 'initialize', the arguments are, in order: pointer to the
-- state buffer to fill, number of rounds, key length, pointer to the key,
-- nonce length, pointer to the nonce.
foreign import ccall "cryptonite_salsa_init"
ccryptonite_salsa_init :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
-- Binding to the C function @cryptonite_salsa_combine@.
--
-- As called from 'combine', the arguments are, in order: pointer to the
-- destination buffer, pointer to the state, pointer to the source buffer,
-- and the source length (as given by @B.length@, narrowed to 'CUInt').
foreign import ccall "cryptonite_salsa_combine"
ccryptonite_salsa_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
-- Binding to the C function @cryptonite_salsa_generate@.
--
-- As called from 'generate', the arguments are, in order: pointer to the
-- destination buffer, pointer to the state, and the requested output
-- length (narrowed to 'CUInt').
foreign import ccall "cryptonite_salsa_generate"
ccryptonite_salsa_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO ()