{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      : Network.Reddit.Auth
-- Copyright   : (c) 2021 Rory Tyler Hayford
-- License     : BSD-3-Clause
-- Maintainer  : rory.hayford@protonmail.com
-- Stability   : experimental
-- Portability : GHC
--
-- Authentication via OAuth for the Reddit API
module Network.Reddit.Auth
    ( loadAuthConfig
    , getAccessToken
    , getAccessTokenWith
    , getAuthURL
    , redditURL
    , oauthURL
    , refreshAccessToken
    ) where

import           Conduit
                 ( (.|)
                 , decodeUtf8LenientC
                 , runConduit
                 , sinkLazy
                 , withSourceFile
                 )

import           Control.Monad.Catch         ( MonadThrow(throwM) )
import           Control.Monad.Reader        ( asks )

import           Data.Aeson                  ( decode, eitherDecode )
import           Data.ByteString             ( ByteString )
import qualified Data.ByteString.Lazy        as LB
import qualified Data.CaseInsensitive        as CI
import           Data.Conduit.Binary         ( sinkLbs )
import           Data.Foldable               ( asum )
import           Data.Function               ( on )
import           Data.Generics.Product       ( HasField(field) )
import           Data.Ini.Config             ( IniParser )
import qualified Data.Ini.Config             as Ini
import qualified Data.Text                   as T
import           Data.Text                   ( Text )
import qualified Data.Text.Encoding          as T
import qualified Data.Text.Lazy              as LT

import           Lens.Micro

import           Network.HTTP.Client.Conduit
                 ( Request
                 , RequestBody(RequestBodyLBS)
                 )
import qualified Network.HTTP.Client.Conduit as H
import           Network.HTTP.Simple         ( withResponse )
import           Network.Reddit.Types
import           Network.Reddit.Utils

import           UnliftIO                    ( MonadUnliftIO )
import           UnliftIO.Directory

import           Web.FormUrlEncoded          ( toForm, urlEncodeAsFormStable )
import           Web.HttpApiData             ( ToHttpApiData(toQueryParam) )
import           Web.Internal.FormUrlEncoded ( Form )

-- | Load the auth file, looking in the following locations, in order:
--
--      * $PWD\/auth.ini
--      * XDG_CONFIG_HOME\/heddit\/auth.ini
--
-- The first file found wins. If neither location contains an @auth.ini@,
-- an 'OtherError' is thrown.
--
-- __Note__: Only 'ScriptApp's and 'ApplicationOnly' apps are supported
loadAuthConfig
    :: (MonadUnliftIO m, MonadThrow m) => ClientSite -> m AuthConfig
loadAuthConfig cs = do
    cwDir <- getCurrentDirectory
    cfgDir <- getXdgDirectory XdgConfig "heddit"
    -- 'findFile' returns the first match, so the order of this list is the
    -- lookup precedence: the working directory must come before the XDG
    -- config directory to match the documented behaviour
    findFile [ cwDir, cfgDir ] "auth.ini" >>= \case
        Nothing -> throwM . OtherError
            $ mconcat [ "No auth.ini file found in the current directory"
                      , " or $XDG_CONFIG_HOME/heddit, please create one"
                      ]
        Just fp -> parseAuthIni cs fp

-- | Read the ini file at the given path (decoding it leniently as UTF-8) and
-- parse the section named by the 'ClientSite' into an 'AuthConfig'. A parse
-- failure is rethrown as a 'userError' carrying the parser's message
parseAuthIni :: forall m.
             (MonadUnliftIO m, MonadThrow m)
             => ClientSite
             -> FilePath
             -> m AuthConfig
parseAuthIni cs fp = withSourceFile @_ @m fp $ \src -> do
    contents <- runConduit $ src .| decodeUtf8LenientC .| sinkLazy
    either (throwM . userError) pure
        $ Ini.parseIniFile (LT.toStrict contents) (authConfigP cs)

-- | Parser for a single named section of the auth file. The 'ScriptApp'
-- parser is tried first, as it needs the additional @username@ and @password@
-- fields; if it fails, the section is parsed as an 'ApplicationOnly' app
authConfigP :: Text -> IniParser AuthConfig
authConfigP sec = asum [ scriptP, appOnlyP ]
  where
    appOnlyP = Ini.section sec
        $ AuthConfig <$> Ini.field "id"
        <*> (ApplicationOnly <$> Ini.field "secret")
        <*> Ini.fieldOf "agent" uaP

    scriptP  = Ini.section sec
        $ AuthConfig <$> Ini.field "id"
        <*> (ScriptApp <$> Ini.field "secret"
             <*> (PasswordFlow <$> Ini.field "username"
                  <*> Ini.field "password"))
        <*> Ini.fieldOf "agent" uaP

    -- The @agent@ field holds exactly four comma-separated components
    uaP :: Text -> Either [Char] UserAgent
    uaP t = case T.splitOn "," t of
        [ platform, appID, version, author ] -> Right UserAgent { .. }
        _ -> Left
            $ mconcat [ "User agent must be of the form"
                      , " '<platform>,<appID>,<version>,<author>'"
                      ]

-- | Get the URL required to authorize your application, for 'WebApp's and
-- 'InstalledApp's
getAuthURL
    :: Foldable t
    => URL
    -- ^ A redirect URI, which must exactly match the one
    -- registered with Reddit when creating your application
    -> TokenDuration
    -> t Scope
    -- ^ The OAuth scopes to request authorization for
    -> ClientID
    -> Text
    -- ^ Text that is embedded in the callback URI when the
    -- client completes the request. It must be composed
    -- of printable ASCII characters and should be unique
    -- for the client
    -> URL
getAuthURL redirectURI duration scopes clientID state =
    T.decodeUtf8 $ "https://" <> mconcat pieces
  where
    -- Host, path and query string of the request, in that order
    pieces  = [ H.host, H.path, H.queryString ] <*> [ request ]

    query   = LB.toStrict . urlEncodeAsFormStable
        $ mkTextForm [ ("client_id", clientID)
                     , ("duration", toQueryParam duration)
                     , ("redirect_uri", redirectURI)
                     , ("response_type", "code")
                     , ("state", state)
                     , ("scope", joinParams scopes)
                     ]

    request = H.defaultRequest
        { H.host        = redditURL
        , H.path        = joinPathSegments [ "api", "v1", "authorize" ]
        , H.queryString = "?" <> query
        }

-- | Generate an 'AccessToken' from an 'AuthConfig'. This serves to create an
-- initial token for all 'AppType's, and can also be used to refresh tokens for
-- 'ScriptApp's and 'ApplicationOnly' apps
--
-- The first argument converts the configured 'AppType' into the form that is
-- posted to the token endpoint. The request is sent with HTTP basic auth,
-- using the client ID as the user name and the client secret as the password
-- (an empty password for 'InstalledApp's)
getAccessToken :: (MonadUnliftIO m, MonadThrow m)
               => (AppType -> Form)
               -> AuthConfig
               -> m AccessToken
getAccessToken f ac@AuthConfig { .. } = do
    req <- routeToRequest . mkAuthRoute $ f appType
    makeTokenRequest . setUAHeader ac
        $ H.applyBasicAuth (T.encodeUtf8 clientID) basicSecret req
  where
    basicSecret = case appType of
        ScriptApp clientSecret _     -> T.encodeUtf8 clientSecret
        ApplicationOnly clientSecret -> T.encodeUtf8 clientSecret
        WebApp clientSecret _        -> T.encodeUtf8 clientSecret
        InstalledApp {}              -> mempty

-- | Exchange a refresh token for a new 'AccessToken'. This is only possible
-- for 'WebApp's and 'InstalledApp's; for any other 'AppType' a
-- 'ConfigurationError' is thrown
getAccessTokenWith
    :: (MonadUnliftIO m, MonadThrow m) => Token -> AuthConfig -> m AccessToken
getAccessTokenWith rt ac@AuthConfig { .. } = case appType of
    ScriptApp {}          -> cfgError
    ApplicationOnly {}    -> cfgError
    WebApp clientSecret _ -> refreshWith clientSecret
    InstalledApp {}       -> refreshWith mempty
  where
    -- The user agent header is set here as well, so that refresh requests
    -- identify the client in the same way as 'getAccessToken' does
    refreshWith secret = makeTokenRequest
        . setUAHeader ac
        . applyAuth clientID secret
        =<< mkReq

    mkReq              = routeToRequest . mkAuthRoute
        $ mkTextForm [ ("grant_type", "refresh_token")
                     , ("refresh_token", rt)
                     ]

    applyAuth          = H.applyBasicAuth `on` T.encodeUtf8

    cfgError           = throwM
        $ ConfigurationError "getAccessTokenWith: unsupported application type"

-- | Perform the request and decode the response body as an 'AccessToken'. If
-- that fails, the body is decoded as an 'APIException', which is thrown when
-- the decoding succeeds; otherwise a 'JSONParseError' carrying the raw body
-- is thrown
makeTokenRequest
    :: forall m. (MonadUnliftIO m, MonadThrow m) => Request -> m AccessToken
makeTokenRequest req = withResponse @_ @m req $ \resp -> do
    bodyBS <- runConduit $ H.responseBody resp .| sinkLbs
    case eitherDecode bodyBS of
        Right token -> pure token
        Left err    -> case decode @APIException bodyBS of
            Just e  -> throwM e
            Nothing -> throwM . flip JSONParseError bodyBS
                $ "getAccessToken: Failed to parse JSON - " <> T.pack err

-- | Generate the 'APIAction' that POSTs the given 'Form' to the access token
-- endpoint
mkAuthRoute :: Form -> APIAction a
mkAuthRoute form = defaultAPIAction
    { method       = POST
    , pathSegments = [ "api", "v1", "access_token" ]
    , requestData  = WithForm form
    }

-- | Convert an API 'APIAction' to a 'Request' against 'redditURL'. Form data
-- is only accepted for @POST@ and @PUT@ actions; apart from that, only
-- actions without any data are supported. Everything else results in an
-- 'InvalidRequest' being thrown
routeToRequest :: MonadThrow m => APIAction a -> m Request
routeToRequest APIAction { .. } = case requestData of
    WithForm fd
        | method `elem` [ POST, PUT ] -> pure
            $ mkRequest { H.requestBody =
                              RequestBodyLBS $ urlEncodeAsFormStable fd
                        }
        | otherwise -> invalidRequest
    NoData -> pure mkRequest
    _ -> invalidRequest
  where
    mkRequest      = H.defaultRequest
        { H.host   = redditURL
        , H.secure = True
        , H.port   = 443
        , H.method = bshow method
        , H.path   = joinPathSegments pathSegments
        }

    invalidRequest = throwM $ InvalidRequest "Invalid request body"

-- | Prepend a @user-agent@ header, rendered from the 'UserAgent' of the
-- 'AuthConfig', to the headers of the request
setUAHeader :: AuthConfig -> Request -> Request
setUAHeader AuthConfig { .. } req =
    req { H.requestHeaders = newHeader : headers }
  where
    newHeader = (CI.mk "user-agent", ua)

    ua        = writeUA userAgent

    headers   = H.requestHeaders req

-- | Refresh the access token
--
-- 'ScriptApp's and 'ApplicationOnly' apps simply request a fresh token. For
-- 'WebApp's and 'InstalledApp's a refresh token is required: it is loaded
-- from the configured 'TokenManager' if there is one, and looked up in the
-- current client state otherwise. A 'ConfigurationError' is thrown if no
-- refresh token can be found
refreshAccessToken :: MonadReddit m => m AccessToken
refreshAccessToken = do
    ac@AuthConfig { .. } <- asks (^. field @"authConfig")
    case appType of
        ScriptApp {}       -> getAccessToken toForm ac
        ApplicationOnly {} -> getAccessToken toForm ac
        WebApp {}          -> tryRefresh ac
        InstalledApp {}    -> tryRefresh ac
  where
    tryRefresh ac      = asks (^. field @"tokenManager") >>= \case
        Just TokenManager { .. } -> do
            token <- flip getAccessTokenWith ac =<< loadToken
            -- Hand the refresh token of the new access token back to the
            -- token manager
            putToken $ token ^. field @"refreshToken"
            pure token
        Nothing                  -> lookupRefreshToken >>= \case
            Nothing -> cfgError "refreshAccessToken: No refresh token available"
            Just rt -> getAccessTokenWith rt ac

    lookupRefreshToken =
        readClientState $ field @"accessToken" . field @"refreshToken"

    cfgError           = throwM . ConfigurationError

-- | The endpoint for non-OAuth actions
redditURL :: ByteString
redditURL = "www.reddit.com"

-- | The endpoint for OAuth actions
oauthURL :: ByteString
oauthURL = "oauth.reddit.com"