module Crypto.HKDF
( hkdfExtract
, hkdfExpand
, hkdf
) where
import Crypto.Hash (HashAlgorithm)
import Crypto.MAC (HMAC, hmacAlg)
import Data.Byteable (toBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat, empty, length, take)
import qualified Data.ByteString.Char8 as C8 (singleton)
import Data.Char (chr)
-- | HKDF-Extract: compute @HMAC(salt, ikm)@, using the salt as the HMAC key
-- and the input keying material as the HMAC message.
--
-- Convert the result with 'toBytes' to obtain the pseudorandom key (PRK)
-- expected by 'hkdfExpand'. An empty salt is passed through to HMAC as an
-- empty key.
hkdfExtract
  :: (HashAlgorithm a)
  => a           -- ^ hash algorithm
  -> ByteString  -- ^ salt (used as the HMAC key)
  -> ByteString  -- ^ input keying material (used as the HMAC message)
  -> HMAC a
hkdfExtract alg salt ikm = hmacAlg alg salt ikm
-- | HKDF-Expand: stretch a pseudorandom key into @l@ octets of output keying
-- material bound to the context string @info@.
--
-- The output is @T(1) | T(2) | ...@ truncated to @l@ octets, where
-- @T(0)@ is empty and @T(i) = HMAC(prk, T(i-1) | info | i)@.
--
-- Returns 'Nothing' when @l@ is negative or exceeds @255 * HashLen@, because
-- the block counter is a single octet and so at most 255 blocks exist.
-- @l == 0@ yields @Just ""@.
hkdfExpand
  :: (HashAlgorithm a)
  => a           -- ^ hash algorithm
  -> ByteString  -- ^ pseudorandom key (PRK), e.g. @toBytes@ of 'hkdfExtract'
  -> ByteString  -- ^ context / application info (may be empty)
  -> Int         -- ^ requested output length @l@ in octets
  -> Maybe ByteString
hkdfExpand alg prk info l
  | l < 0 || l > 255 * chunkSize = Nothing
  | otherwise =
      -- +1 accounts for the empty seed T(0) at the head of hkdfChunks.
      Just $ BS.take l $ BS.concat $ take (blockCount + 1) hkdfChunks
  where
    -- [T(0) = "", T(1), T(2), ...]; lazily generated, so only the blocks
    -- selected by 'take' are ever computed.
    hkdfChunks = map fst $ iterate (hkdfSingle alg prk info) (BS.empty, 1)
    -- HashLen, measured from the first real block T(1).
    chunkSize = BS.length $ hkdfChunks !! 1
    -- Exactly ceiling (l / HashLen) blocks. The previous count produced one
    -- surplus block when l was a multiple of HashLen, which at the maximum
    -- length meant computing T(256) with a wrapped counter octet.
    blockCount = (l + chunkSize - 1) `div` chunkSize
-- | State threaded through the HKDF-Expand chain: the previous block paired
-- with the counter value to use for the next block.
type HKDFIteration = (ByteString, Int)

-- | One step of the HKDF-Expand chain: given @(T(n-1), n)@, produce
-- @(T(n), n + 1)@ where @T(n) = HMAC(prk, T(n-1) | info | n)@.
--
-- The counter is encoded as one octet via 'C8.singleton', which truncates the
-- character to its low 8 bits. Callers should therefore keep @n@ within
-- @1..255@; larger values wrap silently.
hkdfSingle
  :: (HashAlgorithm a)
  => a              -- ^ hash algorithm
  -> ByteString     -- ^ pseudorandom key (the HMAC key)
  -> ByteString     -- ^ context / application info
  -> HKDFIteration  -- ^ previous block and current counter
  -> HKDFIteration
hkdfSingle alg prk info (previous, counter) = (block, counter + 1)
  where
    counterOctet = C8.singleton (chr counter)
    message = BS.concat [previous, info, counterOctet]
    block = toBytes (hmacAlg alg prk message)
-- | Full HKDF: run 'hkdfExtract' on the salt and input keying material, then
-- feed the resulting pseudorandom key to 'hkdfExpand'.
--
-- Partially applying @hkdf alg salt ikm@ lets the extracted key be reused for
-- several @info@ / length requests. The result is 'Nothing' whenever
-- 'hkdfExpand' rejects the requested length.
hkdf
  :: (HashAlgorithm a)
  => a           -- ^ hash algorithm
  -> ByteString  -- ^ salt (HMAC key for the extract step)
  -> ByteString  -- ^ input keying material
  -> ByteString  -- ^ context / application info
  -> Int         -- ^ output length in octets
  -> Maybe ByteString
hkdf alg salt ikm = hkdfExpand alg prk
  where
    prk = toBytes (hkdfExtract alg salt ikm)