{-# LANGUAGE OverloadedStrings #-}
{-|
    Module: Web.OIDC.Client.CodeFlow
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.CodeFlow
    (
      getAuthenticationRequestUrl
    , getValidTokens
    , prepareAuthenticationRequestUrl
    , requestTokens

    -- * For testing
    , validateClaims
    , getCurrentIntDate
    ) where

import           Control.Monad                      (unless, when)
import           Control.Monad.Catch                (MonadCatch, MonadThrow,
                                                     catch, throwM)
import           Control.Monad.IO.Class             (MonadIO, liftIO)
import           Data.Aeson                         (FromJSON, eitherDecode)
import qualified Data.ByteString.Char8              as B
import           Data.List                          (nub)
import           Data.Maybe                         (isNothing)
import           Data.Monoid                        ((<>))
import           Data.Text                          (Text, pack, unpack)
import           Data.Text.Encoding                 (decodeUtf8With)
import           Data.Text.Encoding.Error           (lenientDecode)
import           Data.Time.Clock.POSIX              (getPOSIXTime)
import qualified Jose.Jwt                           as Jwt
import           Network.HTTP.Client                (Manager, Request (..),
                                                     getUri, httpLbs,
                                                     responseBody,
                                                     setQueryString,
                                                     urlEncodedBody)
import           Network.URI                        (URI)

import           Prelude                            hiding (exp)

import qualified Web.OIDC.Client.Discovery.Provider as P
import           Web.OIDC.Client.Internal           (parseUrl)
import qualified Web.OIDC.Client.Internal           as I
import           Web.OIDC.Client.Settings           (OIDC (..))
import           Web.OIDC.Client.Tokens             (IdTokenClaims (..), validateIdToken,
                                                     Tokens (..))
import           Web.OIDC.Client.Types              (Code, Nonce,
                                                     OpenIdException (..),
                                                     Parameters, Scope,
                                                     SessionStore (..), State,
                                                     openId)

-- | Make the URL for an Authorization Request after generating a fresh state
--   and nonce with the given 'SessionStore'.
--
-- Both values are produced by 'sessionStoreGenerate' and persisted together
-- with 'sessionStoreSave' before the URL is built. The state is handed to
-- 'getAuthenticationRequestUrl' as the @state@ parameter, and the nonce is
-- appended to the caller-supplied optional parameters under the key @nonce@.
prepareAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => SessionStore m
    -> OIDC
    -> Scope            -- ^ used to specify what are privileges requested for tokens. (use `ScopeValue`)
    -> Parameters       -- ^ Optional parameters
    -> m URI
prepareAuthenticationRequestUrl store oidc scope params = do
    state  <- sessionStoreGenerate store
    nonce' <- sessionStoreGenerate store
    -- Save the pair before redirecting so the callback can look the nonce up by state.
    sessionStoreSave store state nonce'
    getAuthenticationRequestUrl oidc scope (Just state) $ params ++ [("nonce", Just nonce')]

-- | Get and validate the tokens for the given code, using the nonce that the
--   'SessionStore' holds for the state returned by the IdP.
--
-- The nonce is looked up with 'sessionStoreGet'; if nothing is stored for the
-- state, 'UnknownState' is thrown and no token request is made. Otherwise the
-- tokens are fetched and validated by 'requestTokens', and the session info is
-- then removed with 'sessionStoreDelete'.
--
-- Note that 'sessionStoreDelete' runs only after 'requestTokens' has returned
-- successfully; if it throws, the session info is left in the store.
getValidTokens
    :: (MonadThrow m, MonadCatch m, MonadIO m, FromJSON a)
    => SessionStore m
    -> OIDC
    -> Manager
    -> State
    -> Code
    -> m (Tokens a)
getValidTokens store oidc mgr stateFromIdP code = do
    savedNonce <- sessionStoreGet store stateFromIdP
    -- No saved nonce means this state was not issued through the store.
    when (isNothing savedNonce) $ throwM UnknownState
    result <- liftIO $ requestTokens oidc savedNonce code mgr
    sessionStoreDelete store
    return result

-- | Make the URL for an Authorization Request.
--
-- The URL is the provider's authorization endpoint ('oidcAuthorizationServerUrl')
-- with a query string made of, in order: the required parameters
-- (@response_type=code@, @client_id@, @redirect_uri@, @scope@), the @state@
-- parameter when one is given, and finally the caller's optional parameters.
--
-- The @scope@ value always starts with 'openId'; duplicate scope values are
-- removed and the rest are joined with single spaces.
--
-- An 'HttpException' raised while parsing the endpoint URL is passed to
-- 'I.rethrow'.
{-# WARNING getAuthenticationRequestUrl "This function doesn't manage state and nonce. Use prepareAuthenticationRequestUrl only unless your IdP doesn't support state and/or nonce." #-}
getAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => OIDC
    -> Scope            -- ^ used to specify what are privileges requested for tokens. (use `ScopeValue`)
    -> Maybe State      -- ^ used for CSRF mitigation. (recommended parameter)
    -> Parameters       -- ^ Optional parameters
    -> m URI
getAuthenticationRequestUrl oidc scope state params = do
    req <- parseUrl endpoint `catch` I.rethrow
    return $ getUri $ setQueryString query req
  where
    endpoint  = oidcAuthorizationServerUrl oidc
    query     = requireds ++ state' ++ params
    requireds =
        [ ("response_type", Just "code")
        , ("client_id",     Just $ oidcClientId oidc)
        , ("redirect_uri",  Just $ oidcRedirectUri oidc)
        -- 'openId' is always requested; 'nub' drops repeated scope values.
        , ("scope",         Just $ B.pack . unwords . nub . map unpack $ openId : scope)
        ]
    -- Emit the @state@ parameter only when the caller supplied one.
    state'    = maybe [] (\s -> [("state", Just s)]) state

-- TODO: error response

-- | Request and validate tokens.
--
-- This function POSTs an @authorization_code@ grant (code, client id, client
-- secret and redirect URI as a URL-encoded body) to the OP's token endpoint,
-- decodes the JSON response, and validates the received ID Token with the
-- given nonce. The returned 'Tokens' value is therefore a validated one.
--
-- An 'HttpException' raised while building or sending the request is passed to
-- 'I.rethrow'. If the response body cannot be decoded, 'JsonException' is
-- thrown with the decoder's message. Validation failures are thrown as
-- 'OpenIdException' by the validation step.
{-# WARNING requestTokens "This function doesn't manage state and nonce. Use getValidTokens only unless your IdP doesn't support state and/or nonce." #-}
requestTokens :: FromJSON a => OIDC -> Maybe Nonce -> Code -> Manager -> IO (Tokens a)
requestTokens oidc savedNonce code manager = do
    json <- getTokensJson `catch` I.rethrow
    case eitherDecode json of
        Right ts -> validate oidc savedNonce ts
        Left err -> throwM . JsonException $ pack err
  where
    -- Send the token request and return the raw response body.
    getTokensJson = do
        req <- parseUrl endpoint
        let req' = urlEncodedBody body $ req { method = "POST" }
        responseBody <$> httpLbs req' manager
    endpoint = oidcTokenEndpoint oidc
    cid      = oidcClientId oidc
    sec      = oidcClientSecret oidc
    redirect = oidcRedirectUri oidc
    body     =
        [ ("grant_type",    "authorization_code")
        , ("code",          code)
        , ("client_id",     cid)
        , ("client_secret", sec)
        , ("redirect_uri",  redirect)
        ]

-- | Validate a token endpoint response and convert it into 'Tokens'.
--
-- The ID Token JWT from the response is processed by 'validateIdToken' to
-- obtain its claims, which are then checked by 'validateClaims' against the
-- provider's issuer, this client's id (decoded leniently as UTF-8), the
-- current time and the saved nonce. Any exception from those steps propagates
-- to the caller; on success the remaining response fields are copied over
-- unchanged.
validate :: FromJSON a => OIDC -> Maybe Nonce -> I.TokensResponse -> IO (Tokens a)
validate oidc savedNonce tres = do
    let jwt' = I.idToken tres
    claims' <- validateIdToken oidc jwt'
    now <- getCurrentIntDate
    validateClaims
        (P.issuer . P.configuration . oidcProvider $ oidc)
        (decodeUtf8With lenientDecode . oidcClientId $ oidc)
        now
        savedNonce
        claims'
    return Tokens {
          accessToken  = I.accessToken tres
        , tokenType    = I.tokenType tres
        , idToken      = claims'
        , idTokenJwt   = jwt'
        , expiresIn    = I.expiresIn tres
        , refreshToken = I.refreshToken tres
        }

-- | Check the claims of an ID Token, throwing 'ValidationException' on the
--   first check that fails.
--
-- The checks run in this order:
--
-- 1. the token's @iss@ must equal the expected issuer (first argument);
-- 2. the client id (second argument) must be an element of the token's @aud@;
-- 3. the given current time must be strictly earlier than the token's @exp@;
-- 4. the token's @nonce@ must equal the saved nonce, where both being
--    'Nothing' counts as equal.
validateClaims :: Text -> Text -> Jwt.IntDate -> Maybe Nonce -> IdTokenClaims a -> IO ()
validateClaims issuer' clientId' now savedNonce claims' = do
    let iss' = iss claims'
    unless (iss' == issuer')
        $ throwM $ ValidationException $ "issuer from token \"" <> iss' <> "\" is different than expected issuer \"" <> issuer' <> "\""

    let aud' = aud claims'
    unless (clientId' `elem` aud')
        $ throwM $ ValidationException $ "our client \"" <> clientId' <> "\" isn't contained in the token's audience " <> (pack . show) aud'

    unless (now < exp claims')
        $ throwM $ ValidationException "received token has expired"

    unless (nonce claims' == savedNonce)
        $ throwM $ ValidationException "Inconsistent nonce"

-- | Get the current POSIX time from 'getPOSIXTime', wrapped as a 'Jwt.IntDate'.
getCurrentIntDate :: IO Jwt.IntDate
getCurrentIntDate = Jwt.IntDate <$> getPOSIXTime