{-# LANGUAGE BangPatterns #-}
module Crypto.KDF.HKDF
( PRK
, extract
, extractSkip
, expand
) where
import Data.Word
import Crypto.Hash
import Crypto.MAC.HMAC
import Crypto.Internal.ByteArray (ScrubbedBytes, ByteArray, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B
-- | Pseudo-random key consumed by 'expand'.
--
-- * The @PRK@ constructor holds the HMAC value computed by 'extract'.
-- * The @PRK_NoExpand@ constructor holds input keying material copied
--   verbatim by 'extractSkip', which bypasses the HMAC extract step.
data PRK a = PRK (HMAC a) | PRK_NoExpand ScrubbedBytes
    deriving (Eq)
-- | Exposes the raw bytes of a 'PRK'.
--
-- For keys built by 'extract' these are the bytes of the HMAC value.
-- For keys built by 'extractSkip' they are the stored key material.
-- Both cases simply delegate to the wrapped value's own instance.
instance ByteArrayAccess (PRK a) where
    length (PRK hm)          = B.length hm
    length (PRK_NoExpand sb) = B.length sb

    withByteArray (PRK hm)          = B.withByteArray hm
    withByteArray (PRK_NoExpand sb) = B.withByteArray sb
-- | HKDF extract step: derive a pseudo-random key from input keying
-- material.
--
-- The result is the HMAC of the input keying material keyed with the
-- salt, wrapped as a 'PRK' ready for 'expand'.
extract :: (HashAlgorithm a, ByteArrayAccess salt, ByteArrayAccess ikm)
        => salt  -- ^ salt, used as the HMAC key
        -> ikm   -- ^ input keying material, used as the HMAC message
        -> PRK a -- ^ pseudo-random key
extract salt ikm = PRK $ hmac salt ikm
-- | Build a 'PRK' directly from input keying material, skipping the HMAC
-- extract step.
--
-- The bytes are copied unchanged into a 'ScrubbedBytes' buffer and used
-- as the HMAC key by 'expand'.
--
-- NOTE(review): presumably only appropriate when the input is already
-- uniformly random key material (cf. RFC 5869 §3.3); confirm with callers.
extractSkip :: ByteArrayAccess ikm
            => ikm   -- ^ input keying material, used verbatim as the key
            -> PRK a -- ^ pseudo-random key
extractSkip ikm = PRK_NoExpand $ B.convert ikm
-- | HKDF expand step: derive @outputLength@ bytes from a pseudo-random key.
--
-- The blocks are built as follows:
--
-- * @T(0)@ is empty.
-- * Each block @T(i)@ is the HMAC (keyed with the 'PRK' bytes) of
--   @T(i-1) ++ info ++ [i]@, with the one-byte counter @i@ starting at 1.
--
-- The blocks are concatenated and the final one is truncated so that
-- exactly @outputLength@ bytes are returned. A zero or negative length
-- yields empty output.
--
-- Calls 'error' (when the result is forced) if more than 255 blocks would
-- be needed, i.e. @outputLength > 255 * HashLen@. That is the limit set by
-- RFC 5869 and the most a one-byte counter can index.
expand :: (HashAlgorithm a, ByteArrayAccess info, ByteArray out)
       => PRK a -- ^ pseudo-random key (from 'extract' or 'extractSkip')
       -> info  -- ^ optional context / application-specific information
       -> Int   -- ^ output length in bytes
       -> out   -- ^ output keying material
expand prkAt infoAt outputLength =
    let hF = hFGet prkAt
     in B.concat $ loop hF B.empty outputLength 1
  where
    -- | HMAC function keyed with the PRK's bytes.
    hFGet :: (HashAlgorithm a, ByteArrayAccess b) => PRK a -> (b -> HMAC a)
    hFGet prk = case prk of
        PRK hmacKey      -> hmac hmacKey
        PRK_NoExpand ikm -> hmac ikm

    -- Context info converted once, reused for every block.
    info :: ScrubbedBytes
    info = B.convert infoAt

    -- | Produce the remaining blocks.
    --
    -- Arguments: the keyed HMAC, the previous block T(i-1), the number
    -- of bytes still needed, and the counter i.
    loop :: HashAlgorithm a
         => (ScrubbedBytes -> HMAC a)
         -> ScrubbedBytes
         -> Int
         -> Word8
         -> [ScrubbedBytes]
    loop hF tim1 n i
        | n <= 0    = []
        -- The counter started at 1, so reaching 0 means it wrapped after
        -- block 255. Continuing would emit blocks outside RFC 5869, so fail
        -- loudly instead of silently returning non-standard output.
        | i == 0    = error "Crypto.KDF.HKDF.expand: output length exceeds 255 * hash length (RFC 5869 limit)"
        | otherwise =
            let input   = B.concat [tim1, info, B.singleton i] :: ScrubbedBytes
                ti      = B.convert $ hF input
                hashLen = B.length ti
                r       = n - hashLen
             in (if n >= hashLen then ti else B.take n ti)
              : loop hF ti r (i + 1)