{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.Dispatch
( FetchCreds
, dispatchAuthRequest
) where
import Control.Exception.Safe (throwString, tryIO)
import Control.Monad (unless, (<=<))
import Data.Monoid ((<>))
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2
import System.Random (newStdGen, randomRs)
import URI.ByteString.Extension
import Yesod.Auth
import Yesod.Auth.OAuth2.ErrorResponse (onErrorResponse)
import Yesod.Core
-- | An action that produces a user's 'Creds' from an 'OAuth2Token'.
--
-- It is given the HTTP 'Manager' and the token, and runs in 'IO'.
type FetchCreds m = Manager -> OAuth2Token -> IO (Creds m)
-- | Route a request made to this plugin to the matching handler.
--
-- @"GET"@ with @["forward"]@ goes to 'dispatchForward', @"GET"@ with
-- @["callback"]@ goes to 'dispatchCallback'; any other combination of the
-- last two arguments responds with 'notFound'.
dispatchAuthRequest
    :: Text
    -- ^ Plugin name (used for the session key and the callback route)
    -> OAuth2
    -- ^ OAuth2 settings to authenticate with
    -> FetchCreds m
    -- ^ How to obtain credentials from a token; only used by the callback
    -> Text
    -- ^ Request method, matched against @"GET"@
    -> [Text]
    -- ^ Remaining path pieces
    -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 getCreds method pieces =
    case (method, pieces) of
        ("GET", ["forward"]) -> dispatchForward name oauth2
        ("GET", ["callback"]) -> dispatchCallback name oauth2 getCreds
        _ -> notFound
-- | Start the OAuth2 flow.
--
-- A fresh CSRF token is stored in the session under 'tokenSessionKey', the
-- 'OAuth2' value is given its callback and @state@ parameter, and the client
-- is redirected to the resulting authorization URL.
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent
dispatchForward name oauth2 = do
    stateToken <- setSessionCSRF (tokenSessionKey name)
    configured <- withCallbackAndState name oauth2 stateToken
    redirect (toText (authorizationUrl configured))
-- | Finish the OAuth2 flow.
--
-- Steps, in order:
--
-- 1. Check the @state@ parameter against the token stored in the session.
-- 2. Run 'onErrorResponse' with the local failure handler.
-- 3. Require the @code@ parameter and exchange it for a token.
-- 4. Fetch credentials with the supplied 'FetchCreds' and hand them to
--    'setCredsRedirect'.
--
-- A 'Left' from the token exchange or from fetching credentials is logged
-- and answered with 'permissionDenied'.
dispatchCallback :: Text -> OAuth2 -> FetchCreds m -> AuthHandler m TypedContent
dispatchCallback name oauth2 getCreds = do
    stateToken <- verifySessionCSRF (tokenSessionKey name)
    onErrorResponse errInvalidOAuth
    authCode <- requireGetParam "code"
    httpManager <- authHttpManager
    configured <- withCallbackAndState name oauth2 stateToken
    accessTok <- denyLeft
        (fetchAccessToken httpManager configured (ExchangeToken authCode))
    -- 'tryIO' turns the exceptions it catches from 'getCreds' into a 'Left'
    denyLeft (tryIO (getCreds httpManager accessTok)) >>= setCredsRedirect
  where
    -- Run the 'IO' action; a 'Left' is rejected, a 'Right' is unwrapped.
    denyLeft :: (MonadHandler m, MonadLogger m, Show e) => IO (Either e a) -> m a
    denyLeft = either errInvalidOAuth pure <=< liftIO

    -- Log the shown error, then short-circuit with a permission-denied
    -- response carrying a fixed message.
    errInvalidOAuth :: (MonadHandler m, MonadLogger m, Show e) => e -> m a
    errInvalidOAuth failure = do
        $(logError) (T.pack ("OAuth2 error: " <> show failure))
        permissionDenied "Invalid OAuth2 authentication attempt"
-- | Fill in the per-request parts of an 'OAuth2' value.
--
-- The callback is the rendered @PluginR name ["callback"]@ route, and the
-- given CSRF token is added to the authorize endpoint as the @state@ query
-- parameter.
--
-- Throws (via 'throwString') when the rendered callback cannot be parsed by
-- 'fromText'.
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
withCallbackAndState name oauth2 csrf = do
    render <- getParentUrlRender
    let callbackText = render (PluginR name ["callback"])
        authorizeEndpoint =
            oauthOAuthorizeEndpoint oauth2
                `withQuery` [("state", encodeUtf8 csrf)]
    callback <- case fromText callbackText of
        Just uri -> pure uri
        Nothing ->
            liftIO
                ( throwString
                    ( "Invalid callback URI: "
                        <> T.unpack callbackText
                        <> ". Not using an absolute Approot?"
                    )
                )
    pure oauth2
        { oauthCallback = Just callback
        , oauthOAuthorizeEndpoint = authorizeEndpoint
        }
-- | A renderer for subsite routes: the route is first passed through
-- 'getRouteToParent' and the result is rendered with 'getUrlRender'.
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = do
    renderParent <- getUrlRender
    toParent <- getRouteToParent
    pure (renderParent . toParent)
-- | Generate a token, store it in the session under the given key, and
-- return it.
--
-- The token is 30 characters drawn from @'a'..'z'@ using a generator from
-- 'newStdGen'.
--
-- NOTE(review): 'System.Random' generators are not intended for
-- security-sensitive use; confirm whether this token warrants a CSPRNG.
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
    gen <- liftIO newStdGen
    let csrfToken = T.pack (take 30 (randomRs ('a', 'z') gen))
    setSession sessionKey csrfToken
    pure csrfToken
-- | Check the request's @state@ parameter against the session.
--
-- The value stored under the given session key is looked up and then
-- deleted regardless of the outcome. If it does not equal the @state@
-- parameter (including when nothing is stored), the request is rejected with
-- 'permissionDenied'; otherwise the token is returned.
verifySessionCSRF :: MonadHandler m => Text -> m Text
verifySessionCSRF sessionKey = do
    received <- requireGetParam "state"
    stored <- lookupSession sessionKey
    -- Removed before comparing, so a mismatching attempt also consumes it.
    deleteSession sessionKey
    let tokensMatch = stored == Just received
    unless tokensMatch (permissionDenied "Invalid OAuth2 state token")
    pure received
-- | Look up a GET parameter, responding with 'invalidArgs' when it is
-- absent.
requireGetParam :: MonadHandler m => Text -> m Text
requireGetParam key = do
    found <- lookupGetParam key
    case found of
        Just value -> pure value
        Nothing -> invalidArgs [missingMessage]
  where
    missingMessage = "The '" <> key <> "' parameter is required"
-- | Session key under which the CSRF token for the named plugin is kept:
-- the plugin name with a fixed prefix.
tokenSessionKey :: Text -> Text
tokenSessionKey name = T.append prefix name
  where
    prefix = "_yesod_oauth2_"