{-|
Module      : Z.Crypto.KDF
Description : Key Derivation Functions
Copyright   : Dong Han, 2021
              AnJie Dong, 2021
License     : BSD
Maintainer  : winterland1989@gmail.com
Stability   : experimental
Portability : non-portable

KDF(Key Derivation Function) and PBKDF(Password Based Key Derivation Function).

-}
module Z.Crypto.KDF (
  -- * KDF
    KDFType(..)
  , BlockCipherType (..)
  , HashType(..)
  , MACType(..)
  , kdf
  , kdf'
  -- * PBKDF
  , PBKDFType(..)
  , pbkdf
  , pbkdfTimed
  -- * Internal helps
  , kdfTypeToCBytes
  , pbkdfTypeToParam
  ) where

import           Z.Botan.Exception
import           Z.Botan.FFI
import           Z.Crypto.Cipher   (BlockCipherType (..))
import           Z.Crypto.Hash     (HashType (..), hashTypeToCBytes)
import           Z.Crypto.MAC      (MACType (..), macTypeToCBytes)
import           Z.Data.CBytes     (CBytes, withCBytes, withCBytesUnsafe)
import qualified Z.Data.CBytes     as CB
import qualified Z.Data.Vector     as V
import           Z.Foreign

-----------------------------
-- Key Derivation Function --
-----------------------------

-- | Key derivation functions are used to turn some amount of shared secret material into uniform random keys
-- suitable for use with symmetric algorithms. An example of an input which is useful for a KDF is a shared
-- secret created using Diffie-Hellman key agreement.
data KDFType
    = HKDF MACType
    -- ^ Defined in RFC 5869, HKDF uses HMAC to process inputs.
    -- HKDF is the combined Extract+Expand operation.
    -- Use the combined HKDF unless you need compatibility with some other system.
    | HKDF_Extract MACType
    -- ^ Only the HKDF-Extract step of RFC 5869.
    | HKDF_Expand MACType
    -- ^ Only the HKDF-Expand step of RFC 5869.
    | KDF2 HashType
    -- ^ KDF2 comes from IEEE 1363. It uses a hash function.
    | KDF1_18033 HashType
    -- ^ KDF1 from ISO 18033-2. Very similar to (but incompatible with) KDF2.
    | KDF1 HashType
    -- ^ KDF1 from IEEE 1363. It can only produce an output at most the length of the hash function used.
    | TLS_PRF
    -- ^ The pseudo-random function used by TLS 1.0 and TLS 1.1.
    | TLS_12_PRF MACType
    -- ^ The pseudo-random function used by TLS 1.2, parameterized by a MAC.
    | SP800_108_Counter MACType
    -- ^ KDF from NIST SP 800-108 in counter mode.
    | SP800_108_Feedback MACType
    -- ^ KDF from NIST SP 800-108 in feedback mode.
    | SP800_108_Pipeline MACType
    -- ^ KDF from NIST SP 800-108 in double-pipeline mode.
    | SP800_56AHash HashType
    -- ^ NIST SP 800-56A KDF using hash function
    | SP800_56AMAC MACType
    -- ^ NIST SP 800-56A KDF using HMAC
    | SP800_56C MACType
    -- ^ NIST SP 800-56C KDF using HMAC

-- | Render a 'KDFType' as the algorithm specification string handed to Botan.
--
-- Parameterized KDFs become @NAME(inner)@, where @inner@ is the name of the
-- underlying MAC ('macTypeToCBytes') or hash ('hashTypeToCBytes');
-- 'TLS_PRF' takes no parameter and is rendered as a bare name.
kdfTypeToCBytes :: KDFType -> CBytes
kdfTypeToCBytes kdfType = case kdfType of
    HKDF mt               -> viaMAC  "HKDF("               mt
    HKDF_Extract mt       -> viaMAC  "HKDF-Extract("       mt
    HKDF_Expand mt        -> viaMAC  "HKDF-Expand("        mt
    KDF2 ht               -> viaHash "KDF2("               ht
    KDF1_18033 ht         -> viaHash "KDF1-18033("         ht
    KDF1 ht               -> viaHash "KDF1("               ht
    TLS_PRF               -> "TLS-PRF"
    TLS_12_PRF mt         -> viaMAC  "TLS-12-PRF("         mt
    SP800_108_Counter mt  -> viaMAC  "SP800-108-Counter("  mt
    SP800_108_Feedback mt -> viaMAC  "SP800-108-Feedback(" mt
    SP800_108_Pipeline mt -> viaMAC  "SP800-108-Pipeline(" mt
    SP800_56AHash ht      -> viaHash "SP800-56A("          ht
    SP800_56AMAC mt       -> viaMAC  "SP800-56A("          mt
    SP800_56C mt          -> viaMAC  "SP800-56C("          mt
  where
    -- prefix already ends with "(", only the closing paren is appended
    viaMAC  prefix mt = CB.concat [prefix, macTypeToCBytes mt, ")"]
    viaHash prefix ht = CB.concat [prefix, hashTypeToCBytes ht, ")"]

