module OpenID.Connect.Client.TokenResponse
( decodeIdentityToken
, verifyIdentityTokenClaims
) where
import Control.Lens ((^.), (.~), (^?), (#))
import Control.Monad.Except
import Control.Monad.Reader
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JWT as JWT
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Function ((&))
import Data.Functor.Identity
import qualified Data.HashMap.Strict as Map
import Data.Maybe (isJust)
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime)
import OpenID.Connect.Authentication (ClientID)
import OpenID.Connect.Client.Provider
import OpenID.Connect.TokenResponse
-- | Parse the compact-serialised @id_token@ of a token response into a
-- 'SignedJWT', carrying every other field of the response over unchanged.
--
-- This only decodes the JWS structure via 'JOSE.decodeCompact'; no
-- signature or claim verification is performed here.
decodeIdentityToken
  :: TokenResponse Text
  -> Either JOSE.Error (TokenResponse SignedJWT)
decodeIdentityToken token = do
    jwt <- runIdentity (runExceptT (JOSE.decodeCompact compact))
    pure (jwt <$ token)
  where
    -- The token text is UTF-8 encoded into the lazy bytes jose expects.
    compact = LChar8.fromStrict (Text.encodeUtf8 (idToken token))
-- | Verify the signature and claims of a decoded identity token.
--
-- Keys in the 'JWKSet' are attempted starting from the /last/ one in the
-- set, moving towards the front. The first key whose verification and
-- additional OpenID Connect checks succeed wins.
--
-- If every key fails, the error produced by the first key of the set is
-- returned. An empty set yields 'JOSE.NoUsableKeys'.
--
-- JWT validation accepts any audience at this stage (audience is checked
-- by 'additionalValidation'). It allows 120 seconds of clock skew,
-- requires the issuer to match the discovery document and checks @iat@.
verifyIdentityTokenClaims
  :: Discovery
  -> ClientID
  -> UTCTime
  -> JWKSet
  -> Text
  -> TokenResponse SignedJWT
  -> Either JWTError (TokenResponse ClaimsSet)
verifyIdentityTokenClaims disco clientId now keys nonce token =
    case keys of
      JWKSet candidates -> firstSuccess candidates
  where
    -- Hand-unrolled right fold: the tail is attempted before the head,
    -- and the head is only tried once everything after it has failed.
    firstSuccess :: [JWK] -> Either JWTError (TokenResponse ClaimsSet)
    firstSuccess [] = Left (JWT.JWSError JOSE.NoUsableKeys)
    firstSuccess (candidate : rest) =
      case firstSuccess rest of
        Right verified -> Right verified
        Left _ -> verifyWithKey candidate

    -- Signature and standard-claim checks via jose, followed by the
    -- OpenID Connect specific checks. The result is re-attached to the
    -- original token response.
    verifyWithKey :: JWK -> Either JWTError (TokenResponse ClaimsSet)
    verifyWithKey key = do
      claims <-
        runIdentity . runExceptT $
          JWT.verifyClaimsAt validation key now (idToken token)
      checked <- additionalValidation clientId nonce claims
      pure (checked <$ token)

    validation =
      JWT.defaultJWTValidationSettings (const True)
        & allowedSkew .~ 120
        & issuerPredicate .~ verifyIssuer
        & checkIssuedAt .~ True

    verifyIssuer :: JWT.StringOrURI -> Bool
    verifyIssuer iss = iss == (JWT.uri # getURI (issuer disco))
-- | Monad for the OpenID Connect specific checks. It gives read-only
-- access to the claims set under validation and stops at the first
-- 'JWTError' thrown.
type Validate = ExceptT JWTError (Reader ClaimsSet)
-- | Evaluate a predicate against the claims set in the environment. When
-- the predicate does not hold, abort validation with the given error.
orFailWith :: (ClaimsSet -> Bool) -> JWTError -> Validate ()
orFailWith predicate err =
  asks predicate >>= \satisfied ->
    if satisfied then pure () else throwError err
-- | Whether the unregistered claim named @key@ exists in the claims set
-- and equals @val@. An absent claim never matches.
claimEq :: Text -> Aeson.Value -> ClaimsSet -> Bool
claimEq key val claims =
  maybe False (val ==) (Map.lookup key (claims ^. JWT.unregisteredClaims))
-- | Validation steps required by OpenID Connect Core (section 3.1.3.7)
-- that are not covered by the generic JWT checks in @jose@.
--
-- Checks run in this order, and the first failure is returned:
--
--   * the @nonce@ claim equals the expected nonce;
--   * @iat@ is present;
--   * @sub@ is present;
--   * @aud@ contains the client ID. If it lists several audiences, the
--     @azp@ claim must also equal the client ID;
--   * @exp@ is present. @jose@ only enforces expiry when the claim exists,
--     so without this a token lacking @exp@ would never expire.
--
-- When every check passes, the claims set is returned unchanged.
additionalValidation :: ClientID -> Text -> ClaimsSet -> Either JWTError ClaimsSet
additionalValidation clientId nonce = go
  where
    go :: ClaimsSet -> Either JWTError ClaimsSet
    go claims =
      checks
        & runExceptT
        & flip runReaderT claims
        & runIdentity
        & (claims <$)

    checks :: Validate ()
    checks = do
      verifyNonce `orFailWith` JWT.JWTClaimsSetDecodeError "invalid nonce"
      -- jose has no dedicated "missing iat" error, so this one is reused.
      verifyIat `orFailWith` JWT.JWTIssuedAtFuture
      verifySub `orFailWith` JWT.JWTClaimsSetDecodeError "missing subject"
      verifyAudience `orFailWith` JWT.JWTNotInAudience
      verifyExp `orFailWith` JWT.JWTClaimsSetDecodeError "missing expiration"

    verifyNonce :: ClaimsSet -> Bool
    verifyNonce = claimEq "nonce" (Aeson.String nonce)

    -- The client ID MUST appear in aud. A token addressed to some other
    -- audience is rejected even if azp names this client. With multiple
    -- audiences, azp must additionally identify this client.
    verifyAudience :: ClaimsSet -> Bool
    verifyAudience claims =
      case claims ^. claimAud of
        Just (JWT.Audience [aud]) -> isClient aud
        Just (JWT.Audience auds) -> any isClient auds && verifyAzp claims
        Nothing -> False

    isClient :: JWT.StringOrURI -> Bool
    isClient aud = Just aud == clientId ^? JWT.stringOrUri

    verifyAzp :: ClaimsSet -> Bool
    verifyAzp = claimEq "azp" (Aeson.String clientId)

    verifySub :: ClaimsSet -> Bool
    verifySub = isJust . (^. claimSub)

    verifyIat :: ClaimsSet -> Bool
    verifyIat = isJust . (^. claimIat)

    verifyExp :: ClaimsSet -> Bool
    verifyExp = isJust . (^. claimExp)