{-# LANGUAGE OverloadedStrings #-}
module Jose.Jwe
( jwkEncode
, jwkDecode
, rsaEncode
, rsaDecode
)
where
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder, private_pub)
import Crypto.Random (MonadRandom)
import qualified Data.Aeson as A
import Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (isNothing)
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk
import qualified Jose.Internal.Parser as P
-- | Encrypt a payload as a JWE using the supplied JWK.
--
-- The content key produced inside 'doEncode' is protected according to the
-- kind of key supplied:
--
-- * an RSA public key encrypts it with 'rsaEncrypt';
-- * an RSA private key does the same using its public part ('private_pub');
-- * a symmetric key wraps it with 'keyWrap';
-- * any other JWK yields a 'KeyError'.
--
-- The protected header always carries @alg@ and @enc@. A @cty@ of @"JWT"@ is
-- added for a 'Nested' payload, and a @kid@ is added when the JWK has one.
jwkEncode :: MonadRandom m
    => JweAlg   -- ^ Algorithm used to protect the content key
    -> Enc      -- ^ Content encryption algorithm
    -> Jwk      -- ^ Key used to protect the content key
    -> Payload  -- ^ Claims, or a nested JWT, to encrypt
    -> m (Either JwtError Jwt)
jwkEncode a e jwk payload = runExceptT $ case jwk of
    RsaPublicJwk kPub kid _ _ -> encodeWith kid (doRsa kPub)
    RsaPrivateJwk kPr kid _ _ -> encodeWith kid (doRsa (private_pub kPr))
    SymmetricJwk kek kid _ _  ->
        encodeWith kid (ExceptT . return . keyWrap a (BA.convert kek))
    _                         -> throwE $ KeyError "JWK cannot encode a JWE"
  where
    -- Every supported key type differs only in how the content key is
    -- protected; the header, encryption method and plaintext are shared.
    encodeWith kid protectKey = doEncode (hdr kid) e protectKey bytes

    doRsa kPub = ExceptT . rsaEncrypt kPub a

    -- Serialised protected header. The field order is fixed because the
    -- JSON text is assembled by hand rather than from an aeson object.
    hdr :: Maybe KeyId -> B.ByteString
    hdr kid = BL.toStrict $
        BL.concat
            [ "{\"alg\":"
            , A.encode a
            , ",\"enc\":"
            , A.encode e
            , maybe "" (\c -> BL.concat [",\"cty\":\"", c, "\"" ]) contentType
            , if isNothing kid then "" else BL.concat [",\"kid\":", A.encode kid ]
            , "}"
            ]

    -- A nested JWT is flagged with a content type; plain claims are not.
    (contentType, bytes) = case payload of
        Claims c       -> (Nothing, c)
        Nested (Jwt b) -> (Just "JWT", b)
-- | Decrypt a JWE using the supplied JWK.
--
-- * An RSA private key decrypts the content key with 'rsaDecrypt', using a
--   blinder freshly generated from the key's modulus.
-- * A symmetric key unwraps the content key with 'keyUnwrap'.
-- * An 'UnsupportedJwk', or any other kind of key, yields a 'KeyError'.
--
-- On success the decoded header and plaintext are returned wrapped in the
-- 'Jwe' constructor of 'JwtContent'.
jwkDecode :: MonadRandom m
    => Jwk         -- ^ Key used to recover the content key
    -> ByteString  -- ^ The encoded JWE
    -> m (Either JwtError JwtContent)
jwkDecode jwk jwt = runExceptT $ case jwk of
    RsaPrivateJwk kPr _ _ _ -> do
        blinder <- lift $ generateBlinder (public_n $ private_pub kPr)
        fmap Jwe (doDecode (rsaDecrypt (Just blinder) kPr) jwt)
    SymmetricJwk kb _ _ _ ->
        fmap Jwe (doDecode (keyUnwrap (BA.convert kb)) jwt)
    UnsupportedJwk _ ->
        throwE (KeyError "Unsupported JWK cannot be used to decode JWE")
    _ ->
        throwE $ KeyError "This JWK cannot decode a JWE"
-- | Parse and decrypt a JWE, given a function that recovers the content
-- encryption key (CEK) from the encrypted key segment.
--
-- Fails with the parser's error if the input cannot be parsed, with
-- 'BadHeader' if the input parses but is not a JWE, and with 'BadCrypto' if
-- the payload cannot be decrypted.
--
-- A failure of the supplied CEK decoder is never reported directly: a
-- randomly generated key is substituted, so the failure only surfaces if
-- payload decryption then fails, as 'BadCrypto'.
-- NOTE(review): presumably this is to avoid exposing a key-decryption
-- oracle to callers -- confirm against the library's security notes.
doDecode :: MonadRandom m
    => (JweAlg -> ByteString -> Either JwtError ScrubbedBytes)
    -- ^ Recovers the CEK from the encrypted key for the header's algorithm
    -> ByteString
    -- ^ The encoded JWE
    -> ExceptT JwtError m Jwe
