{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.Dispatch
( FetchToken
, fetchAccessToken
, fetchAccessToken2
, FetchCreds
, dispatchAuthRequest
) where
import Control.Monad.Except
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors)
import URI.ByteString.Extension
import UnliftIO.Exception
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.DispatchError
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)
-- | Strategy for exchanging an authorization code for an 'OAuth2Token'.
--
-- It receives the HTTP 'Manager', the provider's 'OAuth2' settings and the
-- 'ExchangeToken' obtained on the callback, and yields either the provider's
-- error or the token. The export list also exposes 'fetchAccessToken' and
-- 'fetchAccessToken2', presumably intended to fill this role.
type FetchToken =
    Manager
    -> OAuth2
    -> ExchangeToken
    -> IO (OAuth2Result Errors OAuth2Token)
-- | Strategy for turning an 'OAuth2Token' into user 'Creds' for @site@,
-- performing any HTTP it needs through the supplied 'Manager'.
type FetchCreds site = Manager -> OAuth2Token -> IO (Creds site)
-- | Route a request for this OAuth2 plugin to the matching handler.
--
-- Two routes are served, both only for @GET@:
--
-- * @forward@  — starts the handshake ('dispatchForward');
-- * @callback@ — completes it ('dispatchCallback').
--
-- Every other method/path combination answers with 'notFound'. Errors
-- raised inside either handler are turned into a response by
-- 'handleDispatchError'.
dispatchAuthRequest
    :: Text             -- ^ Plugin name
    -> OAuth2           -- ^ Provider client settings
    -> FetchToken       -- ^ How to exchange the code for a token
    -> FetchCreds m     -- ^ How to turn the token into credentials
    -> Text             -- ^ HTTP method
    -> [Text]           -- ^ Path pieces below the plugin route
    -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 tokenFetcher credsFetcher verb segments =
    case (verb, segments) of
        ("GET", ["forward"]) ->
            handleDispatchError (dispatchForward name oauth2)
        ("GET", ["callback"]) ->
            handleDispatchError
                (dispatchCallback name oauth2 tokenFetcher credsFetcher)
        _ ->
            notFound
-- | Handle @GET \/forward@: begin the OAuth2 handshake.
--
-- A fresh CSRF token is generated and stored in the session under this
-- plugin's key ('setSessionCSRF'). The provider settings are then extended
-- with our callback route and that token as @state@ ('withCallbackAndState').
-- Finally, the user agent is redirected to the resulting 'authorizationUrl'.
dispatchForward
    :: (MonadError DispatchError m, MonadAuthHandler site m)
    => Text
    -> OAuth2
    -> m TypedContent
dispatchForward name oauth2 =
    setSessionCSRF (tokenSessionKey name)
        >>= withCallbackAndState name oauth2
        >>= redirect . toText . authorizationUrl
-- | Handle @GET \/callback@: finish the OAuth2 handshake.
--
-- Steps, in order:
--
-- 1. If 'onErrorResponse' detects an error response, abort with
--    'OAuth2HandshakeError'.
-- 2. Check the @state@ parameter against the session's CSRF token
--    ('verifySessionCSRF').
-- 3. Read the required @code@ parameter ('requireGetParam').
-- 4. Exchange the code with the caller's 'FetchToken'; a @Left@ result is
--    rethrown as 'OAuth2ResultError'.
-- 5. Build 'Creds' with the caller's 'FetchCreds'. An 'IOException' it
--    throws becomes 'FetchCredsIOException'; a 'YesodOAuth2Exception'
--    becomes 'FetchCredsYesodOAuth2Exception'.
-- 6. Pass the credentials to 'setCredsRedirect'.
dispatchCallback
    :: (MonadError DispatchError m, MonadAuthHandler site m)
    => Text
    -> OAuth2
    -> FetchToken
    -> FetchCreds site
    -> m TypedContent
dispatchCallback name oauth2 getToken getCreds = do
    onErrorResponse (throwError . OAuth2HandshakeError)
    csrf <- verifySessionCSRF (tokenSessionKey name)
    code <- requireGetParam "code"
    manager <- authHttpManager
    -- Rebuild the same callback/state-augmented settings that the forward
    -- step produced, using the verified state token.
    oauth2' <- withCallbackAndState name oauth2 csrf
    tokenResult <- liftIO (getToken manager oauth2' (ExchangeToken code))
    token <- case tokenResult of
        Left err -> throwError (OAuth2ResultError err)
        Right tok -> pure tok
    -- The inner handler covers IOException; the outer one covers
    -- YesodOAuth2Exception. Both are turned into DispatchError values.
    creds <-
        (liftIO (getCreds manager token)
            `catch` (throwError . FetchCredsIOException))
            `catch` (throwError . FetchCredsYesodOAuth2Exception)
    setCredsRedirect creds
-- | Prepare provider settings for this plugin's handshake.
--
-- 'oauthCallback' is set to the rendered URL of the plugin's @callback@
-- route. @state=<csrf>@ is appended to the authorization endpoint's query.
-- Throws 'InvalidCallbackUri' when the rendered route cannot be parsed by
-- 'fromText'.
withCallbackAndState
    :: (MonadError DispatchError m, MonadAuthHandler site m)
    => Text
    -> OAuth2
    -> Text
    -> m OAuth2
withCallbackAndState name oauth2 csrf = do
    render <- getParentUrlRender
    let callbackText = render (PluginR name ["callback"])
    case fromText callbackText of
        Nothing ->
            throwError (InvalidCallbackUri callbackText)
        Just callbackUri ->
            pure
                (oauth2
                    { oauthCallback = Just callbackUri
                    , oauthOAuthorizeEndpoint =
                        withQuery
                            (oauthOAuthorizeEndpoint oauth2)
                            [("state", encodeUtf8 csrf)]
                    })
-- | A renderer for routes of the current subsite.
--
-- Each route is first lifted into the parent site ('getRouteToParent') and
-- then rendered with the parent site's 'getUrlRender'.
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = do
    renderParent <- getUrlRender
    toParent <- getRouteToParent
    pure (renderParent . toParent)
-- | Generate a random CSRF token, store it in the session under the given
-- key, and return it.
--
-- The token is @'randomText' 64@ with every @+@ character removed.
-- NOTE(review): presumably this is because a @+@ in the @state@ query
-- parameter can come back decoded as a space — confirm.
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
    token <- liftIO (T.filter (/= '+') <$> randomText 64)
    setSession sessionKey token
    pure token
-- | Check that the callback's @state@ parameter equals the token stored in
-- the session under @sessionKey@, and return it.
--
-- Throws 'MissingParameter' when @state@ is absent. Throws
-- 'InvalidStateToken' (carrying the stored and received values) when they
-- differ.
verifySessionCSRF
    :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
verifySessionCSRF sessionKey = do
    received <- requireGetParam "state"
    stored <- lookupSession sessionKey
    -- Once a state value has arrived, the stored token is removed before the
    -- comparison, so it cannot be reused whether or not it matches.
    deleteSession sessionKey
    unless (stored == Just received) $
        throwError (InvalidStateToken stored received)
    pure received
-- | Look up a query-string parameter.
--
-- Throws 'MissingParameter' with the parameter name when it is absent.
requireGetParam
    :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
requireGetParam key = do
    found <- lookupGetParam key
    case found of
        Just value -> pure value
        Nothing -> throwError (MissingParameter key)
-- | Session key under which the CSRF token for the named plugin is kept:
-- the plugin name prefixed with @_yesod_oauth2_@.
tokenSessionKey :: Text -> Text
tokenSessionKey = T.append (T.pack "_yesod_oauth2_")