module Web.Spock
(
runSpock, runSpockNoBanner, spockAsApp
, spock, SpockM, SpockCtxM
, Path, root, Var, var, static, (<//>), wildcard
, renderRoute
, subcomponent, prehook
, RouteSpec
, get, post, getpost, head, put, delete, patch, hookRoute
, hookRouteCustom, hookAny, hookAnyCustom
, C.StdMethod (..)
, middleware
, SpockAction, SpockActionCtx
, module Web.Spock.Action
, HasSpock(..), SessionManager
, module Web.Spock.SessionActions
, getCsrfToken, getClientCsrfToken, csrfCheck
, WebStateM, WebStateT, WebState
, getSpockHeart, runSpockIO, getSpockPool
)
where
import Web.Spock.Action
import Web.Spock.Core hiding
( hookRoute', hookAny'
, get, post, getpost, head, put, delete, patch, hookRoute
, hookRouteCustom, hookAny, hookAnyCustom
)
import Web.Spock.Internal.Monad
import Web.Spock.Internal.SessionManager
import Web.Spock.Internal.Types
import Web.Spock.SessionActions
import qualified Web.Spock.Core as C
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Trans.Resource
import Data.Pool
import Network.HTTP.Types.Status (status403)
import Prelude hiding (head)
import qualified Data.HVect as HV
import qualified Data.Text as T
import qualified Data.Vault.Lazy as V
import qualified Network.Wai as Wai
-- | Application-definition monad with the unit context: a 'SpockCtxM' whose
-- @ctx@ parameter is fixed to @()@.
type SpockM conn sess st = SpockCtxM () conn sess st
-- | Application-definition monad carrying a context @ctx@: 'SpockCtxT' layered
-- over 'WebStateM', which is parameterised by @conn@, @sess@ and @st@.
type SpockCtxM ctx conn sess st = SpockCtxT ctx (WebStateM conn sess st)
-- | Build a 'Wai.Middleware' from a 'SpockCfg' and an application definition.
--
-- Steps, in order:
--
-- 1. Obtain a connection pool according to @spc_database@: a pool of @()@
--    values (5 stripes, keep-open time 60, 5 resources per stripe) for
--    'PCNoDatabase', the supplied pool for 'PCPool', or a freshly created
--    pool configured from the builder for 'PCConn'.
-- 2. Create the session manager from @spc_sessionCfg@.
-- 3. Assemble the 'WebState' from the pool, the session manager,
--    @spc_initialState@ and the configuration itself.
-- 4. Run the application through 'spockConfigT', registering the session
--    manager's middleware before the user supplied definition.
spock :: forall conn sess st. SpockCfg conn sess st -> SpockM conn sess st () -> IO Wai.Middleware
spock spockCfg spockAppl = do
    connectionPool <-
        case spc_database spockCfg of
          PCNoDatabase ->
              -- No database: the pool only ever hands out unit values.
              createPool (pure ()) (\_ -> pure ()) 5 60 5
          PCPool existingPool ->
              pure existingPool
          PCConn builder ->
              let poolCfg = cb_poolConfiguration builder
              in createPool
                   (cb_createConn builder)
                   (cb_destroyConn builder)
                   (pc_stripes poolCfg)
                   (pc_keepOpenTime poolCfg)
                   (pc_resPerStripe poolCfg)
    sessionMgr <-
        createSessionManager (spc_sessionCfg spockCfg) $
            SessionIf
                { si_queryVault = queryVault
                , si_modifyVault = modifyVault
                , si_setRawMultiHeader = setRawMultiHeader
                , si_vaultKey = V.newKey
                }
    let internalState =
            WebState connectionPool sessionMgr (spc_initialState spockCfg) spockCfg
        coreConfig =
            defaultSpockConfig
                { sc_maxRequestSize = spc_maxRequestSize spockCfg
                , sc_errorHandler = spc_errorHandler spockCfg
                }
    -- The runner is kept as an inline lambda so it stays polymorphic in its
    -- result type at the call site.
    spockConfigT
        coreConfig
        (\act -> runResourceT (runReaderT (runWebStateT act) internalState))
        $ do
            middleware (sm_middleware (web_sessionMgr internalState))
            spockAppl
-- | Fetch the CSRF token by asking the session manager (@sm_getCsrfToken@).
-- The lookup runs in the unit context via 'runInContext', so it is usable
-- from any @ctx@.
getCsrfToken :: SpockActionCtx ctx conn sess st T.Text
getCsrfToken = runInContext () (getSessMgr >>= sm_getCsrfToken)
-- | Read the CSRF token supplied by the client.
--
-- The header named by @spc_csrfHeaderName@ is looked up first, then the
-- parameter named by @spc_csrfPostName@; the header value wins when both are
-- present. Yields 'Nothing' when neither is present.
getClientCsrfToken :: SpockActionCtx ctx conn sess st (Maybe T.Text)
getClientCsrfToken = do
    cfg <- getSpockCfg
    -- Effects run left to right: header lookup, then parameter lookup.
    (<|>) <$> header (spc_csrfHeaderName cfg) <*> param (spc_csrfPostName cfg)
-- | Compare the client supplied CSRF token with the one from 'getCsrfToken'.
--
-- Does nothing when the client sent a token equal to the expected one. When
-- the token is missing or different, sets status 403 and responds with the
-- text @Broken/Missing CSRF Token@.
csrfCheck :: SpockActionCtx ctx conn sess st ()
csrfCheck = do
    expected <- getCsrfToken
    supplied <- getClientCsrfToken
    -- 'Nothing' never equals @Just expected@, so a missing token is rejected
    -- by the same test as a wrong one.
    if supplied == Just expected
        then pure ()
        else reject
  where
    reject = do
        setStatus status403
        text "Broken/Missing CSRF Token"
-- | Shape shared by all route-registering combinators in this module: given a
-- 'Path' and an action taking one argument per element of @xs@ (via
-- 'HV.HVectElim'), produce an application-definition step.
type RouteSpec xs ps ctx conn sess st
    =  Path xs ps
    -> HV.HVectElim xs (SpockActionCtx ctx conn sess st ())
    -> SpockCtxM ctx conn sess st ()
-- | Register a route for a standard HTTP method. The method is wrapped as
-- @MethodStandard (HttpMethod ...)@ and handed to @hookRoute'@, which adds
-- the CSRF check where applicable.
hookRoute :: HV.HasRep xs => StdMethod -> RouteSpec xs ps ctx conn sess st
hookRoute stdMethod = hookRoute' (MethodStandard (HttpMethod stdMethod))
-- | Register an action for the given route under the @GET@ method.
get :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
get route action = hookRoute GET route action
-- | Register an action for the given route under the @POST@ method.
post :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
post route action = hookRoute POST route action
-- | Register the same action for the given route under both @POST@ and
-- @GET@; the @POST@ registration happens first.
getpost :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
getpost route action = do
    hookRoute POST route action
    hookRoute GET route action
-- | Register an action for the given route under the @HEAD@ method.
head :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
head route action = hookRoute HEAD route action
-- | Register an action for the given route under the @PUT@ method.
put :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
put route action = hookRoute PUT route action
-- | Register an action for the given route under the @DELETE@ method.
delete :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
delete route action = hookRoute DELETE route action
-- | Register an action for the given route under the @PATCH@ method.
patch :: HV.HasRep xs => RouteSpec xs ps ctx conn sess st
patch route action = hookRoute PATCH route action
-- | Register a route for a custom (non-standard) HTTP method given by name.
-- The name is wrapped as 'MethodCustom'; @hookRoute'@ applies its CSRF
-- wrapping only to standard methods, so custom methods are never checked.
hookRouteCustom :: HV.HasRep xs => T.Text -> RouteSpec xs ps ctx conn sess st
hookRouteCustom methodName = hookRoute' (MethodCustom methodName)
-- | Register a handler for a standard HTTP method that receives a list of
-- 'T.Text' values (presumably the request path segments -- confirm against
-- @Web.Spock.Core@). Delegates to @hookAny'@, which adds the CSRF check
-- where applicable.
hookAny :: StdMethod -> ([T.Text] -> SpockActionCtx ctx conn sess st ()) -> SpockCtxM ctx conn sess st ()
hookAny stdMethod = hookAny' (MethodStandard (HttpMethod stdMethod))
-- | Like 'hookAny', but for a custom HTTP method given by name. The name is
-- wrapped as 'MethodCustom'; @hookAny'@ applies its CSRF check only to
-- standard methods, so custom methods are never checked.
hookAnyCustom :: T.Text -> ([T.Text] -> SpockActionCtx ctx conn sess st ()) -> SpockCtxM ctx conn sess st ()
hookAnyCustom methodName = hookAny' (MethodCustom methodName)
-- | CSRF-aware wrapper around @C.hookAny'@.
--
-- When the method is a standard method for which 'shouldCheckCsrf' holds and
-- @spc_csrfProtection@ is enabled in the configuration, 'csrfCheck' runs
-- before the supplied action. In every other case the action runs unchanged.
hookAny' :: SpockMethod -> ([T.Text] -> SpockActionCtx ctx conn sess st ()) -> SpockCtxM ctx conn sess st ()
hookAny' m action = do
    cfg <- getSpockCfg
    let csrfRequired =
            case m of
              MethodStandard (HttpMethod stdMethod) ->
                  shouldCheckCsrf stdMethod && spc_csrfProtection cfg
              _ -> False
    C.hookAny' m $ \segments ->
        if csrfRequired
            then csrfCheck >> action segments
            else action segments
-- | CSRF-aware wrapper around @C.hookRoute'@.
--
-- When the method is a standard method for which 'shouldCheckCsrf' holds and
-- @spc_csrfProtection@ is enabled in the configuration, the action is
-- replaced by one that runs 'csrfCheck' first. Otherwise the action is
-- registered unchanged.
hookRoute' ::
    forall xs ps ctx conn sess st.
    (HV.HasRep xs)
    => SpockMethod
    -> RouteSpec xs ps ctx conn sess st
hookRoute' m path action = do
    cfg <- getSpockCfg
    let csrfRequired =
            case m of
              MethodStandard (HttpMethod stdMethod) ->
                  shouldCheckCsrf stdMethod && spc_csrfProtection cfg
              _ -> False
        -- The action takes its arguments curried; to prepend the check it is
        -- uncurried onto an 'HV.HVect' here and curried back at the use site.
        guardedAction :: HV.HVect xs -> SpockActionCtx ctx conn sess st ()
        guardedAction args = csrfCheck >> HV.uncurry action args
    C.hookRoute' m path (if csrfRequired then HV.curry guardedAction else action)
-- | Whether requests using the given standard method are subject to the CSRF
-- check: 'False' for @GET@, @HEAD@ and @OPTIONS@, 'True' for every other
-- method.
shouldCheckCsrf :: StdMethod -> Bool
shouldCheckCsrf m = m `notElem` [GET, HEAD, OPTIONS]