{-# LANGUAGE OverloadedStrings #-}
module Jose.Jwe
( jwkEncode
, jwkDecode
, rsaEncode
, rsaDecode
)
where
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder, private_pub)
import Crypto.Random (MonadRandom)
import Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk
import qualified Jose.Internal.Parser as P
-- | Create a JWE, using a JWK to encrypt the generated content key.
--
-- * An 'RsaPublicJwk' encrypts the content key with 'rsaEncrypt'.
-- * An 'RsaPrivateJwk' does the same using its public half ('private_pub').
-- * A 'SymmetricJwk' wraps the content key with 'keyWrap'.
-- * Any other JWK yields @Left (KeyError "JWK cannot encode a JWE")@.
--
-- The header is 'defJweHdr' with @alg@, @enc@, the JWK's key ID (@kid@) and a
-- content type (@cty@) filled in. @cty@ is @"JWT"@ for a nested token and
-- absent for plain claims.
jwkEncode :: MonadRandom m
    => JweAlg                   -- ^ Algorithm used to encrypt the content key
    -> Enc                      -- ^ Content encryption algorithm
    -> Jwk                      -- ^ Key used to encrypt the content key
    -> Payload                  -- ^ Token content (claims or a nested JWT)
    -> m (Either JwtError Jwt)  -- ^ The encoded JWE, or an error
jwkEncode a e jwk payload = runExceptT (encodeWith jwk)
  where
    encodeWith (RsaPublicJwk pubKey keyId _ _) =
        doEncode (headerFor keyId) (rsaWrap pubKey) body
    encodeWith (RsaPrivateJwk privKey keyId _ _) =
        doEncode (headerFor keyId) (rsaWrap (private_pub privKey)) body
    encodeWith (SymmetricJwk kek keyId _ _) =
        doEncode (headerFor keyId) (symmetricWrap kek) body
    encodeWith _ =
        throwE (KeyError "JWK cannot encode a JWE")

    -- Encrypt the content key with an RSA public key.
    rsaWrap pubKey cek = ExceptT (rsaEncrypt pubKey a cek)

    -- Wrap the content key with a symmetric key-encryption key.
    symmetricWrap kek cek = except (keyWrap a (BA.convert kek) cek)

    headerFor keyId = defJweHdr
        { jweAlg = a
        , jweEnc = e
        , jweKid = keyId
        , jweCty = contentType
        }

    (contentType, body) = case payload of
        Claims c       -> (Nothing, c)
        Nested (Jwt b) -> (Just "JWT", b)
-- | Try to decode a JWE using a JWK.
--
-- * An 'RsaPrivateJwk' recovers the content key with 'rsaDecrypt'. It uses a
--   blinder freshly generated from the key's modulus.
-- * A 'SymmetricJwk' recovers it with 'keyUnwrap'.
-- * Any other JWK yields @Left (KeyError "JWK cannot decode a JWE")@.
--
-- A successfully decoded token is returned wrapped in the 'Jwe' constructor
-- of 'JwtContent'.
jwkDecode :: MonadRandom m
    => Jwk                              -- ^ Key used to recover the content key
    -> ByteString                       -- ^ The encoded JWE
    -> m (Either JwtError JwtContent)   -- ^ The decoded content, or an error
jwkDecode jwk jwt = runExceptT (Jwe <$> decodeWith jwk)
  where
    decodeWith (RsaPrivateJwk privKey _ _ _) = do
        blinder <- lift (generateBlinder (public_n (private_pub privKey)))
        doDecode (rsaDecrypt (Just blinder) privKey) jwt
    decodeWith (SymmetricJwk keyBytes _ _ _) =
        doDecode (keyUnwrap (BA.convert keyBytes)) jwt
    decodeWith _ =
        throwE (KeyError "JWK cannot decode a JWE")
-- | Parse and decrypt an encoded JWE.
--
-- The supplied function recovers the content encryption key (CEK) from the
-- header's @alg@ and the encrypted-key segment. A random dummy CEK is
-- substituted when either of these holds:
--
-- * the function returns an error;
-- * the recovered key's length differs from a freshly generated CEK for the
--   header's @enc@.
--
-- A key-recovery failure is therefore never reported directly. It is expected
-- to surface as 'BadCrypto' when payload decryption then fails. (Presumably
-- this is the random-CEK mitigation of RFC 7516 §11.5 — confirm.)
--
-- Errors:
--
-- * whatever 'P.parseJwt' reports if the token cannot be parsed;
-- * @BadHeader "Content is not a JWE"@ if it parses as something else;
-- * 'BadCrypto' if 'decryptPayload' fails.
doDecode :: MonadRandom m
    => (JweAlg -> ByteString -> Either JwtError ScrubbedBytes) -- ^ Recovers the CEK from @alg@ and the encrypted key
    -> ByteString                                              -- ^ The encoded JWE
    -> ExceptT JwtError m Jwe                                  -- ^ Header and decrypted content
