{-# LANGUAGE CPP #-}
-- | Cookie-based session middleware for WAI.
--
-- A 'Session' is a (lookup, insert) pair of actions. 'withSession' obtains one
-- from a 'SessionStore' for every request and hands it to the wrapped
-- application through the request's 'Data.Vault.Vault'.
module Network.Wai.Session (Session, SessionStore, withSession, genSessionId) where

import Data.Monoid (mconcat)
import Data.String (fromString)
import Control.Monad.IO.Class (liftIO)
import Network.HTTP.Types (ResponseHeaders)
import Network.Wai (Middleware, Request(..))
#if MIN_VERSION_wai(3,0,0)
import Network.Wai.Internal (Response(ResponseBuilder,ResponseFile,ResponseStream,ResponseRaw))
#else
import Network.Wai.Internal (Response(ResponseBuilder,ResponseFile,ResponseSource))
#endif
import Web.Cookie (parseCookies, renderSetCookie, SetCookie(..))
#if MIN_VERSION_vault(0,3,0)
import Data.Vault.Lazy (Key)
import qualified Data.Vault.Lazy as Vault
#else
import Data.Vault (Key)
import qualified Data.Vault as Vault
#endif
import Data.ByteString (ByteString, foldr')
import Data.ByteString.Lazy (toStrict)
import Data.ByteString.Builder (word8HexFixed, toLazyByteString)
import qualified Blaze.ByteString.Builder as Builder
import System.Entropy (getEntropy)

-- | Type representing a single session (a lookup, insert pair)
type Session m k v = ((k -> m (Maybe v)), (k -> v -> m ()))

-- | A 'SessionStore' takes in the contents of the cookie (if there was one)
-- and returns a ('Session', 'IO' action to get new contents for cookie) pair
type SessionStore m k v = (Maybe ByteString -> IO (Session m k v, IO ByteString))

-- | Fully parameterised middleware for cookie-based sessions.
--
-- The value of the request cookie named @cookieName@ (or 'Nothing' if the
-- request carries no such cookie) is passed to the 'SessionStore'. The
-- resulting 'Session' is inserted into the request vault under @vkey@ before
-- the wrapped application runs. Once the application has produced its
-- response, the store's "new cookie contents" action is run and a
-- @Set-Cookie@ header built from @cookieDefaults@ is prepended to that
-- response's headers.
withSession
  :: SessionStore m k v  -- ^ The 'SessionStore' to use for sessions
  -> ByteString          -- ^ Name to use for the session cookie (MUST BE ASCII)
  -> SetCookie           -- ^ Settings for the cookie (path, expiry, etc)
  -> Key (Session m k v) -- ^ 'Data.Vault.Vault' key to use when passing the session through
  -> Middleware
#if MIN_VERSION_wai(3,0,0)
withSession sessions cookieName cookieDefaults vkey app req respond = do
#else
withSession sessions cookieName cookieDefaults vkey app req = do
#endif
  (session, getNewCookie) <- liftIO $ sessions $ lookup cookieName =<< cookies
  let sessionReq = req { vault = Vault.insert vkey session (vault req) }
#if MIN_VERSION_wai(3,0,0)
  -- The new cookie value is fetched only after the application has produced
  -- its response, so any session writes it performed are reflected.
  app sessionReq $ \r -> do
    newCookieVal <- liftIO getNewCookie
    respond $ mapHeader (prependSetCookie newCookieVal) r
#else
  resp <- app sessionReq
  newCookieVal <- liftIO getNewCookie
  return $ mapHeader (prependSetCookie newCookieVal) resp
#endif
  where
    -- Add a Set-Cookie header for value @v@ in front of the existing headers.
    prependSetCookie v hs = (setCookie, newCookie v) : hs
    newCookie v = Builder.toByteString $ renderSetCookie $
      cookieDefaults { setCookieName = cookieName, setCookieValue = v }
    cookies = fmap parseCookies $ lookup ciCookie (requestHeaders req)
    setCookie = fromString "Set-Cookie"
    ciCookie = fromString "Cookie"

-- | Simple session ID generator using cryptographically strong random IDs
--
-- Useful for session stores that use session IDs. Reads 32 bytes via
-- 'getEntropy' and renders each byte as exactly two hexadecimal digits, so
-- every generated ID is 64 characters long.
genSessionId :: IO ByteString
genSessionId = do
  randBytes <- getEntropy 32
  return $ prettyPrint randBytes
  where
    -- A fixed-width encoding is required. The shortest-form encoding
    -- (word8Hex) drops the leading zero of bytes below 0x10. That makes IDs
    -- variable-length and lets distinct byte strings collide, e.g. [0x01,0x23]
    -- and [0x12,0x03] would both render as "123".
    prettyPrint :: ByteString -> ByteString
    prettyPrint =
      toStrict . toLazyByteString . mconcat
        . Data.ByteString.foldr' (\byte acc -> word8HexFixed byte : acc) []

-- | Run a function over the headers in a 'Response'.
--
-- For 'ResponseRaw' (wai >= 3) the function is applied to the nested
-- 'Response' that the constructor carries.
mapHeader :: (ResponseHeaders -> ResponseHeaders) -> Response -> Response
mapHeader f (ResponseFile s h b1 b2) = ResponseFile s (f h) b1 b2
mapHeader f (ResponseBuilder s h b) = ResponseBuilder s (f h) b
#if MIN_VERSION_wai(3,0,0)
mapHeader f (ResponseStream s h b) = ResponseStream s (f h) b
mapHeader f (ResponseRaw io resp) = ResponseRaw io (mapHeader f resp)
#else
mapHeader f (ResponseSource s h b) = ResponseSource s (f h) b
#endif