{-# LANGUAGE CPP #-}
module Servant.Auth.Server.Internal.Types where
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Time
import Data.Monoid          (Monoid (..))
import Data.Semigroup       (Semigroup (..))
import Data.Time            (getCurrentTime)
import GHC.Generics         (Generic)
import Network.Wai          (Request)
import qualified Control.Monad.Fail as Fail
-- | The outcome of a single authentication attempt.
--
-- Only 'Authenticated' carries a value. The 'Semigroup' instance treats
-- 'Indefinite' as "no decision yet" and keeps the first other result.
--
-- NOTE: constructor order is significant. The derived 'Ord' instance ranks
-- BadPassword < NoSuchUser < Authenticated _ < Indefinite, so do not reorder.
data AuthResult val
  -- | Authentication was rejected. Presumably the credentials did not match
  -- a known user (inferred from the name only; confirm against callers).
  = BadPassword
  -- | Authentication was rejected. Presumably no matching user was found
  -- (inferred from the name only; confirm against callers).
  | NoSuchUser
  -- | Authentication succeeded with the given value. This is the only
  -- constructor through which monadic binds ('>>=') continue.
  | Authenticated val
  -- | No decision was reached. Combining with '<>' or '<|>' falls through to
  -- the next result, and this is also the 'mempty' and 'mzero' value.
  -- Presumably produced when the credentials a check looks for are absent
  -- (TODO confirm).
  | Indefinite
  deriving (Eq, Show, Read, Generic, Ord, Functor, Traversable, Foldable)
-- | Left-biased choice that skips 'Indefinite': the first result that is not
-- 'Indefinite' wins. 'Indefinite' is therefore an identity on both sides.
-- Only the left operand is forced; the right one is returned untouched.
instance Semigroup (AuthResult val) where
  lhs <> rhs = case lhs of
    Indefinite -> rhs
    _          -> lhs
-- | 'Indefinite' is the identity for the left-biased '<>'; 'mappend' is
-- kept in its canonical form so older base versions still have it defined.
instance Monoid (AuthResult val) where
  mappend = (<>)
  mempty  = Indefinite
-- | 'pure' wraps a value as 'Authenticated'.
--
-- 'pure' is the primary definition and 'return' is intentionally left to its
-- default (@return = pure@). This is the canonical layout expected by GHC's
-- @-Wnoncanonical-monad-instances@. '<*>' goes through the 'Monad' instance
-- via 'ap', which is safe because '>>=' is defined directly by pattern match.
instance Applicative AuthResult where
  pure  = Authenticated
  (<*>) = ap
-- | Binding continues only through 'Authenticated'. Every other result
-- short-circuits and is propagated unchanged; the continuation is not called.
instance Monad AuthResult where
  Authenticated v >>= f = f v
  BadPassword     >>= _ = BadPassword
  NoSuchUser      >>= _ = NoSuchUser
  Indefinite      >>= _ = Indefinite
-- | Choice reuses the 'Monoid' structure: 'empty' is 'Indefinite' and '<|>'
-- keeps the first result that is not 'Indefinite'.
instance Alternative AuthResult where
  empty   = mempty
  l <|> r = l <> r
-- | Same structure as the 'Monoid' instance: 'mzero' is 'Indefinite' and
-- 'mplus' is the left-biased '<>'.
instance MonadPlus AuthResult where
  mplus x y = x <> y
  mzero     = mempty
-- | An authentication procedure: given the incoming WAI 'Request', run some
-- 'IO' and produce an 'AuthResult'.
--
-- Checks combine with '<>' and '<|>': the next check runs only when the
-- previous one answered 'Indefinite'. They sequence with '>>=', continuing
-- only after 'Authenticated'.
newtype AuthCheck val = AuthCheck
  { runAuthCheck :: Request -> IO (AuthResult val) }
  deriving (Generic, Functor)
-- | Try the left check first. The right-hand check runs, against the same
-- request, only when the left one returns 'Indefinite'. Any other left
-- result is returned as-is, and the right-hand IO is never executed.
instance Semigroup (AuthCheck val) where
  AuthCheck primary <> AuthCheck fallback = AuthCheck run
    where
      run req = primary req >>= \outcome -> case outcome of
        Indefinite -> fallback req
        _          -> pure outcome
-- | The identity check ignores the request and answers 'Indefinite', so it
-- always defers to whatever check it is combined with.
instance Monoid (AuthCheck val) where
  mappend = (<>)
  mempty  = AuthCheck (\_ -> pure Indefinite)
-- | 'pure' builds a check that ignores the request and answers
-- 'Authenticated' with the given value.
--
-- 'pure' is the primary definition and 'return' is intentionally left to its
-- default (@return = pure@). This is the canonical layout expected by GHC's
-- @-Wnoncanonical-monad-instances@.
instance Applicative AuthCheck where
  pure x = AuthCheck (\_ -> pure (Authenticated x))
  (<*>)  = ap
-- | Sequencing runs each check against the same 'Request'. The continuation
-- is invoked only after 'Authenticated'; any other result short-circuits the
-- rest of the chain and is returned unchanged.
instance Monad AuthCheck where
  AuthCheck ac >>= f = AuthCheck $ \req -> do
    aresult <- ac req
    case aresult of
      Authenticated usr -> runAuthCheck (f usr) req
      BadPassword       -> pure BadPassword
      NoSuchUser        -> pure NoSuchUser
      Indefinite        -> pure Indefinite
#if !MIN_VERSION_base(4,13,0)
  -- Before base 4.13, 'fail' was still a 'Monad' method; route it to the
  -- 'MonadFail' instance so both agree.
  fail = Fail.fail
#endif
-- | A failed pattern match, or an explicit 'fail', yields 'Indefinite'
-- regardless of the request. The failure message is discarded.
instance Fail.MonadFail AuthCheck where
  fail _message = AuthCheck (\_ -> pure Indefinite)
-- | The reader environment is the request being authenticated. 'ask' always
-- succeeds with it; 'local' rewrites it before the inner check sees it.
instance MonadReader Request AuthCheck where
  ask = AuthCheck (pure . Authenticated)
  local adjust (AuthCheck check) = AuthCheck (check . adjust)
-- | Run an IO action while ignoring the request. Its result is always
-- wrapped as 'Authenticated'.
instance MonadIO AuthCheck where
  liftIO io = AuthCheck (\_ -> fmap Authenticated io)
-- | Read the current wall-clock time via 'getCurrentTime', ignoring the
-- request. The reading is always 'Authenticated'.
instance MonadTime AuthCheck where
  currentTime = AuthCheck $ \_ -> Authenticated <$> getCurrentTime
-- | Choice reuses the 'Monoid' structure: 'empty' answers 'Indefinite', and
-- '<|>' falls through to the right-hand check only on 'Indefinite'.
instance Alternative AuthCheck where
  empty   = mempty
  l <|> r = l <> r
-- | Same structure as the 'Monoid' instance: 'mzero' answers 'Indefinite'
-- and 'mplus' is the fall-through combination '<>'.
instance MonadPlus AuthCheck where
  mplus x y = x <> y
  mzero     = mempty