doDecode decodeCek jwt = except (P.parseJwt jwt) >>= decryptJwe
  where
    decryptJwe (P.DecodableJwe hdr (P.EncryptedCEK ek) iv (P.Payload payload) tag (P.AAD aad)) = do
        let enc = jweEnc hdr
        -- Random key of the correct length for @enc@. It replaces the
        -- recovered key whenever that key is unusable.
        (dummyCek, _) <- lift (generateCmkAndIV enc)
        let cek = case decodeCek (jweAlg hdr) ek of
                Right k | BA.length k == BA.length dummyCek -> k
                _                                           -> dummyCek
        case decryptPayload enc cek iv aad tag payload of
            Nothing     -> throwE BadCrypto
            Just claims -> return (hdr, claims)
    decryptJwe _ = throwE (BadHeader "Content is not a JWE")
-- | Build an encoded JWE from a complete header.
--
-- Steps:
--
-- 1. Generate a fresh content encryption key (CEK) and IV for the header's
--    @enc@.
-- 2. Encrypt the content under the CEK, using the base64url-encoded header as
--    additional authenticated data.
-- 3. Encrypt the CEK with the supplied function.
--
-- The result is five base64url-encoded segments joined with @"."@: header,
-- encrypted key, IV, ciphertext, and authentication tag.
--
-- Fails with 'BadCrypto' if content encryption fails. Also fails with
-- whatever error the key-encryption function produces.
doEncode :: (MonadRandom m, ByteArray ba)
    => JweHeader                                         -- ^ Header; its @enc@ selects the content cipher
    -> (ScrubbedBytes -> ExceptT JwtError m ByteString)  -- ^ Encrypts the generated CEK
    -> ba                                                -- ^ Plaintext content
    -> ExceptT JwtError m Jwt                            -- ^ The encoded JWE
doEncode h encryptKey claims = do
    (cmk, iv) <- lift (generateCmkAndIV e)
    -- 'encryptPayload' signals failure with Nothing. Report that as an error
    -- instead of binding it with a lazy partial pattern, which would return a
    -- 'Jwt' that crashes when forced.
    (AuthTag sig, ct) <- maybe (throwE BadCrypto) return (encryptPayload e cmk iv aad claims)
    jweKey <- encryptKey cmk
    let jwe = B.intercalate "." $ map B64.encode
            [hdr, jweKey, BA.convert iv, BA.convert ct, BA.convert sig]
    return (Jwt jwe)
  where
    e   = jweEnc h
    hdr = encodeHeader h
    -- The encoded header doubles as the additional authenticated data.
    aad = B64.encode hdr
-- | Create a JWE whose content key is encrypted with an RSA public key.
--
-- The header is 'defJweHdr' with only @alg@ and @enc@ set. The content key is
-- encrypted with 'rsaEncrypt'.
rsaEncode :: MonadRandom m
    => JweAlg                   -- ^ Key-encryption algorithm
    -> Enc                      -- ^ Content encryption algorithm
    -> PublicKey                -- ^ RSA public key used to encrypt the content key
    -> ByteString               -- ^ The token content (claims)
    -> m (Either JwtError Jwt)  -- ^ The encoded JWE, or an error
rsaEncode a e kPub claims = runExceptT (doEncode header wrapCek claims)
  where
    header = defJweHdr { jweAlg = a, jweEnc = e }
    wrapCek cek = ExceptT (rsaEncrypt kPub a cek)
-- | Decrypt a JWE using an RSA private key.
--
-- A fresh blinder is generated from the key's modulus. It is passed to
-- 'rsaDecrypt' when recovering the content key.
rsaDecode :: MonadRandom m
    => PrivateKey               -- ^ RSA private key
    -> ByteString               -- ^ The encoded JWE
    -> m (Either JwtError Jwe)  -- ^ The header and decrypted content, or an error
rsaDecode pk jwt = runExceptT $
    lift (generateBlinder modulus) >>= \blinder ->
        doDecode (rsaDecrypt (Just blinder) pk) jwt
  where
    modulus = public_n (private_pub pk)