{-# LANGUAGE OverloadedStrings #-}
module Web.OIDC.Client.CodeFlow
(
getAuthenticationRequestUrl
, getValidTokens
, prepareAuthenticationRequestUrl
, requestTokens
, validateClaims
, getCurrentIntDate
) where
import Control.Monad (unless, when)
import Control.Monad.Catch (MonadCatch, MonadThrow,
    catch, throwM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson (FromJSON, eitherDecode)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Either (partitionEithers)
import Data.List (nub)
import Data.Maybe (isNothing)
import Data.Monoid ((<>))
import Data.Text (Text, pack, unpack)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Jose.Jwt (Jwt, JwtContent (Jwe, Jws, Unsecured))
import qualified Jose.Jwt as Jwt
import Network.HTTP.Client (Manager, Request (..),
    getUri, httpLbs,
    responseBody,
    setQueryString,
    urlEncodedBody)
import Network.URI (URI)
import Prelude hiding (exp)
import qualified Web.OIDC.Client.Discovery.Provider as P
import Web.OIDC.Client.Internal (parseUrl)
import qualified Web.OIDC.Client.Internal as I
import Web.OIDC.Client.Settings (OIDC (..))
import Web.OIDC.Client.Tokens (IdTokenClaims (..),
    Tokens (..))
import Web.OIDC.Client.Types (Code, Nonce,
    OpenIdException (..),
    Parameters, Scope,
    SessionStore (..), State,
    openId)
-- | Build the authorization request URL using a freshly generated state and
-- nonce.
--
-- Two values are obtained from 'sessionStoreGenerate' and handed to
-- 'sessionStoreSave' before the URL is assembled. The first is sent as the
-- @state@ parameter; the second is appended to the caller's parameters as
-- @nonce@.
prepareAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => SessionStore m  -- ^ generates and stores the state/nonce pair
    -> OIDC            -- ^ client settings
    -> Scope           -- ^ requested scopes
    -> Parameters      -- ^ extra query parameters; @nonce@ is added after them
    -> m URI
prepareAuthenticationRequestUrl store oidc scope params = do
    freshState <- sessionStoreGenerate store
    freshNonce <- sessionStoreGenerate store
    sessionStoreSave store freshState freshNonce
    let paramsWithNonce = params ++ [("nonce", Just freshNonce)]
    getAuthenticationRequestUrl oidc scope (Just freshState) paramsWithNonce
-- | Exchange an authorization code for tokens after checking the state
-- returned by the IdP against the one held in the session store.
--
-- Steps, in order:
--
-- 1. Read the saved state and nonce with 'sessionStoreGet'.
-- 2. Throw 'ValidationException' if the saved state differs from
--    @stateFromIdP@ (or no state was saved).
-- 3. Throw 'ValidationException' if no nonce was saved.
-- 4. Call 'requestTokens' with the saved nonce.
-- 5. On success only, call 'sessionStoreDelete' and return the tokens.
--
-- Any exception raised by 'requestTokens' propagates and leaves the session
-- store untouched.
getValidTokens
    :: (MonadThrow m, MonadCatch m, MonadIO m, FromJSON a)
    => SessionStore m  -- ^ holds the state/nonce saved when the request URL was built
    -> OIDC            -- ^ client settings
    -> Manager         -- ^ HTTP connection manager used for the token request
    -> State           -- ^ @state@ value received from the IdP
    -> Code            -- ^ authorization code received from the IdP
    -> m (Tokens a)
getValidTokens store oidc mgr stateFromIdP code = do
    (state, savedNonce) <- sessionStoreGet store
    -- The state comes from the redirect request, so it may be arbitrary bytes.
    -- Decode it leniently: the strict 'decodeUtf8' would raise a
    -- 'UnicodeException' when the message is rendered, hiding the real
    -- validation failure.
    when (state /= Just stateFromIdP) $
        throwM $ ValidationException $
            "Inconsistent state: " <> decodeUtf8With lenientDecode stateFromIdP
    when (isNothing savedNonce) $
        throwM $ ValidationException "Nonce is not saved!"
    result <- liftIO $ requestTokens oidc savedNonce code mgr
    sessionStoreDelete store
    return result
{-# WARNING getAuthenticationRequestUrl "This function doesn't manage state and nonce. Use prepareAuthenticationRequestUrl only unless your IdP doesn't support state and/or nonce." #-}
-- | Build the authorization request URL without touching any session store.
--
-- The query string is assembled in this order:
--
-- 1. the fixed parameters @response_type=code@, @client_id@, @redirect_uri@
--    and @scope@;
-- 2. @state@, only when one is supplied;
-- 3. the caller's extra parameters, unchanged.
--
-- The @scope@ value is 'openId' followed by the requested scopes, with
-- duplicates removed and entries separated by spaces.
--
-- Exceptions raised while parsing the authorization endpoint are handed to
-- 'I.rethrow'.
getAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => OIDC         -- ^ client settings
    -> Scope        -- ^ requested scopes; 'openId' is always added
    -> Maybe State  -- ^ optional @state@ value
    -> Parameters   -- ^ extra query parameters
    -> m URI
getAuthenticationRequestUrl oidc scope state params =
    getUri . setQueryString (mandatory ++ stateParam ++ params)
        <$> (parseUrl (oidcAuthorizationServerUrl oidc) `catch` I.rethrow)
  where
    mandatory =
        [ ("response_type", Just "code")
        , ("client_id", Just (oidcClientId oidc))
        , ("redirect_uri", Just (oidcRedirectUri oidc))
        , ("scope", Just scopeValue)
        ]
    scopeValue = B.pack (unwords (nub (map unpack (openId : scope))))
    -- Emit the state parameter only when the caller provided one.
    stateParam = maybe [] (\s -> [("state", Just s)]) state
{-# WARNING requestTokens "This function doesn't manage state and nonce. Use getValidTokens only unless your IdP doesn't support state and/or nonce." #-}
-- | Request tokens from the token endpoint and validate the response.
--
-- A form-encoded POST carrying @grant_type=authorization_code@, the code, the
-- client id and secret, and the redirect URI is sent to 'oidcTokenEndpoint'.
-- The response body is decoded as JSON and passed to 'validate' together
-- with the supplied nonce.
--
-- Throws 'JsonException' when the body cannot be decoded. Exceptions raised
-- while building or sending the request are handed to 'I.rethrow'.
requestTokens :: FromJSON a => OIDC -> Maybe Nonce -> Code -> Manager -> IO (Tokens a)
requestTokens oidc savedNonce code manager = do
    raw <- fetchTokenResponse `catch` I.rethrow
    either (throwM . JsonException . pack) (validate oidc savedNonce) (eitherDecode raw)
  where
    fetchTokenResponse = do
        baseReq <- parseUrl (oidcTokenEndpoint oidc)
        let postReq = urlEncodedBody form (baseReq { method = "POST" })
        responseBody <$> httpLbs postReq manager
    form =
        [ ("grant_type", "authorization_code")
        , ("code", code)
        , ("client_id", oidcClientId oidc)
        , ("client_secret", oidcClientSecret oidc)
        , ("redirect_uri", oidcRedirectUri oidc)
        ]
-- | Validate a token response and convert it into 'Tokens'.
--
-- The ID token is decoded with 'validateIdToken'. Its claims are then checked
-- by 'validateClaims' against the provider's issuer, this client's id, the
-- current time and the supplied nonce. The remaining response fields are
-- copied through unchanged.
validate :: FromJSON a => OIDC -> Maybe Nonce -> I.TokensResponse -> IO (Tokens a)
validate oidc savedNonce tres = do
    claims' <- validateIdToken oidc (I.idToken tres)
    now <- getCurrentIntDate
    validateClaims expectedIssuer expectedAudience now savedNonce claims'
    return Tokens
        { accessToken = I.accessToken tres
        , tokenType = I.tokenType tres
        , idToken = claims'
        , expiresIn = I.expiresIn tres
        , refreshToken = I.refreshToken tres
        }
  where
    expectedIssuer = P.issuer (P.configuration (oidcProvider oidc))
    expectedAudience = decodeUtf8 (oidcClientId oidc)
-- | Decode the ID token with the provider's key set and parse its claims.
--
-- One decode attempt is made for every entry in the provider's
-- @idTokenSigningAlgValuesSupported@ list; all attempts run before a result
-- is picked. The first successful attempt wins. If none succeeds the first
-- error is reported, and if there were no attempts at all a 'Jwt.KeyError'
-- is reported.
--
-- Throws 'UnsecuredJwt' for an unsecured token, 'JwtExceptoin' when decoding
-- fails, and 'JsonException' when the payload is not valid claims JSON.
validateIdToken :: FromJSON a => OIDC -> Jwt -> IO (IdTokenClaims a)
validateIdToken oidc jwt' = do
    attempts <- traverse tryDecode encodings
    case pickResult (partitionEithers attempts) of
        Left err -> throwM $ JwtExceptoin err
        Right (Unsecured payload) -> throwM $ UnsecuredJwt payload
        Right (Jws (_header, payload)) -> parsePayload payload
        Right (Jwe (_header, payload)) -> parsePayload payload
  where
    provider = oidcProvider oidc
    encodings =
        map (Jwt.JwsEncoding . P.getJwsAlg)
            (P.idTokenSigningAlgValuesSupported (P.configuration provider))
    tryDecode encoding =
        Jwt.decode (P.jwkSet provider) (Just encoding) (Jwt.unJwt jwt')
    -- Prefer any success; otherwise the first failure; otherwise a fixed error.
    pickResult (_, success : _) = Right success
    pickResult (failure : _, []) = Left failure
    pickResult ([], []) = Left $ Jwt.KeyError "No Keys available for decoding"
    parsePayload payload =
        either (throwM . JsonException . pack) return
            (eitherDecode (BL.fromStrict payload))
-- | Check ID token claims, throwing 'ValidationException' on the first failure.
--
-- The checks run in this order:
--
-- 1. the token's @iss@ equals the expected issuer;
-- 2. the client id appears in the token's @aud@;
-- 3. the current time is strictly before the token's @exp@;
-- 4. the token's @nonce@ equals the saved nonce (both may be absent).
validateClaims :: Text -> Text -> Jwt.IntDate -> Maybe Nonce -> IdTokenClaims a -> IO ()
validateClaims issuer' clientId' now savedNonce claims' = do
    ensure (tokenIssuer == issuer') $
        "issuer from token \"" <> tokenIssuer <> "\" is different than expected issuer \"" <> issuer' <> "\""
    ensure (clientId' `elem` audience) $
        "our client \"" <> clientId' <> "\" isn't contained in the token's audience " <> pack (show audience)
    ensure (now < exp claims') "received token has expired"
    ensure (nonce claims' == savedNonce) "Inconsistent nonce"
  where
    tokenIssuer = iss claims'
    audience = aud claims'
    -- Throw a validation error carrying the message unless the condition holds.
    ensure :: Bool -> Text -> IO ()
    ensure ok message = unless ok . throwM $ ValidationException message
-- | Current POSIX time from 'getPOSIXTime', wrapped in 'Jwt.IntDate'.
getCurrentIntDate :: IO Jwt.IntDate
getCurrentIntDate = fmap Jwt.IntDate getPOSIXTime