module Snap.Snaplet.CustomAuth.OAuth2.Internal
  ( oauth2Init
  , saveAction
  , redirectToProvider
  ) where
import Control.Error.Util hiding (err)
import Control.Lens
import Control.Monad.Except
import Control.Monad.Trans.Except
import Control.Monad.Trans.Maybe
import Control.Monad.State
import Data.Aeson
import qualified Data.Binary
import Data.Binary (Binary)
import Data.Binary.Orphans ()
import qualified Data.ByteString.Base64
import Data.ByteString.Lazy (toStrict, fromStrict)
import Data.Char (chr)
import qualified Data.Configurator as C
import Data.HashMap.Lazy (HashMap)
import qualified Data.HashMap.Lazy as M
import Data.Maybe (isJust, isNothing, catMaybes)
import Data.Monoid
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1, decodeUtf8', encodeUtf8)
import Data.Time.Clock (UTCTime, getCurrentTime, diffUTCTime)
import Network.HTTP.Client (Manager)
import Network.OAuth.OAuth2
import Prelude hiding (lookup)
import Snap hiding (path)
import Snap.Snaplet.Session
import System.Random
import URI.ByteString
import Snap.Snaplet.CustomAuth.AuthManager
import Snap.Snaplet.CustomAuth.Types hiding (name)
import Snap.Snaplet.CustomAuth.User (setUser, currentUser, recoverSession)
import Snap.Snaplet.CustomAuth.Util (getStateName, getParamText, setFailure)
-- | Initialise OAuth2 support for the auth snaplet.
--
-- Reads @hostname@ (required), @protocol@ (default @"http"@) and the list
-- @oauth2.providers@ from the snaplet's user config. It then registers the
-- @oauth2createaccount@, @oauth2callback/:provider@ and
-- @oauth2login/:provider@ routes, each wrapped with the settings' 'bracket'.
-- It returns the configured providers keyed by their lower-cased name.
--
-- A provider is silently skipped when any of its @oauth2.<name>.*@ keys is
-- missing or when one of its endpoint URIs fails to parse.
oauth2Init
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Initializer b (AuthManager u e b) (HashMap Text Provider)
oauth2Init s = do
  cfg <- getSnapletUserConfig
  root <- getSnapletRootURL
  hostname <- liftIO $ C.require cfg "hostname"
  scheme <- liftIO $ C.lookupDefault "http" cfg "protocol"
  names <- liftIO $ C.lookupDefault [] cfg "oauth2.providers"
  -- Avoid producing "//oauth2callback/..." when the snaplet is mounted at
  -- the site root (empty root URL).
  let rootPrefix = if root == "" then "" else "/" <> root
      makeProvider name = let
        -- The provider map is keyed by the lower-cased name, and the
        -- ":provider" route parameter is looked up in it verbatim. The
        -- callback path must therefore use the same lower-cased form,
        -- otherwise a mixed-case provider name yields a callback URL
        -- that never matches.
        pName = T.toLower name
        name' = "oauth2." <> name
        lk = MaybeT . C.lookup cfg . (name' <>)
        -- Look up a config value and parse it as an absolute URI.
        lku n = lk n >>=
          MaybeT . return . hush . parseURI strictURIParserOptions . encodeUtf8
        callback = URI (Scheme scheme)
                   (Just $ Authority Nothing (Host hostname) Nothing)
                   (rootPrefix <> "/oauth2callback/" <> encodeUtf8 pName)
                   mempty Nothing
        in Provider pName Nothing
           <$> lk ".scope"
           <*> lku ".endpoint.identity"
           <*> lk ".identityField"
           <*> (OAuth2
                <$> lk ".clientId"
                <*> lk ".clientSecret"
                <*> lku ".endpoint.auth"
                <*> lku ".endpoint.access"
                <*> pure (Just callback))
  addRoutes $ mapped._2 %~ (bracket s) $
    [ ("oauth2createaccount", oauth2CreateAccount s)
    , ("oauth2callback/:provider", oauth2Callback s)
    , ("oauth2login/:provider", redirectLogin)
    ]
  liftIO $ M.fromList . map (\x -> (providerName x, x)) . catMaybes <$>
    (mapM (runMaybeT . makeProvider) names)
-- | Route handler for @oauth2login/:provider@.
--
-- Looks up the @provider@ request parameter among the configured providers
-- and redirects to it. Falls through with 'pass' when the parameter is
-- missing, names an unknown provider, or the redirect reports failure.
redirectLogin
  :: Handler b (AuthManager u e b) ()
redirectLogin = do
  known <- gets providers
  requested <- getParamText "provider"
  case requested >>= (`M.lookup` known) of
    Nothing -> pass
    Just p -> do
      ok <- redirectToProvider (providerName p)
      if ok then return () else pass
-- | Build the provider's authorization URL.
--
-- Appends the @state@ token (checked again in the callback) and the
-- provider's configured @scope@ as query parameters.
getRedirUrl
  :: Provider
  -> Text
  -> URI
getRedirUrl p token = appendQueryParams extra baseUrl
  where
    baseUrl = authorizationUrl (oauth p)
    extra =
      [ ("state", encodeUtf8 token)
      , ("scope", encodeUtf8 (scope p))
      ]
-- | Redirect to the provider registered under @pName@.
--
-- Returns 'False' without doing anything when no such provider is
-- configured.
redirectToProvider
  :: Text
  -> Handler b (AuthManager u e b) Bool
redirectToProvider pName = do
  known <- gets providers
  case M.lookup pName known of
    Nothing -> return False
    Just provider -> redirectToProvider' provider
-- | Send the client to @provider@'s authorization endpoint.
--
-- Stores a fresh 20-character alphanumeric @state@ token under the state
-- name, and the current time (rendered with 'show') under
-- @<state name>_stamp@, so the callback can validate the provider's reply.
-- The redirect uses HTTP status 303. Snap's 'redirect'' short-circuits the
-- handler, so the 'Bool' result is presumably never observed.
--
-- NOTE(review): the token comes from 'randomRIO' (System.Random's global
-- generator). That is not a cryptographically secure source; consider a
-- CSPRNG for this anti-CSRF value.
redirectToProvider'
  :: Provider
  -> Handler b (AuthManager u e b) Bool
redirectToProvider' provider = do
  store <- gets stateStore'
  now <- liftIO getCurrentTime
  stateName <- getStateName
  token <- liftIO $ T.pack <$> replicateM tokenLength (alnum <$> randomRIO (0, 61))
  withTop' store $ do
    setInSession stateName token
    setInSession (stateName <> "_stamp") (T.pack (show now))
    commitSession
  redirect' (serializeURIRef' (getRedirUrl provider token)) 303
  where
    tokenLength = 20
    -- Map 0..61 onto '0'-'9', 'A'-'Z' and 'a'-'z'.
    alnum :: Int -> Char
    alnum i
      | i < 10 = chr (48 + i)
      | i < 36 = chr (55 + i)
      | otherwise = chr (61 + i)