-- | Derive a key using the given KDF algorithm.
--
-- The output buffer is zeroed before Botan is called, because some KDFs
-- XOR their output into the buffer instead of overwriting it. A negative
-- Botan return code is turned into an exception by 'throwBotanIfMinus_'.
kdf
  :: KDFType    -- ^ the KDF algorithm to use
  -> Int        -- ^ length of output key
  -> V.Bytes    -- ^ secret
  -> V.Bytes    -- ^ salt
  -> V.Bytes    -- ^ label
  -> IO V.Bytes
kdf algo siz secret salt label =
    withCBytesUnsafe (kdfTypeToCBytes algo) $ \ algoBA ->
        withPrimVectorUnsafe secret $ \ secretBA secretOff secretLen ->
            withPrimVectorUnsafe salt $ \ saltBA saltOff saltLen ->
                withPrimVectorUnsafe label $ \ labelBA labelOff labelLen ->
                    fst <$> allocPrimVectorUnsafe siz (\ buf -> do
                        -- some kdf needs xor output buffer, so we clear it first
                        clearMBA buf siz
                        throwBotanIfMinus_ $
                            hs_botan_kdf algoBA buf siz
                                secretBA secretOff secretLen
                                saltBA saltOff saltLen
                                labelBA labelOff labelLen)

-- | Derive a key using the given KDF algorithm, with an empty salt and an
-- empty label.
--
-- @kdf' algo siz secret@ is the same as @kdf algo siz secret mempty mempty@.
kdf'
  :: KDFType    -- ^ the KDF algorithm to use
  -> Int        -- ^ length of output key
  -> V.Bytes    -- ^ secret
  -> IO V.Bytes
kdf' algo siz secret = kdf algo siz secret mempty mempty

--------------------------------------------
-- Password-Based Key Derivation Function --
--------------------------------------------

-- | Password-based key derivation algorithms.
--
-- Often one needs to convert a human readable password into a cryptographic key. It is useful to slow down
-- this computation in order to reduce the speed of brute force search, so these algorithms are parameterized
-- in some way which allows their required computation to be tuned.
data PBKDFType
    = PBKDF2 MACType Int
    -- ^ @PBKDF2 mac iterations@.
    -- PBKDF2 is the “standard” password derivation scheme,
    -- widely implemented in many different libraries.
    | Scrypt Int Int Int
    -- ^ @Scrypt n r p@.
    -- Scrypt is a relatively newer design which is “memory hard”:
    -- in addition to requiring large amounts of CPU power it uses a large block of memory to compute the hash.
    -- This makes brute force attacks using ASICs substantially more expensive.
    | Argon2d Int Int Int
    -- ^ @Argon2d iterations memory parallelism@, where Botan interprets memory in KiB.
    -- Argon2 is the winner of the PHC (Password Hashing Competition) and provides a tunable memory hard PBKDF.
    | Argon2i Int Int Int
    -- ^ @Argon2i iterations memory parallelism@, with the same parameters as 'Argon2d'.
    | Argon2id Int Int Int
    -- ^ @Argon2id iterations memory parallelism@, with the same parameters as 'Argon2d'.
    | Bcrypt Int
    -- ^ @Bcrypt iterations@, i.e. Bcrypt-PBKDF.
    | OpenPGP_S2K HashType Int
    -- ^ @OpenPGP_S2K hash iterations@.
    -- The OpenPGP algorithm is weak and strange, and should be avoided unless implementing OpenPGP.

-- | Translate a 'PBKDFType' into the @(algorithm name, param1, param2, param3)@
-- tuple passed to Botan's password hashing API. Unused parameters are @0@.
--
-- Botan's parameter order is algorithm specific:
--
--   * PBKDF2, Bcrypt-PBKDF and OpenPGP-S2K take the iteration count as @param1@.
--   * Scrypt takes @(N, r, p)@.
--   * The Argon2 family takes @(M, t, p)@, i.e. memory first, then iterations,
--     then parallelism.
pbkdfTypeToParam :: PBKDFType -> (CBytes, Int, Int, Int)
pbkdfTypeToParam (PBKDF2 mt i) =
    (CB.concat ["PBKDF2(", macTypeToCBytes mt, ")"], i, 0, 0)
