{-# OPTIONS_HADDOCK hide, not-home #-}
{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}
module Network.Wai.Auth.Internal
  ( OAuth2TokenBinary(..)
  , Metadata(..)
  , encodeToken
  , decodeToken
  , oauth2Login
  , refreshTokens
  ) where

import qualified Data.Aeson                           as Aeson
import           Data.Binary                          (Binary(get, put), encode,
                                                      decodeOrFail)
import qualified Data.ByteString                      as S
import qualified Data.ByteString.Char8                as S8 (pack)
import qualified Data.ByteString.Lazy                 as SL
import qualified Data.Text                            as T
import           Data.Text.Encoding                   (encodeUtf8,
                                                       decodeUtf8With)
import           Data.Text.Encoding.Error             (lenientDecode)
import           GHC.Generics                         (Generic)
import           Network.HTTP.Client                  (Manager)
import           Network.HTTP.Types                   (Status, status303,
                                                       status403, status404,
                                                       status501)
import qualified Network.OAuth.OAuth2                 as OA2
import           Network.Wai                          (Request, Response,
                                                       queryString, responseLBS)
import           Network.Wai.Middleware.Auth.Provider
import qualified URI.ByteString                       as U
import           URI.ByteString                       (URI)

-- | Inverse of 'encodeToken': deserialize an OAuth2 token from the compact
-- binary form written by the 'Binary' instance of 'OAuth2TokenBinary'.
--
-- Returns @Left msg@, carrying the decoder's error message, when the input
-- cannot be parsed. Bytes left over after a successful parse are ignored.
decodeToken :: S.ByteString -> Either String OA2.OAuth2Token
decodeToken bs =
  case decodeOrFail (SL.fromStrict bs) of
    Right (_, _, OAuth2TokenBinary token) -> Right token
    Left (_, _, err)                      -> Left err

-- | Serialize an OAuth2 token to a strict 'S.ByteString' via the 'Binary'
-- instance of 'OAuth2TokenBinary'. 'decodeToken' reverses this.
encodeToken :: OA2.OAuth2Token -> S.ByteString
encodeToken = SL.toStrict . encode . OAuth2TokenBinary

-- | Wrapper that gives 'OA2.OAuth2Token' a 'Binary' instance (presumably so
-- that no orphan instance is declared on the library type).
newtype OAuth2TokenBinary =
  OAuth2TokenBinary { unOAuth2TokenBinary :: OA2.OAuth2Token }
  deriving (Show)

-- | Binary wire format of a token. The fields are written, and read back, in
-- this fixed order:
--
-- 1. access token text
-- 2. optional refresh token text
-- 3. optional @expires_in@ value
-- 4. optional token type
-- 5. optional ID token text
--
-- 'put' and 'get' must stay in lock-step. Reordering either one makes
-- previously encoded tokens undecodable.
instance Binary OAuth2TokenBinary where
  put (OAuth2TokenBinary token) = do
    put (OA2.atoken (OA2.accessToken token))
    put (OA2.rtoken <$> OA2.refreshToken token)
    put (OA2.expiresIn token)
    put (OA2.tokenType token)
    put (OA2.idtoken <$> OA2.idToken token)
  get = do
    access  <- OA2.AccessToken <$> get
    refresh <- fmap OA2.RefreshToken <$> get
    expires <- get
    tokTy   <- get
    idTok   <- fmap OA2.IdToken <$> get
    pure . OAuth2TokenBinary $
      OA2.OAuth2Token access refresh expires tokTy idTok

-- | Handle one request of the OAuth2 authorization-code flow for a provider.
--
-- The request is dispatched on @suffix@, the path segments below the
-- provider's login route:
--
-- * @[]@: respond with a 303 redirect to the provider's authorization URL.
--   When @oa2Scope@ is given, its entries are joined with spaces and added
--   as a @scope@ query parameter.
-- * @[\"complete\"]@: the provider callback.
--
--     * A @code@ query parameter is exchanged for tokens. On success the
--       tokens, serialized with 'encodeToken', are passed to @onSuccess@.
--       An exchange failure yields @onFailure@ with 501.
--     * Without a code, an @error@ parameter of @access_denied@ yields 403.
--     * Any other @error@ value yields 501 with that value echoed.
--     * A valueless @error@ yields 501 naming @providerName@.
--     * A missing @error@ yields 404.
--
-- * anything else: @onFailure@ with 404.
--
-- NOTE(review): no OAuth2 @state@ parameter is generated or checked here, so
-- CSRF protection of the callback (RFC 6749 §10.12) must be provided
-- elsewhere. Confirm.
oauth2Login
  :: OA2.OAuth2
  -> Manager
  -> Maybe [T.Text]
  -> T.Text
  -> Request
  -> [T.Text]
  -> (AuthLoginState -> IO Response)
  -> (Status -> S.ByteString -> IO Response)
  -> IO Response
oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure =
  case suffix of
    [] -> do
      -- https://tools.ietf.org/html/rfc6749#section-3.3
      let scope = encodeUtf8 . T.intercalate " " <$> oa2Scope
          redirectUrl =
            getRedirectURI $
              appendQueryParams
                (OA2.authorizationUrl oauth2)
                (maybe [] ((: []) . ("scope", )) scope)
      return $
        responseLBS
          status303
          [("Location", redirectUrl)]
          "Redirect to OAuth2 Authentication server"
    ["complete"] ->
      let params = queryString req
       in case lookup "code" params of
            Just (Just code) -> do
              eRes <- OA2.fetchAccessToken man oauth2 (getExchangeToken code)
              case eRes of
                Left err    -> onFailure status501 (S8.pack (show err))
                Right token -> onSuccess (encodeToken token)
            _ ->
              case lookup "error" params of
                Just (Just "access_denied") ->
                  onFailure
                    status403
                    "User rejected access to the application."
                Just (Just errorCode) ->
                  -- NOTE(review): errorCode comes straight from the query
                  -- string. If onFailure renders it as HTML it must be
                  -- escaped there. Verify.
                  onFailure status501 ("Received an error: " <> errorCode)
                Just Nothing ->
                  onFailure status501 $
                    "Unknown error connecting to " <> encodeUtf8 providerName
                Nothing ->
                  onFailure
                    status404
                    "Page not found. Please continue with login."
    _ -> onFailure status404 "Page not found. Please continue with login."

-- | Use the refresh token held in @tokens@ to obtain a new access token.
--
-- Returns 'Nothing' in two cases: @tokens@ carries no refresh token, or the
-- provider rejects the refresh request (the error itself is discarded).
--
-- RFC 6749 §6 lets the authorization server choose whether to issue a new
-- refresh token with the response. When the response carries none, the
-- previous refresh token is copied into the returned token. Otherwise it
-- would be lost and no further refresh would be possible.
refreshTokens :: OA2.OAuth2Token -> Manager -> OA2.OAuth2 -> IO (Maybe OA2.OAuth2Token)
refreshTokens tokens manager oauth2 =
  case OA2.refreshToken tokens of
    Nothing -> pure Nothing
    Just oldRefresh -> do
      res <- OA2.refreshAccessToken manager oauth2 oldRefresh
      pure $ case res of
        Left _          -> Nothing
        Right newTokens -> Just (keepRefreshToken oldRefresh newTokens)
  where
    -- Prefer a freshly issued refresh token; fall back to the old one.
    keepRefreshToken old new =
      case OA2.refreshToken new of
        Just _  -> new
        Nothing -> new { OA2.refreshToken = Just old }

-- | Wrap the raw @code@ query-parameter value as an 'OA2.ExchangeToken'.
-- Invalid UTF-8 sequences are replaced via 'lenientDecode' rather than
-- rejected.
getExchangeToken :: S.ByteString -> OA2.ExchangeToken
getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode

-- | 'OA2.appendQueryParams' with its arguments flipped: the URI comes first,
-- then the key/value pairs to append to its query string.
appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams uri params = OA2.appendQueryParams params uri

-- | Serialize a URI reference to a strict 'S.ByteString', e.g. for use as
-- the value of a @Location@ header.
getRedirectURI :: U.URIRef a -> S.ByteString
getRedirectURI = U.serializeURIRef'

-- | Authorization-server metadata. The field set appears to mirror the
-- OpenID Connect Discovery 1.0 provider metadata document; confirm against
-- the providers that consume it.
--
-- JSON keys are the snake_case forms of the field names, as produced by
-- 'metadataAesonOptions'. For example 'jwksUri' maps to @jwks_uri@.
data Metadata
  = Metadata
      { issuer                            :: T.Text
      , authorizationEndpoint             :: U.URI
      , tokenEndpoint                     :: U.URI
      , userinfoEndpoint                  :: Maybe T.Text
      , revocationEndpoint                :: Maybe T.Text
      , jwksUri                           :: T.Text
      , responseTypesSupported            :: [T.Text]
      , subjectTypesSupported             :: [T.Text]
      , idTokenSigningAlgValuesSupported  :: [T.Text]
      , scopesSupported                   :: Maybe [T.Text]
      , tokenEndpointAuthMethodsSupported :: Maybe [T.Text]
      , claimsSupported                   :: Maybe [T.Text]
      }
  deriving (Generic)

-- | Generic decoding, with field names mapped to JSON keys by
-- 'metadataAesonOptions'.
instance Aeson.FromJSON Metadata where
  parseJSON = Aeson.genericParseJSON metadataAesonOptions

-- | Generic encoding using 'metadataAesonOptions', the same options the
-- 'Aeson.FromJSON' instance uses, so both sides agree on the JSON key names.
instance Aeson.ToJSON Metadata where
  toJSON     = Aeson.genericToJSON metadataAesonOptions
  toEncoding = Aeson.genericToEncoding metadataAesonOptions

-- | Aeson options for 'Metadata'. Record field names are converted from
-- camelCase to snake_case with 'Aeson.camelTo2' and @'_'@. For example,
-- @tokenEndpointAuthMethodsSupported@ becomes
-- @token_endpoint_auth_methods_supported@.
metadataAesonOptions :: Aeson.Options
metadataAesonOptions =
  Aeson.defaultOptions {Aeson.fieldLabelModifier = Aeson.camelTo2 '_'}