module Network.Wai.Session.PostgreSQL
( dbStore
, clearSession
, defaultSettings
, fromSimpleConnection
, purgeOldSessions
, purger
, ratherSecureGen
, SimpleConnection
, StoreSettings (..)
, WithPostgreSQLConn (..)
) where
import Control.Applicative ((<$>))
import Control.Concurrent
import Control.Concurrent.MVar
import Control.Exception.Base
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import Data.Default
import Data.Int (Int64)
import Data.Pool (Pool, withResource)
import Data.Serialize (encode, decode, Serialize)
import Data.String (fromString)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Database.PostgreSQL.Simple
import Network.Wai (Request, requestHeaders)
import Network.Wai.Session
import Numeric (showHex)
import System.Entropy (getEntropy)
import Web.Cookie (parseCookies)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
data StoreSettings = StoreSettings {
storeSettingsSessionTimeout :: Int64
, storeSettingsKeyGen :: IO B.ByteString
, storeSettingsCreateTable :: Bool
, storeSettingsLog :: String -> IO ()
, storeSettingsPurgeInterval :: Int
}
instance Default StoreSettings where
def = defaultSettings
class WithPostgreSQLConn a where
withPostgreSQLConn :: a -> (Connection -> IO b) -> IO b
fromSimpleConnection :: Connection -> IO SimpleConnection
fromSimpleConnection connection = do
mvar <- newMVar ()
return $ SimpleConnection (mvar, connection)
newtype SimpleConnection = SimpleConnection (MVar (), Connection)
instance WithPostgreSQLConn SimpleConnection where
withPostgreSQLConn (SimpleConnection (mvar, conn)) =
bracket (takeMVar mvar >> return conn) (\_ -> putMVar mvar ())
instance WithPostgreSQLConn (Pool Connection) where
withPostgreSQLConn = withResource
qryCreateTable1 :: Query
qryCreateTable1 = "CREATE TABLE wai_pg_sessions (id bigserial NOT NULL, session_key character varying NOT NULL, session_created_at bigint NOT NULL, session_last_access bigint NOT NULL, session_invalidate_key boolean NOT NULL DEFAULT false, CONSTRAINT wai_pg_sessions_pkey PRIMARY KEY (id), CONSTRAINT wai_pg_sessions_session_key UNIQUE (session_key)) WITH ( OIDS=FALSE );"
qryCreateIndex1 :: Query
qryCreateIndex1 = "CREATE INDEX idx_session_last_access ON public.wai_pg_sessions USING btree (session_last_access);"
qryCreateTable2 :: Query
qryCreateTable2 = "CREATE TABLE wai_pg_session_data ( id bigserial NOT NULL, wai_pg_session bigint, key bytea, value bytea, CONSTRAINT wai_pg_session_data_pkey PRIMARY KEY (id), CONSTRAINT wai_pg_session_data_wai_pg_session_fkey FOREIGN KEY (wai_pg_session) REFERENCES wai_pg_sessions (id) MATCH SIMPLE ON UPDATE RESTRICT ON DELETE CASCADE, CONSTRAINT wai_pg_session_data_wai_pg_session_key_key UNIQUE (wai_pg_session, key) ) WITH (OIDS=FALSE);"
qryCreateSession :: Query
qryCreateSession = "INSERT INTO wai_pg_sessions (session_key, session_created_at, session_last_access) VALUES (?,?,?) RETURNING id"
qryCreateSessionEntry :: Query
qryCreateSessionEntry = "INSERT INTO wai_pg_session_data (wai_pg_session,key,value) VALUES (?,?,?)"
qryUpdateSession :: Query
qryUpdateSession = "UPDATE wai_pg_sessions SET session_last_access=? WHERE id=?"
qryUpdateSessionEntry :: Query
qryUpdateSessionEntry = "UPDATE wai_pg_session_data SET value=? WHERE wai_pg_session=? AND key=?"
qryLookupSession :: Query
qryLookupSession = "SELECT id FROM wai_pg_sessions WHERE session_key=? AND session_last_access>=?"
qryLookupSession' :: Query
qryLookupSession' = "UPDATE wai_pg_sessions SET session_last_access=? WHERE id=?"
qryLookupSession'' :: Query
qryLookupSession'' = "SELECT value FROM wai_pg_session_data WHERE wai_pg_session=? AND key=?"
qryLookupSession''' :: Query
qryLookupSession''' = "SELECT id FROM wai_pg_session_data WHERE wai_pg_session=? AND key=?"
qryPurgeOldSessions :: Query
qryPurgeOldSessions = "DELETE FROM wai_pg_sessions WHERE session_last_access<?"
qryCheckNewKey :: Query
qryCheckNewKey = "SELECT session_invalidate_key FROM wai_pg_sessions WHERE session_key=?"
qryInvalidateSess1 :: Query
qryInvalidateSess1 = "UPDATE wai_pg_sessions SET session_invalidate_key=TRUE WHERE session_key=?"
qryInvalidateSess2 :: Query
qryInvalidateSess2 = "DELETE FROM wai_pg_session_data WHERE wai_pg_session=(SELECT id FROM wai_pg_sessions WHERE session_key=?)"
qryUpdateKey :: Query
qryUpdateKey = "UPDATE wai_pg_sessions SET session_key=?,session_invalidate_key=FALSE WHERE session_key=?"
dbStore :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> StoreSettings -> IO (SessionStore m k v)
dbStore pool stos = do
when (storeSettingsCreateTable stos) $
withPostgreSQLConn pool $ \ conn ->
unerror $ do
void $ execute_ conn qryCreateTable1
void $ execute_ conn qryCreateTable2
void $ execute_ conn qryCreateIndex1
storeSettingsLog stos "Created tables."
return $ dbStore' pool stos
purgeOldSessions :: WithPostgreSQLConn a => a -> StoreSettings -> IO Int64
purgeOldSessions pool stos = do
curtime <- round <$> liftIO getPOSIXTime
count <- withPostgreSQLConn pool $ \ conn ->
execute conn qryPurgeOldSessions (Only (curtime storeSettingsSessionTimeout stos))
storeSettingsLog stos $ "Purged " ++ show count ++ " session(s)."
return count
purger :: WithPostgreSQLConn a => a -> StoreSettings -> IO ThreadId
purger pool stos = forkIO . forever . unerror $ do
purgeOldSessions pool stos
threadDelay $ storeSettingsPurgeInterval stos
defaultSettings :: StoreSettings
defaultSettings = StoreSettings
{ storeSettingsSessionTimeout=3600
, storeSettingsKeyGen=ratherSecureGen 24
, storeSettingsCreateTable=True
, storeSettingsLog=putStrLn
, storeSettingsPurgeInterval=600000000
}
ratherSecureGen :: Int -> IO B.ByteString
ratherSecureGen n = TE.encodeUtf8 . prettyPrint <$> getEntropy n
prettyPrint :: B.ByteString -> T.Text
prettyPrint = T.pack . concatMap (`showHex` "") . B.unpack
dbStore' :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> StoreSettings -> SessionStore m k v
dbStore' pool stos Nothing = do
newKey <- storeSettingsKeyGen stos
curtime <- liftIO getPOSIXTime
sessionPgId <- withPostgreSQLConn pool $ \ conn -> do
[Only res] <- query conn qryCreateSession (newKey, round curtime :: Int64, round curtime :: Int64) :: IO [Only Int64]
return (res :: Int64)
backend pool stos newKey sessionPgId
dbStore' pool stos (Just key) = do
curtime <- round <$> liftIO getPOSIXTime
res <- withPostgreSQLConn pool $ \ conn ->
query conn qryLookupSession (key, curtime storeSettingsSessionTimeout stos) :: IO [Only Int64]
case res of
[Only sessionPgId] -> backend pool stos key sessionPgId
_ -> dbStore' pool stos Nothing
clearSession :: (WithPostgreSQLConn a) => a -> B.ByteString -> Request -> IO ()
clearSession pool cookieName req = do
let map = [] :: [(k, v)]
map' = ""
cookies = parseCookies <$> lookup (fromString "Cookie") (requestHeaders req)
Just key = lookup cookieName =<< cookies
withPostgreSQLConn pool $ \ conn ->
withTransaction conn $ do
void $ execute conn qryInvalidateSess1 (Only key)
void $ execute conn qryInvalidateSess2 (Only key)
backend :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> StoreSettings -> B.ByteString -> Int64 -> IO (Session m k v, IO B.ByteString)
backend pool stos key sessionPgId =
return ( (
reader pool key sessionPgId
, writer pool key sessionPgId ), withPostgreSQLConn pool $ \conn -> do
curtime <- liftIO getPOSIXTime
void $ execute conn qryLookupSession' (round curtime :: Int64, sessionPgId)
[Only shouldNewKey] <- query conn qryCheckNewKey (Only key)
if shouldNewKey then do
newKey' <- storeSettingsKeyGen stos
execute conn qryUpdateKey (newKey', key)
return newKey'
else
return key
)
reader :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> B.ByteString -> Int64 -> k -> m (Maybe v)
reader pool key sessionPgId k = do
res <- liftIO $ withPostgreSQLConn pool $ \conn -> do
query conn qryLookupSession'' (sessionPgId, Binary $ encode k)
case res of
[Only value] -> case decode (fromBinary value) of
Right value' -> return $ Just value'
Left error -> return Nothing
[] -> return Nothing
writer :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m) => a -> B.ByteString -> Int64 -> k -> v -> m ()
writer pool key sessionPgId k v = do
let k' = Binary $ encode k
v' = Binary $ encode v
liftIO $ withPostgreSQLConn pool $ \conn ->
withTransaction conn $ do
res <- query conn qryLookupSession''' (sessionPgId, k') :: IO [Only Int64]
case res of
[Only id] -> void $ execute conn qryUpdateSessionEntry (v', sessionPgId, k')
_ -> void $ execute conn qryCreateSessionEntry (sessionPgId, k', v')
ignoreSqlError :: SqlError -> IO ()
ignoreSqlError _ = return ()
unerror :: IO a -> IO ()
unerror action = void action `catch` ignoreSqlError