{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
module Snap.Snaplet.Session.Backends.RedisSession
( initRedisSessionManager
) where
import Control.Monad.Reader
import Data.ByteString (ByteString)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Serialize (Serialize)
import qualified Data.Serialize as S
import Data.Text (Text)
import Data.Text.Encoding
import Data.Typeable
import Snap.Core (Snap)
import Web.ClientSession
import Database.Redis
import Snap.Snaplet
import Snap.Snaplet.RedisDB
import Snap.Snaplet.Session
import Snap.Snaplet.Session.SessionManager
-- | Session key/value data. 'commit' writes it to Redis as a hash and
-- 'load' reads it back from there.
type Session = HashMap Text Text
-- | A loaded session.
data RedisSession = RedisSession
  { rsCSRFToken :: Text
    -- ^ CSRF token. It is the only field encoded into the cookie (see the
    -- 'Serialize' instance) and it also names the Redis hash holding the
    -- session data (see 'sessionKey').
  , rsSession :: Session
    -- ^ Session data; filled from Redis by 'load', never from the cookie.
  }
  deriving (Eq, Show)
-- | Cookie serialization for 'RedisSession'. Only the CSRF token is
-- encoded; the session data lives in Redis, so decoding yields an empty
-- 'rsSession' which 'load' then fills from Redis.
instance Serialize RedisSession where
  put = S.put . encodeUtf8 . rsCSRFToken
  -- Decode with the total 'decodeUtf8'' so a token that is not valid UTF-8
  -- makes the decoder fail (and 'load' fall back to a fresh session)
  -- instead of throwing an exception later, when the token is first used.
  get = do
    raw <- S.get
    case decodeUtf8' raw of
      Left err ->
        fail $ "RedisSession: CSRF token is not valid UTF-8: " ++ show err
      Right token -> return $ RedisSession token HM.empty
-- | UTF-8 encode both halves of a session field/value pair so it can be
-- written to Redis.
encodeTuple :: (Text, Text) -> (ByteString, ByteString)
encodeTuple (field, value) = (field', value')
  where
    field' = encodeUtf8 field
    value' = encodeUtf8 value
-- | Inverse of 'encodeTuple': decode a field/value pair read from Redis.
-- NOTE(review): 'decodeUtf8' throws on invalid UTF-8, so this assumes only
-- this module writes the session hashes — confirm.
decodeTuple :: (ByteString, ByteString) -> (Text, Text)
decodeTuple (rawField, rawValue) =
  let decodedField = decodeUtf8 rawField
      decodedValue = decodeUtf8 rawValue
  in (decodedField, decodedValue)
-- | Create a brand-new session: a freshly generated CSRF token and no data.
mkCookieSession :: RNG -> IO RedisSession
mkCookieSession rng = withToken <$> mkCSRFToken rng
  where
    withToken token = RedisSession token HM.empty
-- | State for the Redis-backed session manager.
data RedisSessionManager = RedisSessionManager {
    session :: Maybe RedisSession
    -- ^ The current session; 'Nothing' until 'load' or 'reset' fills it in.
  , siteKey :: Key
    -- ^ Key passed to 'getSecureCookie' / 'setSecureCookie'.
  , cookieName :: ByteString
    -- ^ Name of the session cookie.
  , cookieDomain :: Maybe ByteString
    -- ^ Cookie domain; only passed on when built against snap >= 1.0.
  , timeOut :: Maybe Int
    -- ^ Passed to the secure-cookie functions and used as the Redis
    -- @EXPIRE@ value in 'commit'; 'Nothing' makes 'commit' issue
    -- @PERSIST@ instead.
  , randomNumberGenerator :: RNG
    -- ^ Randomness source used to generate new CSRF tokens.
  , _redisConnection :: Connection
    -- ^ Redis connection used for every session read and write.
  } deriving (Typeable)
-- | Make sure the manager holds a session: an existing one is kept as-is,
-- otherwise a fresh session with a new CSRF token is installed.
loadDefSession :: RedisSessionManager -> IO RedisSessionManager
loadDefSession mgr = maybe freshSession (const (return mgr)) (session mgr)
  where
    freshSession = do
      newSession <- mkCookieSession (randomNumberGenerator mgr)
      return $! mgr { session = Just newSession }
-- | Apply a function to a session's data, leaving its CSRF token unchanged.
modSession :: (Session -> Session) -> RedisSession -> RedisSession
modSession f rs = rs { rsSession = f (rsSession rs) }
-- | Redis key of the hash holding a session's data: @"session:" <> token@,
-- UTF-8 encoded.
sessionKey :: Text -> ByteString
sessionKey = encodeUtf8 . mappend "session:"
-- | Create the session snaplet.
--
-- Arguments, in order:
--
-- * the key file read by 'getKey';
-- * the cookie name;
-- * an optional cookie domain;
-- * an optional timeout;
-- * the 'RedisDB' snaplet whose connection stores the session data.
--
-- The manager starts with no session loaded.
initRedisSessionManager
  :: FilePath
  -> ByteString
  -> Maybe ByteString
  -> Maybe Int
  -> RedisDB
  -> SnapletInit b SessionManager
initRedisSessionManager fp cn cd to c =
  makeSnaplet "RedisSession"
    "A snaplet providing sessions via HTTP cookies with a Redis backend."
    Nothing (liftIO buildManager)
  where
    buildManager = do
      key <- getKey fp
      rng <- mkRNG
      let manager = RedisSessionManager Nothing key cn cd to rng (_connection c)
      return $! SessionManager manager
-- | Session manager backed by Redis. The cookie carries only the CSRF token;
-- the session data lives in the Redis hash named by 'sessionKey'.
instance ISessionManager RedisSessionManager where
  -- Populate the in-memory session if it is not loaded yet.
  -- The cookie payload is decoded to recover the CSRF token, and that
  -- token is then used to read the session hash from Redis.
  -- A missing or undecodable cookie yields a fresh session via
  -- 'loadDefSession'.
  load mgr@(RedisSessionManager r _ _ _ _ rng con) =
    case r of
      Just _ -> return mgr
      Nothing -> do
        pl <- getPayload mgr
        case pl of
          Nothing -> liftIO $ loadDefSession mgr
          Just (Payload x) ->
            case S.decode x of
              Left _ -> liftIO $ loadDefSession mgr
              Right cs -> liftIO $ do
                sess <- runRedis con $ do
                  l <- hgetall (sessionKey $ rsCSRFToken cs)
                  case l of
                    -- On a Redis error the old token is dropped and a new
                    -- session is started.
                    Left _ -> liftIO $ mkCookieSession rng
                    Right l' -> do
                      -- 'commit' stores the placeholder field ("", "") for
                      -- empty sessions (presumably because HMSET needs at
                      -- least one field); strip it so it never appears as
                      -- session data.
                      let stored = filter (/= ("", "")) l'
                      return cs { rsSession = HM.fromList (map decodeTuple stored) }
                return mgr { session = Just sess }

  -- Persist the session and refresh the cookie.
  -- A loaded session is written in one MULTI/EXEC transaction: delete the
  -- old hash, write all fields, then set or clear the expiry. Only the CSRF
  -- token goes into the cookie. Transaction failures are raised with
  -- 'error'.
  commit mgr@(RedisSessionManager r _ _ _ to rng con) = do
    pl <- case r of
      Just r' -> liftIO $ runRedis con $ do
        let key = sessionKey (rsCSRFToken r')
            fields = map encodeTuple $ HM.toList (rsSession r')
        res <- multiExec $ do
          _ <- del [key]
          res1 <- case fields of
            -- Placeholder so an empty session still has a hash; 'load'
            -- filters it back out.
            [] -> hmset key [("", "")]
            _ -> hmset key fields
          res2 <- case to of
            Just i -> expire key $ toInteger i
            Nothing -> persist key
          return $ (,) <$> res1 <*> res2
        case res of
          TxSuccess _ -> return . Payload $ S.encode r'
          TxError e -> error e
          TxAborted -> error "transaction aborted"
      -- No session loaded: issue a cookie for a freshly generated session
      -- (nothing is written to Redis in this case).
      Nothing -> liftIO $ Payload . S.encode <$> mkCookieSession rng
    setPayload mgr pl

  -- Delete the current session's Redis hash (if any) and replace it with a
  -- fresh session carrying a new CSRF token. A Redis error is raised with
  -- 'error'.
  reset mgr@(RedisSessionManager r _ _ _ _ rng con) = do
    case r of
      Just r' -> liftIO $ runRedis con $ do
        res1 <- del [sessionKey $ rsCSRFToken r']
        case res1 of
          Left e -> error $ show e
          Right _ -> return ()
      Nothing -> return ()
    cs <- liftIO $ mkCookieSession rng
    return $ mgr { session = Just cs }

  -- Expiry is refreshed by 'commit' (EXPIRE / PERSIST), so nothing to do.
  touch = id

  -- Set a key in the loaded session; a no-op when no session is loaded.
  insert k v mgr@(RedisSessionManager r _ _ _ _ _ _) = case r of
    Just r' -> mgr { session = Just $ modSession (HM.insert k v) r' }
    Nothing -> mgr

  -- Look up a key in the loaded session, if any.
  lookup k (RedisSessionManager r _ _ _ _ _ _) = r >>= HM.lookup k . rsSession

  -- Remove a key from the loaded session; a no-op when no session is loaded.
  delete k mgr@(RedisSessionManager r _ _ _ _ _ _) = case r of
    Just r' -> mgr { session = Just $ modSession (HM.delete k) r' }
    Nothing -> mgr

  -- CSRF token of the loaded session, or the empty text if none is loaded.
  csrf (RedisSessionManager r _ _ _ _ _ _) = case r of
    Just r' -> rsCSRFToken r'
    Nothing -> ""

  -- All key/value pairs of the loaded session (empty if none is loaded).
  toList (RedisSessionManager r _ _ _ _ _ _) = case r of
    Just r' -> HM.toList . rsSession $ r'
    Nothing -> []
-- | Opaque cookie contents: a 'Serialize'-encoded 'RedisSession' (in
-- practice only its CSRF token). The 'Serialize' instance is derived from
-- the wrapped 'ByteString' via GeneralizedNewtypeDeriving so the payload
-- can be passed to 'getSecureCookie' / 'setSecureCookie'.
newtype Payload = Payload ByteString
  deriving (Eq, Show, Ord, Serialize)
-- | Read the session cookie via 'getSecureCookie', using the manager's
-- cookie name, site key and timeout. Yields 'Nothing' when no usable
-- cookie is present.
getPayload :: RedisSessionManager -> Snap (Maybe Payload)
getPayload mgr = getSecureCookie name key expiry
  where
    name = cookieName mgr
    key = siteKey mgr
    expiry = timeOut mgr
-- | Write the payload to the session cookie via 'setSecureCookie', using the
-- manager's cookie name, site key and timeout. The cookie domain argument is
-- only passed when building against snap >= 1.0, hence the CPP guard.
setPayload :: RedisSessionManager -> Payload -> Snap ()
setPayload mgr = setSecureCookie
  (cookieName mgr)
#if MIN_VERSION_snap(1,0,0)
  (cookieDomain mgr)
#endif
  (siteKey mgr) (timeOut mgr)