module Network.Wai.Session.PostgreSQL
( dbStore
, WithPostgreSQLConn (..)
, StoreSettings (..)
) where
import Control.Exception.Base
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import Data.Int (Int64)
import Data.Serialize (encode, decode, Serialize)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Database.PostgreSQL.Simple
import Network.Wai.Session
import qualified Data.ByteString as B
-- | Configuration for the PostgreSQL-backed session store.
data StoreSettings = StoreSettings
    { storeSettingsSessionTimeout :: Int64
      -- ^ Session timeout. The store consults this together with the current
      -- POSIX time (rounded to whole seconds) when looking up an existing
      -- session key; presumably a duration in seconds -- TODO confirm.
    , storeSettingsKeyGen :: IO B.ByteString
      -- ^ Action run to produce the key for a newly created session.
    }
-- | Things that can lend out a PostgreSQL 'Connection' for the duration of
-- an 'IO' action (a bare connection, or e.g. a pool wrapper supplied by the
-- application).
class WithPostgreSQLConn a where
    -- | Run the given action with a connection obtained from the first
    -- argument and return the action's result.
    withPostgreSQLConn :: a -> (Connection -> IO b) -> IO b
-- | A plain 'Connection' simply lends itself out. The release step is a
-- no-op, so the connection is neither closed nor otherwise touched once the
-- action has finished.
instance WithPostgreSQLConn Connection where
    withPostgreSQLConn conn = bracket (pure conn) (const (pure ()))
-- | Creates the @session@ table: a serial primary key, a unique textual
-- session key, creation / last-access timestamps stored as @bigint@, and the
-- serialized session contents as @bytea@.
qryCreateTable :: Query
qryCreateTable = "CREATE TABLE session (id bigserial NOT NULL, session_key character varying NOT NULL, session_created_at bigint NOT NULL, session_last_access bigint NOT NULL, session_value bytea NOT NULL, CONSTRAINT session_pkey PRIMARY KEY (id), CONSTRAINT session_session_key_key UNIQUE (session_key)) WITH (OIDS=FALSE);"

-- | Inserts a new session row.
-- Parameters: key, created-at, last-access, value.
qryCreateSession :: Query
qryCreateSession = "INSERT INTO session (session_key, session_created_at, session_last_access, session_value) VALUES (?,?,?,?)"

-- | Replaces the stored value of a session and bumps its last-access time.
-- Parameters: value, last-access, key.
qryUpdateSession :: Query
qryUpdateSession = "UPDATE session SET session_value=?,session_last_access=? WHERE session_key=?"

-- | Fetches the value of a session whose last access is not older than the
-- given cutoff. Parameters: key, cutoff.
qryLookupSession :: Query
qryLookupSession = "SELECT session_value FROM session WHERE session_key=? AND session_last_access>=?"

-- | Bumps the last-access time of a session. Parameters: last-access, key.
qryLookupSession' :: Query
qryLookupSession' = "UPDATE session SET session_last_access=? WHERE session_key=?"

