module Happstack.Server.Internal.Monads where
import Control.Applicative (Applicative, pure, (<*>), Alternative(empty,(<|>)))
import Control.Concurrent (newMVar)
import Control.Monad ( MonadPlus(mzero, mplus), ap, liftM, msum
)
import Control.Monad.Base ( MonadBase, liftBase )
import Control.Monad.Catch ( MonadCatch(..), MonadThrow(..) )
import Control.Monad.Error ( ErrorT(ErrorT), runErrorT
, Error, MonadError, throwError
, catchError, mapErrorT
)
import Control.Monad.Reader ( ReaderT(ReaderT), runReaderT
, MonadReader, ask, local, mapReaderT
)
import qualified Control.Monad.RWS.Lazy as Lazy ( RWST, mapRWST )
import qualified Control.Monad.RWS.Strict as Strict ( RWST, mapRWST )
import Control.Monad.Trans.Except ( ExceptT, mapExceptT )
import Control.Monad.State.Class ( MonadState, get, put )
import qualified Control.Monad.State.Lazy as Lazy ( StateT, mapStateT )
import qualified Control.Monad.State.Strict as Strict ( StateT, mapStateT )
import Control.Monad.Trans ( MonadTrans, lift
, MonadIO, liftIO
)
import Control.Monad.Trans.Control ( MonadTransControl(..)
, MonadBaseControl(..)
, ComposeSt, defaultLiftBaseWith, defaultRestoreM
)
import Control.Monad.Writer.Class ( MonadWriter, tell, pass, listens )
import qualified Control.Monad.Writer.Lazy as Lazy ( WriterT(WriterT), runWriterT, mapWriterT )
import qualified Control.Monad.Writer.Strict as Strict ( WriterT, mapWriterT )
import qualified Control.Monad.Writer.Class as Writer ( listen )
import Control.Monad.Trans.Maybe (MaybeT(MaybeT), runMaybeT)
import qualified Data.ByteString.Char8 as P
import qualified Data.ByteString.Lazy.UTF8 as LU (fromString)
import Data.Char (ord)
import Data.List (inits, isPrefixOf, stripPrefix, tails)
import Data.Maybe (fromMaybe)
import Data.Monoid (Monoid(mempty, mappend), Dual(..), Endo(..))
import qualified Paths_happstack_server as Cabal
import qualified Data.Version as DV
import Debug.Trace (trace)
import Happstack.Server.Internal.Cookie (Cookie)
import Happstack.Server.Internal.RFC822Headers (parseContentType)
import Happstack.Server.Types
import Prelude (Bool(..), Either(..), Eq(..), Functor(..), IO(..), Monad(..), Char, Maybe(..), String, Show(..), ($), (.), (>), (++), (&&), (||), (=<<), const, concatMap, flip, id, otherwise, zip)
-- | A 'WebT' computation specialised to run over 'IO'.
type Web a = WebT IO a
-- | A 'ServerPartT' computation specialised to run over 'IO'.
type ServerPart a = ServerPartT IO a
-- | A 'WebT' computation that can also read a 'Request'.
--
-- Representation: a 'ReaderT' 'Request' layer wrapped around 'WebT' @m@.
-- The 'Monad', 'MonadPlus' and 'Functor' instances are derived from that
-- underlying stack.
newtype ServerPartT m a = ServerPartT { unServerPartT :: ReaderT Request (WebT m) a }
deriving (Monad, MonadPlus, Functor)
-- | Exceptions are caught in the wrapped 'ReaderT' stack; the handler's
-- result is unwrapped so that both branches live in that same stack.
instance MonadCatch m => MonadCatch (ServerPartT m) where
    catch body recover =
        ServerPartT (unServerPartT body `catch` \exc -> unServerPartT (recover exc))
-- | Throwing is delegated to the wrapped 'ReaderT' stack.
instance MonadThrow m => MonadThrow (ServerPartT m) where
    throwM exc = ServerPartT (throwM exc)
-- | Lift a base-monad action into @m@ first, then through 'ServerPartT'.
instance MonadBase b m => MonadBase b (ServerPartT m) where
    liftBase base = lift (liftBase base)
-- | 'IO' actions are lifted via the 'MonadIO' instance of the wrapped stack.
instance MonadIO m => MonadIO (ServerPartT m) where
    liftIO io = ServerPartT (liftIO io)