doDecode decodeCek jwt = do
    encodedJwt <- either throwE return (P.parseJwt jwt)
    case encodedJwt of
        P.DecodableJwe hdr (P.EncryptedCEK ek) iv (P.Payload payload) tag (P.AAD aad) -> do
            let alg = jweAlg hdr
                enc = jweEnc hdr
            -- Always generate the substitute key, whether or not it ends up
            -- being needed; the IV generated alongside it is discarded.
            (dummyCek, _) <- lift $ generateCmkAndIV enc
            let decryptedCek = either (const dummyCek) id $ decodeCek alg ek
                -- A recovered key of the wrong size for this encryption
                -- method is treated the same as a failed key decryption.
                cek = if BA.length decryptedCek == BA.length dummyCek
                        then decryptedCek
                        else dummyCek
            claims <- maybe (throwE BadCrypto) return $
                decryptPayload enc cek iv aad tag payload
            return (hdr, claims)

        _ -> throwE (BadHeader "Content is not a JWE")
-- | Build a compact-serialised JWE from an already-serialised header.
--
-- A fresh content key and IV are generated for the given encryption method,
-- the claims are encrypted with the base64-encoded header as additional
-- authenticated data, and the content key is protected with the supplied
-- function. The result is the five base64-encoded segments (header,
-- encrypted key, IV, ciphertext, authentication tag) joined with @.@.
--
-- Fails with 'BadCrypto' if 'encryptPayload' cannot produce a ciphertext,
-- and with whatever error the key-protection function reports.
doEncode :: (MonadRandom m, ByteArray ba)
    => ByteString
    -- ^ Serialised protected header
    -> Enc
    -- ^ Content encryption method
    -> (ScrubbedBytes -> ExceptT JwtError m ByteString)
    -- ^ Protects (encrypts or wraps) the generated content key
    -> ba
    -- ^ Plaintext to encrypt
    -> ExceptT JwtError m Jwt
doEncode hdr e encryptKey claims = do
    (cmk, iv) <- lift (generateCmkAndIV e)
    -- Report an encryption failure as an error value instead of crashing on
    -- a partial @Just@ pattern match.
    (AuthTag sig, ct) <- maybe (throwE BadCrypto) return $
        encryptPayload e cmk iv aad claims
    jweKey <- encryptKey cmk
    let jwe = B.intercalate "." $ map B64.encode
            [ hdr
            , jweKey
            , BA.convert iv
            , BA.convert ct
            , BA.convert sig
            ]
    return (Jwt jwe)
  where
    -- The additional authenticated data is the encoded header itself.
    aad = B64.encode hdr
-- | Encrypt claims as a JWE using an RSA public key.
--
-- The content key generated by 'doEncode' is encrypted with 'rsaEncrypt'
-- under the given algorithm. The protected header contains only the @alg@
-- and @enc@ fields.
rsaEncode :: MonadRandom m
    => JweAlg      -- ^ RSA algorithm used to encrypt the content key
    -> Enc         -- ^ Content encryption algorithm
    -> PublicKey   -- ^ RSA key used to encrypt the content key
    -> ByteString  -- ^ The claims (plaintext) to encrypt
    -> m (Either JwtError Jwt)
rsaEncode a e kPub claims =
    runExceptT $ doEncode hdr e (ExceptT . rsaEncrypt kPub a) claims
  where
    -- Hand-assembled JSON header, laid out like the one in 'jwkEncode'.
    hdr = BL.toStrict $
        BL.concat
            [ "{\"alg\":"
            , A.encode a
            , ",\"enc\":"
            , A.encode e
            , "}"
            ]
-- | Decrypt a JWE using an RSA private key.
--
-- A blinder is generated from the key's modulus and passed to 'rsaDecrypt',
-- which 'doDecode' uses to recover the content key. Returns the decoded
-- header and plaintext, or the error reported by 'doDecode'.
rsaDecode :: MonadRandom m
    => PrivateKey  -- ^ RSA key used to decrypt the content key
    -> ByteString  -- ^ The encoded JWE
    -> m (Either JwtError Jwe)
rsaDecode pk jwt = runExceptT $ do
    blinder <- lift $ generateBlinder (public_n $ private_pub pk)
    doDecode (rsaDecrypt (Just blinder) pk) jwt