-- | Fetches the value of a session regardless of its age. Parameters: key.
qryLookupSession'' :: Query
qryLookupSession'' = "SELECT session_value FROM session WHERE session_key=?"
-- | Build a session store on top of the given connection source.
--
-- The @session@ table creation statement is issued first; any 'SqlError' it
-- raises is discarded via 'unerror', after which the store function is
-- returned.
dbStore :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> StoreSettings -> IO (SessionStore m k v)
dbStore pool stos =
    withPostgreSQLConn pool createTable >> pure (dbStore' pool stos)
  where
    -- Best-effort table creation on the borrowed connection.
    createTable conn = unerror (execute_ conn qryCreateTable)
-- | The actual session store function.
--
-- * Given 'Nothing', a fresh key is generated with 'storeSettingsKeyGen', a
--   row with an empty value is inserted for it, and a session bound to that
--   key is returned.
-- * Given @'Just' key@, the key is reused only if exactly one row with that
--   key exists whose last access lies within the configured timeout;
--   otherwise a brand-new session is created as in the 'Nothing' case.
dbStore' :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> StoreSettings -> SessionStore m k v
dbStore' pool stos Nothing = do
    newKey <- storeSettingsKeyGen stos
    curtime <- round <$> getPOSIXTime
    withPostgreSQLConn pool $ \conn ->
        -- The value column starts out empty; it is only filled by 'writer'.
        void $ execute conn qryCreateSession
            (newKey, curtime :: Int64, curtime, B.empty)
    backend pool newKey []
dbStore' pool stos (Just key) = do
    curtime <- round <$> getPOSIXTime
    -- Oldest last-access timestamp that still counts as a live session.
    let cutoff = curtime - storeSettingsSessionTimeout stos
    res <- withPostgreSQLConn pool $ \conn ->
        query conn qryLookupSession (key, cutoff) :: IO [Only B.ByteString]
    case res of
        [Only _] -> backend pool key []
        -- Unknown or expired key: hand out a new session instead.
        _ -> dbStore' pool stos Nothing
-- | Package the 'reader' / 'writer' pair for a session key together with an
-- action that yields that key.
backend :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> B.ByteString -> [(k, v)] -> IO (Session m k v, IO B.ByteString)
backend pool key mappe = return (session, keyAction)
  where
    session = (reader pool key mappe, writer pool key mappe)
    keyAction = return key
-- | Look up a single entry of the session identified by the given key.
--
-- The session's last-access timestamp is bumped to the current time first,
-- then the stored value is fetched and decoded as an association list.
-- Returns 'Nothing' when the entry is absent, when the stored value does not
-- decode, or when the query does not yield exactly one row.
--
-- The association-list argument is accepted for interface symmetry with
-- 'writer' but is not consulted.
reader :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> B.ByteString -> [(k, v)] -> k -> m (Maybe v)
reader pool key _ k = do
    curtime <- round <$> liftIO getPOSIXTime
    res <- liftIO $ withPostgreSQLConn pool $ \conn -> do
        void $ execute conn qryLookupSession' (curtime :: Int64, key)
        query conn qryLookupSession'' (Only key)
    return $ case res of
        [Only stored] ->
            either (const Nothing) (lookup k) (decode (fromBinary stored))
        -- Zero rows (or, unexpectedly, several): treat as "no such entry"
        -- rather than failing on an unmatched pattern.
        _ -> Nothing
-- | Store a single entry in the session identified by the given key.
--
-- The current value is fetched and decoded as an association list (a value
-- that does not decode, or a missing row, counts as an empty list), any
-- previous binding for the entry is dropped, the new binding is prepended,
-- and the re-encoded list is written back along with a fresh last-access
-- timestamp. If no row exists for the key, the final @UPDATE@ matches
-- nothing and the write is a no-op.
--
-- NOTE(review): the read and the update are two separate statements, so
-- concurrent writers to the same session can overwrite each other.
--
-- The association-list argument is accepted for interface symmetry with
-- 'reader' but is not consulted.
writer :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> B.ByteString -> [(k, v)] -> k -> v -> m ()
writer pool key _ k v = liftIO $ do
    curtime <- round <$> getPOSIXTime
    withPostgreSQLConn pool $ \conn -> do
        rows <- query conn qryLookupSession'' (Only key)
        let current = case rows of
                [Only stored] ->
                    either (const []) id (decode (fromBinary stored))
                -- No (unique) row: start from an empty store instead of
                -- failing on an unmatched pattern.
                _ -> []
            updated = (k, v) : filter ((/= k) . fst) current
        void $ execute conn qryUpdateSession
            (Binary (encode updated), curtime :: Int64, key)
-- | Exception handler that discards a 'SqlError' and carries on.
ignoreSqlError :: SqlError -> IO ()
ignoreSqlError = const (return ())
-- | Run an action purely for its effects, throwing away its result and
-- suppressing any 'SqlError' it raises. Other exception types propagate
-- unchanged.
unerror :: IO a -> IO ()
unerror action = handle ignoreSqlError (void action)