{-# LANGUAGE DuplicateRecordFields #-}

module Network.OIDC.WellKnown where

import Control.Monad.Except
import Data.Aeson
import Data.Bifunctor
import Data.ByteString.Lazy (ByteString)
import Data.ByteString.Lazy qualified as BSL
import Data.Text.Encoding.Error (lenientDecode)
import Data.Text.Lazy (Text)
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.Encoding qualified as TL
import GHC.Generics
import Network.HTTP.Simple
import Network.HTTP.Types.Status
import URI.ByteString

-- | A slim subset of an OpenID Provider configuration (discovery) document.
--
-- Only the five fields below are modelled. Because none of them is a
-- 'Maybe', decoding (see the 'FromJSON' instance) fails if any of the
-- corresponding snake_case keys is missing from the document.
--
-- NOTE(review): OIDC Discovery 1.0 lists @userinfo_endpoint@ as RECOMMENDED
-- rather than REQUIRED, so some providers may not decode — confirm whether
-- that is acceptable.
--
-- TODO: could add more fields to be complete.
data OpenIDConfiguration = OpenIDConfiguration
  { -- | Issuer identifier (JSON key @issuer@).
    issuer :: Text,
    -- | Authorization endpoint URL (JSON key @authorization_endpoint@).
    authorizationEndpoint :: Text,
    -- | Token endpoint URL (JSON key @token_endpoint@).
    tokenEndpoint :: Text,
    -- | UserInfo endpoint URL (JSON key @userinfo_endpoint@).
    userinfoEndpoint :: Text,
    -- | JSON Web Key Set URL (JSON key @jwks_uri@).
    jwksUri :: Text
  }
  deriving (Generic)

-- | Endpoint URLs of an OpenID Provider, already parsed into 'URI's.
--
-- Built by 'fetchWellKnownUris' from the textual fields of an
-- 'OpenIDConfiguration'.
data OpenIDConfigurationUris = OpenIDConfigurationUris
  { -- | Parsed from the @authorization_endpoint@ entry.
    authorizationUri :: URI,
    -- | Parsed from the @token_endpoint@ entry.
    tokenUri :: URI,
    -- | Parsed from the @userinfo_endpoint@ entry.
    userinfoUri :: URI,
    -- | Parsed from the @jwks_uri@ entry.
    jwksUri :: URI
  }
  deriving (Generic)

-- | Generic decoder for the provider's discovery JSON.
--
-- Record field names are converted with @camelTo2 '_'@ before being looked
-- up, so e.g. @authorizationEndpoint@ reads the key @authorization_endpoint@
-- and @jwksUri@ reads @jwks_uri@.
instance FromJSON OpenIDConfiguration where
  parseJSON = genericParseJSON snakeCaseOptions
    where
      snakeCaseOptions = defaultOptions {fieldLabelModifier = camelTo2 '_'}

-- | Path of the OpenID Provider configuration document. 'fetchWellKnown'
-- appends it to @https://@ followed by the provider's domain.
wellknownUrl :: TL.Text
wellknownUrl = TL.pack "/.well-known/openid-configuration"

-- | Download and decode the OpenID configuration document of a provider.
--
-- The request is built by @parseRequest@ for the URL formed from
-- @"https://"@, the given domain and 'wellknownUrl'. It is sent with
-- @httpLbs@, and the response is interpreted by 'handleWellKnownResponse'.
-- That function's 'Left' (a non-200 status or an undecodable body) becomes
-- this action's error.
--
-- Not every failure goes through 'ExceptT':
--
-- * @parseRequest@ runs in 'IO' and throws on a URL it cannot parse.
-- * Exceptions raised by @httpLbs@ are not caught here either.
--   NOTE(review): these are presumably network/HTTP failures — confirm.
fetchWellKnown ::
  MonadIO m =>
  -- | Domain of the provider. It is placed after @https://@, so pass it
  -- without a scheme.
  TL.Text ->
  ExceptT Text m OpenIDConfiguration
fetchWellKnown domain = do
  request <- liftIO (parseRequest (TL.unpack configUrl))
  response <- httpLbs request
  ExceptT (pure (handleWellKnownResponse response))
  where
    configUrl = "https://" <> domain <> wellknownUrl

-- | Fetch the OpenID configuration of @domain@ (via 'fetchWellKnown') and
-- parse its authorization, token, userinfo and JWKS endpoints into 'URI's.
--
-- Errors from 'fetchWellKnown' are passed through unchanged. The
-- @issuer@ field is fetched but not used here.
--
-- Endpoints are parsed with 'strictURIParserOptions' in this order:
-- authorization, token, userinfo, JWKS. The first one that fails aborts
-- with a message naming the JSON key, its raw value and the
-- 'URIParseError'.
fetchWellKnownUris :: MonadIO m => TL.Text -> ExceptT Text m OpenIDConfigurationUris
fetchWellKnownUris domain = do
  OpenIDConfiguration
    { authorizationEndpoint = rawAuthorization,
      tokenEndpoint = rawToken,
      userinfoEndpoint = rawUserinfo,
      jwksUri = rawJwks
    } <-
    fetchWellKnown domain
  parsedAuthorization <- parseEndpoint "authorization_endpoint" rawAuthorization
  parsedToken <- parseEndpoint "token_endpoint" rawToken
  parsedUserinfo <- parseEndpoint "userinfo_endpoint" rawUserinfo
  parsedJwks <- parseEndpoint "jwks_uri" rawJwks
  pure
    OpenIDConfigurationUris
      { authorizationUri = parsedAuthorization,
        tokenUri = parsedToken,
        userinfoUri = parsedUserinfo,
        jwksUri = parsedJwks
      }
  where
    -- Parse one endpoint URL taken from the discovery document. The JSON
    -- key is carried into the error so callers can tell which endpoint is
    -- malformed. Previously only the bare URIParseError was shown.
    parseEndpoint :: Applicative n => Text -> Text -> ExceptT Text n URI
    parseEndpoint key raw =
      withExceptT (describeError key raw) . ExceptT . pure $
        parseURI strictURIParserOptions (BSL.toStrict (TL.encodeUtf8 raw))

    -- Render a parse failure together with the key and raw value it came from.
    describeError :: Text -> Text -> URIParseError -> Text
    describeError key raw err =
      "fetchWellKnownUris: cannot parse "
        <> key
        <> " "
        <> TL.pack (show raw)
        <> ": "
        <> TL.pack (show err)

-- | Interpret the HTTP response of a well-known configuration request.
--
-- * Status 200: the body is JSON-decoded into an 'OpenIDConfiguration'.
--   A decoding failure yields a 'Left' prefixed with
--   @"handleWellKnownResponse decode response failed: "@.
-- * Any other status: yields a 'Left' containing
--   @"handleWellKnownResponse failed: "@, the shown status, @": "@ and the
--   response body.
--
-- The error body is decoded leniently: invalid UTF-8 bytes become U+FFFD.
-- The strict 'TL.decodeUtf8' used before threw on such input, so a
-- non-UTF-8 error page turned this 'Left' into an exception when the
-- message was forced.
handleWellKnownResponse :: Response ByteString -> Either Text OpenIDConfiguration
handleWellKnownResponse resp
  | rStatus == status200 =
      first
        (("handleWellKnownResponse decode response failed: " <>) . TL.pack)
        (eitherDecode rawBody)
  | otherwise =
      Left $
        "handleWellKnownResponse failed: "
          <> TL.pack (show rStatus)
          <> ": "
          <> TL.decodeUtf8With lenientDecode rawBody
  where
    rawBody = getResponseBody resp
    rStatus = getResponseStatus resp