{-# LANGUAGE CPP #-}
module Servant.Auth.Server.Internal.Types where
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Time
import Data.Monoid (Monoid (..))
import Data.Semigroup (Semigroup (..))
import Data.Time (getCurrentTime)
import GHC.Generics (Generic)
import Network.Wai (Request)
import qualified Control.Monad.Fail as Fail
-- | The outcome of an authentication attempt.
--
-- Constructor order is significant: the derived 'Ord' instance orders
-- results as @BadPassword < NoSuchUser < Authenticated _ < Indefinite@.
data AuthResult val
  = BadPassword
  -- ^ A definitive failure (by its name, a rejected password).
  | NoSuchUser
  -- ^ A definitive failure (by its name, an unknown user).
  | Authenticated val
  -- ^ Success, carrying the authenticated value.
  | Indefinite
  -- ^ No decision was reached; this is the identity of '<>'.
  deriving (Eq, Ord, Read, Show, Functor, Foldable, Traversable, Generic)
-- | Left-biased choice that skips 'Indefinite': the left operand wins
-- unless it is 'Indefinite'. A definitive failure on the left is
-- therefore kept, not overridden by the right operand.
instance Semigroup (AuthResult val) where
  lhs <> rhs = case lhs of
    Indefinite -> rhs
    _ -> lhs
-- | 'Indefinite' is the identity, because '<>' skips it on either side.
-- 'mappend' is kept for base versions where it is not yet defaulted
-- from 'Semigroup'.
instance Monoid (AuthResult val) where
  mappend = (<>)
  mempty = Indefinite
-- | 'pure' wraps a value as a successful ('Authenticated') result.
-- '<*>' is derived from the 'Monad' instance, so the first result that is
-- not 'Authenticated' short-circuits the rest.
instance Applicative AuthResult where
  pure = Authenticated
  (<*>) = ap

-- | Bind continues only with an 'Authenticated' value. 'BadPassword',
-- 'NoSuchUser' and 'Indefinite' are propagated unchanged.
--
-- 'return' is intentionally not defined: it defaults to 'pure'. Defining it
-- here (and 'pure' in terms of it) is the non-canonical form that
-- @-Wnoncanonical-monad-instances@ reports.
instance Monad AuthResult where
  Authenticated v >>= f = f v
  BadPassword >>= _ = BadPassword
  NoSuchUser >>= _ = NoSuchUser
  Indefinite >>= _ = Indefinite
-- | 'empty' is 'Indefinite', and '<|>' is the same choice as '<>': it falls
-- through to the right operand only when the left one is 'Indefinite'.
instance Alternative AuthResult where
  empty = mempty
  lhs <|> rhs = lhs <> rhs
-- | Mirrors the 'Monoid' structure: 'mzero' is 'Indefinite' and 'mplus'
-- is the left-biased '<>'.
instance MonadPlus AuthResult where
  mplus lhs rhs = lhs <> rhs
  mzero = Indefinite
-- | An authentication check: given the incoming 'Request', run some IO and
-- produce an 'AuthResult'.
--
-- Checks combine with '<>' / '<|>', which fall through to the next check
-- on 'Indefinite', and with '>>=', which continues only on 'Authenticated'.
newtype AuthCheck val = AuthCheck
  { runAuthCheck :: Request -> IO (AuthResult val)
  } deriving (Functor, Generic)
-- | Run the left check first. The right check runs, against the same
-- request, only when the left one yields 'Indefinite'. Any other result
-- (success or a definitive failure) is returned as is.
instance Semigroup (AuthCheck val) where
  AuthCheck primary <> AuthCheck fallback = AuthCheck go
    where
      go req = primary req >>= \res -> case res of
        Indefinite -> fallback req
        _ -> pure res
-- | The identity check ignores the request and answers 'Indefinite', so it
-- never pre-empts the other operand of '<>'.
instance Monoid (AuthCheck val) where
  mempty = AuthCheck (\_ -> pure Indefinite)
  mappend = (<>)
-- | 'pure' builds a check that ignores the request and always succeeds
-- with the given value. '<*>' is derived from the 'Monad' instance.
instance Applicative AuthCheck where
  pure x = AuthCheck $ \_ -> pure (Authenticated x)
  (<*>) = ap

-- | Bind runs the first check against the request. Only if that check is
-- 'Authenticated' is its value passed to the continuation, whose check is
-- then run against the same request. Any other outcome is returned without
-- running the continuation.
--
-- 'return' is intentionally not defined: it defaults to 'pure'. Defining it
-- here (and 'pure' in terms of it) is the non-canonical form that
-- @-Wnoncanonical-monad-instances@ reports.
instance Monad AuthCheck where
  AuthCheck ac >>= f = AuthCheck $ \req -> do
    aresult <- ac req
    case aresult of
      Authenticated usr -> runAuthCheck (f usr) req
      BadPassword -> pure BadPassword
      NoSuchUser -> pure NoSuchUser
      Indefinite -> pure Indefinite
#if !MIN_VERSION_base(4,13,0)
  -- Before base 4.13, fail was still a Monad method; route it to the
  -- MonadFail instance so both behave the same.
  fail = Fail.fail
#endif
-- | Failing inside an 'AuthCheck' (for example, a refuted pattern in
-- do-notation) yields 'Indefinite' for every request. The failure message
-- is discarded.
instance Fail.MonadFail AuthCheck where
  fail _msg = AuthCheck (\_req -> pure Indefinite)
-- | The reader environment is the incoming WAI 'Request'. 'ask' always
-- succeeds with it; 'local' runs the wrapped check against a transformed
-- request.
instance MonadReader Request AuthCheck where
  ask = AuthCheck (pure . Authenticated)
  local f (AuthCheck check) = AuthCheck (check . f)
-- | Lift an IO action into a check that ignores the request. Its result is
-- always wrapped as 'Authenticated'; exceptions from the action are not
-- caught here.
instance MonadIO AuthCheck where
  liftIO io = AuthCheck (\_ -> fmap Authenticated io)
-- | Wall-clock time from 'getCurrentTime'. The request is ignored and the
-- result is always 'Authenticated'.
--
-- NOTE(review): newer monad-time releases may add further class methods
-- (possibly a monotonic clock). Confirm against the monad-time version this
-- package builds with that no method is left undefined here.
instance MonadTime AuthCheck where
  currentTime = AuthCheck (\_ -> Authenticated <$> getCurrentTime)
-- | 'empty' is the check that always answers 'Indefinite'. '<|>' tries the
-- left check and falls back to the right one only on 'Indefinite'.
instance Alternative AuthCheck where
  empty = mempty
  lhs <|> rhs = lhs <> rhs
-- | Mirrors the 'Monoid' structure: 'mzero' is the always-'Indefinite'
-- check and 'mplus' is the fall-through '<>'.
instance MonadPlus AuthCheck where
  mplus lhs rhs = lhs <> rhs
  mzero = mempty