module Jose.Internal.Crypto
( hmacSign
, hmacVerify
, rsaSign
, rsaVerify
, rsaEncrypt
, rsaDecrypt
, ecVerify
, encryptPayload
, decryptPayload
, generateCmkAndIV
, pad
, unpad
)
where
import Control.Applicative
import Crypto.Cipher.Types (AuthTag(..))
import Control.Monad.Error
import Crypto.Number.Serialize (os2ip)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.RSA.PKCS15 as PKCS15
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import Crypto.Random (CPRG, cprgGenerate)
import qualified Crypto.Cipher.AES as AES
import Crypto.PubKey.HashDescr
import Crypto.MAC.HMAC (hmac)
import Data.Byteable (constEqBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.Serialize as Serialize
import qualified Data.Text as T
import Data.Word (Word64, Word8)
import Jose.Jwa
import Jose.Types (JwtError(..))
-- | Compute the HMAC of a message for one of the JWS HMAC algorithms
-- (HS256, HS384, HS512).
--
-- Returns 'BadAlgorithm' if the algorithm has no entry in 'hmacHashes'.
hmacSign :: JwsAlg      -- ^ HMAC algorithm
         -> ByteString  -- ^ secret key
         -> ByteString  -- ^ message to authenticate
         -> Either JwtError ByteString
hmacSign a k m = do
    hash <- maybe (Left $ BadAlgorithm $ T.pack $ "Not an HMAC algorithm: " ++ show a) return $ lookup a hmacHashes
    return $ hmac (hashFunction hash) (blockSize a) k m
  where
    -- HMAC pads the key to the hash's internal block size (RFC 2104):
    -- 64 bytes for SHA-256 but 128 bytes for SHA-384 and SHA-512. Using 64
    -- for every algorithm yields non-standard MACs for HS384/HS512.
    -- Only evaluated after the lookup above has succeeded.
    blockSize HS256 = 64
    blockSize _     = 128
-- | Check an HMAC signature by recomputing it with 'hmacSign' and comparing
-- the result to the supplied signature with 'constEqBytes'. An algorithm
-- that 'hmacSign' rejects makes verification fail.
hmacVerify :: JwsAlg      -- ^ HMAC algorithm
           -> ByteString  -- ^ secret key
           -> ByteString  -- ^ signed message
           -> ByteString  -- ^ signature to check
           -> Bool
hmacVerify a key msg sig = case hmacSign a key msg of
    Left _         -> False
    Right expected -> expected `constEqBytes` sig
-- | Sign a message with RSASSA-PKCS1-v1_5 for the RSA JWS algorithms.
--
-- Returns 'BadAlgorithm' (via 'lookupRSAHash') for a non-RSA algorithm and
-- 'BadCrypto' if the underlying signing primitive reports an error.
rsaSign :: Maybe RSA.Blinder  -- ^ optional blinder for the private-key operation
        -> JwsAlg             -- ^ RSA signing algorithm
        -> RSA.PrivateKey     -- ^ signing key
        -> ByteString         -- ^ message to sign
        -> Either JwtError ByteString
rsaSign blinder a key msg =
    lookupRSAHash a >>= \hash ->
        case PKCS15.sign blinder hash key msg of
            Left _          -> Left BadCrypto
            Right signature -> Right signature
-- | Verify an RSASSA-PKCS1-v1_5 signature. A non-RSA algorithm makes
-- verification fail.
rsaVerify :: JwsAlg         -- ^ RSA signing algorithm
          -> RSA.PublicKey  -- ^ verification key
          -> ByteString     -- ^ signed message
          -> ByteString     -- ^ signature to check
          -> Bool
rsaVerify a key msg sig =
    either (const False) (\hash -> PKCS15.verify hash key msg sig) (lookupRSAHash a)
-- | Verify an ECDSA signature for one of the JWS algorithms ES256, ES384 or
-- ES512. Any other algorithm yields 'False'.
--
-- The signature must be the fixed-length big-endian encoding of R followed
-- by S (RFC 7518, section 3.4): 64 octets for ES256, 96 for ES384 and 132
-- for ES512. A signature of any other length is rejected outright instead of
-- being split in half, so non-canonical encodings of an (r, s) pair are not
-- accepted.
ecVerify :: JwsAlg           -- ^ ECDSA algorithm
         -> ECDSA.PublicKey  -- ^ verification key
         -> ByteString       -- ^ signed message
         -> ByteString       -- ^ signature (R || S)
         -> Bool
ecVerify a key msg sig = case (lookupECHash a, coordinateSize a) of
    (Just hash, Just n) | B.length sig == 2 * n ->
        let (r, s) = B.splitAt n sig
        in  ECDSA.verify hash key (ECDSA.Signature (os2ip r) (os2ip s)) msg
    _ -> False
  where
    -- Octet length of each of R and S for the curve paired with the
    -- algorithm (P-256, P-384 and P-521 respectively).
    coordinateSize ES256 = Just 32
    coordinateSize ES384 = Just 48
    coordinateSize ES512 = Just 66
    coordinateSize _     = Nothing
-- | The hash underlying each JWS HMAC algorithm.
hmacHashes :: [(JwsAlg, HashDescr)]
hmacHashes = zip [HS256, HS384, HS512]
                 [hashDescrSHA256, hashDescrSHA384, hashDescrSHA512]
-- | The hash function paired with each ECDSA JWS algorithm, or 'Nothing' for
-- any non-ECDSA algorithm.
lookupECHash :: JwsAlg -> Maybe HashFunction
lookupECHash ES256 = Just (hashFunction hashDescrSHA256)
lookupECHash ES384 = Just (hashFunction hashDescrSHA384)
lookupECHash ES512 = Just (hashFunction hashDescrSHA512)
lookupECHash _     = Nothing
-- | The hash description paired with each RSA JWS algorithm. Any other
-- algorithm yields a 'BadAlgorithm' error naming it.
lookupRSAHash :: JwsAlg -> Either JwtError HashDescr
lookupRSAHash RS256 = Right hashDescrSHA256
lookupRSAHash RS384 = Right hashDescrSHA384
lookupRSAHash RS512 = Right hashDescrSHA512
lookupRSAHash other = Left (BadAlgorithm (T.pack ("Not an RSA algorithm: " ++ show other)))
-- | Draw a random content encryption key and IV, sized by 'keySize' and
-- 'ivSize' for the given algorithm, threading the generator through both
-- draws and returning its final state.
generateCmkAndIV :: CPRG g
                 => g       -- ^ random generator
                 -> Enc     -- ^ content encryption algorithm
                 -> ((B.ByteString, B.ByteString), g)
generateCmkAndIV g e =
    let (cmk, g1) = cprgGenerate (keySize e) g
        (iv,  g2) = cprgGenerate (ivSize e) g1
    in  ((cmk, iv), g2)
-- | Content encryption key length in bytes. The CBC+HMAC variants carry both
-- a MAC key and an AES key, split in half by the payload functions, so they
-- are twice the AES key size.
keySize :: Enc -> Int
keySize enc = case enc of
    A128GCM       -> 16
    A256GCM       -> 32
    A128CBC_HS256 -> 32
    A256CBC_HS512 -> 64
-- | IV length in bytes: 12 for the AES-GCM algorithms, 16 (one AES block)
-- for everything else.
ivSize :: Enc -> Int
ivSize enc = case enc of
    A128GCM -> 12
    A256GCM -> 12
    _       -> 16
-- | RSA-encrypt some content with the recipient's public key, using
-- RSAES-PKCS1-v1_5 for 'RSA1_5' or RSAES-OAEP (with 'oaepParams') for
-- 'RSA_OAEP'. Returns the ciphertext with the updated generator.
--
-- The result type has no error channel. If the RSA primitive reports a
-- failure (e.g. content too long for the key), demanding the ciphertext
-- raises an 'error' that includes the underlying RSA error. Previously this
-- surfaced as an opaque irrefutable-pattern failure.
rsaEncrypt :: CPRG g
           => g              -- ^ random generator
           -> JweAlg         -- ^ key management algorithm
           -> RSA.PublicKey  -- ^ recipient's public key
           -> B.ByteString   -- ^ data to encrypt (presumably the JWE content encryption key)
           -> (B.ByteString, g)
rsaEncrypt gen a pubKey content = (either failed id result, g')
  where
    encrypt = case a of
        RSA1_5   -> PKCS15.encrypt gen
        RSA_OAEP -> OAEP.encrypt gen oaepParams
    (result, g') = encrypt pubKey content
    failed err = error $ "Jose.Internal.Crypto.rsaEncrypt: RSA encryption failed: " ++ show err
-- | RSA-decrypt an encrypted JWE key with the recipient's private key, using
-- RSAES-PKCS1-v1_5 for 'RSA1_5' or RSAES-OAEP (with 'oaepParams') for
-- 'RSA_OAEP'. Any error reported by the RSA primitive becomes 'BadCrypto'.
rsaDecrypt :: Maybe RSA.Blinder  -- ^ optional blinder for the private-key operation
           -> JweAlg             -- ^ key management algorithm
           -> RSA.PrivateKey     -- ^ recipient's private key
           -> B.ByteString       -- ^ encrypted key
           -> Either JwtError B.ByteString
rsaDecrypt blinder a rsaKey jweKey =
    case decryptWith rsaKey jweKey of
        Left _      -> throwError BadCrypto
        Right plain -> return plain
  where
    decryptWith = case a of
        RSA1_5   -> PKCS15.decrypt blinder
        RSA_OAEP -> OAEP.decrypt blinder oaepParams
-- | OAEP parameters used for 'RSA_OAEP': SHA-1 as the hash, with the
-- library's default settings for everything else.
oaepParams :: OAEP.OAEPParams
oaepParams = OAEP.defaultOAEPParams sha1
  where
    sha1 = hashFunction hashDescrSHA1
-- | Authenticate and decrypt a JWE payload.
--
-- AES-GCM: the ciphertext is decrypted and the computed GCM tag is compared
-- with the supplied one.
--
-- AES-CBC + HMAC-SHA2: the HMAC over @aad || iv || ciphertext || AL@ (AL is
-- the AAD length in bits, serialised as a 'Word64') is checked first. The
-- ciphertext is only decrypted and unpadded once it has been authenticated.
-- Decrypting first would let a caller distinguish padding failures
-- ('BadCrypto') from tag failures ('BadSignature'), which is a padding
-- oracle (RFC 7518 section 5.2.2.2 requires authentication first).
--
-- Throws 'BadCrypto' for malformed input (CEK or IV of the wrong length for
-- the algorithm, bad CBC ciphertext length, invalid padding) and
-- 'BadSignature' when the authentication tag does not match.
decryptPayload :: MonadError JwtError m
               => Enc         -- ^ content encryption algorithm
               -> ByteString  -- ^ content encryption key
               -> ByteString  -- ^ initialization vector
               -> ByteString  -- ^ additional authenticated data
               -> ByteString  -- ^ authentication tag to check
               -> ByteString  -- ^ ciphertext
               -> m ByteString
decryptPayload e cek iv aad sig ct = do
    -- The CEK may come from an attacker-chosen encrypted key, so validate
    -- lengths before handing key material or IVs to the AES primitives.
    unless (B.length cek == keySize e && B.length iv == ivSize e) $ throwError BadCrypto
    case e of
        A128GCM       -> decryptGCM
        A256GCM       -> decryptGCM
        A128CBC_HS256 -> decryptCBC 16 hashDescrSHA256
        A256CBC_HS512 -> decryptCBC 32 hashDescrSHA512
  where
    -- Fail with BadSignature unless the computed tag matches the supplied one.
    checkTag tag = unless (tag == AuthTag sig) $ throwError BadSignature

    decryptGCM = do
        let (plaintext, tag) = AES.decryptGCM (AES.initAES cek) iv aad ct
        checkTag tag
        return plaintext

    decryptCBC l h = do
        -- CBC output is a non-empty whole number of 16-byte blocks, because
        -- padding always adds at least one byte.
        unless (not (B.null ct) && B.length ct `mod` 16 == 0) $ throwError BadCrypto
        let (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
            al = fromIntegral (B.length aad) * 8 :: Word64
        -- Authenticate before touching the plaintext or its padding.
        checkTag $ authTag l h macKey $ B.concat [aad, iv, ct, Serialize.encode al]
        unpad $ AES.decryptCBC (AES.initAES encKey) iv ct
-- | Encrypt a JWE payload, returning the ciphertext and its authentication
-- tag.
--
-- AES-GCM encrypts directly with the whole CEK. The AES-CBC + HMAC-SHA2
-- variants split the CEK in half: the first half is the MAC key and the
-- second the AES key. The message is padded with 'pad' before encryption,
-- and the tag is the truncated HMAC of @aad || iv || ciphertext || AL@
-- (AL = AAD length in bits, serialised as a 'Word64').
encryptPayload :: Enc         -- ^ content encryption algorithm
               -> ByteString  -- ^ content encryption key
               -> ByteString  -- ^ initialization vector
               -> ByteString  -- ^ additional authenticated data
               -> ByteString  -- ^ plaintext
               -> (ByteString, AuthTag)
encryptPayload e cek iv aad msg = case e of
    A128GCM       -> gcm
    A256GCM       -> gcm
    A128CBC_HS256 -> cbc 16 hashDescrSHA256
    A256CBC_HS512 -> cbc 32 hashDescrSHA512
  where
    gcm = AES.encryptGCM (AES.initAES cek) iv aad msg

    cbc tagLen hashDescr = (ciphertext, authTag tagLen hashDescr macKey macInput)
      where
        (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
        ciphertext = AES.encryptCBC (AES.initAES encKey) iv (pad msg)
        aadBits = fromIntegral (B.length aad) * 8 :: Word64
        macInput = B.concat [aad, iv, ciphertext, Serialize.encode aadBits]
-- | Compute the truncated HMAC authentication tag used by the AES-CBC +
-- HMAC-SHA2 content encryption algorithms.
authTag :: Int         -- ^ tag length in bytes; the HMAC output is truncated to this
        -> HashDescr   -- ^ hash used for the HMAC
        -> ByteString  -- ^ MAC key
        -> ByteString  -- ^ data to authenticate
        -> AuthTag
authTag l h k m = AuthTag $ B.take l $ hmac hash blockSize k m
  where
    hash = hashFunction h
    -- HMAC must pad the key to the hash's internal block size (RFC 2104):
    -- 64 bytes for SHA-1/SHA-256, 128 bytes for SHA-384/SHA-512. The two
    -- families are told apart by digest length (at most 32 bytes vs. at
    -- least 48). A fixed 64 made A256CBC-HS512 tags non-standard.
    blockSize
        | B.length (hash B.empty) > 32 = 128
        | otherwise                    = 64
-- | Strip the PKCS#7 padding added by 'pad' (16-byte blocks).
--
-- Throws 'BadCrypto' if the input is empty or the padding is malformed. The
-- final byte must be a pad length from 1 to 16 that is no longer than the
-- input, and every padding byte must equal it.
unpad :: MonadError JwtError m => ByteString -> m ByteString
unpad bs
    | B.null bs                                 = throwError BadCrypto
    | padLen < 1 || padLen > 16 || padLen > len = throwError BadCrypto
    | B.any (/= padByte) padding                = throwError BadCrypto
    | otherwise                                 = return pt
  where
    len = B.length bs
    -- Safe: only evaluated after the emptiness guard above has passed.
    padByte = B.last bs
    padLen = fromIntegral padByte :: Int
    (pt, padding) = B.splitAt (len - padLen) bs
-- | Apply PKCS#7 padding for a 16-byte block cipher. It appends @n@ copies
-- of the byte @n@, where @n@ (1 to 16) brings the length up to the next
-- multiple of 16. Input that is already block-aligned gains a whole block of
-- padding, so the padding can always be removed unambiguously.
pad :: ByteString -> ByteString
pad bs = B.append bs padding
  where
    lastBlockSize = B.length bs `mod` 16
    padByte = fromIntegral (16 - lastBlockSize) :: Word8
    padding = B.replicate (fromIntegral padByte) padByte