{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Crypto.Store.PKCS5.PBES1
( PBEParameter(..)
, Key
, pkcs5
, pkcs12
, pkcs12rc2
, pkcs12stream
, pkcs12mac
, rc4Combine
) where
import Basement.Block (Block)
import Basement.Compat.IsList
import Basement.Endianness
import qualified Basement.String as S
import Crypto.Cipher.Types
import qualified Crypto.Cipher.RC4 as RC4
import qualified Crypto.Hash as Hash
import Data.ASN1.Types
import Data.Bits
import Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B
import Data.ByteString (ByteString)
import Data.Maybe (fromMaybe)
import Data.Memory.PtrMethods
import Data.Word
import Foreign.Ptr (plusPtr)
import Foreign.Storable
import Crypto.Store.ASN1.Parse
import Crypto.Store.ASN1.Generate
import Crypto.Store.CMS.Algorithms
import Crypto.Store.CMS.Util
import Crypto.Store.Error
-- | Secret key material, held in a 'B.ScrubbedBytes' buffer.
type Key = B.ScrubbedBytes
-- | Parameters shared by the password-based schemes in this module: a salt
-- and an iteration count.  The ASN.1 instances encode and decode it as a
-- @SEQUENCE@ of an @OCTET STRING@ followed by an @INTEGER@.
data PBEParameter = PBEParameter
    { pbeSalt           :: Salt -- ^ Salt mixed with the password during key derivation
    , pbeIterationCount :: Int  -- ^ Number of hash iterations applied during key derivation
    }
    deriving (Show,Eq)
-- | Encodes the parameters as a @SEQUENCE@ holding the salt as an
-- @OCTET STRING@ followed by the iteration count as an @INTEGER@.
instance ASN1Elem e => ProduceASN1Object e PBEParameter where
    asn1s PBEParameter{..} = asn1Container Sequence (saltElem . itersElem)
      where
        saltElem  = gOctetString pbeSalt
        itersElem = gIntVal (toInteger pbeIterationCount)
-- | Decodes a @SEQUENCE@ whose first element is an @OCTET STRING@ (the salt)
-- and whose second element is an @INTEGER@ (the iteration count).  Any other
-- element shape makes the pattern match, and therefore the parser, fail.
instance Monoid e => ParseASN1Object e PBEParameter where
    parse = onNextContainer Sequence $ do
        OctetString saltBytes <- getNext
        IntVal count <- getNext
        return (PBEParameter saltBytes (fromInteger count))
-- | Build CBC content-encryption parameters for a block cipher from raw IV
-- bytes.  The bytes are converted with 'makeIV'; when that conversion yields
-- 'Nothing' the IV component is a call to 'error'.
cbcWith :: (BlockCipher cipher, ByteArrayAccess iv)
        => ContentEncryptionCipher cipher -> iv -> ContentEncryptionParams
cbcWith cipher iv =
    ParamsCBC cipher $ case makeIV iv of
        Just goodIV -> goodIV
        Nothing     -> error "PKCS5: bad initialization vector"
-- | Build RC2-CBC content-encryption parameters from the integer passed on
-- to 'ParamsCBCRC2' and raw IV bytes.  When 'makeIV' rejects the bytes the IV
-- component is a call to 'error'.
rc2cbcWith :: ByteArrayAccess iv => Int -> iv -> ContentEncryptionParams
rc2cbcWith len iv =
    ParamsCBCRC2 len (maybe (error "PKCS5: bad RC2 initialization vector") id (makeIV iv))
-- | Run the input through RC4 initialized with the given key and return the
-- combined bytes.  The updated RC4 state is discarded, and the result is
-- always 'Right'.
rc4Combine :: (ByteArrayAccess key, ByteArray ba) => key -> ba -> Either StoreError ba
rc4Combine key input = Right output
  where
    (_, output) = RC4.combine (RC4.initialize key) input
-- | Convert a UTF-8 password into big-endian 16-bit code units, one per
-- character, followed by a terminating zero unit (a NUL byte is appended to
-- the input before decoding).
--
-- Returns 'Nothing' when the input cannot be decoded entirely as UTF-8, or
-- when it contains a character above U+FFFF.  Such a character does not fit
-- in a single 16-bit unit, and converting it with 'toEnum' would otherwise
-- throw an exception from pure code.
toUCS2 :: (ByteArrayAccess butf8, ByteArray bucs2) => butf8 -> Maybe bucs2
toUCS2 pwdUTF8
    | B.null r && all inBMP chars = Just pwdUCS2
    | otherwise                   = Nothing
  where
    -- 'r' holds the bytes left over when decoding stopped early.
    (p, _, r) = S.fromBytes S.UTF8 $ B.snoc (B.convert pwdUTF8) 0
    chars     = toList p
    pwdBlock  = fromList $ map ucs2 chars :: Block (BE Word16)
    pwdUCS2   = B.convert pwdBlock

    -- Only characters representable in 16 bits can be converted.
    inBMP :: Char -> Bool
    inBMP ch = fromEnum ch <= 0xFFFF

    -- Safe: only evaluated once 'inBMP' holds for every character.
    ucs2 :: Char -> BE Word16
    ucs2 = toBE . toEnum . fromEnum
-- | Password-based encryption or decryption using 'pbkdf1' and a CBC cipher
-- with an 8-byte block.
--
-- Sixteen bytes are derived from the password and parameters: the first
-- eight become the key and the remaining eight the IV.  The supplied
-- continuation then processes the content.  A cipher whose block size is not
-- 8, or a failure in 'pbkdf1', is reported through the failure continuation.
pkcs5 :: (Hash.HashAlgorithm hash, BlockCipher cipher, ByteArrayAccess password)
      => (StoreError -> result)
      -> (Key -> ContentEncryptionParams -> ByteString -> result)
      -> DigestProxy hash
      -> ContentEncryptionCipher cipher
      -> PBEParameter
      -> ByteString
      -> password
      -> result
pkcs5 failure encdec hashAlg cec pbeParam bs pwd
    | proxyBlockSize cec == 8 = either failure withDerived derived
    | otherwise               = failure (InvalidParameter "Invalid cipher block size")
  where
    derived = pbkdf1 hashAlg pwd pbeParam 16 :: Either StoreError Key

    withDerived dk =
        let (key, iv) = B.splitAt 8 dk
         in encdec key (cbcWith cec iv) bs
-- | PBKDF1-style key derivation: hash the password followed by the salt,
-- re-hash the digest until 'pbeIterationCount' hash applications have been
-- performed in total, and return the first @dkLen@ bytes.
--
-- Returns 'Left' when the iteration count is below 1 (the parameters may come
-- from parsed, untrusted input, and such a count would otherwise trigger a
-- negative-index exception) or when @dkLen@ exceeds the digest size.
pbkdf1 :: (Hash.HashAlgorithm hash, ByteArrayAccess password, ByteArray out)
       => DigestProxy hash
       -> password
       -> PBEParameter
       -> Int
       -> Either StoreError out
pbkdf1 hashAlg pwd PBEParameter{..} dkLen
    | pbeIterationCount < 1 = Left (InvalidParameter "Invalid iteration count")
    | dkLen > B.length t1   = Left (InvalidParameter "Derived key too long")
    | otherwise             = Right (B.convert $ B.takeView tc dkLen)
  where
    a  = hashFromProxy hashAlg
    -- First iteration: digest of password || salt.
    t1 = Hash.hashFinalize (Hash.hashUpdate (Hash.hashUpdate (Hash.hashInitWith a) pwd) pbeSalt)
    -- Remaining iterations; the index is non-negative thanks to the guard
    -- above and 'iterate' is infinite, so '!!' cannot fail here.
    tc = iterate (Hash.hashWith a) t1 !! pred pbeIterationCount
-- | Password-based encryption or decryption with a CBC block cipher, using
-- key material produced by 'pkcs12Derive'.
--
-- The password is first converted with 'toUCS2'; a failed conversion is
-- reported through the failure continuation.  The IV (ID byte 2, one cipher
-- block long) and the key (ID byte 1, maximum key size of the scheme) are
-- then derived and handed to the supplied continuation with the content.
pkcs12 :: (Hash.HashAlgorithm hash, BlockCipher cipher, ByteArrayAccess password)
       => (StoreError -> result)
       -> (Key -> ContentEncryptionParams -> ByteString -> result)
       -> DigestProxy hash
       -> ContentEncryptionCipher cipher
       -> PBEParameter
       -> ByteString
       -> password
       -> result
pkcs12 failure encdec hashAlg cec pbeParam bs pwdUTF8 =
    maybe (failure passwordNotUTF8) withPassword (toUCS2 pwdUTF8)
  where
    withPassword pwdUCS2 = encdec key eScheme bs
      where
        iv      = pkcs12Derive hashAlg pbeParam 2 pwdUCS2 (proxyBlockSize cec) :: B.Bytes
        eScheme = cbcWith cec iv
        key     = pkcs12Derive hashAlg pbeParam 1 pwdUCS2 (getMaximumKeySize eScheme) :: Key
-- | RC2-CBC variant of the 'pkcs12Derive'-based scheme.  The 'Int' argument
-- is forwarded to 'rc2cbcWith'.
--
-- The password is converted with 'toUCS2' (failure is reported through the
-- failure continuation), an 8-byte IV is derived with ID byte 2 and the key
-- with ID byte 1, then the supplied continuation processes the content.
pkcs12rc2 :: (Hash.HashAlgorithm hash, ByteArrayAccess password)
          => (StoreError -> result)
          -> (Key -> ContentEncryptionParams -> ByteString -> result)
          -> DigestProxy hash
          -> Int
          -> PBEParameter
          -> ByteString
          -> password
          -> result
pkcs12rc2 failure encdec hashAlg len pbeParam bs pwdUTF8 =
    maybe (failure passwordNotUTF8) withPassword (toUCS2 pwdUTF8)
  where
    withPassword pwdUCS2 = encdec key eScheme bs
      where
        iv      = pkcs12Derive hashAlg pbeParam 2 pwdUCS2 8 :: B.Bytes
        eScheme = rc2cbcWith len iv
        key     = pkcs12Derive hashAlg pbeParam 1 pwdUCS2 (getMaximumKeySize eScheme) :: Key
-- | 'pkcs12Derive'-based scheme for ciphers that need only a key.
--
-- The password is converted with 'toUCS2' (failure is reported through the
-- failure continuation), a key of the requested length is derived with ID
-- byte 1, and the supplied continuation processes the content with it.
pkcs12stream :: (Hash.HashAlgorithm hash, ByteArrayAccess password)
             => (StoreError -> result)
             -> (Key -> ByteString -> result)
             -> DigestProxy hash
             -> Int
             -> PBEParameter
             -> ByteString
             -> password
             -> result
pkcs12stream failure encdec hashAlg keyLen pbeParam bs pwdUTF8 =
    maybe (failure passwordNotUTF8) withPassword (toUCS2 pwdUTF8)
  where
    withPassword pwdUCS2 = encdec (deriveKey pwdUCS2) bs

    deriveKey :: ByteString -> Key
    deriveKey pwdUCS2 = pkcs12Derive hashAlg pbeParam 1 pwdUCS2 keyLen
-- | 'pkcs12Derive'-based message authentication using HMAC over the given
-- digest algorithm.
--
-- The password is converted with 'toUCS2' (failure is reported through the
-- failure continuation), a key is derived with ID byte 3 at the maximum key
-- size of the MAC algorithm, and the supplied continuation is invoked with
-- the key, the MAC algorithm and the content.
pkcs12mac :: (Hash.HashAlgorithm hash, ByteArrayAccess password)
          => (StoreError -> result)
          -> (Key -> MACAlgorithm -> ByteString -> result)
          -> DigestProxy hash
          -> PBEParameter
          -> ByteString
          -> password
          -> result
pkcs12mac failure macFn hashAlg pbeParam bs pwdUTF8 =
    maybe (failure passwordNotUTF8) authenticate (toUCS2 pwdUTF8)
  where
    macAlg = HMAC hashAlg

    authenticate pwdUCS2 =
        let key = pkcs12Derive hashAlg pbeParam 3 pwdUCS2 (getMaximumKeySize macAlg) :: Key
         in macFn key macAlg bs
-- | Error passed to the failure continuation by the @pkcs12*@ functions when
-- 'toUCS2' cannot convert the supplied password.
passwordNotUTF8 :: StoreError
passwordNotUTF8 = InvalidPassword "Provided password is not valid UTF-8"
-- | Key-material derivation shared by 'pkcs12', 'pkcs12rc2', 'pkcs12stream'
-- and 'pkcs12mac'.
--
-- Arguments, in order: the digest algorithm, the salt and iteration count,
-- the ID byte that fills the diversifier block, the password bytes (callers
-- in this module pass the output of 'toUCS2'), and the number of output
-- bytes wanted.
--
-- NOTE(review): the structure (diversifier block, salt and password expanded
-- to block multiples, iterated hashing, "add B + 1" update of each block)
-- looks like the algorithm of RFC 7292 appendix B.2 — confirm.
--
-- NOTE(review): an iteration count below 1 makes the '!!' below index
-- negatively and throw; this function has no way to report that as an error.
pkcs12Derive :: (Hash.HashAlgorithm hash, ByteArray bout)
             => DigestProxy hash
             -> PBEParameter
             -> Word8
             -> ByteString
             -> Int
             -> bout
pkcs12Derive hashAlg PBEParameter{..} idByte pwdUCS2 n =
    -- Concatenate the first 'c' digests and keep exactly 'n' bytes.
    B.take n $ B.concat $ take c $ loop t (s `B.append` p)
  where
    a = hashFromProxy hashAlg
    -- 'v' is the per-algorithm size returned by 'getV', 'u' the digest size.
    v = getV (DigestAlgorithm hashAlg)
    u = Hash.hashDigestSize a
    -- Number of digests needed to cover 'n' bytes (ceiling of n / u).
    c = (n + u - 1) `div` u
    -- Diversifier: 'v' copies of the ID byte.
    d = B.replicate v idByte :: B.Bytes
    -- Hash context pre-loaded with the diversifier, reused for every round.
    t = Hash.hashUpdate (Hash.hashInitWith a) d
    -- Password and salt, each expanded with 'extendedToMult'.
    p = pwdUCS2 `extendedToMult` v
    s = pbeSalt `extendedToMult` v
    -- Lazily produces the infinite sequence of digests; 'x' is the shared
    -- context and 'i' the current salt/password input.
    loop :: Hash.HashAlgorithm hash
         => Hash.Context hash -> ByteString -> [Hash.Digest hash]
    loop x i = let z = Hash.hashFinalize (Hash.hashUpdate x i)
                   -- 'z' counts as the first iteration, hence 'pred'.
                   ai = iterate Hash.hash z !! pred pbeIterationCount
                   -- Digest repeated to fill 'v' bytes.
                   b = ai `extendedTo` v
                   -- Next input: every 'v'-byte chunk of 'i' combined with 'b' by 'add1'.
                   j = B.concat $ map (add1 b) (chunks v i)
                in ai : loop x j
-- | Size in bytes used by 'pkcs12Derive' for the diversifier and for padding
-- the salt and password: 64 for MD2, MD4, MD5, SHA-1, SHA-224 and SHA-256,
-- and 128 for SHA-384 and SHA-512.
getV :: DigestAlgorithm -> Int
getV (DigestAlgorithm alg) =
    case alg of
        MD2    -> 64
        MD4    -> 64
        MD5    -> 64
        SHA1   -> 64
        SHA224 -> 64
        SHA256 -> 64
        SHA384 -> 128
        SHA512 -> 128
-- | Produce a value of the type carried by a proxy.  The result is
-- 'undefined': in this module it is only passed to hashing functions to
-- select the algorithm by type, and it must never be evaluated.
hashFromProxy :: proxy a -> a
hashFromProxy _ = undefined
-- | Split a byte array into consecutive pieces of @n@ bytes.  The last piece
-- holds whatever remains (between 1 and @n@ bytes); an empty input yields an
-- empty list.
chunks :: ByteArray ba => Int -> ba -> [ba]
chunks n bs =
    case compare total n of
        GT -> let (piece, rest) = B.splitAt n bs in piece : chunks n rest
        _ | total > 0 -> [bs]
          | otherwise -> []
  where
    total = B.length bs
-- | Concatenate copies of the input to produce exactly @n@ bytes; the final
-- copy is truncated when @n@ is not a multiple of the input length.
--
-- Precondition: the input must be non-empty (the remainder computation
-- divides by its length).
extendedTo :: (ByteArrayAccess bin, ByteArray bout) => bin -> Int -> bout
bs `extendedTo` n =
    B.allocAndFreeze n $ \pout ->
        B.withByteArray bs $ \pin -> do
            -- Whole copies only at offsets where a full copy still fits,
            -- i.e. @off + len <= n@.  Using @n - 1@ as the bound would add
            -- one more full-length copy that runs past the end of the
            -- @n@-byte buffer whenever @n@ is not a multiple of @len@.
            mapM_ (\off -> memCopy (pout `plusPtr` off) pin len)
                  (enumFromThenTo 0 len (n - len))
            -- Trailing partial copy of @r@ bytes (a no-op when @r == 0@).
            memCopy (pout `plusPtr` (n - r)) pin r
  where
    len = B.length bs
    r   = n `mod` len
{-# NOINLINE extendedTo #-}
-- | Concatenate copies of the input to produce the smallest output whose
-- length is a multiple of @n@ bytes and at least the input length (the final
-- copy may be truncated).  An empty input yields an empty output.
--
-- An input whose length is already a multiple of @n@ is returned unchanged.
-- Without that check, an input of length @2n@, @3n@, ... would get a whole
-- extra block appended, because @n - len `mod` n@ is then @n@, not @0@.
extendedToMult :: ByteArray ba => ba -> Int -> ba
bs `extendedToMult` n
    | len == 0  = B.empty
    | r == 0    = bs
    | len > n   = bs `B.append` B.take (n - r) bs
    | otherwise = bs `extendedTo` n
  where
    len = B.length bs
    -- Bytes occupied in the last, partially filled block.
    r   = len `mod` n
-- | Treat both inputs as big-endian unsigned integers and compute
-- @a + b + 1@, keeping as many bytes as @a@ has.  The addition runs from the
-- last byte towards the first; a carry out of the first byte is dropped.
-- When @b@ is longer than @a@, only its trailing bytes take part.
add1 :: ByteString -> ByteString -> ByteString
add1 a b =
    B.allocAndFreeze alen $ \pc ->
        B.withByteArray a $ \pa ->
        B.withByteArray b $ \pb ->
            -- The initial carry of 1 supplies the "+ 1".
            loop3 pa pb pc alen blen 1
  where
    alen = B.length a
    blen = B.length b

    -- Main loop while both 'a' and 'b' have bytes left.  'ma' and 'mb' count
    -- the remaining bytes of each input, 'c' is the incoming carry.
    loop3 !pa !pb !pc !ma !mb !c
        | ma == 0   = return ()
        | mb == 0   = loop2 pa pc ma c
        | otherwise = do
            let na = pred ma
                nb = pred mb
            ba <- peekElemOff pa na
            bb <- peekElemOff pb nb
            -- 'cc' is the outgoing carry, 'bc' the result byte.
            let (cc, bc) = carryAdd3 c ba bb
            pokeElemOff pc na bc
            loop3 pa pb pc na nb cc

    -- Once 'b' is exhausted, only the carry is propagated through the
    -- remaining bytes of 'a'.
    loop2 !pa !pc !ma !c
        | ma == 0   = return ()
        | otherwise = do
            let na = pred ma
            ba <- peekElemOff pa na
            let (cc, bc) = carryAdd2 c ba
            pokeElemOff pc na bc
            loop2 pa pc na cc
-- | Split a 16-bit word into its high byte and its low byte, in that order.
split16 :: Word16 -> (Word8, Word8)
split16 w = (fromIntegral hi, fromIntegral lo)
  where
    (hi, lo) = w `divMod` 256
-- | Add two bytes in 16-bit arithmetic and return the carry byte followed by
-- the low result byte.
carryAdd2 :: Word8 -> Word8 -> (Word8, Word8)
carryAdd2 a b = split16 total
  where
    total = sum (map fromIntegral [a, b])
-- | Add three bytes in 16-bit arithmetic and return the carry byte followed
-- by the low result byte.
carryAdd3 :: Word8 -> Word8 -> Word8 -> (Word8, Word8)
carryAdd3 a b c = split16 total
  where
    total = sum (map fromIntegral [a, b, c])