{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}

{-|
    Module: Web.OIDC.Client.Tokens
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.Tokens
    ( Tokens(..)
    , IdTokenClaims(..)
    , validateIdToken
    )
where

import           Control.Applicative                ((<|>))
import           Control.Exception                  (throwIO)
import           Control.Monad.IO.Class             (MonadIO, liftIO)
import           Data.Aeson                         (FromJSON (parseJSON),
                                                     FromJSON, Value (Object),
                                                     eitherDecode, withObject,
                                                     (.:), (.:?))
import           Data.ByteString                    (ByteString)
import qualified Data.ByteString.Lazy.Char8         as BL
import           Data.Either                        (partitionEithers)
import           Data.Monoid                        ((<>))
import           Data.Text                          (Text, pack)
import           Data.Text.Encoding                 (encodeUtf8)
import           GHC.Generics                       (Generic)
import           Jose.Jwt                           (IntDate, Jwt, JwtContent (Jwe, Jws, Unsecured))
import qualified Jose.Jwt                           as Jwt
import           Prelude                            hiding (exp)
import qualified Web.OIDC.Client.Discovery.Provider as P
import           Web.OIDC.Client.Settings           (OIDC (..))
import           Web.OIDC.Client.Types              (OpenIdException (..))

-- | A bundle of token values, parameterised over the type @a@ used for the
--   additional (non-standard) claims carried by the ID token.
--
--   NOTE(review): the field names mirror the members of an OAuth 2.0 / OIDC
--   token response; that mapping is inferred from the names, not shown here.
data Tokens a = Tokens
    { accessToken  :: Text
    , tokenType    :: Text
    , idToken      :: IdTokenClaims a  -- ^ Parsed claims of the ID token.
    , idTokenJwt   :: Jwt              -- ^ The ID token as a 'Jwt' value.
    , expiresIn    :: Maybe Integer    -- ^ Optional; presumably a lifetime in seconds (confirm against caller).
    , refreshToken :: Maybe Text       -- ^ Optional.
    }
  deriving (Show, Eq)

-- | Claims required for an <https://openid.net/specs/openid-connect-core-1_0.html#IDToken ID Token>,
--   plus recommended claims (nonce) and other custom claims.
data IdTokenClaims a = IdTokenClaims
    { iss         :: !Text                -- ^ Value of the @iss@ claim.
    , sub         :: !Text                -- ^ Value of the @sub@ claim.
    , aud         :: ![Text]              -- ^ Value of the @aud@ claim, always held as a list.
    , exp         :: !IntDate             -- ^ Value of the @exp@ claim.
    , iat         :: !IntDate             -- ^ Value of the @iat@ claim.
    , nonce       :: !(Maybe ByteString)  -- ^ Value of the optional @nonce@ claim, UTF-8 encoded.
    , otherClaims :: !a                   -- ^ Caller-defined view of the claims object.
    }
  deriving (Show, Eq, Generic)

-- | Parse the claims of an ID token from a JSON object.
--
--   * @iss@, @sub@, @exp@ and @iat@ are mandatory keys.
--   * @aud@ is mandatory and may be either a JSON array of strings or a
--     single string; a single string becomes a one-element list.
--   * @nonce@ is optional; when present its text is UTF-8 encoded.
--   * 'otherClaims' is produced by running the 'FromJSON' parser for @a@ over
--     the /whole/ object, so @a@ sees the standard claims as well.
instance FromJSON a => FromJSON (IdTokenClaims a) where
    parseJSON = withObject "IdTokenClaims" $ \o ->
        IdTokenClaims
            <$> o .: "iss"
            <*> o .: "sub"
            -- Try the list form first, then fall back to a lone string.
            <*> (o .: "aud" <|> ((:[]) <$> (o .: "aud")))
            <*> o .: "exp"
            <*> o .: "iat"
            <*> (fmap encodeUtf8 <$> o .:? "nonce")
            -- Re-wrap the object so the parser for @a@ can pick what it needs.
            <*> parseJSON (Object o)

-- | Decode an ID token and parse its payload into 'IdTokenClaims'.
--
--   The token is decoded against the provider's JWK set once for every entry
--   in the provider configuration's @idTokenSigningAlgValuesSupported@ list.
--   The first successful decoding is used.
--
--   Throws (via 'throwIO'):
--
--   * 'UnsecuredJwt' when the token decodes as an unsecured JWT;
--   * 'JwtException' when no decoding attempt succeeds (carrying the first
--     error), or when there were no algorithms to try;
--   * 'JsonException' when the payload cannot be parsed as 'IdTokenClaims'.
--
--   This function only decodes and parses; it performs no checks on the
--   claim values themselves (issuer, audience, expiry, nonce).
validateIdToken :: (MonadIO m, FromJSON a) => OIDC -> Jwt -> m (IdTokenClaims a)
validateIdToken oidc jwt' = do
    let jwks  = P.jwkSet . oidcProvider $ oidc
        token = Jwt.unJwt jwt'
        algs  = P.idTokenSigningAlgValuesSupported
              . P.configuration
              $ oidcProvider oidc
    -- Every advertised algorithm is attempted; the results are reduced afterwards.
    decoded <- selectDecodedResult <$> traverse (tryDecode jwks token) algs
    case decoded of
        Right (Unsecured payload)      -> liftIO . throwIO $ UnsecuredJwt payload
        Right (Jws (_header, payload)) -> parsePayload payload
        Right (Jwe (_header, payload)) -> parsePayload payload
        Left err                       -> liftIO . throwIO $ JwtException err
  where
    -- One decoding attempt for a single advertised algorithm. Algorithms the
    -- provider module marks as unsupported yield an error without decoding.
    tryDecode jwks token = \case
        P.JwsAlgJson  alg -> liftIO $ Jwt.decode jwks (Just $ Jwt.JwsEncoding alg) token
        P.Unsupported alg -> return $ Left $ Jwt.BadAlgorithm ("Unsupported algorithm: " <> alg)

    -- Prefer the first success; otherwise report the first failure. An empty
    -- input list (no algorithms were tried) is reported as a key error.
    selectDecodedResult xs = case partitionEithers xs of
        (_, k : _) -> Right k
        (e : _, _) -> Left e
        ([], [])   -> Left $ Jwt.KeyError "No Keys available for decoding"

    -- Decode the JSON payload, turning a parse failure into 'JsonException'.
    parsePayload payload = case eitherDecode $ BL.fromStrict payload of
        Right x   -> return x
        Left  err -> liftIO . throwIO . JsonException $ pack err