-- 'MonadTransControl' / 'MonadBaseControl' instances for 'ServerPartT'.
--
-- Two variants are provided, selected by the monad-control version:
-- from 1.0.0 on, 'StT' and 'StM' are associated type synonyms; for older
-- versions they are associated newtypes, so the state has to be wrapped
-- ('StSP', 'StMSP') and unwrapped ('unStSP', 'unStMSP') explicitly.
#if MIN_VERSION_monad_control(1,0,0)
-- The saved state is the 'ReaderT' state nested inside the 'WebT' state.
instance MonadTransControl ServerPartT where
type StT ServerPartT a = StT WebT (StT (ReaderT Request) a)
-- Peel the 'ReaderT' layer first, then the 'WebT' layer; the runner handed
-- to @f@ applies the two run functions in the opposite order.
liftWith f = ServerPartT $ liftWith $ \runReader ->
liftWith $ \runWeb ->
f $ runWeb . runReader . unServerPartT
-- One 'restoreT' per layer peeled by 'liftWith'.
restoreT = ServerPartT . restoreT . restoreT
instance MonadBaseControl b m => MonadBaseControl b (ServerPartT m) where
type StM (ServerPartT m) a = ComposeSt ServerPartT m a
liftBaseWith = defaultLiftBaseWith
restoreM = defaultRestoreM
#else
-- Pre-1.0.0 monad-control: same logic, with the state wrapped in newtypes.
instance MonadTransControl ServerPartT where
newtype StT ServerPartT a = StSP {unStSP :: StT WebT (StT (ReaderT Request) a)}
liftWith f = ServerPartT $ liftWith $ \runReader ->
liftWith $ \runWeb ->
f $ liftM StSP . runWeb . runReader . unServerPartT
restoreT = ServerPartT . restoreT . restoreT . liftM unStSP
instance MonadBaseControl b m => MonadBaseControl b (ServerPartT m) where
newtype StM (ServerPartT m) a = StMSP {unStMSP :: ComposeSt ServerPartT m a}
liftBaseWith = defaultLiftBaseWith StMSP
restoreM = defaultRestoreM unStMSP
#endif
-- | Supply a 'Request' to a 'ServerPartT', yielding the underlying 'WebT'.
runServerPartT :: ServerPartT m a -> Request -> WebT m a
runServerPartT part rq = runReaderT (unServerPartT part) rq
-- | Build a 'ServerPartT' from a function of the 'Request'
-- (the inverse of 'runServerPartT').
withRequest :: (Request -> WebT m a) -> ServerPartT m a
withRequest handler = ServerPartT (ReaderT handler)
-- | Turn a 'WebT' into a 'ServerPartT' that ignores the 'Request'.
anyRequest :: Monad m => WebT m a -> ServerPartT m a
anyRequest web = withRequest (const web)
-- | Transform the fully unwrapped computation ('UnWebT') underneath a
-- 'ServerPartT'; the 'Request' is passed through unchanged.
mapServerPartT :: (UnWebT m a -> UnWebT n b)
               -> (ServerPartT m a -> ServerPartT n b)
mapServerPartT f ma = withRequest (mapWebT f . runServerPartT ma)
-- | Like 'mapServerPartT', but the transformation also receives the
-- 'Request' the part is being run against.
mapServerPartT' :: (Request -> UnWebT m a -> UnWebT n b)
                -> (ServerPartT m a -> ServerPartT n b)
mapServerPartT' f ma = withRequest run
  where
    run rq = mapWebT (f rq) (runServerPartT ma rq)
-- | A lifted action is first lifted into 'WebT' and then ignores the 'Request'.
instance MonadTrans ServerPartT where
    lift action = withRequest (const (lift action))
-- | The monoid structure is exactly the 'MonadPlus' structure.
instance (Monad m, MonadPlus m) => Monoid (ServerPartT m a) where
    mempty      = mzero
    mappend l r = l `mplus` r
-- | 'Applicative' defined in terms of the 'Monad' instance.
instance (Monad m, Functor m) => Applicative (ServerPartT m) where
    pure x    = return x
    mf <*> mx = ap mf mx
-- | 'Alternative' defined in terms of the 'MonadPlus' instance.
instance (Functor m, MonadPlus m) => Alternative (ServerPartT m) where
    empty   = mzero
    l <|> r = mplus l r
-- | Writer operations are forwarded to the 'MonadWriter' instance of the
-- underlying 'WebT', running the part against the current 'Request'.
--
-- The trailing @>>= return@ that used to follow 'Writer.listen' and 'pass'
-- has been dropped: by the monad right-identity law it is a no-op.
instance (Monad m, MonadWriter w m) => MonadWriter w (ServerPartT m) where
    tell w   = lift (tell w)
    listen m = withRequest $ \rq -> Writer.listen (runServerPartT m rq)
    pass m   = withRequest $ \rq -> pass (runServerPartT m rq)
-- | Errors of the base monad are thrown by lifting; when catching, both the
-- action and the handler are run against the same 'Request'.
instance (Monad m, MonadError e m) => MonadError e (ServerPartT m) where
    throwError err = lift (throwError err)
    catchError action handler = withRequest $ \rq ->
        catchError (runServerPartT action rq)
                   (\err -> runServerPartT (handler err) rq)
-- | Exposes the reader environment of the base monad @m@ (not the
-- 'Request'; use 'askRq' / 'localRq' for that).
instance (Monad m, MonadReader r m) => MonadReader r (ServerPartT m) where
    ask        = lift ask
    local fn m = withRequest (local fn . runServerPartT m)
-- | State operations are lifted from the base monad.
instance (Monad m, MonadState s m) => MonadState s (ServerPartT m) where
    get   = lift get
    put s = lift (put s)
