{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
module Network.Wai.Middleware.Auth.OAuth2
( OAuth2(..)
, oAuth2Parser
, URIParseException(..)
, parseAbsoluteURI
, getAccessToken
) where
import Control.Monad.Catch
import Data.Aeson.TH (defaultOptions,
deriveJSON,
fieldLabelModifier)
import Data.Functor ((<&>))
import Data.Int (Int64)
import Data.Proxy (Proxy (..))
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Foreign.C.Types (CTime (..))
import Network.HTTP.Client.TLS (getGlobalManager)
import qualified Network.OAuth.OAuth2 as OA2
import Network.Wai (Request)
import Network.Wai.Auth.Internal (decodeToken, encodeToken,
oauth2Login,
refreshTokens)
import Network.Wai.Auth.Tools (toLowerUnderscore)
import qualified Network.Wai.Middleware.Auth as MA
import Network.Wai.Middleware.Auth.Provider
import System.PosixCompat.Time (epochTime)
import qualified URI.ByteString as U
-- | Configuration of a generic OAuth2 provider.
--
-- The JSON representation is produced by the 'deriveJSON' splice below: each
-- key is the record field name with its three-character @oa2@ prefix dropped
-- and 'toLowerUnderscore' applied (presumably camelCase to snake_case, e.g.
-- @oa2ClientId@ -> @client_id@ — NOTE(review): confirm against the tests).
data OAuth2 = OAuth2
  { oa2ClientId :: T.Text
    -- ^ Client id; becomes @oauthClientId@ of the hoauth2 client.
  , oa2ClientSecret :: T.Text
    -- ^ Client secret; becomes @oauthClientSecret@ of the hoauth2 client.
  , oa2AuthorizeEndpoint :: T.Text
    -- ^ Authorization endpoint URL, parsed with 'parseAbsoluteURI'.
  , oa2AccessTokenEndpoint :: T.Text
    -- ^ Token endpoint URL, parsed with 'parseAbsoluteURI'; used both for
    -- login and for refreshing tokens.
  , oa2Scope :: Maybe [T.Text]
    -- ^ Optional scopes, handed to 'oauth2Login' as-is.
  , oa2ProviderInfo :: ProviderInfo
    -- ^ Provider metadata returned by 'getProviderInfo'.
  }
-- | Raised (via 'throwM') by 'parseAbsoluteURI' when a URL cannot be parsed.
-- It wraps the underlying @uri-bytestring@ parse error. Because the provider
-- parses its configured endpoints with 'parseAbsoluteURI', 'handleLogin' can
-- propagate it too.
data URIParseException = URIParseException U.URIParseError deriving Show
-- Makes the error throwable/catchable through the @exceptions@ machinery.
instance Exception URIParseException
-- | Parse text as an absolute 'U.URI' using the strict @uri-bytestring@
-- parser options. The text is UTF-8 encoded before parsing.
--
-- Throws a 'URIParseException' (via 'throwM') carrying the parser's error
-- when the input is not a valid URI.
parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI =
  either (throwM . URIParseException) pure
    . U.parseURI U.strictURIParserOptions
    . encodeUtf8
-- | Convert the configured client id into the value stored in the hoauth2
-- client record. Currently the identity.
-- NOTE(review): presumably kept as a seam for hoauth2 versions whose client
-- id type differs — confirm.
getClientId :: T.Text -> T.Text
getClientId clientId = clientId
-- | Convert the configured client secret into the value stored in the hoauth2
-- client record. Currently the identity.
-- NOTE(review): presumably kept as a seam for hoauth2 versions whose secret
-- type differs — confirm.
getClientSecret :: T.Text -> T.Text
getClientSecret secret = secret
-- | 'ProviderParser' for the 'OAuth2' provider type, built with
-- 'mkProviderParser'. NOTE(review): presumably decodes the provider via the
-- JSON instances derived by the 'deriveJSON' splice in this module — confirm.
oAuth2Parser :: ProviderParser
oAuth2Parser = mkProviderParser oAuth2Proxy
  where
    oAuth2Proxy :: Proxy OAuth2
    oAuth2Proxy = Proxy
-- | Generic OAuth2 provider. Login is delegated to 'oauth2Login'; on later
-- requests the stored tokens are refreshed once the access token expires.
instance AuthProvider OAuth2 where
  getProviderName _ = "oauth2"

  getProviderInfo = oa2ProviderInfo

  -- Build the hoauth2 client, with its callback pointing at this provider's
  -- @complete@ URL, and hand the whole flow to 'oauth2Login'.
  -- Throws 'URIParseException' if any configured or rendered URL is malformed.
  handleLogin oa2 req suffix renderUrl onSuccess onFailure = do
    let callbackTxt = renderUrl (ProviderUrl ["complete"]) []
    oauth2 <- mkOAuth2Client oa2 (Just callbackTxt)
    man <- getGlobalManager
    oauth2Login
      oauth2
      man
      (oa2Scope oa2)
      (getProviderName oa2)
      req
      suffix
      onSuccess
      onFailure

  -- Keep the user unchanged while the stored access token is still valid.
  -- Once it has expired, attempt a refresh. Returns 'Nothing' when the stored
  -- login state does not decode or when the refresh yields no tokens.
  refreshLoginState oa2 req user =
    case decodeToken (authLoginState user) of
      Left _ -> pure Nothing
      Right tokens -> do
        CTime now <- epochTime
        if tokenExpired user now tokens
          then do
            -- Endpoints are parsed only on this refresh path, rather than on
            -- every request carrying a still-valid token.
            oauth2 <- mkOAuth2Client oa2 Nothing
            man <- getGlobalManager
            rRes <- refreshTokens tokens man oauth2
            -- The new token's lifetime is measured from now, so the login
            -- time is reset alongside the stored tokens.
            pure $
              rRes <&> \newTokens ->
                ( req
                , user
                    { authLoginState = encodeToken newTokens
                    , authLoginTime = fromIntegral now
                    }
                )
          else pure (Just (req, user))

-- | Build the hoauth2 client record for this provider. It is shared by
-- 'handleLogin' (with a callback URL) and 'refreshLoginState' (without one).
-- The authorize endpoint, the token endpoint and the optional callback are
-- parsed in that order with 'parseAbsoluteURI', so a malformed URL surfaces as
-- a 'URIParseException'.
mkOAuth2Client :: MonadThrow m => OAuth2 -> Maybe T.Text -> m OA2.OAuth2
mkOAuth2Client OAuth2 {..} callbackTxt = do
  authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint
  accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint
  callbackURI <- traverse parseAbsoluteURI callbackTxt
  pure
    OA2.OAuth2
      { oauthClientId = getClientId oa2ClientId
      , oauthClientSecret = Just (getClientSecret oa2ClientSecret)
      , oauthOAuthorizeEndpoint = authEndpointURI
      , oauthAccessTokenEndpoint = accessTokenEndpointURI
      , oauthCallback = callbackURI
      }
-- | Whether the access token in @tokens@ has outlived its advertised lifetime,
-- measured from the user's 'authLoginTime' to @now@, both in the same epoch
-- seconds.
--
-- @expires_in@ is the token's lifetime, so the token counts as expired from
-- the instant @loginTime + expiresIn@ onwards (inclusive). The previous strict
-- @<@ treated that exact second as still valid.
--
-- A token without an @expires_in@ value is never considered expired.
tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool
tokenExpired user now tokens =
  case OA2.expiresIn tokens of
    Nothing -> False
    Just lifetime -> authLoginTime user + fromIntegral lifetime <= now
-- Derive aeson's JSON instances for 'OAuth2'. Each key is the record field
-- name with its three-character "oa2" prefix dropped and 'toLowerUnderscore'
-- applied (presumably camelCase to snake_case, e.g. clientId -> client_id —
-- NOTE(review): confirm).
$(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2)
-- | Tokens of the user attached to the request, decoded from the user's
-- stored login state. Returns 'Nothing' when the request has no
-- authenticated user or when that state fails to decode.
getAccessToken :: Request -> Maybe OA2.OAuth2Token
getAccessToken req =
  MA.getAuthUser req >>= \authUser ->
    case decodeToken (authLoginState authUser) of
      Left _ -> Nothing
      Right token -> Just token