-- | Fetch the provider's identity document using @token@ and extract the
-- user's identifier from it.
--
-- Issues an authenticated request (via 'authGetJSON') to the provider's
-- identity endpoint. Returns the provider's 'identityField' only when the
-- request succeeds and that field is a JSON string; otherwise 'Nothing'.
getUserInfo
  :: OAuth2Settings u i e b
  -> Provider
  -> AccessToken
  -> Handler b (AuthManager u e b) (Maybe Text)
getUserInfo s provider token = liftIO $ do
  reply <- fetchIdentity (httpManager s) token (identityEndpoint provider)
  return $ hush reply >>= stringField (identityField provider)
  where
    -- Pin authGetJSON's polymorphic result to JSON objects.
    fetchIdentity :: Manager -> AccessToken -> URI
                  -> IO (OAuth2Result (HashMap Text Value) (HashMap Text Value))
    fetchIdentity = authGetJSON
    stringField key obj = case M.lookup key obj of
      Just (String x) -> Just x
      _ -> Nothing
-- | Route handler for @oauth2callback/:provider@.
--
-- Dispatches to @oauth2Callback'@ for the provider named by the @provider@
-- request parameter. Falls through with 'pass' when the parameter is
-- missing or names an unknown provider.
oauth2Callback
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Handler b (AuthManager u e b) ()
oauth2Callback s = do
  known <- gets providers
  requested <- getParamText "provider"
  case requested >>= (`M.lookup` known) of
    Nothing -> pass
    Just provider -> oauth2Callback' s provider
