{-# LANGUAGE OverloadedStrings, FlexibleContexts #-}
{-# OPTIONS_HADDOCK prune #-}
module Jose.Jwt
( module Jose.Types
, encode
, decode
, decodeClaims
)
where
import Control.Monad (msum, when, unless)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (MonadRandom)
import Data.Aeson (decodeStrict',FromJSON)
import Data.ByteString (ByteString)
import Data.Maybe (isNothing)
import qualified Data.ByteString.Char8 as BC
import qualified Jose.Internal.Base64 as B64
import qualified Jose.Internal.Parser as P
import Jose.Types
import Jose.Jwk
import Jose.Jwa
import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe
-- | Encode a payload as a JWT using the requested encoding.
--
-- * @JwsEncoding None@ produces an unsecured token built from a fixed
--   @{"alg":"none"}@ header and the base64-encoded claims. Only a 'Claims'
--   payload is accepted here; a 'Nested' payload yields 'BadClaims'.
-- * Any other 'JwsEncoding' or a 'JweEncoding' uses the /first/ key in the
--   supplied list for which 'canEncodeJws' / 'canEncodeJwe' holds, and
--   yields a 'KeyError' if there is no such key.
encode :: MonadRandom m
    => [Jwk]                     -- ^ Candidate keys; the first suitable one is used
    -> JwtEncoding               -- ^ The JWS or JWE encoding to apply
    -> Payload                   -- ^ The payload to encode
    -> m (Either JwtError Jwt)   -- ^ The encoded JWT, or the reason encoding failed
encode jwks encoding msg = runExceptT $ case encoding of
    JwsEncoding None -> case msg of
        -- NOTE(review): this emits two '.'-separated segments with no
        -- trailing '.'. RFC 7515 compact serialization of an unsecured JWS
        -- ends with an empty signature segment ("hdr.payload."). Confirm
        -- what the parser and other consumers expect before changing it.
        Claims p -> return . Jwt $ BC.intercalate "." [unsecuredHdr, B64.encode p]
        Nested _ -> throwE BadClaims
    JwsEncoding a -> case filter (canEncodeJws a) jwks of
        []      -> throwE (KeyError "No matching key found for JWS algorithm")
        (k : _) -> ExceptT (Jws.jwkEncode a k msg)
    JweEncoding a e -> case filter (canEncodeJwe a) jwks of
        []      -> throwE (KeyError "No matching key found for JWE algorithm")
        (k : _) -> ExceptT (Jwe.jwkEncode a e k msg)
  where
    -- Base64 encoding of the fixed header used for unsecured tokens.
    -- 'BC.pack' (rather than a bare literal) pins the input type of the
    -- polymorphic 'B64.encode'.
    unsecuredHdr :: ByteString
    unsecuredHdr = B64.encode (BC.pack "{\"alg\":\"none\"}")
-- | Parse a JWT and try to decode it with the supplied keys.
--
-- The token is first parsed with 'P.parseJwt'; a parse failure is returned
-- as is. What happens next depends on the parsed form:
--
-- * An unsecured token is accepted only when the expected encoding is
--   exactly @Just (JwsEncoding None)@. Any other expectation, including
--   'Nothing', yields 'BadAlgorithm'.
-- * For a JWS or JWE, a supplied expected encoding must equal the one in
--   the token header, otherwise 'BadAlgorithm' is returned. With 'Nothing'
--   the header's own algorithm is used without this check.
--
-- The key list is filtered with 'canDecodeJws' / 'canDecodeJwe'. An empty
-- result is a 'KeyError'. Every remaining key is tried in order, errors
-- from individual keys are discarded, and the first successful result (in
-- key order) is returned. If no key succeeds a 'KeyError' is returned.
decode :: MonadRandom m
    => [Jwk]                           -- ^ The keys to try
    -> Maybe JwtEncoding               -- ^ The expected encoding, if it should be enforced
    -> ByteString                      -- ^ The encoded JWT
    -> m (Either JwtError JwtContent)  -- ^ The decoded content, or the reason decoding failed
decode keySet encoding jwt = runExceptT $ do
    decodableJwt <- ExceptT (return (P.parseJwt jwt))
    decodings <- case (decodableJwt, encoding) of
        (P.Unsecured p, Just (JwsEncoding None)) ->
            return [Just (Unsecured p)]
        (P.Unsecured _, _) ->
            throwE (BadAlgorithm "JWT is unsecured but expected 'alg' was not 'none'")
        (P.DecodableJws hdr _ _ _, e) -> do
            unless (isNothing e || e == Just (JwsEncoding (jwsAlg hdr))) $
                throwE (BadAlgorithm "Expected 'alg' doesn't match JWS header")
            ks <- checkKeys (filter (canDecodeJws hdr) keySet)
            mapM decodeWithJws ks
        (P.DecodableJwe hdr _ _ _ _ _, e) -> do
            unless (isNothing e || e == Just (JweEncoding (jweAlg hdr) (jweEnc hdr))) $
                throwE (BadAlgorithm "Expected encoding doesn't match JWE header")
            ks <- checkKeys (filter (canDecodeJwe hdr) keySet)
            mapM decodeWithJwe ks
    -- 'msum' over 'Maybe' picks the first 'Just', i.e. the first key that worked.
    case msum decodings of
        Nothing         -> throwE $ KeyError "None of the keys was able to decode the JWT"
        Just jwtContent -> return jwtContent
  where
    -- Attempt JWS decoding with one key. Private keys are reduced to their
    -- public part first. Any failure becomes 'Nothing' so that the
    -- remaining keys can still be tried.
    decodeWithJws :: MonadRandom m => Jwk -> ExceptT JwtError m (Maybe JwtContent)
    decodeWithJws k = return . either (const Nothing) (Just . Jws) $ case k of
        Ed25519PublicJwk kPub _    -> Jws.ed25519Decode kPub jwt
        Ed25519PrivateJwk _ kPub _ -> Jws.ed25519Decode kPub jwt
        Ed448PublicJwk kPub _      -> Jws.ed448Decode kPub jwt
        Ed448PrivateJwk _ kPub _   -> Jws.ed448Decode kPub jwt
        RsaPublicJwk kPub _ _ _    -> Jws.rsaDecode kPub jwt
        RsaPrivateJwk kPr _ _ _    -> Jws.rsaDecode (private_pub kPr) jwt
        EcPublicJwk kPub _ _ _ _   -> Jws.ecDecode kPub jwt
        EcPrivateJwk kPr _ _ _ _   -> Jws.ecDecode (ECDSA.toPublicKey kPr) jwt
        SymmetricJwk kb _ _ _      -> Jws.hmacDecode kb jwt
        UnsupportedJwk _           -> Left (KeyError "Unsupported JWKs cannot be used")

    -- Attempt JWE decoding with one key; a failure becomes 'Nothing'.
    decodeWithJwe :: MonadRandom m => Jwk -> ExceptT JwtError m (Maybe JwtContent)
    decodeWithJwe k = either (const Nothing) Just <$> lift (Jwe.jwkDecode k jwt)

    -- Fail early when filtering left no candidate keys at all.
    checkKeys :: Monad m => [a] -> ExceptT JwtError m [a]
    checkKeys [] = throwE $ KeyError "No suitable key was found to decode the JWT"
    checkKeys ks = return ks
-- | Extract the header and claims from a three-part, dot-separated token.
--
-- The input is split on @'.'@ and must yield exactly three components,
-- otherwise @BadDots 2@ is returned. The first component is base64-decoded
-- and passed to 'parseHeader'. The second is base64-decoded and parsed as
-- JSON into the requested claims type, giving 'BadClaims' if the JSON
-- cannot be decoded. The third component is not examined by this function.
decodeClaims :: (FromJSON a)
    => ByteString                       -- ^ The encoded token
    -> Either JwtError (JwtHeader, a)   -- ^ The parsed header and claims
decodeClaims jwt = case BC.split '.' jwt of
    -- Matching on the list shape replaces a length check followed by
    -- partial 'head' / 'tail' calls.
    [hdrPart, claimsPart, _] -> do
        hdr    <- B64.decode hdrPart >>= parseHeader
        claims <- B64.decode claimsPart >>= parseClaims
        return (hdr, claims)
    _ -> Left (BadDots 2)
  where
    parseClaims :: FromJSON b => ByteString -> Either JwtError b
    parseClaims bs = maybe (Left BadClaims) Right (decodeStrict' bs)