-- |
-- Module      : Crypto.KDF.Scrypt
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Scrypt key derivation function as defined in Colin Percival's paper
-- "Stronger Key Derivation via Sequential Memory-Hard Functions"
-- <http://www.tarsnap.com/scrypt/scrypt.pdf>.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Crypto.KDF.Scrypt
    ( Parameters(..)
    , generate
    ) where

import           Data.Word
import           Foreign.Marshal.Alloc
import           Foreign.Ptr (Ptr, plusPtr)
import           Control.Monad (forM_)

import           Crypto.Hash (SHA256(..))
import qualified Crypto.KDF.PBKDF2 as PBKDF2
import           Crypto.Internal.Compat (popCount, unsafeDoIO)
import           Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B

-- | Parameters for Scrypt
data Parameters = Parameters
    { n            :: Word64 -- ^ CPU/memory cost parameter, also known as N. Must be a power of 2 greater than 1.
    , r            :: Int    -- ^ Block size parameter. Must be positive and satisfy r * p < 2^30.
    , p            :: Int    -- ^ Parallelization parameter. Must be positive and satisfy r * p < 2^30.
    , outputLength :: Int    -- ^ The number of bytes to generate out of Scrypt.
    }

-- | Binding to the C routine @crypton_scrypt_smix@.
--
-- The argument descriptions below reflect how 'generate' calls this
-- routine; the C implementation itself is not visible from this module.
foreign import ccall "crypton_scrypt_smix"
    ccrypton_scrypt_smix
        :: Ptr Word8  -- block to mix: 'generate' passes a pointer to one 128 * r byte lane
        -> Word32     -- the r parameter
        -> Word64     -- the n parameter
        -> Ptr Word8  -- scratch buffer: 'generate' provides 128 * n * r bytes, 8-byte aligned
        -> Ptr Word8  -- scratch buffer: 'generate' provides 256 * r + 64 bytes, 8-byte aligned
        -> IO ()

-- | Generate the scrypt key derivation data.
--
-- The password and salt are first expanded with PBKDF2 (HMAC-SHA256, using
-- @PBKDF2.Parameters 1 intLen@) into @p@ lanes of @128 * r@ bytes. Each lane
-- is then mixed in place by the C routine @crypton_scrypt_smix@, and the
-- mixed buffer is used as the salt of a second PBKDF2 pass that produces
-- 'outputLength' bytes.
--
-- Calls 'error' when the parameters are invalid:
--
-- * @r@ or @p@ is not positive, or @r * p >= 2^30@;
--
-- * @n@ is not a power of 2;
--
-- * one of the working buffers (@128 * r * p@, @128 * n * r@ or
--   @256 * r + 64@ bytes) does not fit in an 'Int'.
generate :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray output)
         => Parameters
         -> password
         -> salt
         -> output
generate params password salt
    | r params < 1 || p params < 1 || rp >= 0x40000000 =
        error "Scrypt: invalid parameters: r and p constraint"
    | popCount (n params) /= 1 =
        error "Scrypt: invalid parameters: n not a power of 2"
    | bLen > maxInt || vLen > maxInt || xyLen > maxInt =
        error "Scrypt: invalid parameters: memory requirement too large"
    | otherwise = unsafeDoIO $ do
        -- First PBKDF2 pass: expand password and salt into p lanes of 128 * r bytes.
        let b = PBKDF2.generate prf (PBKDF2.Parameters 1 intLen) password salt :: B.Bytes
        -- Mix every lane in place on a private copy of b; v and xy are
        -- scratch buffers shared by all lanes.
        newSalt <- B.copy b $ \bPtr ->
            allocaBytesAligned (fromInteger vLen) 8 $ \v ->
            allocaBytesAligned (fromInteger xyLen) 8 $ \xy ->
                forM_ [0 .. p params - 1] $ \i ->
                    ccrypton_scrypt_smix (bPtr `plusPtr` (i * laneLen))
                                         (fromIntegral (r params)) (n params) v xy

        -- Second PBKDF2 pass: the mixed buffer is the salt for the final output.
        return $ PBKDF2.generate prf (PBKDF2.Parameters 1 (outputLength params)) password (newSalt :: B.Bytes)
  where
    prf     = PBKDF2.prfHMAC SHA256
    -- Size in bytes of one lane, and of the whole buffer made of p lanes.
    -- Only evaluated once the guards above have shown they cannot overflow.
    laneLen = 128 * r params
    intLen  = p params * laneLen

    -- All validation arithmetic is done in Integer so that oversized
    -- parameters cannot wrap around and slip past the checks.
    rp, bLen, vLen, xyLen, maxInt :: Integer
    rp      = toInteger (r params) * toInteger (p params)
    bLen    = 128 * rp
    vLen    = 128 * toInteger (n params) * toInteger (r params)
    xyLen   = 256 * toInteger (r params) + 64
    maxInt  = toInteger (maxBound :: Int)
{-# NOINLINE generate #-}