{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
module Network.Wai.Middleware.Auth.OAuth2
  ( OAuth2(..)
  , oAuth2Parser
  , URIParseException(..)
  , parseAbsoluteURI
  , getAccessToken
  ) where

import           Control.Monad.Catch
import           Data.Aeson.TH                        (defaultOptions,
                                                       deriveJSON,
                                                       fieldLabelModifier)
import           Data.Functor                         ((<&>))
import           Data.Int                             (Int64)
import           Data.Proxy                           (Proxy (..))
import qualified Data.Text                            as T
import           Data.Text.Encoding                   (encodeUtf8)
import           Foreign.C.Types                      (CTime (..))
import           Network.HTTP.Client.TLS              (getGlobalManager)
import qualified Network.OAuth.OAuth2                 as OA2
import           Network.Wai                          (Request)
import           Network.Wai.Auth.Internal            (decodeToken, encodeToken,
                                                       oauth2Login,
                                                       refreshTokens)
import           Network.Wai.Auth.Tools               (toLowerUnderscore)
import qualified Network.Wai.Middleware.Auth          as MA
import           Network.Wai.Middleware.Auth.Provider
import           System.PosixCompat.Time              (epochTime)
import qualified URI.ByteString                       as U

-- | General OAuth2 authentication `Provider`.
--
-- Both endpoints are stored as plain text and are only turned into URIs
-- (with `parseAbsoluteURI`) inside `handleLogin` and `refreshLoginState`.
data OAuth2 = OAuth2
  { oa2ClientId            :: T.Text
  -- ^ Client identifier; becomes @oauthClientId@ of the `OA2.OAuth2`
  -- configuration.
  , oa2ClientSecret        :: T.Text
  -- ^ Client secret; becomes @oauthClientSecret@ of the `OA2.OAuth2`
  -- configuration.
  , oa2AuthorizeEndpoint   :: T.Text
  -- ^ Absolute URL used as @oauthOAuthorizeEndpoint@.
  , oa2AccessTokenEndpoint :: T.Text
  -- ^ Absolute URL used as @oauthAccessTokenEndpoint@.
  , oa2Scope               :: Maybe [T.Text]
  -- ^ Optional scopes, handed to `oauth2Login` unchanged.
  , oa2ProviderInfo        :: ProviderInfo
  -- ^ Value returned by `getProviderInfo` for this provider.
  }

-- | Used for validating proper url structure. Wraps the underlying
-- `U.URIParseError`. Can be thrown by `parseAbsoluteURI` and consequently by
-- `handleLogin` and `refreshLoginState` for `OAuth2` `Provider` instance,
-- since both of them parse the configured endpoints.
--
-- @since 0.1.2.0
newtype URIParseException = URIParseException U.URIParseError deriving Show

instance Exception URIParseException

-- | Strictly parse a textual absolute URI.
--
-- The text is UTF-8 encoded and fed to `U.parseURI` with
-- `U.strictURIParserOptions`. A parse error is reported by throwing
-- `URIParseException`; otherwise the parsed URI is returned.
--
-- @since 0.1.2.0
parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI =
  either (throwM . URIParseException) pure
    . U.parseURI U.strictURIParserOptions
    . encodeUtf8

-- | Convert the configured client id into the value stored in the
-- @oauthClientId@ field. Currently returns its argument unchanged.
getClientId :: T.Text -> T.Text
getClientId clientId = clientId

-- | Convert the configured client secret into the value wrapped in the
-- @oauthClientSecret@ field. Currently returns its argument unchanged.
getClientSecret :: T.Text -> T.Text
getClientSecret clientSecret = clientSecret

-- | Aeson parser for `OAuth2` provider, built by `mkProviderParser` from a
-- proxy that fixes the provider type to `OAuth2`.
--
-- @since 0.1.0
oAuth2Parser :: ProviderParser
oAuth2Parser = mkProviderParser oauth2Proxy
  where
    oauth2Proxy :: Proxy OAuth2
    oauth2Proxy = Proxy


-- | `AuthProvider` implementation for the generic `OAuth2` provider.
--
-- Both `handleLogin` and `refreshLoginState` parse the configured endpoints
-- with `parseAbsoluteURI` first, so either of them may throw
-- `URIParseException`.
instance AuthProvider OAuth2 where
  getProviderName = const "oauth2"
  getProviderInfo provider = oa2ProviderInfo provider
  -- Builds the `OA2.OAuth2` configuration (including the rendered
  -- @complete@ callback URL) and delegates the login flow to `oauth2Login`.
  handleLogin provider@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do
    authorizeUri <- parseAbsoluteURI oa2AuthorizeEndpoint
    tokenUri <- parseAbsoluteURI oa2AccessTokenEndpoint
    callbackUri <- parseAbsoluteURI (renderUrl (ProviderUrl ["complete"]) [])
    manager <- getGlobalManager
    let config =
          OA2.OAuth2
            { oauthClientId = getClientId oa2ClientId
            , oauthClientSecret = Just (getClientSecret oa2ClientSecret)
            , oauthOAuthorizeEndpoint = authorizeUri
            , oauthAccessTokenEndpoint = tokenUri
            , oauthCallback = Just callbackUri
            }
    oauth2Login
      config
      manager
      oa2Scope
      (getProviderName provider)
      req
      suffix
      onSuccess
      onFailure
  -- Yields `Nothing` when the stored login state cannot be decoded, the
  -- unchanged request/user pair while the tokens are still valid, and the
  -- outcome of `refreshTokens` (with updated login state and login time on
  -- success) once they have expired.
  refreshLoginState OAuth2 {..} req user = do
    authorizeUri <- parseAbsoluteURI oa2AuthorizeEndpoint
    tokenUri <- parseAbsoluteURI oa2AccessTokenEndpoint
    case decodeToken (authLoginState user) of
      Left _ -> pure Nothing
      Right tokens -> do
        CTime now <- epochTime
        if not (tokenExpired user now tokens)
          then pure (Just (req, user))
          else do
            let config =
                  OA2.OAuth2
                    { oauthClientId = getClientId oa2ClientId
                    , oauthClientSecret = Just (getClientSecret oa2ClientSecret)
                    , oauthOAuthorizeEndpoint = authorizeUri
                    , oauthAccessTokenEndpoint = tokenUri
                    -- `Nothing` is not the truth here: a callback endpoint
                    -- does exist, but the function that renders it is not
                    -- available in this context. This is harmless because
                    -- this configuration is only used to obtain a refresh
                    -- token, which does not need the callback endpoint.
                    , oauthCallback = Nothing
                    }
                -- Store the fresh tokens and restart the expiry clock.
                withNewTokens newTokens =
                  ( req
                  , user
                      { authLoginState = encodeToken newTokens
                      , authLoginTime = fromIntegral now
                      }
                  )
            manager <- getGlobalManager
            refreshed <- refreshTokens tokens manager config
            pure (refreshed <&> withNewTokens)

-- | Decide whether the user's tokens have outlived their lifetime.
--
-- Tokens without an @expiresIn@ value are never considered expired.
-- Otherwise they are expired once the user's login time plus the lifetime
-- lies strictly before @now@.
tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool
tokenExpired user now tokens =
  maybe False pastExpiry (OA2.expiresIn tokens)
  where
    pastExpiry lifetime = authLoginTime user + fromIntegral lifetime < now

-- Derive the JSON instances for `OAuth2`. The first three characters of each
-- record field name (the @oa2@ prefix) are dropped and the remainder is
-- passed through `toLowerUnderscore` to obtain the JSON key.
-- NOTE(review): presumably this yields keys such as @client_id@ -- confirm
-- against the definition of `toLowerUnderscore`.
$(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2)

-- | Get the @AccessToken@ for the current user.
--
-- Yields @Nothing@ when the request carries no authenticated user or when
-- the user's login state cannot be decoded into tokens. If called on a
-- @Request@ behind the middleware, should always return a @Just@ value.
--
-- @since 0.2.0.0
getAccessToken :: Request -> Maybe OA2.OAuth2Token
getAccessToken req =
  MA.getAuthUser req >>= either (const Nothing) Just . decodeToken . authLoginState