{-# LANGUAGE CPP #-}

module Network.OIDC.WellKnown where

import Control.Monad.Except
#if MIN_VERSION_base(4,18,0)
import Control.Monad.IO.Class
#endif
import Data.Aeson
import Data.Aeson.Types
import Data.Bifunctor
import Data.ByteString.Lazy (ByteString)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Text.Lazy (Text)
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.Encoding qualified as TL
import Network.HTTP.Simple
import Network.HTTP.Types.Status
import URI.ByteString
import URI.ByteString.Aeson ()

-- | A slimmed-down OpenID Provider configuration: only a subset of the
-- provider metadata published in the discovery document is kept.
--
-- TODO: more provider metadata fields could be added for completeness.
--
-- See spec <https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata>
data OpenIDConfiguration = OpenIDConfiguration
  { issuer :: URI
    -- ^ JSON key @issuer@.
  , authorizationEndpoint :: URI
    -- ^ JSON key @authorization_endpoint@.
  , tokenEndpoint :: URI
    -- ^ JSON key @token_endpoint@.
  , userinfoEndpoint :: URI
    -- ^ JSON key @userinfo_endpoint@.
  , jwksUri :: URI
    -- ^ JSON key @jwks_uri@.
  , deviceAuthorizationEndpoint :: URI
    -- ^ JSON key @device_authorization_endpoint@.
    --
    -- NOTE(review): this key is defined by RFC 8628 (Device Authorization
    -- Grant) and is optional provider metadata, so discovery documents that
    -- omit it cannot be parsed into this type. Consider @Maybe URI@ — TODO
    -- confirm with callers before changing the field type.
  }
  deriving (Eq, Show)

-- | Decode a provider discovery document.
--
-- Every key is read with '.:', so parsing fails if any of the six keys is
-- absent or its value cannot be decoded as a 'URI'. Keys are read in record
-- field order, so the first missing key is the one reported.
instance FromJSON OpenIDConfiguration where
  parseJSON :: Value -> Parser OpenIDConfiguration
  parseJSON = withObject "parseJSON OpenIDConfiguration" fromFields
    where
      -- Positional construction: the key order below must match the
      -- field order of 'OpenIDConfiguration'.
      fromFields obj =
        OpenIDConfiguration
          <$> obj .: "issuer"
          <*> obj .: "authorization_endpoint"
          <*> obj .: "token_endpoint"
          <*> obj .: "userinfo_endpoint"
          <*> obj .: "jwks_uri"
          <*> obj .: "device_authorization_endpoint"

-- | Path suffix that 'fetchWellKnown' appends to the domain to locate the
-- OpenID Connect discovery document.
wellknownUrl :: Text
wellknownUrl = TL.pack "/.well-known/openid-configuration"

-- | Fetch and decode the OpenID Provider configuration document.
--
-- The request URL is @https://@ followed by the domain and 'wellknownUrl'.
-- Any trailing @/@ characters are removed from the domain first. OpenID
-- Connect Discovery 1.0 §4 requires this for issuers with a path component
-- (e.g. @idp.example.com/realms/main/@); otherwise the URL would contain
-- @//.well-known@.
--
-- The response is interpreted by 'handleWellKnownResponse': a non-200 status
-- or a JSON decode failure becomes a 'Left'.
--
-- NOTE(review): exceptions thrown by 'parseRequest' (malformed URL) and by
-- 'httpLbs' (network/TLS failures) are not caught here. They propagate out
-- of the 'ExceptT' instead of becoming a 'Left'.
--
-- NOTE(review): the returned 'issuer' is not compared against the requested
-- domain. Discovery §4.3 expects that check, so callers should perform it.
fetchWellKnown ::
  MonadIO m =>
  -- | Domain (host, optionally followed by the issuer's path)
  TL.Text ->
  ExceptT Text m OpenIDConfiguration
fetchWellKnown domain = ExceptT $ do
  req <- liftIO $ parseRequest (TL.unpack uri)
  resp <- httpLbs req
  pure (handleWellKnownResponse resp)
  where
    -- Strip terminating slashes before appending the well-known path (spec §4).
    uri :: Text
    uri = "https://" <> TL.dropWhileEnd (== '/') domain <> wellknownUrl

-- | Interpret the HTTP response to a well-known configuration request.
--
-- * Status 200: the body is JSON-decoded into an 'OpenIDConfiguration'. A
--   decode failure yields a 'Left' carrying aeson's error message.
-- * Any other status: a 'Left' containing the shown status and the raw body.
--
-- The body is decoded leniently, with invalid UTF-8 bytes replaced by U+FFFD.
-- The previous strict 'TL.decodeUtf8' threw a 'UnicodeException' from pure
-- code when an error page was not valid UTF-8, defeating the 'Either' result.
handleWellKnownResponse :: Response ByteString -> Either Text OpenIDConfiguration
handleWellKnownResponse resp
  | rStatus == status200 =
      first
        (("handleWellKnownResponse decode response failed: " <>) . TL.pack)
        (eitherDecode rawBody)
  | otherwise =
      Left $
        "handleWellKnownResponse failed: "
          <> TL.pack (show rStatus)
          -- Separator so the status and the body do not run together.
          <> ": "
          <> TL.decodeUtf8With lenientDecode rawBody
  where
    rawBody :: ByteString
    rawBody = getResponseBody resp

    rStatus :: Status
    rStatus = getResponseStatus resp