pbkdfTypeToParam (Scrypt n r p) = ("Scrypt", n, r, p)
-- The Argon2 constructors store (iterations, memory, parallelism), but Botan
-- expects (memory, iterations, parallelism), so the first two are swapped here.
pbkdfTypeToParam (Argon2d i m p) = ("Argon2d", m, i, p)
pbkdfTypeToParam (Argon2i i m p) = ("Argon2i", m, i, p)
pbkdfTypeToParam (Argon2id i m p) = ("Argon2id", m, i, p)
pbkdfTypeToParam (Bcrypt i) = ("Bcrypt-PBKDF", i, 0, 0)
pbkdfTypeToParam (OpenPGP_S2K ht i) =
    (CB.concat ["OpenPGP-S2K(", hashTypeToCBytes ht, ")"], i, 0, 0)

-- | Derive a key from a passphrase with the given PBKDF algorithm, using the
-- cost parameters (iterations, memory, ...) stored in the 'PBKDFType'.
--
-- The output buffer is zeroed before Botan fills it, and a negative Botan
-- return code is raised as an exception by 'throwBotanIfMinus_'.
pbkdf :: PBKDFType  -- ^ PBKDF algorithm type
      -> Int        -- ^ length of output key
      -> CBytes     -- ^ passphrase
      -> V.Bytes    -- ^ salt
      -> IO V.Bytes
pbkdf typ siz pwd salt =
    withCBytesUnsafe algoName $ \ nameBA ->
    withCBytesUnsafe pwd $ \ passBA ->
    withPrimVectorUnsafe salt $ \ saltBA saltOff saltLen ->
        fst <$> allocPrimVectorUnsafe siz (derive nameBA passBA saltBA saltOff saltLen)
  where
    (algoName, p1, p2, p3) = pbkdfTypeToParam typ
    -- fill the freshly allocated output buffer with the derived key
    derive nameBA passBA saltBA saltOff saltLen out = do
        clearMBA out siz
        throwBotanIfMinus_ $
            hs_botan_pwdhash nameBA p1 p2 p3
                out siz
                passBA (CB.length pwd)
                saltBA saltOff saltLen

-- | Derive a key from a passphrase using the given PBKDF algorithm. The cost
-- parameters stored in the 'PBKDFType' are ignored; instead the PBKDF is run
-- until the given number of milliseconds have passed.
--
-- The output buffer is zeroed before Botan fills it, and a negative Botan
-- return code is raised as an exception by 'throwBotanIfMinus_'.
pbkdfTimed :: PBKDFType  -- ^ PBKDF algorithm type (its parameters are ignored)
           -> Int        -- ^ run until this many milliseconds have passed
           -> Int        -- ^ length of output key
           -> CBytes     -- ^ passphrase
           -> V.Bytes    -- ^ salt
           -> IO V.Bytes
pbkdfTimed typ msec siz pwd s
    -- For long runs (> 0.1s) use the @_safe@ binding so the call can run on its
    -- own OS thread without stopping the GC. It takes Ptr-based arguments, hence
    -- withCBytes / withPrimVectorSafe / allocPrimVectorSafe in this branch.
    | msec > 100 =
        withCBytes algo $ \ algo' ->
        withCBytes pwd $ \ pwd' ->
        withPrimVectorSafe s $ \ s' sLen ->
            fst <$> allocPrimVectorSafe siz (\ buf -> do
                clearPtr buf siz
                throwBotanIfMinus_ $
                    -- withPrimVectorSafe yields no offset, so the salt offset is 0
                    hs_botan_pwdhash_timed_safe
                        algo' msec buf siz
                        pwd' (CB.length pwd)
                        s' 0 sLen)
    -- Short runs use the cheaper unsafe binding on unlifted byte arrays.
    | otherwise =
        withCBytesUnsafe algo $ \ algo' ->
        withCBytesUnsafe pwd $ \ pwd' ->
        withPrimVectorUnsafe s $ \ s' sOff sLen ->
            fst <$> allocPrimVectorUnsafe siz (\ buf -> do
                clearMBA buf siz
                throwBotanIfMinus_ $
                    hs_botan_pwdhash_timed
                        algo' msec buf siz
                        pwd' (CB.length pwd)
                        s' sOff sLen)
  where
    (algo, _, _, _) = pbkdfTypeToParam typ