{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Middleware.Auth.OIDC
(
OpenIDConnect
, discover
, oidcClientId
, oidcClientSecret
, oidcProviderInfo
, oidcManager
, oidcScopes
, oidcAllowedSkew
, getAccessToken
, getIdToken
) where
import Control.Applicative ((<|>))
import qualified Crypto.JOSE as JOSE
import qualified Crypto.JWT as JWT
import Control.Monad.Except (runExceptT)
import Data.Aeson (FromJSON(parseJSON),
                   withObject, (.:), (.:?), (.!=))
import qualified Data.ByteString.Char8 as S8
import Data.Function ((&))
import qualified Data.Time.Clock as Clock
import Data.Traversable (for)
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Data.Vault.Lazy as Vault
import Foreign.C.Types (CTime (..))
import qualified Lens.Micro as Lens
import qualified Lens.Micro.Extras as Lens.Extras
import Network.HTTP.Simple (httpJSON,
                            getResponseBody,
                            parseRequestThrow)
import Network.Wai.Middleware.Auth.OAuth2 (parseAbsoluteURI,
                                           getAccessToken)
import qualified Network.OAuth.OAuth2 as OA2
import Network.HTTP.Client (Manager)
import Network.HTTP.Client.TLS (getGlobalManager)
import Network.Wai (Request, vault)
import Network.Wai.Auth.Internal (Metadata(..),
                                  decodeToken, encodeToken,
                                  oauth2Login,
                                  refreshTokens)
import Network.Wai.Middleware.Auth.Provider
import System.IO.Unsafe (unsafePerformIO)
import System.PosixCompat.Time (epochTime)
import qualified Text.Hamlet
import qualified URI.ByteString as U
-- | An OpenID Connect identity provider.
--
-- Build one with 'discover' (or decode it from JSON), then override fields by
-- record update. 'discover' leaves 'oidcClientId' and 'oidcClientSecret'
-- empty, so callers must set them before use.
data OpenIDConnect
  = OpenIDConnect
      { oidcMetadata :: Metadata
        -- ^ Provider metadata. Its issuer is compared with the ID token's
        -- @iss@ claim, and its authorization and token endpoints configure
        -- the OAuth2 client.
      , oidcJwkSet :: JOSE.JWKSet
        -- ^ Keys used to verify ID token signatures.
      , oidcClientId :: T.Text
        -- ^ OAuth2 client id. The audience check accepts an @aud@ value only
        -- if it equals this id.
      , oidcClientSecret :: T.Text
        -- ^ OAuth2 client secret.
      , oidcProviderInfo :: ProviderInfo
        -- ^ Presentation info returned by 'getProviderInfo'.
      , oidcManager :: Maybe Manager
        -- ^ HTTP manager used for login and token refresh. 'Nothing' means
        -- the global TLS manager ('getGlobalManager') is used.
      , oidcScopes :: [T.Text]
        -- ^ Scopes requested during login.
      , oidcAllowedSkew :: Clock.NominalDiffTime
        -- ^ Allowed clock skew passed to the JWT validator.
      }
-- | Decode an 'OpenIDConnect' configuration from a JSON object.
--
-- Required keys: @metadata@, @jwk_set@, @client_id@, @client_secret@.
--
-- Optional keys, with their defaults:
--
-- * @provider_info@ ('defProviderInfo')
-- * @scopes@ (@["openid"]@)
-- * @allowed_skew@ (@0@)
--
-- An optional key that is absent or @null@ takes its default.
--
-- The HTTP manager cannot be set from JSON. It is always 'Nothing', so the
-- global TLS manager is used.
instance FromJSON OpenIDConnect where
  parseJSON =
    withObject "OpenIDConnect Object" $ \obj -> do
      metadata <- obj .: "metadata"
      jwkSet <- obj .: "jwk_set"
      clientId <- obj .: "client_id"
      clientSecret <- obj .: "client_secret"
      -- Use '.:?' rather than '.:' so a *missing* key falls back to the
      -- default. With '.:' the key would be mandatory, and '.!=' would only
      -- cover an explicit null.
      providerInfo <- obj .:? "provider_info" .!= defProviderInfo
      scopes <- obj .:? "scopes" .!= ["openid"]
      allowedSkew <- obj .:? "allowed_skew" .!= 0
      pure OpenIDConnect {
        oidcMetadata = metadata,
        oidcJwkSet = jwkSet,
        oidcClientId = clientId,
        oidcClientSecret = clientSecret,
        oidcProviderInfo = providerInfo,
        oidcManager = Nothing,
        oidcScopes = scopes,
        oidcAllowedSkew = allowedSkew
      }
-- | 'AuthProvider' implementation for OpenID Connect.
--
-- Login is delegated to the shared OAuth2 flow ('oauth2Login') with the
-- configured scopes.
--
-- When refreshing a session, the stored ID token is re-validated first. Only
-- if that fails are the tokens refreshed, after which the new ID token must
-- validate. In both successful paths the validated claims are stored in the
-- request vault (see 'getIdToken').
instance AuthProvider OpenIDConnect where
  getProviderName _ = "oidc"

  getProviderInfo = oidcProviderInfo

  handleLogin oidc@OpenIDConnect {..} req suffix renderUrl onSuccess onFailure = do
    client <- mkOauth2 oidc (Just renderUrl)
    httpManager <- maybe getGlobalManager pure oidcManager
    let providerName = getProviderName oidc
        requestedScopes = Just oidcScopes
    oauth2Login client httpManager requestedScopes providerName req suffix onSuccess onFailure

  refreshLoginState oidc req user =
    either (const (pure Nothing)) revalidate (decodeToken (authLoginState user))
    where
      -- Pair the request (with claims attached) and the given user.
      accept claims someUser = pure (Just (storeClaims claims req, someUser))

      -- Keep the current session as-is while its ID token still validates.
      revalidate tokens =
        validateIdToken' oidc tokens
          >>= maybe (refreshSession tokens) (\claims -> accept claims user)

      -- The ID token is invalid or expired, so fall back to a token refresh.
      refreshSession tokens = do
        client <- mkOauth2 oidc Nothing
        httpManager <- maybe getGlobalManager pure (oidcManager oidc)
        refreshed <- refreshTokens tokens httpManager client
        case refreshed of
          Nothing -> pure Nothing
          Just newTokens ->
            validateIdToken' oidc newTokens
              >>= maybe (pure Nothing) (adopt newTokens)

      -- Store the refreshed tokens and stamp the user with the refresh time.
      adopt newTokens claims = do
        CTime now <- epochTime
        let refreshedUser =
              user
                { authLoginState = encodeToken newTokens
                , authLoginTime = fromIntegral now
                }
        accept claims refreshedUser
-- | Build an 'OpenIDConnect' provider by fetching the provider's discovery
-- document and then its JSON Web Key Set.
--
-- The argument is the issuer URL. Following OpenID Connect Discovery 1.0,
-- section 4, the discovery document is fetched from the issuer URL with
-- trailing slashes removed and @/.well-known/openid-configuration@ appended.
-- This means issuers with a path component (e.g.
-- @https://example.com/realms/demo@) are supported. For backward
-- compatibility, a URL that already ends in
-- @/.well-known/openid-configuration@ is used unchanged.
--
-- The result has an empty 'oidcClientId' and 'oidcClientSecret', which
-- callers must set. The other fields are initialised to:
--
-- * 'oidcProviderInfo': 'defProviderInfo'
-- * 'oidcManager': 'Nothing'
-- * 'oidcScopes': @["openid"]@
-- * 'oidcAllowedSkew': @0@
--
-- Both HTTP requests use 'parseRequestThrow' and 'httpJSON', so request
-- failures and undecodable responses surface as exceptions. Presumably
-- 'parseAbsoluteURI' also throws on a malformed URL; confirm in that module.
discover :: T.Text -> IO OpenIDConnect
discover urlText = do
  base <- parseAbsoluteURI urlText
  let wellKnown :: S8.ByteString
      wellKnown = "/.well-known/openid-configuration"
      -- Strip trailing slashes so "https://host/" and "https://host/realm/"
      -- don't produce a double slash before ".well-known".
      issuerPath = fst (S8.spanEnd (== '/') (U.uriPath base))
      discoveryPath
        | wellKnown `S8.isSuffixOf` issuerPath = issuerPath
        | otherwise = issuerPath <> wellKnown
      uri = base { U.uriPath = discoveryPath }
  metadata <- fetchMetadata uri
  jwkset <- fetchJWKSet (jwksUri metadata)
  pure OpenIDConnect
    { oidcClientId = ""
    , oidcClientSecret = ""
    , oidcMetadata = metadata
    , oidcJwkSet = jwkset
    , oidcProviderInfo = defProviderInfo
    , oidcManager = Nothing
    , oidcScopes = ["openid"]
    , oidcAllowedSkew = 0
    }
-- | Provider info used when none is configured: a generic title, with an
-- empty logo URL and an empty description.
defProviderInfo :: ProviderInfo
defProviderInfo =
  ProviderInfo defaultTitle emptyLogoUrl emptyDescription
  where
    defaultTitle = "OpenID Connect Provider"
    emptyLogoUrl = ""
    emptyDescription = ""
-- | Download and decode the provider's discovery document from the given
-- URI. Request failures and JSON decoding errors are thrown as exceptions.
fetchMetadata :: U.URI -> IO Metadata
fetchMetadata metadataEndpoint = do
  response <- httpJSON =<< parseRequestThrow endpointString
  pure (getResponseBody response)
  where
    endpointString = S8.unpack (U.serializeURIRef' metadataEndpoint)
-- | Download and decode the JSON Web Key Set published at the given URL.
-- Request failures and JSON decoding errors are thrown as exceptions.
fetchJWKSet :: T.Text -> IO JOSE.JWKSet
fetchJWKSet jwkSetEndpoint =
  parseRequestThrow (T.unpack jwkSetEndpoint)
    >>= fmap getResponseBody . httpJSON
-- | Build the OAuth2 client description for this provider.
--
-- When a URL renderer is supplied, the callback is the provider's
-- @complete@ route rendered to an absolute URI. Without a renderer (for
-- example during token refresh), no callback is set.
mkOauth2 :: OpenIDConnect -> Maybe (Text.Hamlet.Render ProviderUrl) -> IO OA2.OAuth2
mkOauth2 OpenIDConnect {..} renderUrl =
  buildClient <$> for renderUrl completionUri
  where
    completionUri render =
      parseAbsoluteURI (render (ProviderUrl ["complete"]) [])
    buildClient callback =
      OA2.OAuth2
        { oauthClientId = oidcClientId
        , oauthClientSecret = Just oidcClientSecret
        , oauthOAuthorizeEndpoint = authorizationEndpoint oidcMetadata
        , oauthAccessTokenEndpoint = tokenEndpoint oidcMetadata
        , oauthCallback = callback
        }
-- | Decode a compact-serialised ID token and verify its signature and claims
-- against the provider's key set and 'validationSettings'.
--
-- Returns the verified claims, or the JOSE/JWT error that caused rejection.
validateIdToken :: OpenIDConnect -> OA2.IdToken -> IO (Either JWT.JWTError JWT.ClaimsSet)
validateIdToken oidc (OA2.IdToken rawToken) =
  runExceptT (JOSE.decodeCompact encoded >>= JWT.verifyClaims settings keys)
  where
    encoded = TLE.encodeUtf8 (TL.fromStrict rawToken)
    settings = validationSettings oidc
    keys = oidcJwkSet oidc
-- | Validate the ID token carried in an OAuth2 token response.
--
-- Returns 'Nothing' if the response has no ID token or validation fails.
-- The reason for the failure is discarded.
validateIdToken' :: OpenIDConnect -> OA2.OAuth2Token -> IO (Maybe JWT.ClaimsSet)
validateIdToken' oidc tokens =
  maybe (pure Nothing) checkToken (OA2.idToken tokens)
  where
    checkToken tok = do
      result <- validateIdToken oidc tok
      pure (either (const Nothing) Just result)
-- | JWT validation settings for this provider's ID tokens. The settings:
--
-- * accept an audience only if it matches the client id ('validateAudience');
-- * require an issued-at check;
-- * accept an issuer only if it matches the metadata issuer ('validateIssuer');
-- * apply the configured allowed clock skew.
validationSettings :: OpenIDConnect -> JWT.JWTValidationSettings
validationSettings oidc =
  JWT.defaultJWTValidationSettings audienceOk
    & Lens.set JWT.jwtValidationSettingsCheckIssuedAt True
    & Lens.set JWT.jwtValidationSettingsIssuerPredicate issuerOk
    & Lens.set JWT.jwtValidationSettingsAllowedSkew (oidcAllowedSkew oidc)
  where
    audienceOk = validateAudience oidc
    issuerOk = validateIssuer oidc
-- | Audience predicate: 'True' iff the @aud@ value, rendered as text, equals
-- the configured client id.
validateAudience :: OpenIDConnect -> JWT.StringOrURI -> Bool
validateAudience oidc audClaim =
  maybe False (== expectedClientId) (fromStringOrURI audClaim)
  where
    expectedClientId = oidcClientId oidc
-- | Issuer predicate: 'True' iff the @iss@ value, rendered as text, equals
-- the issuer advertised in the provider metadata.
validateIssuer :: OpenIDConnect -> JWT.StringOrURI -> Bool
validateIssuer oidc issClaim =
  maybe False (== expectedIssuer) (fromStringOrURI issClaim)
  where
    expectedIssuer = issuer (oidcMetadata oidc)
-- | Render a JWT 'JWT.StringOrURI' claim as text.
--
-- The plain-string form is preferred. Otherwise the URI form is rendered
-- with 'show'.
fromStringOrURI :: JWT.StringOrURI -> Maybe T.Text
fromStringOrURI claim = asString <|> asUri
  where
    asString = Lens.Extras.preview JWT.string claim
    asUri = T.pack . show <$> Lens.Extras.preview JWT.uri claim
-- | Attach validated ID token claims to the request's vault so that
-- downstream handlers can retrieve them with 'getIdToken'.
storeClaims :: JWT.ClaimsSet -> Request -> Request
storeClaims claims req =
  let updatedVault = Vault.insert idTokenKey claims (vault req)
  in req { vault = updatedVault }
-- | Retrieve the ID token claims stored for this request by the OpenID
-- Connect provider, if any.
getIdToken :: Request -> Maybe JWT.ClaimsSet
getIdToken = Vault.lookup idTokenKey . vault
-- | Vault key under which validated ID token claims are stored per request.
-- 'storeClaims' writes to it and 'getIdToken' reads from it.
--
-- The key is a process-wide global created with 'unsafePerformIO'. The
-- NOINLINE pragma stops GHC from inlining the definition. Inlining could
-- create several distinct keys, so claims written under one key would be
-- unreadable through another.
idTokenKey :: Vault.Key JWT.ClaimsSet
idTokenKey = unsafePerformIO Vault.newKey
{-# NOINLINE idTokenKey #-}