{-# LANGUAGE CPP #-}

module Network.OIDC.WellKnown where

import Control.Monad.Except
#if MIN_VERSION_base(4,18,0)
import Control.Monad.IO.Class
#endif
import Data.Aeson
import Data.Aeson.Types
import Data.Bifunctor
import Data.ByteString.Lazy (ByteString)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Text.Lazy (Text)
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.Encoding qualified as TL
import Network.HTTP.Simple
import Network.HTTP.Types.Status
import URI.ByteString
import URI.ByteString.Aeson ()

-- | A slimmed-down OpenID Provider Metadata document.
--
-- Only the endpoints this module needs are kept; every field is a parsed
-- 'URI'. TODO: could add more fields to be complete.
--
-- See spec <https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata>
data OpenIDConfiguration = OpenIDConfiguration
  { issuer :: URI
  -- ^ Decoded from the @issuer@ key.
  , authorizationEndpoint :: URI
  -- ^ Decoded from the @authorization_endpoint@ key.
  , tokenEndpoint :: URI
  -- ^ Decoded from the @token_endpoint@ key.
  , userinfoEndpoint :: URI
  -- ^ Decoded from the @userinfo_endpoint@ key.
  , jwksUri :: URI
  -- ^ Decoded from the @jwks_uri@ key.
  , deviceAuthorizationEndpoint :: URI
  -- ^ Decoded from the @device_authorization_endpoint@ key.
  }
  deriving (Eq, Show)

-- | Decode an OpenID Provider Metadata JSON object.
--
-- Every field is read with '.:', so each key is mandatory. A missing key,
-- or a value that does not decode as a 'URI', fails the whole parse.
--
-- NOTE(review): the Discovery spec marks @userinfo_endpoint@ only as
-- RECOMMENDED, and @device_authorization_endpoint@ (RFC 8628) is optional.
-- Providers omitting either will be rejected here. Relaxing this requires
-- changing those fields to @Maybe URI@.
instance FromJSON OpenIDConfiguration where
  parseJSON :: Value -> Parser OpenIDConfiguration
  parseJSON = withObject "parseJSON OpenIDConfiguration" $ \obj ->
    OpenIDConfiguration
      <$> obj .: "issuer"
      <*> obj .: "authorization_endpoint"
      <*> obj .: "token_endpoint"
      <*> obj .: "userinfo_endpoint"
      <*> obj .: "jwks_uri"
      <*> obj .: "device_authorization_endpoint"

-- | Path suffix appended to an issuer location to reach its discovery
-- document (OpenID Connect Discovery 1.0, section 4).
wellknownUrl :: TL.Text
wellknownUrl = TL.pack "/.well-known/openid-configuration"

-- | Fetch and decode the OpenID Provider Configuration document.
--
-- The request URL is @"https://" <> domain <> 'wellknownUrl'@. Per OpenID
-- Connect Discovery 1.0 section 4, any terminating @/@ on the issuer location
-- is removed before the well-known suffix is appended. Without this,
-- @"example.com/realms/foo/"@ would request a @//.well-known@ path.
--
-- A non-200 status or an undecodable body is returned as 'Left' via
-- 'handleWellKnownResponse'.
--
-- NOTE(review): failures from 'parseRequest' (malformed URL) and from
-- 'httpLbs' (transport errors) are not caught here. They surface as
-- exceptions in @m@, not as 'Left'.
fetchWellKnown ::
  MonadIO m =>
  -- | Domain, optionally followed by an issuer path; must not include the
  -- scheme, since @https://@ is prepended.
  TL.Text ->
  ExceptT Text m OpenIDConfiguration
fetchWellKnown domain = ExceptT $ do
  -- Spec: "any terminating / MUST be removed before appending
  -- /.well-known/openid-configuration".
  let issuerLocation = TL.dropWhileEnd (== '/') domain
      uri = "https://" <> issuerLocation <> wellknownUrl
  req <- liftIO $ parseRequest (TL.unpack uri)
  resp <- httpLbs req
  pure (handleWellKnownResponse resp)

-- | Interpret the HTTP response of a discovery request.
--
-- * Status 200: the body is decoded as JSON into an 'OpenIDConfiguration'.
--   A decode failure becomes 'Left' carrying aeson's error message.
-- * Any other status: 'Left' carrying the shown 'Status' and the body.
--
-- The body in the error branch is decoded with 'lenientDecode', so invalid
-- UTF-8 becomes replacement characters. The previous strict 'TL.decodeUtf8'
-- would throw a 'UnicodeException' whenever the caller forced the error
-- 'Text'.
handleWellKnownResponse :: Response ByteString -> Either Text OpenIDConfiguration
handleWellKnownResponse resp
  | rStatus == status200 =
      first
        (("handleWellKnownResponse decode response failed: " <>) . TL.pack)
        (eitherDecode rawBody)
  | otherwise =
      Left $
        "handleWellKnownResponse failed: "
          <> TL.pack (show rStatus)
          -- Separator so the shown status and the body do not run together.
          <> ": "
          <> TL.decodeUtf8With lenientDecode rawBody
  where
    rawBody = getResponseBody resp
    rStatus = getResponseStatus resp