-- | Filter operations are forwarded to the underlying 'WebT'.
instance Monad m => FilterMonad Response (ServerPartT m) where
    setFilter f     = anyRequest (setFilter f)
    composeFilter f = anyRequest (composeFilter f)
    getFilter m     = withRequest (getFilter . runServerPartT m)
-- | 'finishWith' is forwarded to the underlying 'WebT', ignoring the 'Request'.
instance Monad m => WebMonad Response (ServerPartT m) where
    finishWith r = anyRequest (finishWith r)
-- | Monads that give access to a 'Request'.
class Monad m => ServerMonad m where
-- | Retrieve the 'Request'.
askRq :: m Request
-- | Run a computation with the 'Request' replaced by the result of
-- applying the given function to it.
localRq :: (Request -> Request) -> m a -> m a
-- | The 'Request' is the environment of the wrapped 'ReaderT' layer.
instance Monad m => ServerMonad (ServerPartT m) where
    askRq        = ServerPartT ask
    localRq f sp = ServerPartT (local f (unServerPartT sp))
-- | Fetch the request environment as a triple of
-- (query-string inputs, body inputs, cookies).
--
-- The body is only read (via 'readInputsBody') when the method is 'POST' or
-- 'PUT' and the content type is considered decodable: missing/unparseable,
-- @application/x-www-form-urlencoded@ or @multipart/form-data@. In every
-- other case the body component is @Just []@.
smAskRqEnv :: (ServerMonad m, MonadIO m) => m ([(String, Input)], Maybe [(String, Input)], [(String, Cookie)])
smAskRqEnv = do
    rq <- askRq
    let carriesBody = rqMethod rq == POST || rqMethod rq == PUT
    mbi <- liftIO $
        if carriesBody && isDecodable (ctype rq)
            then readInputsBody rq
            else return (Just [])
    return (rqInputsQuery rq, mbi, rqCookies rq)
  where
    -- Parsed Content-Type header, if present and parseable.
    ctype :: Request -> Maybe ContentType
    ctype req = getHeader "content-type" req >>= parseContentType . P.unpack

    -- Whether a body with this content type should be decoded into inputs.
    isDecodable :: Maybe ContentType -> Bool
    isDecodable mct = case mct of
        Nothing                                                    -> True
        Just (ContentType "application" "x-www-form-urlencoded" _) -> True
        Just (ContentType "multipart" "form-data" _)               -> True
        Just _                                                     -> False
