{-|

Copyright:

  This file is part of the package openid-connect.  It is subject to
  the license terms in the LICENSE file found in the top-level
  directory of this distribution and at:

    https://code.devalot.com/open/openid-connect

  No part of this package, including this file, may be copied,
  modified, propagated, or distributed except according to the terms
  contained in the LICENSE file.

License: BSD-2-Clause

-}
module OpenID.Connect.Client.TokenResponse
  ( decodeIdentityToken
  , verifyIdentityTokenClaims
  ) where

--------------------------------------------------------------------------------
-- Imports:
import Control.Lens ((^.), (.~), (^?), (#))
import Control.Monad.Except
import Control.Monad.Reader
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JWT as JWT
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Function ((&))
import Data.Functor.Identity
import qualified Data.HashMap.Strict as Map
import Data.Maybe (isJust)
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime)
import OpenID.Connect.Authentication (ClientID)
import OpenID.Connect.Client.Provider
import OpenID.Connect.TokenResponse

--------------------------------------------------------------------------------
-- | Decode the compact-serialized @id_token@ of a 'TokenResponse' into a
-- 'SignedJWT', keeping every other field of the response unchanged.
--
-- Only the JWS structure is parsed here; the signature and claims are
-- verified separately by 'verifyIdentityTokenClaims'.
decodeIdentityToken
  :: TokenResponse Text
  -> Either JOSE.Error (TokenResponse SignedJWT)
decodeIdentityToken token = do
    jwt <- runIdentity (runExceptT (JOSE.decodeCompact raw))
    pure (jwt <$ token)
  where
    -- The compact serialization as a lazy, UTF-8 encoded byte string.
    raw = LChar8.fromStrict (Text.encodeUtf8 (idToken token))

--------------------------------------------------------------------------------
-- | Identity token verification and claim validation.
--
-- Every key in the 'JWKSet' is a candidate.  Keys are tried starting
-- from the end of the set and the first one that verifies the token
-- (and passes the extra OpenID Connect checks) wins.  When no key
-- succeeds, the error produced by the first key in the set is returned;
-- an empty set yields @'JWT.JWSError' 'JOSE.NoUsableKeys'@.
verifyIdentityTokenClaims
  :: Discovery                -- ^ Provider discovery document.
  -> ClientID                 -- ^ Intended audience.
  -> UTCTime                  -- ^ Current time.
  -> JWKSet                   -- ^ Available keys to try.
  -> Text                     -- ^ Nonce.
  -> TokenResponse SignedJWT  -- ^ Signed identity token.
  -> Either JWTError (TokenResponse ClaimsSet)
verifyIdentityTokenClaims disco clientId now keys nonce token =
    foldr fallback noUsableKeys jwks
  where
    JWKSet jwks = keys

    -- Result when there are no keys at all.
    noUsableKeys :: Either JWTError (TokenResponse ClaimsSet)
    noUsableKeys = Left (JWT.JWSError JOSE.NoUsableKeys)

    -- Keep a success coming from a key later in the list; otherwise
    -- try this key instead (its error replaces the later one).
    fallback
      :: JWK
      -> Either JWTError (TokenResponse ClaimsSet)
      -> Either JWTError (TokenResponse ClaimsSet)
    fallback key laterResult =
      case laterResult of
        Right verified -> Right verified
        Left _         -> verifyWithKey key

    -- Verify the signature and standard claims with one key, then run
    -- the OpenID Connect specific checks.
    verifyWithKey :: JWK -> Either JWTError (TokenResponse ClaimsSet)
    verifyWithKey key = do
      verified <- runIdentity . runExceptT $
        JWT.verifyClaimsAt settings key now (idToken token)
      validated <- additionalValidation clientId nonce verified
      pure (validated <$ token)

    -- The audience predicate here accepts anything because
    -- 'additionalValidation' performs its own aud/azp check.  The skew
    -- literal is a NominalDiffTime, i.e. seconds.
    settings :: JWTValidationSettings
    settings =
      JWT.defaultJWTValidationSettings (const True)
        & checkIssuedAt   .~ True
        & issuerPredicate .~ verifyIssuer
        & allowedSkew     .~ 120

    -- The iss claim must equal the issuer URI from discovery.
    verifyIssuer :: JWT.StringOrURI -> Bool
    verifyIssuer iss = iss == expectedIssuer

    expectedIssuer :: JWT.StringOrURI
    expectedIssuer = JWT.uri # getURI (issuer disco)

-- FIXME: validate the at_hash
-- FIXME: rp-id_token-bad-sig-hs256 (Request an ID token and verify
-- its signature using the 'client_secret' as MAC key.)

--------------------------------------------------------------------------------
-- | A claim check: reads the 'ClaimsSet' from the environment and
-- short-circuits with the first 'JWTError' thrown.
type Validate a = ExceptT JWTError (Reader ClaimsSet) a

--------------------------------------------------------------------------------
-- | Throw the given error unless the predicate holds for the claims
-- held in the 'Validate' environment.
orFailWith :: (ClaimsSet -> Bool) -> JWTError -> Validate ()
orFailWith predicate err = do
  passed <- asks predicate
  if passed
    then pure ()
    else throwError err

--------------------------------------------------------------------------------
-- | True when the unregistered claim named @key@ is present and equal
-- to @val@; False when it is absent or has a different value.
claimEq :: Text -> Aeson.Value -> ClaimsSet -> Bool
claimEq key val claims =
    maybe False (val ==) (Map.lookup key extraClaims)
  where
    extraClaims = claims ^. JWT.unregisteredClaims

--------------------------------------------------------------------------------
-- | OpenID Connect checks applied after 'JWT.verifyClaimsAt' succeeds.
--
-- Returns the claims set unchanged when every check passes; otherwise
-- returns the error of the first failing check.  In order, the checks
-- are:
--
--   * the @nonce@ claim equals the expected nonce;
--
--   * an @iat@ claim is present;
--
--   * a @sub@ claim is present;
--
--   * the @aud@ claim contains the client ID, and when it lists more
--     than one audience the @azp@ claim also equals the client ID.
additionalValidation
  :: ClientID   -- ^ Client ID that must appear in the @aud@ claim.
  -> Text       -- ^ Nonce expected in the @nonce@ claim.
  -> ClaimsSet  -- ^ Claims to check.
  -> Either JWTError ClaimsSet
additionalValidation clientId nonce claims =
    claims <$ runIdentity (runReaderT (runExceptT checks) claims)
  where
    checks :: Validate ()
    checks = do
      verifyNonce    `orFailWith` JWT.JWTClaimsSetDecodeError "invalid nonce"
      verifyIat      `orFailWith` JWT.JWTIssuedAtFuture
      verifySub      `orFailWith` JWT.JWTClaimsSetDecodeError "missing subject"
      verifyAudience `orFailWith` JWT.JWTNotInAudience

    verifyNonce :: ClaimsSet -> Bool
    verifyNonce = claimEq "nonce" (Aeson.String nonce)

    -- OpenID Connect Core 3.1.3.7: the ID token MUST be rejected unless
    -- aud contains our client ID.  A matching azp claim on its own is
    -- NOT sufficient (previously it was accepted even when aud named a
    -- different client or was missing).  With several audiences we
    -- additionally require azp to name us.
    verifyAudience :: ClaimsSet -> Bool
    verifyAudience c =
      case c ^. claimAud of
        Just (JWT.Audience [aud]) -> isClient aud
        Just (JWT.Audience auds)  -> any isClient auds && verifyAzp c
        Nothing                   -> False

    -- Is this audience entry our client ID?
    isClient :: JWT.StringOrURI -> Bool
    isClient aud = Just aud == clientAud

    clientAud :: Maybe JWT.StringOrURI
    clientAud = clientId ^? JWT.stringOrUri

    verifyAzp :: ClaimsSet -> Bool
    verifyAzp = claimEq "azp" (Aeson.String clientId)

    verifySub :: ClaimsSet -> Bool
    verifySub = isJust . (^. claimSub)

    -- JOSE verifies the iat claim if it exists but does not reject if
    -- the iat is missing.  OpenID Connect requires a rejection when
    -- iat is missing.
    verifyIat :: ClaimsSet -> Bool
    verifyIat = isJust . (^. claimIat)