{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
module Web.Spock.Config
    ( SpockCfg (..), defaultSpockCfg
      -- * Database
    , PoolOrConn (..), ConnBuilder (..), PoolCfg (..)
      -- * Sessions
    , defaultSessionCfg, SessionCfg (..)
    , defaultSessionHooks, SessionHooks (..)
    , SessionStore(..), SessionStoreInstance(..)
    , SV.newStmSessionStore
    )
where

import Web.Spock.Action
import Web.Spock.Internal.Types
import qualified Web.Spock.Internal.SessionVault as SV

#if MIN_VERSION_base(4,11,0)
#elif MIN_VERSION_base(4,9,0)
import Data.Semigroup
#else
import Data.Monoid
#endif
import Network.HTTP.Types.Status
import System.IO
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T

-- | NOP session hooks.
--
-- The 'sh_removed' callback ignores its argument (a map keyed by
-- 'SessionId'; presumably the sessions that were just removed — confirm
-- against the session vault) and returns @()@ without doing anything.
defaultSessionHooks :: SessionHooks a
defaultSessionHooks =
    SessionHooks
    { sh_removed = \_ -> pure ()
    }

-- | Session configuration with reasonable defaults and an
-- STM based session store.
--
-- The store is obtained by running 'SV.newStmSessionStore' once per call.
-- The remaining fields are fixed defaults:
--
--   * cookie name @"spockcookie"@, cookie lifetime 'CookieValidForever'
--   * session TTL of 3600 seconds (one hour)
--   * session id entropy of 64 (unit as defined by 'SessionCfg'; presumably
--     bytes — confirm)
--   * 'sc_sessionExpandTTL' enabled
--   * housekeeping interval of 600 seconds (ten minutes)
--   * 'defaultSessionHooks' as hooks
--
-- The argument becomes 'sc_emptySession'.
defaultSessionCfg :: a -> IO (SessionCfg conn a st)
defaultSessionCfg emptySession =
  do store <- SV.newStmSessionStore
     pure
       SessionCfg
       { sc_cookieName = "spockcookie"
       , sc_cookieEOL = CookieValidForever
       , sc_sessionTTL = 3600                 -- one hour
       , sc_sessionIdEntropy = 64
       , sc_sessionExpandTTL = True
       , sc_emptySession = emptySession
       , sc_store = store
       , sc_housekeepingInterval = 60 * 10    -- ten minutes
       , sc_hooks = defaultSessionHooks
       }

-- | Spock configuration with reasonable defaults such as a basic error page
-- and 5MB request body limit. IMPORTANT: CSRF Protection is turned off by
-- default for now to not break any existing Spock applications. Consider
-- turning it on manually as it will become the default in the future.
--
-- Arguments, in order: the empty session value (passed to
-- 'defaultSessionCfg'), the database pool or connection, and the initial
-- application state.
--
-- Defaults set here:
--
--   * request body limit of @5 * 1024 * 1024@ (5 MiB)
--   * errors logged with 'T.hPutStrLn' to 'stderr'
--   * 'errorHandler' as the error page renderer
--   * CSRF token header @"X-Csrf-Token"@ and form field @"__csrf_token"@
--     (only relevant once 'spc_csrfProtection' is switched on)
defaultSpockCfg :: sess -> PoolOrConn conn -> st -> IO (SpockCfg conn sess st)
defaultSpockCfg sess conn st =
  do defSess <- defaultSessionCfg sess
     pure
       SpockCfg
       { spc_initialState = st
       , spc_database = conn
       , spc_sessionCfg = defSess
       , spc_maxRequestSize = Just (5 * 1024 * 1024)
       , spc_logError = T.hPutStrLn stderr
       , spc_errorHandler = errorHandler
       , spc_csrfProtection = False
       , spc_csrfHeaderName = "X-Csrf-Token"
       , spc_csrfPostName = "__csrf_token"
       }

-- | Default error handler: sends the page rendered by 'errorTemplate' for
-- the given status using 'html'.
errorHandler :: Status -> ActionCtxT () IO ()
errorHandler status = html (errorTemplate status)

-- | Render a minimal HTML error page for a status, e.g. a page titled
-- @"404 - Not Found"@ for a 404 status.
--
-- Danger! This should better be done using combinators, but we do not
-- want Spock depending on a specific html combinator framework. Because the
-- markup is assembled by hand, the status text is HTML-escaped before it is
-- spliced in.
errorTemplate :: Status -> T.Text
errorTemplate s =
    "<html><head>"
    <> "<title>" <> message <> "</title>"
    <> "</head>"
    <> "<body>"
    <> "<h1>" <> message <> "</h1>"
    <> "<a href='https://www.spock.li'>powered by Spock</a>"
    <> "</body>"
    <> "</html>"
    where
      -- "<code> - <reason phrase>", escaped for safe inclusion in HTML.
      message :: T.Text
      message =
          escapeHtml (showT (statusCode s) <> " - " <> reason)

      -- The reason phrase is raw bytes. 'T.decodeUtf8' throws on invalid
      -- UTF-8, which would make the error handler itself fail; fall back to
      -- Latin-1, which is total and matches HTTP's obs-text octets.
      reason :: T.Text
      reason = either (const (T.decodeLatin1 rawReason)) id (T.decodeUtf8' rawReason)

      rawReason = statusMessage s

      showT :: Int -> T.Text
      showT = T.pack . show

      escapeHtml :: T.Text -> T.Text
      escapeHtml = T.concatMap escapeChar

      escapeChar :: Char -> T.Text
      escapeChar '<'  = "&lt;"
      escapeChar '>'  = "&gt;"
      escapeChar '&'  = "&amp;"
      escapeChar '"'  = "&quot;"
      escapeChar '\'' = "&#39;"
      escapeChar c    = T.singleton c