-- | Run a computation with a transformed request environment.
--
-- The current (query inputs, body inputs, cookies) triple is passed to the
-- given function; the computation then sees a 'Request' whose query inputs
-- and cookies are the returned ones and whose body is a fresh 'MVar' holding
-- the returned body inputs (an empty list if 'Nothing' was returned).
-- The body is always read with 'readInputsBody', whatever the method.
smLocalRqEnv :: (ServerMonad m, MonadIO m) => (([(String, Input)], Maybe [(String, Input)], [(String, Cookie)]) -> ([(String, Input)], Maybe [(String, Input)], [(String, Cookie)])) -> m b -> m b
smLocalRqEnv f m = do
    rq   <- askRq
    body <- liftIO (readInputsBody rq)
    let (query', body', cookies') = f (rqInputsQuery rq, body, rqCookies rq)
    bodyVar <- liftIO (newMVar (fromMaybe [] body'))
    let updated = rq { rqInputsQuery = query'
                     , rqInputsBody  = bodyVar
                     , rqCookies     = cookies'
                     }
    localRq (\_ -> updated) m
-- | A value tagged with how it combines with what was accumulated before it:
-- per the 'Monoid' instance, a 'Set' on the right discards everything to its
-- left, whereas an 'Append' on the right is combined with it.
data SetAppend a = Set a | Append a
deriving (Eq, Show)
-- | Combining rules:
--
-- * anything followed by @Set y@ is @Set y@ (the left side is discarded);
-- * @Set x@ followed by @Append y@ stays a 'Set' of the combined value;
-- * two 'Append's combine into an 'Append'.
--
-- The alternatives are tried in the same order, left operand first, as the
-- original equations so the evaluation order is unchanged.
instance Monoid a => Monoid (SetAppend a) where
    mempty = Append mempty
    mappend lhs rhs = case (lhs, rhs) of
        (Set x,    Append y) -> Set (x `mappend` y)
        (Append x, Append y) -> Append (x `mappend` y)
        (_,        Set y)    -> Set y
-- | Drop the 'Set' / 'Append' tag and return the wrapped value.
extract :: SetAppend t -> t
extract tagged = case tagged of
    Set x    -> x
    Append x -> x
-- | Map over the wrapped value, keeping the 'Set' / 'Append' tag.
instance Functor SetAppend where
    fmap f tagged = case tagged of
        Set x    -> Set (f x)
        Append x -> Append (f x)
-- | An accumulated filter: an endomorphism on @a@ (wrapped in 'Dual' so the
-- 'Endo' composition order is flipped), tagged with 'SetAppend'.
type FilterFun a = SetAppend (Dual (Endo a))
-- | Unwrap a 'FilterFun' into the plain function it represents.
unFilterFun :: FilterFun a -> (a -> a)
unFilterFun ff = appEndo (getDual (extract ff))
-- | Wrap a plain function as a 'FilterFun' tagged with 'Set'.
filterFun :: (a -> a) -> FilterFun a
filterFun f = Set (Dual (Endo f))
-- | A monad transformer that accumulates a filter on values of type @a@.
--
-- Representation: a lazy 'Lazy.WriterT' whose output is a 'FilterFun' @a@.
newtype FilterT a m b = FilterT { unFilterT :: Lazy.WriterT (FilterFun a) m b }
deriving (Functor, Applicative, Monad, MonadTrans)
-- | Exceptions are caught in the wrapped 'Lazy.WriterT' layer.
instance MonadCatch m => MonadCatch (FilterT a m) where
    catch body recover =
        FilterT (unFilterT body `catch` \exc -> unFilterT (recover exc))
-- | Throwing is delegated to the wrapped 'Lazy.WriterT' layer.
instance MonadThrow m => MonadThrow (FilterT a m) where
    throwM exc = FilterT (throwM exc)
-- | Lift a base-monad action into @m@ first, then through 'FilterT'.
instance MonadBase b m => MonadBase b (FilterT a m) where
    liftBase base = lift (liftBase base)
-- | 'IO' actions are lifted via the wrapped 'Lazy.WriterT' layer.
instance MonadIO m => MonadIO (FilterT a m) where
    liftIO io = FilterT (liftIO io)
-- 'MonadTransControl' / 'MonadBaseControl' instances for 'FilterT'.
--
-- Selected by the monad-control version: from 1.0.0 on, 'StT' / 'StM' are
-- associated type synonyms; older versions use associated newtypes, so the
-- state is wrapped ('StFilter', 'StMFilter') and unwrapped explicitly.
#if MIN_VERSION_monad_control(1,0,0)
-- The saved state is exactly that of the wrapped lazy 'Lazy.WriterT'.
instance MonadTransControl (FilterT a) where
type StT (FilterT a) b = StT (Lazy.WriterT (FilterFun a)) b
liftWith f = FilterT $ liftWith $ \run -> f $ run . unFilterT
restoreT = FilterT . restoreT
instance MonadBaseControl b m => MonadBaseControl b (FilterT a m) where
type StM (FilterT a m) c = ComposeSt (FilterT a) m c
liftBaseWith = defaultLiftBaseWith
restoreM = defaultRestoreM
#else
-- Pre-1.0.0 monad-control: same logic, with the state wrapped in newtypes.
instance MonadTransControl (FilterT a) where
newtype StT (FilterT a) b = StFilter {unStFilter :: StT (Lazy.WriterT (FilterFun a)) b}
liftWith f = FilterT $ liftWith $ \run -> f $ liftM StFilter . run . unFilterT
restoreT = FilterT . restoreT . liftM unStFilter
instance MonadBaseControl b m => MonadBaseControl b (FilterT a m) where
newtype StM (FilterT a m) c = StMFilter {unStMFilter :: ComposeSt (FilterT a) m c}
liftBaseWith = defaultLiftBaseWith StMFilter
restoreM = defaultRestoreM unStMFilter
#endif
-- | Monads that accumulate a filter function on values of type @a@
-- (the functional dependency says the monad determines @a@).
class Monad m => FilterMonad a m | m->a where
-- | Install a filter. In the 'FilterT' instance this is recorded with the
-- 'Set' tag, i.e. it replaces whatever was accumulated before.
setFilter :: (a->a) -> m ()
-- | Add a filter. In the 'FilterT' instance this is recorded with the
-- 'Append' tag, i.e. it is combined with what was accumulated before.
composeFilter :: (a->a) -> m ()
-- | Run a computation and return its result paired with the filter
-- function it accumulated.
getFilter :: m b -> m (b, a->a)
-- | Replace the accumulated filter with the identity function.
ignoreFilters :: (FilterMonad a m) => m ()
ignoreFilters = setFilter (\unchanged -> unchanged)
-- | Filters are written to the underlying writer: 'setFilter' with the
-- 'Set' tag, 'composeFilter' with the 'Append' tag. 'getFilter' listens to
-- the writer output and converts it back to a plain function.
instance Monad m => FilterMonad a (FilterT a m) where
    setFilter f     = FilterT (tell (Set (Dual (Endo f))))
    composeFilter f = FilterT (tell (Append (Dual (Endo f))))
    getFilter m     = FilterT (listens unFilterFun (unFilterT m))
-- | The core web monad transformer.
--
-- Representation, outermost layer first:
--
-- * 'ErrorT' 'Response' -- used by 'finishWith' to stop with a 'Response';
-- * 'FilterT' 'Response' -- accumulates a filter on the 'Response';
-- * 'MaybeT' -- supplies the 'mzero' used by the 'MonadPlus' instance.
newtype WebT m a = WebT { unWebT :: ErrorT Response (FilterT (Response) (MaybeT m)) a }
deriving (Functor)
-- | Exceptions are caught in the wrapped transformer stack.
instance MonadCatch m => MonadCatch (WebT m) where
    catch body recover =
        WebT (unWebT body `catch` \exc -> unWebT (recover exc))
-- | Throwing is delegated to the wrapped transformer stack.
instance MonadThrow m => MonadThrow (WebT m) where
    throwM exc = WebT (throwM exc)
-- | Lift a base-monad action into @m@ first, then through 'WebT'.
instance MonadBase b m => MonadBase b (WebT m) where
    liftBase base = lift (liftBase base)
-- | 'IO' actions are lifted via the wrapped transformer stack.
instance MonadIO m => MonadIO (WebT m) where
    liftIO io = WebT (liftIO io)
-- 'MonadTransControl' / 'MonadBaseControl' instances for 'WebT'.
--
-- Selected by the monad-control version: from 1.0.0 on, 'StT' / 'StM' are
-- associated type synonyms; older versions use associated newtypes, so the
-- state is wrapped ('StWeb', 'StMWeb') and unwrapped explicitly.
#if MIN_VERSION_monad_control(1,0,0)
-- The saved state nests the states of the three layers of 'WebT':
-- 'ErrorT' innermost, then 'FilterT', then 'MaybeT' outermost.
instance MonadTransControl WebT where
type StT WebT a = StT MaybeT
(StT (FilterT Response)
(StT (ErrorT Response) a))
-- Peel 'ErrorT', then 'FilterT', then 'MaybeT'; the runner handed to @f@
-- applies the three run functions in the opposite order.
liftWith f = WebT $ liftWith $ \runError ->
liftWith $ \runFilter ->
liftWith $ \runMaybe ->
f $ runMaybe .
runFilter .
runError . unWebT
-- One 'restoreT' per layer peeled by 'liftWith'.
restoreT = WebT . restoreT . restoreT . restoreT
instance MonadBaseControl b m => MonadBaseControl b (WebT m) where
type StM (WebT m) a = ComposeSt WebT m a
liftBaseWith = defaultLiftBaseWith
restoreM = defaultRestoreM
#else
-- Pre-1.0.0 monad-control: same logic, with the state wrapped in newtypes.
instance MonadTransControl WebT where
newtype StT WebT a = StWeb {unStWeb :: StT MaybeT
(StT (FilterT Response)
(StT (ErrorT Response) a))}
liftWith f = WebT $ liftWith $ \runError ->
liftWith $ \runFilter ->
liftWith $ \runMaybe ->
f $ liftM StWeb . runMaybe .
runFilter .
runError . unWebT
restoreT = WebT . restoreT . restoreT . restoreT . liftM unStWeb
instance MonadBaseControl b m => MonadBaseControl b (WebT m) where
newtype StM (WebT m) a = StMWeb {unStMWeb :: ComposeSt WebT m a}
liftBaseWith = defaultLiftBaseWith StMWeb
restoreM = defaultRestoreM unStMWeb
#endif
-- | The fully unwrapped form of a 'WebT' computation (see 'ununWebT' and
-- 'mkWebT'): an action in @m@ producing
--
-- * 'Nothing' (the 'MaybeT' layer), or
-- * 'Just' a pair of either a 'Response' ('Left', the 'ErrorT' layer) or a
--   value ('Right'), together with the accumulated 'FilterFun'.
type UnWebT m a = m (Maybe (Either Response a, FilterFun Response))
-- | Bind and return are those of the wrapped stack; 'fail' is lifted from
-- the base monad @m@.
instance Monad m => Monad (WebT m) where
    m >>= f  = WebT (unWebT m >>= \a -> unWebT (f a))
    return a = WebT (return a)
    fail msg = lift (fail msg)
-- | Monads in which a computation can be ended early with a value of type
-- @a@ (the functional dependency says the monad determines @a@).
class Monad m => WebMonad a m | m->a where
-- | Stop with the given value; the result type @b@ is unconstrained.
-- In the 'WebT' instance this is implemented with 'throwError'.
finishWith :: a
-> m b
-- | Reset the filter to the identity, run the given action, and finish
-- with its result.
escape :: (WebMonad a m, FilterMonad a m) => m a -> m b
escape gen = (ignoreFilters >> gen) >>= \result -> finishWith result
-- | Like 'escape', but takes the final value directly rather than an
-- action that produces it.
escape' :: (WebMonad a m, FilterMonad a m) => a -> m b
escape' a = do
    ignoreFilters
    finishWith a
-- | Finishing throws the 'Response' in the 'ErrorT' layer of 'WebT'.
instance Monad m => WebMonad Response (WebT m) where
    finishWith r = WebT (throwError r)
-- | Lift through the three layers of 'WebT', one 'lift' per layer.
instance MonadTrans WebT where
    lift m = WebT (lift (lift (lift m)))
-- | 'mzero' comes from the layers beneath 'FilterT'; 'mplus' is applied
-- beneath the 'ErrorT' layer, on the wrapped writer.
instance (Monad m, MonadPlus m) => MonadPlus (WebT m) where
    mzero     = WebT (lift (lift mzero))
    mplus x y = WebT (ErrorT (FilterT (lower x `mplus` lower y)))
      where
        -- Strip the WebT, ErrorT and FilterT wrappers.
        lower w = unFilterT (runErrorT (unWebT w))
-- | Filter operations are forwarded to the 'FilterT' layer. 'getFilter'
-- pairs a successful result with the filter; a 'Left' 'Response' is
-- propagated unchanged and the filter is dropped from the result.
instance Monad m => FilterMonad Response (WebT m) where
    setFilter f     = WebT (lift (setFilter f))
    composeFilter f = WebT (lift (composeFilter f))
    getFilter m     = WebT . ErrorT $ liftM attach (getFilter (runErrorT (unWebT m)))
      where
        attach (outcome, f) = case outcome of
            Left r  -> Left r
            Right a -> Right (a, f)
-- | The monoid structure is exactly the 'MonadPlus' structure.
instance (Monad m, MonadPlus m) => Monoid (WebT m a) where
    mempty      = mzero
    mappend l r = l `mplus` r
-- | Unwrap every layer of a 'WebT' down to the base monad (see 'UnWebT').
ununWebT :: WebT m a -> UnWebT m a
ununWebT w = runMaybeT (Lazy.runWriterT (unFilterT (runErrorT (unWebT w))))
-- | Wrap an 'UnWebT' back into a 'WebT' (the inverse of 'ununWebT').
mkWebT :: UnWebT m a -> WebT m a
mkWebT raw = WebT (ErrorT (FilterT (Lazy.WriterT (MaybeT raw))))
-- | Transform the fully unwrapped computation inside a 'WebT'.
mapWebT :: (UnWebT m a -> UnWebT n b)
        -> (WebT m a -> WebT n b)
mapWebT f = mkWebT . f . ununWebT
-- | Apply a 'WebT' transformation underneath a 'ServerPartT'; the
-- 'Request' is passed through unchanged.
localContext :: Monad m => (WebT m a -> WebT m' a) -> ServerPartT m a -> ServerPartT m' a
localContext fn hs = withRequest (fn . runServerPartT hs)
-- | 'Applicative' defined in terms of the 'Monad' instance.
instance (Monad m, Functor m) => Applicative (WebT m) where
    pure x    = return x
    mf <*> mx = ap mf mx
-- | 'Alternative' defined in terms of the 'MonadPlus' instance.
instance (Functor m, MonadPlus m) => Alternative (WebT m) where
    empty   = mzero
    l <|> r = mplus l r
-- | The reader environment of the base monad; 'local' is applied to the
-- fully unwrapped computation.
instance MonadReader r m => MonadReader r (WebT m) where
    ask      = lift ask
    local fn = mkWebT . local fn . ununWebT
-- | State operations are lifted from the base monad.
instance MonadState st m => MonadState st (WebT m) where
    get   = lift get
    put s = lift (put s)
-- | Errors of the base monad @m@ (distinct from the internal
-- 'ErrorT' 'Response' layer): thrown by lifting, caught on the fully
-- unwrapped computation.
instance MonadError e m => MonadError e (WebT m) where
    throwError err = lift (throwError err)
    catchError action handler =
        mkWebT (ununWebT action `catchError` \err -> ununWebT (handler err))
-- | Writer operations of the base monad, threaded through the unwrapped
-- 'UnWebT' representation.
instance MonadWriter w m => MonadWriter w (WebT m) where
    tell w = lift (tell w)

    -- Attach the listened output to a successful ('Right') result only;
    -- 'Nothing' and 'Left' outcomes pass through without it.
    listen m = mkWebT (liftM tag (Writer.listen (ununWebT m)))
      where
        tag (outcome, w) = case outcome of
            Nothing           -> Nothing
            Just (Left x, f)  -> Just (Left x, f)
            Just (Right x, f) -> Just (Right (x, w), f)

    -- Only a successful ('Right') result carries a function to hand to the
    -- base monad's 'pass'; other outcomes are returned as they are.
    pass m = mkWebT (ununWebT m >>= settle)
      where
        settle outcome = case outcome of
            Nothing           -> return Nothing
            Just (Left x, f)  -> return (Just (Left x, f))
            Just (Right x, f) -> pass (return x) >>= \a -> return (Just (Right a, f))
-- | Combine a list of server parts with 'msum'.
multi :: (Monad m, MonadPlus m) => [ServerPartT m a] -> ServerPartT m a
multi parts = msum parts
-- | Run a handler against the current 'Request' and return its result
-- unchanged. As written this adds no output or tracing of its own; the
-- 'MonadIO' and 'Show' constraints are kept so the signature is unchanged.
--
-- The redundant @r <- ...; return r@ has been dropped: by the monad
-- right-identity law it is a no-op.
debugFilter :: (MonadIO m, Show a) => ServerPartT m a -> ServerPartT m a
debugFilter handle = withRequest (runServerPartT handle)
-- | Emit a message with 'trace' and return the second argument unchanged.
--
-- Messages starting with @"Pattern match failure "@ are rearranged first:
-- for every occurrence of @" at "@ in the message, the text after it is
-- emitted, followed by @": "@ and the text before it. Every other message is
-- traced verbatim.
outputTraceMessage :: String -> a -> a
outputTraceMessage s c
    | "Pattern match failure " `isPrefixOf` s = trace located c
    | otherwise                               = trace s c
  where
    -- (text after " at ", text before it) for every split point of s.
    hits    = [ (k, p) | (i, p) <- zip (tails s) (inits s), Just k <- [stripPrefix " at " i] ]
    located = concatMap (\(k, p) -> k ++ ": " ++ p) hits
-- | Reset the filter to the identity and finish with the error page built
-- by 'failResponse' for the given message.
mkFailMessage :: (FilterMonad Response m, WebMonad Response m) => String -> m b
mkFailMessage s = ignoreFilters >> finishWith (failResponse s)
-- | Build a status-500 'Response' whose body is the HTML from 'failHtml',
-- with the @Content-Type@ header set to UTF-8 encoded HTML.
failResponse :: String -> Response
failResponse s = setHeader "Content-Type" "text/html; charset=UTF-8" page
  where
    page = resultBS 500 (LU.fromString (failHtml s))
-- | Render the HTML page for an internal server error. The page shows the
-- library version and the error string, which is passed through
-- 'escapeString' before being embedded.
failHtml :: String -> String
failHtml errString = doctype ++ heading ++ content
  where
    ver     = DV.showVersion Cabal.version
    doctype = "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.01//EN\" \"http://www.w3.org/TR/html4/strict.dtd\">"
    heading = "<html><head><title>Happstack " ++ ver ++ " Internal Server Error</title></head>"
    content = "<body><h1>Happstack " ++ ver ++ "</h1>"
           ++ "<p>Something went wrong here<br>"
           ++ "Internal server error<br>"
           ++ "Everything has stopped</p>"
           ++ "<p>The error was \"" ++ escapeString errString ++ "\"</p></body></html>"
-- | Escape a string for inclusion in HTML.
--
-- The characters @<@, @>@, @&@ and @"@ are replaced by their named
-- entities, every character above code point 127 by a decimal numeric
-- character reference, and everything else is kept as it is.
--
-- Fix: the special characters used to be mapped to themselves (and the
-- double-quote case was not even a valid string literal), so markup in an
-- error message was emitted unescaped into the error page.
escapeString :: String -> String
escapeString str = concatMap encodeEntity str
  where
    encodeEntity :: Char -> String
    encodeEntity '<' = "&lt;"
    encodeEntity '>' = "&gt;"
    encodeEntity '&' = "&amp;"
    encodeEntity '"' = "&quot;"
    encodeEntity c
        | ord c > 127 = "&#" ++ show (ord c) ++ ";"
        | otherwise   = [c]
-- | Lift 'ServerMonad' through 'ReaderT'.
instance ServerMonad m => ServerMonad (ReaderT r m) where
    askRq       = lift askRq
    localRq f m = mapReaderT (localRq f) m
-- | Lift 'FilterMonad' through 'ReaderT'.
instance FilterMonad res m => FilterMonad res (ReaderT r m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = mapReaderT getFilter m
-- | Lift 'WebMonad' through 'ReaderT'.
instance WebMonad a m => WebMonad a (ReaderT r m) where
    finishWith r = lift (finishWith r)
-- | Lift 'ServerMonad' through the lazy 'Lazy.StateT'.
instance ServerMonad m => ServerMonad (Lazy.StateT s m) where
    askRq       = lift askRq
    localRq f m = Lazy.mapStateT (localRq f) m
-- | Lift 'ServerMonad' through the strict 'Strict.StateT'.
instance ServerMonad m => ServerMonad (Strict.StateT s m) where
    askRq       = lift askRq
    localRq f m = Strict.mapStateT (localRq f) m
-- | Lift 'FilterMonad' through the lazy 'Lazy.StateT'. 'getFilter' regroups
-- the inner result so the filter is paired with the value, not the state.
instance FilterMonad res m => FilterMonad res (Lazy.StateT s m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = Lazy.mapStateT regroup m
      where
        regroup inner =
            getFilter inner >>= \((b, s), f) -> return ((b, f), s)
-- | Lift 'FilterMonad' through the strict 'Strict.StateT'. 'getFilter'
-- regroups the inner result so the filter is paired with the value, not the
-- state.
instance FilterMonad res m => FilterMonad res (Strict.StateT s m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = Strict.mapStateT regroup m
      where
        regroup inner =
            getFilter inner >>= \((b, s), f) -> return ((b, f), s)
-- | Lift 'WebMonad' through the lazy 'Lazy.StateT'.
instance WebMonad a m => WebMonad a (Lazy.StateT s m) where
    finishWith r = lift (finishWith r)
-- | Lift 'WebMonad' through the strict 'Strict.StateT'.
instance WebMonad a m => WebMonad a (Strict.StateT s m) where
    finishWith r = lift (finishWith r)
-- | Lift 'ServerMonad' through the lazy 'Lazy.WriterT'.
instance (ServerMonad m, Monoid w) => ServerMonad (Lazy.WriterT w m) where
    askRq       = lift askRq
    localRq f m = Lazy.mapWriterT (localRq f) m
-- | Lift 'ServerMonad' through the strict 'Strict.WriterT'.
instance (ServerMonad m, Monoid w) => ServerMonad (Strict.WriterT w m) where
    askRq       = lift askRq
    localRq f m = Strict.mapWriterT (localRq f) m
-- | Lift 'FilterMonad' through the lazy 'Lazy.WriterT'. 'getFilter'
-- regroups the inner result so the filter is paired with the value, not the
-- writer output.
instance (FilterMonad res m, Monoid w) => FilterMonad res (Lazy.WriterT w m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = Lazy.mapWriterT regroup m
      where
        regroup inner =
            getFilter inner >>= \((b, w), f) -> return ((b, f), w)
-- | Lift 'FilterMonad' through the strict 'Strict.WriterT'. 'getFilter'
-- regroups the inner result so the filter is paired with the value, not the
-- writer output.
instance (FilterMonad res m, Monoid w) => FilterMonad res (Strict.WriterT w m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = Strict.mapWriterT regroup m
      where
        regroup inner =
            getFilter inner >>= \((b, w), f) -> return ((b, f), w)
-- | Lift 'WebMonad' through the lazy 'Lazy.WriterT'.
instance (WebMonad a m, Monoid w) => WebMonad a (Lazy.WriterT w m) where
    finishWith r = lift (finishWith r)
-- | Lift 'WebMonad' through the strict 'Strict.WriterT'.
instance (WebMonad a m, Monoid w) => WebMonad a (Strict.WriterT w m) where
    finishWith r = lift (finishWith r)
-- | Lift 'ServerMonad' through the lazy 'Lazy.RWST'.
instance (ServerMonad m, Monoid w) => ServerMonad (Lazy.RWST r w s m) where
    askRq       = lift askRq
    localRq f m = Lazy.mapRWST (localRq f) m
-- | Lift 'ServerMonad' through the strict 'Strict.RWST'.
instance (ServerMonad m, Monoid w) => ServerMonad (Strict.RWST r w s m) where
    askRq       = lift askRq
    localRq f m = Strict.mapRWST (localRq f) m
-- | Lift 'FilterMonad' through the lazy 'Lazy.RWST'. 'getFilter' regroups
-- the inner result so the filter is paired with the value, leaving the
-- state and writer output in place.
instance (FilterMonad res m, Monoid w) => FilterMonad res (Lazy.RWST r w s m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = Lazy.mapRWST regroup m
      where
        regroup inner =
            getFilter inner >>= \((b, s, w), f) -> return ((b, f), s, w)
-- | Lift 'FilterMonad' through the strict 'Strict.RWST'. 'getFilter'
-- regroups the inner result so the filter is paired with the value, leaving
-- the state and writer output in place.
instance (FilterMonad res m, Monoid w) => FilterMonad res (Strict.RWST r w s m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = Strict.mapRWST regroup m
      where
        regroup inner =
            getFilter inner >>= \((b, s, w), f) -> return ((b, f), s, w)
-- | Lift 'WebMonad' through the lazy 'Lazy.RWST'.
instance (WebMonad a m, Monoid w) => WebMonad a (Lazy.RWST r w s m) where
    finishWith r = lift (finishWith r)
-- | Lift 'WebMonad' through the strict 'Strict.RWST'.
instance (WebMonad a m, Monoid w) => WebMonad a (Strict.RWST r w s m) where
    finishWith r = lift (finishWith r)
-- | Lift 'ServerMonad' through 'ErrorT'.
instance (Error e, ServerMonad m) => ServerMonad (ErrorT e m) where
    askRq       = lift askRq
    localRq f m = mapErrorT (localRq f) m
-- | Lift 'FilterMonad' through 'ErrorT'. 'getFilter' pairs a successful
-- result with the filter; an error is propagated unchanged and the filter
-- is dropped from the result.
instance (Error e, FilterMonad a m) => FilterMonad a (ErrorT e m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = mapErrorT (\inner -> getFilter inner >>= attach) m
      where
        attach (Left e, _)  = return (Left e)
        attach (Right b, f) = return (Right (b, f))
-- | Lift 'WebMonad' through 'ErrorT'.
instance (Error e, WebMonad a m) => WebMonad a (ErrorT e m) where
    finishWith r = lift (finishWith r)
-- | Lift 'ServerMonad' through 'ExceptT'.
instance ServerMonad m => ServerMonad (ExceptT e m) where
    askRq       = lift askRq
    localRq f m = mapExceptT (localRq f) m
-- | Lift 'FilterMonad' through 'ExceptT'. 'getFilter' pairs a successful
-- result with the filter; an exception is propagated unchanged and the
-- filter is dropped from the result.
instance FilterMonad a m => FilterMonad a (ExceptT e m) where
    setFilter       = lift . setFilter
    composeFilter f = lift (composeFilter f)
    getFilter m     = mapExceptT (\inner -> getFilter inner >>= attach) m
      where
        attach (Left e, _)  = return (Left e)
        attach (Right b, f) = return (Right (b, f))
-- | Lift 'WebMonad' through 'ExceptT'.
instance WebMonad a m => WebMonad a (ExceptT e m) where
    finishWith r = lift (finishWith r)