-- | Validate the provider's callback and resolve the user's identifier.
--
-- The checks run in this order:
--
-- 1. The session stamp written by the redirect is present, parseable and
--    not expired.
-- 2. The session state token exists.
-- 3. The @state@ parameter was received and equals the session token.
-- 4. The provider did not send an @error@ parameter.
--
-- The handler then exchanges the @code@ parameter for an access token and
-- fetches the identity with 'getUserInfo'. On success it continues with
-- 'oauth2Success'. Any failure is reported through the settings'
-- 'oauth2Failure' for 'SCallback'.
oauth2Callback'
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Provider
  -> Handler b (AuthManager u e b) ()
oauth2Callback' s provider = do
  name <- getStateName
  let ss = stateStore s
      mgr = httpManager s
  res <- runExceptT $ do
    let param = oauth provider
    -- A missing or unparseable stamp counts as expired. The partial 'read'
    -- previously used here crashed the handler on a malformed value.
    expiredStamp <- lift $ withTop' ss $
      maybe (return True) (liftIO . isExpiredStamp) =<<
      ((parseStamp =<<) <$> getFromSession (name <> "_stamp"))
    when expiredStamp $ throwE ExpiredState
    hostState <- maybe (throwE StateNotStored) return =<<
      (lift $ withTop' ss $ getFromSession name)
    providerState <- maybe (throwE StateNotReceived) return =<<
      (lift $ getParamText "state")
    when (hostState /= providerState) $ throwE BadState
    -- An "error" parameter from the provider aborts the flow; its text is
    -- kept only if it decodes as UTF-8.
    providerErr <- lift $ getParam "error"
    mapM_ (throwE . ProviderError . hush . decodeUtf8') providerErr
    -- A missing "code" parameter, or an identity that cannot be extracted,
    -- ends in IdExtractionFailed.
    (maybe (throwE IdExtractionFailed) return =<<) $ runMaybeT $ do
      code <- MaybeT $ (fmap ExchangeToken) <$> (lift $ getParamText "code")
      token <- either (const $ lift $ throwE AccessTokenFetchError) return =<< liftIO
        (fetchAccessToken mgr param code)
      MaybeT $ lift $ getUserInfo s provider (accessToken token)
  either (setFailure ((oauth2Failure s) SCallback) (Just $ providerName provider) .
          Right . Create . OAuth2Failure)
    (oauth2Success s provider) res
  where
    -- Total inverse of the 'show' used when the stamp was stored.
    parseStamp :: Text -> Maybe UTCTime
    parseStamp str = case reads (T.unpack str) of
      [(stamp, "")] -> Just stamp
      _ -> Nothing
-- | Continue after a successful callback that yielded the identifier @token@.
--
-- 1. Takes any pending saved action for this provider out of the session,
--    deleting it and committing.
-- 2. Records the provider name and token under @<state name>_provider@ and
--    @<state name>_token@.
-- 3. Resumes the saved action if one was found, otherwise performs a login.
oauth2Success
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Provider
  -> Text
  -> Handler b (AuthManager u e b) ()
oauth2Success s provider token = do
  let pName = providerName provider
  actionKey <- getActionKey pName
  store <- gets stateStore'
  stateName <- getStateName
  pending <- withTop' store $ do
    stored <- getFromSession actionKey
    case stored of
      Nothing -> return Nothing
      Just saved -> do
        deleteFromSession actionKey
        commitSession
        return (Just saved)
  withTop' store $ do
    setInSession (stateName <> "_provider") pName
    setInSession (stateName <> "_token") token
    commitSession
  case pending of
    Nothing -> doOauth2Login s provider token
    Just saved -> doResume s provider token saved
-- | Log in with the identity @token@ obtained from @provider@.
--
-- Refuses with 'AlreadyLoggedIn' when a user is already current.
-- Otherwise it asks the backend's 'oauth2Login' for the matching user,
-- sets that user as current if one is returned, and calls
-- 'oauth2LoginDone'. Failures are reported via 'oauth2Failure' for 'SLogin'.
doOauth2Login
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Provider
  -> Text
  -> Handler b (AuthManager u e b) ()
doOauth2Login s provider token = do
  recoverSession
  existing <- currentUser
  case existing of
    Just _ -> loginFailure $ Right $ Create $ OAuth2Failure AlreadyLoggedIn
    Nothing -> do
      outcome <- (oauth2Login s) pName token
      case outcome of
        Left e -> loginFailure (Left e)
        Right found -> do
          mapM_ setUser found
          oauth2LoginDone s
  where
    pName = providerName provider
    loginFailure err = setFailure ((oauth2Failure s) SLogin) (Just pName) err
-- | Whether @stamp@ lies outside the accepted window.
--
-- A stamp is expired when it is more than 300 seconds in the past, or
-- anywhere in the future.
isExpiredStamp
  :: UTCTime
  -> IO Bool
isExpiredStamp stamp = outsideWindow . (`diffUTCTime` stamp) <$> getCurrentTime
  where
    outsideWindow age = age < 0 || age > 300
-- | Run the backend's 'prepareOAuth2Create' for this provider and token.
--
-- A backend error that 'isDuplicateError' recognises is reported as
-- 'IdentityInUse'. Any other backend error is passed through as 'Left'.
prepareOAuth2Create'
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Provider
  -> Text
  -> Handler b (AuthManager u e b) (Either (Either e CreateFailure) i)
prepareOAuth2Create' s provider token = do
  prepared <- (prepareOAuth2Create s) (providerName provider) token
  case prepared of
    Right i -> return (Right i)
    Left e -> do
      duplicate <- isDuplicateError e
      return $ Left $
        if duplicate then Right (OAuth2Failure IdentityInUse) else Left e
-- | Resume an action previously stored with 'saveAction'.
--
-- @d@ is the stored session value: Base64 text of a 'Data.Binary'-encoded
-- 'SavedAction'. The checks run in this order:
--
-- 1. The payload decodes.
-- 2. A user is logged in if the action requires one.
-- 3. The backend's 'oauth2Check' succeeds.
-- 4. The current user id equals the one recorded with the action.
-- 5. The identity check for the action kind passes.
-- 6. The action stamp has not expired.
--
-- On success the saved payload goes to 'resumeAction'. Failures are
-- reported via 'oauth2Failure' for 'SAction'.
doResume
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Provider
  -> Text
  -> Text
  -> Handler b (AuthManager u e b) ()
doResume s provider token d = do
  recoverSession
  user <- currentUser
  userId <- traverse getUserId user
  res <- runExceptT $ do
    saved <- maybe (throwE $ Right ActionDecodeError) return (decodeSaved d)
    when (requireUser saved && isNothing user) $ throwE (Right AttachNotLoggedIn)
    -- Presumably the id of the user this identity is already attached to
    -- (if any); confirm against the backend's oauth2Check.
    attached <- either (throwE . Left) return =<<
      lift ((oauth2Check s) pName token)
    when (userId /= actionUser saved) $ throwE (Right ActionUserMismatch)
    if requireUser saved
      -- The identity must resolve to the current user.
      then when (maybe True ((/= userId) . Just) attached) $
        throwE (Right ActionUserMismatch)
      -- The identity must not resolve to anyone yet.
      else when (isJust attached) $ throwE (Right AlreadyAttached)
    expired <- liftIO $ isExpiredStamp (actionStamp saved)
    when expired $ throwE (Right ActionTimeout)
    return (savedAction saved)
  either (setFailure ((oauth2Failure s) SAction) (Just pName) . fmap Action)
    ((resumeAction s) pName token) res
  where
    pName = providerName provider
    -- Base64 text -> bytes -> Binary-decoded value; Nothing on any failure.
    decodeSaved txt = do
      bytes <- hush $ Data.ByteString.Base64.decode $ encodeUtf8 txt
      (_, _, value) <- hush $ Data.Binary.decodeOrFail $ fromStrict bytes
      return value
-- | Route handler for @oauth2createaccount@: create a local account for an
-- OAuth2 identity.
--
-- Reads the desired user name from the request parameter
-- @_new<userField>@. Reads the provider name and token that
-- 'oauth2Success' stored in the session. The request is refused when:
--
-- * a user is already logged in;
-- * the name is missing or not valid UTF-8;
-- * no provider/token is stored;
-- * preparation fails; or
-- * 'create' fails.
--
-- If preparation succeeded but creation failed, 'cancelPrepare' undoes it.
-- On success the new user becomes current and 'oauth2AccountCreated' runs.
-- Failures are reported via 'oauth2Failure' for 'SCreate'.
oauth2CreateAccount
  :: IAuthBackend u i e b
  => OAuth2Settings u i e b
  -> Handler b (AuthManager u e b) ()
oauth2CreateAccount s = do
  store <- gets stateStore'
  provs <- gets providers
  field <- gets userField
  rawName <- getParam ("_new" <> field)
  let usrName = rawName >>= hush . decodeUtf8'
  name <- getStateName
  storedProvider <- withTop' store $ getFromSession (name <> "_provider")
  let provider = storedProvider >>= flip M.lookup provs
  prepared <- runExceptT $ do
    existing <- lift $ recoverSession >> currentUser
    when (isJust existing) $ throwE (Right $ OAuth2Failure AlreadyUser)
    userName <- maybe (throwE $ Right MissingName) return usrName
    stored <- lift $ withTop' store $ runMaybeT $ do
      p <- MaybeT $ return provider
      t <- MaybeT $ getFromSession (name <> "_token")
      return (p, t)
    (p, t) <- maybe (throwE $ Right $ OAuth2Failure NoStoredToken) return stored
    i <- ExceptT $ prepareOAuth2Create' s p t
    return (i, userName)
  created <- runExceptT $ do
    (i, userName) <- hoistEither prepared
    usr <- ExceptT $ create userName i
    lift $ setUser usr
    return usr
  -- Roll back a successful preparation when the account was not created.
  case (prepared, created) of
    (Right (i, _), Left _) -> cancelPrepare i
    _ -> return ()
  either (setFailure ((oauth2Failure s) SCreate) (providerName <$> provider) . fmap Create)
    (oauth2AccountCreated s) created
-- | Session key under which a saved action for provider @p@ is stored.
--
-- The key has the form @__<snaplet name>_<root url>_action_<p>@. @"auth"@
-- substitutes for a missing snaplet name or a root URL that is not valid
-- UTF-8.
getActionKey
  :: Text
  -> Handler b (AuthManager u e b) Text
getActionKey p = do
  rootUrl <- getSnapletRootURL
  snapletName <- getSnapletName
  let path = either (const "auth") id (decodeUtf8' rootUrl)
      name = maybe "auth" id snapletName
  return $ T.concat ["__", name, "_", path, "_action_", p]
-- | Save action @a@ so it can be resumed after an OAuth2 round trip with
-- @provider@.
--
-- The value is 'Data.Binary'-encoded and wrapped in a 'SavedAction' along
-- with:
--
-- * the provider name;
-- * the current time;
-- * the current user's id, if any;
-- * @require@, which says whether resuming needs a logged-in user.
--
-- The wrapper is itself Binary-encoded, Base64'd, and stored in the session
-- under 'getActionKey'. Fails via 'guard' when @provider@ is not configured;
-- in Snap's 'Handler' that behaves like 'pass'.
--
-- NOTE(review): unlike doResume/doOauth2Login, this does not call
-- 'recoverSession' before 'currentUser'. Confirm that callers restore the
-- session first; otherwise the recorded user may be Nothing and the
-- resume check would report a user mismatch.
saveAction
  :: (IAuthBackend u i e b, Binary a)
  => Bool
  -> Text
  -> a
  -> Handler b (AuthManager u e b) ()
saveAction require provider a = do
  provs <- gets providers
  -- O(1) membership test instead of a linear scan over the key list.
  guard $ M.member provider provs
  let d = Data.Binary.encode a
  key <- getActionKey provider
  store <- gets stateStore'
  stamp <- liftIO getCurrentTime
  i <- traverse getUserId =<< currentUser
  let payload = SavedAction {
          actionProvider = provider
        , actionStamp = stamp
        , actionUser = i
        , requireUser = require
        , savedAction = toStrict d
        }
  let d' = decodeLatin1 $ Data.ByteString.Base64.encode $
        toStrict . Data.Binary.encode $ payload
  withTop' store $ do
    setInSession key d'
    commitSession