{-# LANGUAGE CPP #-}
module OpenID.Connect.Client.TokenResponse
( decodeIdentityToken
, verifyIdentityTokenClaims
) where
import Control.Lens ((^.), (.~), (^?), (#))
import Control.Monad.Except
import Control.Monad.Reader
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JWT as JWT
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Function ((&))
import Data.Functor.Identity
import Data.Maybe (isJust)
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime)
import OpenID.Connect.Authentication (ClientID)
import OpenID.Connect.Client.Provider
import OpenID.Connect.TokenResponse
#if MIN_VERSION_aeson(2, 0, 0)
import qualified Data.Map.Strict as Map
#else
import qualified Data.HashMap.Strict as Map
#endif
-- | Parse the compact-serialised identity token carried by a token
-- response into a 'SignedJWT'.
--
-- The token text is UTF-8 encoded and handed to 'JOSE.decodeCompact'.
-- On success the parsed JWT replaces the text inside the response; on
-- failure the JOSE decoding error is returned.
decodeIdentityToken
  :: TokenResponse Text
  -> Either JOSE.Error (TokenResponse SignedJWT)
decodeIdentityToken token =
  fmap (<$ token) (runIdentity (runExceptT (JOSE.decodeCompact compact)))
  where
    -- Compact JWS serialisation as a lazy ByteString.
    compact = LChar8.fromStrict (Text.encodeUtf8 (idToken token))
-- | Verify the signature on an identity token and validate its claims.
--
-- Signature checking and jose's registered-claim checks are delegated to
-- 'JWT.verifyClaimsAt'. Those checks cover:
--
--   * the issuer must equal the discovery issuer;
--   * 120 seconds of allowed clock skew;
--   * @iat@, when present, must not be in the future.
--
-- Each key in the 'JWKSet' is tried until one succeeds.
--
-- The OpenID Connect checks in 'additionalValidation' run once, on the
-- verified claims. They are deliberately kept out of the per-key loop.
-- Otherwise a claim failure (e.g. a wrong nonce) detected with the correct
-- key would be hidden behind the signature errors produced by the other
-- keys in the set.
--
-- An empty key set yields @'JWT.JWSError' 'JOSE.NoUsableKeys'@. Keys are
-- tried starting from the end of the list. When every key is rejected,
-- the error produced for the first key in the list is returned.
verifyIdentityTokenClaims
  :: Discovery                  -- ^ Provider discovery document.
  -> ClientID                   -- ^ This client's ID.
  -> UTCTime                    -- ^ Current time.
  -> JWKSet                     -- ^ Candidate verification keys.
  -> Text                       -- ^ Expected nonce.
  -> TokenResponse SignedJWT    -- ^ Response carrying the signed identity token.
  -> Either JWTError (TokenResponse ClaimsSet)
verifyIdentityTokenClaims disco clientId now keys nonce token =
  fmap (<$ token) (verified >>= additionalValidation clientId nonce)
  where
    JWKSet jwks = keys

    -- Signature + registered-claim verification with the first key that
    -- works (see the ordering note in the Haddock above).
    verified :: Either JWTError ClaimsSet
    verified =
      foldr (\k -> either (const (verifyWithKey k)) Right)
            (Left (JWT.JWSError JOSE.NoUsableKeys))
            jwks

    -- jose verification against a single key. The audience predicate
    -- accepts everything here; the audience is checked by
    -- 'additionalValidation'.
    verifyWithKey :: JWK -> Either JWTError ClaimsSet
    verifyWithKey key =
      let settings = JWT.defaultJWTValidationSettings (const True)
                   & allowedSkew .~ 120
                   & issuerPredicate .~ verifyIssuer
                   & checkIssuedAt .~ True
      in JWT.verifyClaimsAt settings key now (idToken token)
         & runExceptT
         & runIdentity

    -- Accept only the issuer advertised in the discovery document.
    verifyIssuer :: JWT.StringOrURI -> Bool
    verifyIssuer = (== (JWT.uri # getURI (issuer disco)))
-- | Monad for the OpenID Connect claim checks. It reads the 'ClaimsSet'
-- under inspection from the environment and short-circuits on the first
-- 'JWTError'.
type Validate = ExceptT JWTError (ReaderT ClaimsSet Identity)
-- | Run a predicate over the claims in the environment.
--
-- If the predicate holds, the check succeeds with @()@. Otherwise the
-- validation aborts with the supplied error.
orFailWith :: (ClaimsSet -> Bool) -> JWTError -> Validate ()
orFailWith predicate err = do
  ok <- asks predicate
  if ok
    then pure ()
    else throwError err
-- | True when the unregistered (non-standard) claim named @key@ is present
-- in the claims set and equals @val@. A missing claim yields 'False'.
claimEq :: Text -> Aeson.Value -> ClaimsSet -> Bool
claimEq key val claims =
  Just val == Map.lookup key (claims ^. JWT.unregisteredClaims)
-- | OpenID Connect checks applied to an already verified claims set.
--
-- Returns the claims unchanged when every check passes. Otherwise returns
-- the error for the first failing check, in this order:
--
--   1. the @nonce@ claim equals the expected nonce;
--   2. an @iat@ claim is present. jose's 'checkIssuedAt' only rejects a
--      future @iat@, not a missing one;
--   3. a @sub@ claim is present;
--   4. the audience is exactly this client, or the @azp@ claim names it.
additionalValidation :: ClientID -> Text -> ClaimsSet -> Either JWTError ClaimsSet
additionalValidation clientId nonce = go
  where
    go :: ClaimsSet -> Either JWTError ClaimsSet
    go claims = checks
      & runExceptT
      & flip runReaderT claims
      & runIdentity
      & (claims <$)

    checks :: Validate ()
    checks = do
      verifyNonce `orFailWith` JWT.JWTClaimsSetDecodeError "invalid nonce"
      -- A missing iat is a malformed token, not one issued in the future.
      verifyIat `orFailWith` JWT.JWTClaimsSetDecodeError "missing issued at"
      verifySub `orFailWith` JWT.JWTClaimsSetDecodeError "missing subject"
      -- NOTE(review): this accepts a token whose aud does not list the
      -- client when azp names it. OIDC Core 3.1.3.7 requires aud to
      -- contain client_id. Confirm this relaxation is intentional.
      (\c -> verifyAudience c || verifyAzp c) `orFailWith` JWT.JWTNotInAudience

    verifyNonce :: ClaimsSet -> Bool
    verifyNonce = claimEq "nonce" (Aeson.String nonce)

    -- Only a single-element audience equal to the client ID is accepted.
    verifyAudience :: ClaimsSet -> Bool
    verifyAudience claims =
      case claims ^. claimAud of
        Just (JWT.Audience [aud]) -> Just aud == clientId ^? JWT.stringOrUri
        _ -> False

    verifyAzp :: ClaimsSet -> Bool
    verifyAzp = claimEq "azp" (Aeson.String clientId)

    verifySub :: ClaimsSet -> Bool
    verifySub = isJust . (^. claimSub)

    verifyIat :: ClaimsSet -> Bool
    verifyIat = isJust . (^. claimIat)