{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- SPDX-FileCopyrightText: 2020 Serokell
--
-- SPDX-License-Identifier: MPL-2.0

-- | Key derivation/generation internals.
module Crypto.Key.Internal
  ( Params (..)
  , DerivationSlip
  , derive
  , rederive

  , DerivationSlipData (..)
  , derivationSlipEncode
  , derivationSlipDecode
  ) where

import Control.Monad (when)
import Data.ByteArray (ByteArrayAccess)
import Data.ByteArray.Sized (ByteArrayN, sizedByteArray, unSizedByteArray)
import Data.ByteString (ByteString)
import Data.Serialize (Serialize (put, get), decode, encode)
import Data.Word (Word8)
import GHC.TypeLits (type (<=))

import qualified Libsodium as Na

import Crypto.Nonce (generate)
import Crypto.Pwhash.Internal (Algorithm (Argon2id_1_3), Params (..), Salt, pwhash)


-- | Opaque bytes that contain the salt and pwhash params.
--
-- Produced by 'derive' (via 'derivationSlipEncode') and consumed by
-- 'rederive' (via 'derivationSlipDecode'), which needs it to reproduce
-- the same key from the same password.
type DerivationSlip = ByteString

-- | Data contained in a derivation slip.
--
-- This data type is used only internally within this module for
-- convenience. It is exported only for testing purposes.
--
-- Currently only one KDF (Argon2id 1.3) is supported, so it is assumed
-- implicitly; however, the actual binary encoding contains an identifier
-- of the KDF used (for forward-compatibility).
data DerivationSlipData = DerivationSlipData
  { params :: !Params
    -- ^ Password-hashing parameters (ops and memory limits).
  , salt :: !(Salt ByteString)
    -- ^ Salt passed to the password hash.
  }
  deriving (Eq, Show)

-- | Identifier of the KDF stored as the first byte of a derivation slip.
--
-- Only Argon2id 1.3 is supported at the moment, and it is tagged with @1@.
-- Shared by 'put' and 'get' so the two sides cannot drift apart.
argon2idTag :: Word8
argon2idTag = 1

-- | Binary layout of a derivation slip, using cereal's encodings:
--
-- 1. one byte: the KDF identifier ('argon2idTag');
-- 2. 'opsLimit' as a 'Word64';
-- 3. 'memLimit' as a 'Word64';
-- 4. the salt, encoded with cereal's 'ByteString' instance.
--
-- Decoding fails on an unknown KDF identifier or on a salt whose length
-- does not match the 'Salt' size.
instance Serialize DerivationSlipData where
  put (DerivationSlipData Params{opsLimit, memLimit} salt) = do
    put argon2idTag  -- algorithm marker for forward-compatibility
    put opsLimit
    put memLimit
    put (unSizedByteArray salt)

  get = do
    tag <- get @Word8
    when (tag /= argon2idTag) $
      fail "Wrong algorithm parameters encoding tag"
    params <- Params <$> get <*> get
    -- 'sizedByteArray' yields 'Nothing' when the decoded length is wrong.
    msalt <- sizedByteArray <$> get @ByteString
    case msalt of
      Nothing -> fail "Unexpected salt size"
      Just salt -> pure $ DerivationSlipData params salt


-- | Encode derivation slip data into bytes.
--
-- Uses the 'Serialize' instance of 'DerivationSlipData'; the result can be
-- turned back into the same data with 'derivationSlipDecode'.
derivationSlipEncode :: DerivationSlipData -> DerivationSlip
derivationSlipEncode = encode

-- | Decode derivation slip data from bytes.
--
-- Returns 'Nothing' if the bytes are not a valid encoding according to the
-- 'Serialize' instance of 'DerivationSlipData'. The decoder's error message
-- is discarded.
derivationSlipDecode :: DerivationSlip -> Maybe DerivationSlipData
derivationSlipDecode = either (const Nothing) Just . decode


-- | Derive a key for the first time.
--
-- Generates a fresh random salt, hashes @passwd@ with Argon2id 1.3 using
-- @params@ and that salt, and returns the key together with the encoded
-- 'DerivationSlip' that 'rederive' needs to reproduce the same key.
--
-- Returns 'Nothing' if 'pwhash' returns 'Nothing'.
derive
  ::  ( ByteArrayAccess passwd
      , ByteArrayN n key
      , Na.CRYPTO_PWHASH_BYTES_MIN <= n, n <= Na.CRYPTO_PWHASH_BYTES_MAX
      )
  => Params
  -> passwd
  -> IO (Maybe (key, DerivationSlip))
derive params passwd = do
  salt <- generate
  mkey <- pwhash Argon2id_1_3 params passwd salt
  -- The slip is what must be stored to call 'rederive' later.
  let slip = derivationSlipEncode (DerivationSlipData params salt)
  pure $ fmap (, slip) mkey

-- | Derive the same key from the same password again.
--
-- Decodes the parameters and salt stored in @slip@ (as produced by
-- 'derive') and re-runs Argon2id 1.3 on @passwd@ with them.
--
-- The slip stores only the parameters and the salt, not the key length
-- @n@, so the caller must choose @key@ consistently with the original
-- 'derive' call.
--
-- Returns 'Nothing' if the slip cannot be decoded or if 'pwhash' returns
-- 'Nothing'.
rederive
  ::  ( ByteArrayAccess passwd
      , ByteArrayN n key
      , Na.CRYPTO_PWHASH_BYTES_MIN <= n, n <= Na.CRYPTO_PWHASH_BYTES_MAX
      )
  => DerivationSlip
  -> passwd
  -> IO (Maybe key)
rederive slip passwd =
  case derivationSlipDecode slip of
    Nothing -> pure Nothing
    Just DerivationSlipData{params, salt} ->
      pwhash Argon2id_1_3 params passwd salt