{-# LANGUAGE OverloadedStrings, FlexibleContexts #-}
{-# OPTIONS_HADDOCK prune #-}

-- | High-level JWT encoding and decoding.
--
-- See the "Jose.Jws" and "Jose.Jwe" modules for specific JWS and JWE examples.
--
-- Example usage with a key stored as a JWK:
--
-- >>> import Jose.Jwe
-- >>> import Jose.Jwa
-- >>> import Jose.Jwk
-- >>> import Data.ByteString
-- >>> import Data.Aeson (decodeStrict)
-- >>> let jsonJwk = "{\"kty\":\"RSA\", \"kid\":\"mykey\", \"n\":\"ofgWCuLjybRlzo0tZWJjNiuSfb4p4fAkd_wWJcyQoTbji9k0l8W26mPddxHmfHQp-Vaw-4qPCJrcS2mJPMEzP1Pt0Bm4d4QlL-yRT-SFd2lZS-pCgNMsD1W_YpRPEwOWvG6b32690r2jZ47soMZo9wGzjb_7OMg0LOL-bSf63kpaSHSXndS5z5rexMdbBYUsLA9e-KXBdQOS-UTo7WTBEMa2R2CapHg665xsmtdVMTBQY4uDZlxvb3qCo5ZwKh9kG4LT6_I5IhlJH7aGhyxXFvUK-DWNmoudF8NAco9_h9iaGNj8q2ethFkMLs91kzk2PAcDTW9gb54h4FRWyuXpoQ\", \"e\":\"AQAB\", \"d\":\"Eq5xpGnNCivDflJsRQBXHx1hdR1k6Ulwe2JZD50LpXyWPEAeP88vLNO97IjlA7_GQ5sLKMgvfTeXZx9SE-7YwVol2NXOoAJe46sui395IW_GO-pWJ1O0BkTGoVEn2bKVRUCgu-GjBVaYLU6f3l9kJfFNS3E0QbVdxzubSu3Mkqzjkn439X0M_V51gfpRLI9JYanrC4D4qAdGcopV_0ZHHzQlBjudU2QvXt4ehNYTCBr6XCLQUShb1juUO1ZdiYoFaFQT5Tw8bGUl_x_jTj3ccPDVZFD9pIuhLhBOneufuBiB4cS98l2SR_RQyGWSeWjnczT0QU91p1DhOVRuOopznQ\"}" :: ByteString
-- >>> let Just jwk = decodeStrict jsonJwk :: Maybe Jwk
-- >>> Right (Jwt jwtEncoded) <- encode [jwk] (JwsEncoding RS256) (Claims "public claims")
-- >>> Right jwtDecoded <- Jose.Jwt.decode [jwk] (Just (JwsEncoding RS256)) jwtEncoded
-- >>> jwtDecoded
-- Jws (JwsHeader {jwsAlg = RS256, jwsTyp = Nothing, jwsCty = Nothing, jwsKid = Just (KeyId "mykey")},"public claims")

module Jose.Jwt
    ( module Jose.Types
    , encode
    , decode
    , decodeClaims
    )
where

import Control.Monad (msum, when, unless)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (MonadRandom)
import Data.Aeson (decodeStrict',FromJSON)
import Data.ByteString (ByteString)
import Data.Maybe (isNothing)
import qualified Data.ByteString.Char8 as BC

import qualified Jose.Internal.Base64 as B64
import qualified Jose.Internal.Parser as P
import Jose.Types
import Jose.Jwk
import Jose.Jwa

import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe


-- | Use the supplied JWKs to create a JWT.
--
-- The list of keys is filtered down to those consistent with the chosen
-- encoding algorithm(s) (via 'canEncodeJws' / 'canEncodeJwe'), and the first
-- such key is used. A 'KeyError' is returned if no key matches.
--
-- For an unsecured token (@JwsEncoding None@) no key is used; the payload
-- must be 'Claims', and a 'Nested' payload is rejected with 'BadClaims'.
encode :: MonadRandom m
    => [Jwk]                     -- ^ The key or keys. At least one must be consistent with the chosen algorithm
    -> JwtEncoding               -- ^ The encoding algorithm(s) used to encode the payload
    -> Payload                   -- ^ The payload (claims)
    -> m (Either JwtError Jwt)   -- ^ The encoded JWT, if successful
encode jwks encoding msg = runExceptT $ case encoding of
    JwsEncoding None -> case msg of
        -- NOTE(review): RFC 7519 section 6.1 shows unsecured JWTs with a
        -- trailing "." (an empty signature part); this emits only
        -- "<header>.<payload>". Confirm what P.parseJwt accepts before changing.
        Claims p -> return $ Jwt $ BC.intercalate "." [unsecuredHdr, B64.encode p]
        Nested _ -> throwE BadClaims
    JwsEncoding a -> case filter (canEncodeJws a) jwks of
        []    -> throwE (KeyError "No matching key found for JWS algorithm")
        -- The encoder already returns an Either in m; wrap it directly.
        (k:_) -> ExceptT (Jws.jwkEncode a k msg)
    JweEncoding a e -> case filter (canEncodeJwe a) jwks of
        []    -> throwE (KeyError "No matching key found for JWE algorithm")
        (k:_) -> ExceptT (Jwe.jwkEncode a e k msg)
  where
    -- Base64-encoded @{"alg":"none"}@ header used for unsecured tokens.
    unsecuredHdr :: ByteString
    unsecuredHdr = B64.encode (BC.pack "{\"alg\":\"none\"}")


-- | Uses the supplied keys to decode a JWT.
-- Locates a matching key by header @kid@ value where possible
-- or by suitable key type for the encoding algorithm.
--
-- Candidate keys are those accepted by 'canDecodeJws' / 'canDecodeJwe' for
-- the token's header. They are tried in list order, and decoding stops at the
-- first key which succeeds. A 'KeyError' is returned if there are no
-- candidate keys, or if none of them decodes the token.
--
-- The algorithm(s) used can optionally be supplied for validation
-- by setting the @JwtEncoding@ parameter, in which case an error will
-- be returned if they don't match. If you expect the tokens to use
-- a particular algorithm, then you should set this parameter.
--
-- For unsecured tokens (with algorithm "none"), the expected algorithm
-- must be set to @Just (JwsEncoding None)@ or an error will be returned.
decode :: MonadRandom m
    => [Jwk]                           -- ^ The keys to use for decoding
    -> Maybe JwtEncoding               -- ^ The expected encoding information
    -> ByteString                      -- ^ The encoded JWT
    -> m (Either JwtError JwtContent)  -- ^ The decoded JWT payload, if successful
decode keySet encoding jwt = runExceptT $ do
    decodableJwt <- except (P.parseJwt jwt)

    decoded <- case (decodableJwt, encoding) of
        (P.Unsecured p, Just (JwsEncoding None)) -> return (Just (Unsecured p))
        -- Never accept an unsecured token unless the caller asked for one.
        (P.Unsecured _, _) ->
            throwE (BadAlgorithm "JWT is unsecured but expected 'alg' was not 'none'")
        (P.DecodableJws hdr _ _ _, e) -> do
            unless (isNothing e || e == Just (JwsEncoding (jwsAlg hdr))) $
                throwE (BadAlgorithm "Expected 'alg' doesn't match JWS header")
            ks <- checkKeys (filter (canDecodeJws hdr) keySet)
            -- Verification is pure, so the lazy msum only verifies until
            -- the first key succeeds.
            return (msum (map decodeWithJws ks))
        (P.DecodableJwe hdr _ _ _ _ _, e) -> do
            unless (isNothing e || e == Just (JweEncoding (jweAlg hdr) (jweEnc hdr))) $
                throwE (BadAlgorithm "Expected encoding doesn't match JWE header")
            ks <- checkKeys (filter (canDecodeJwe hdr) keySet)
            -- Decryption runs in m, so stop explicitly at the first success
            -- instead of decrypting with every remaining candidate key.
            lift (firstSuccess decodeWithJwe ks)

    maybe (throwE (KeyError "None of the keys was able to decode the JWT")) return decoded
  where
    -- Verify the JWS with a single key. Failure is mapped to Nothing so that
    -- the next candidate key can be tried.
    decodeWithJws :: Jwk -> Maybe JwtContent
    decodeWithJws k = either (const Nothing) (Just . Jws) $ case k of
        Ed25519PublicJwk kPub _        -> Jws.ed25519Decode kPub jwt
        Ed25519PrivateJwk _ kPub _     -> Jws.ed25519Decode kPub jwt
        Ed448PublicJwk kPub _          -> Jws.ed448Decode kPub jwt
        Ed448PrivateJwk _ kPub _       -> Jws.ed448Decode kPub jwt
        RsaPublicJwk  kPub _ _ _       -> Jws.rsaDecode kPub jwt
        RsaPrivateJwk kPr  _ _ _       -> Jws.rsaDecode (private_pub kPr) jwt
        EcPublicJwk   kPub _ _ _ _     -> Jws.ecDecode kPub jwt
        EcPrivateJwk  kPr  _ _ _ _     -> Jws.ecDecode (ECDSA.toPublicKey kPr) jwt
        SymmetricJwk  kb   _ _ _       -> Jws.hmacDecode kb jwt

    -- Decrypt the JWE with a single key, again mapping failure to Nothing.
    decodeWithJwe :: MonadRandom n => Jwk -> n (Maybe JwtContent)
    decodeWithJwe k = fmap (either (const Nothing) Just) (Jwe.jwkDecode k jwt)

    -- Run the attempt for each key in order and return the first Just,
    -- without running the attempts for any later keys.
    firstSuccess :: Monad n => (Jwk -> n (Maybe JwtContent)) -> [Jwk] -> n (Maybe JwtContent)
    firstSuccess _ []         = return Nothing
    firstSuccess f (k : rest) = f k >>= maybe (firstSuccess f rest) (return . Just)

    -- Fail early when the key set has no key suitable for this header.
    checkKeys :: Monad n => [a] -> ExceptT JwtError n [a]
    checkKeys [] = throwE (KeyError "No suitable key was found to decode the JWT")
    checkKeys ks = return ks


-- | Convenience function to return the header and claims contained in a JWS,
-- /without/ verifying the signature.
--
-- This is needed in situations such as client assertion authentication,
-- <https://tools.ietf.org/html/rfc7523>, where the contents of the JWT,
-- such as the @sub@ claim, may be required in order to work out
-- which key should be used to verify the token.
--
-- The token must split into exactly three @.@-separated parts, otherwise
-- @BadDots 2@ is returned. The third (signature) part is ignored. If the
-- claims part does not decode as JSON into the requested type, 'BadClaims'
-- is returned.
--
-- Obviously this should not be used by itself to decode a token since
-- no integrity checking is done and the contents may be forged.
decodeClaims :: (FromJSON a)
    => ByteString
    -> Either JwtError (JwtHeader, a)
decodeClaims jwt = case BC.split '.' jwt of
    -- Match the three parts directly, rather than checking the length and
    -- then calling the partial head/tail.
    [hdrPart, claimsPart, _sig] -> do
        hdr    <- B64.decode hdrPart >>= parseHeader
        claims <- B64.decode claimsPart >>= parseClaims
        return (hdr, claims)
    _ -> Left (BadDots 2)
  where
    parseClaims :: FromJSON b => ByteString -> Either JwtError b
    parseClaims bs = maybe (Left BadClaims) Right (decodeStrict' bs)