{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
------------------------------------------------------------------------------
-- | This is a support module meant to back all session back-end
-- implementations.
--
-- It gives us an encrypted and timestamped cookie that can store an arbitrary
-- serializable payload. For security, it will:
--
--   * Encrypt its payload together with a timestamp.
--
--   * Check the timestamp for session expiration every time you read from the
--     cookie. This will limit intercept-and-replay attacks by disallowing
--     cookies older than the timeout threshold.

module Snap.Snaplet.Session.SecureCookie
       ( SecureCookie
       , getSecureCookie
       , setSecureCookie
       , expireSecureCookie
       -- ** Helper functions
       , encodeSecureCookie
       , decodeSecureCookie
       , checkTimeout
       ) where

------------------------------------------------------------------------------
import           Control.Monad
import           Control.Monad.Trans
import           Data.ByteString       (ByteString)
import           Data.Serialize
import           Data.Time
import           Data.Time.Clock.POSIX
import           Snap.Core
import           Web.ClientSession

#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative
#endif

------------------------------------------------------------------------------
-- | Arbitrary payload paired with a timestamp.
--
-- 'setSecureCookie' stamps the payload with the current time when the cookie
-- is written, and 'getSecureCookie' hands that timestamp to 'checkTimeout'
-- when the cookie is read back.
type SecureCookie t = (UTCTime, t)


------------------------------------------------------------------------------
-- | Get the cookie payload.
--
-- Looks the cookie up by name, preferring a cookie already attached to the
-- outgoing response over the one that arrived with the request, and decodes
-- it with 'decodeSecureCookie'. The result is 'Nothing' when no such cookie
-- exists, when decoding fails, or when 'checkTimeout' reports the embedded
-- timestamp as timed out.
getSecureCookie :: (MonadSnap m, Serialize t)
                => ByteString       -- ^ Cookie name
                -> Key              -- ^ Encryption key
                -> Maybe Int        -- ^ Timeout in seconds
                -> m (Maybe t)
getSecureCookie name key timeout = do
    rqCookie  <- getCookie name
    rspCookie <- getResponseCookie name <$> getResponse
    -- 'mplus' on 'Maybe' keeps the first 'Just', so a cookie set on the
    -- response during this request wins over the incoming one.
    let ck  = rspCookie `mplus` rqCookie
        val = fmap cookieValue ck >>= decodeSecureCookie key
    case val of
      Nothing      -> return Nothing
      Just (ts, t) -> do
          expired <- checkTimeout timeout ts
          return $ if expired then Nothing else Just t


------------------------------------------------------------------------------
-- | Decode secure cookie payload with key.
--
-- Decrypts the value and deserializes it as a pair of POSIX seconds and the
-- payload. Returns 'Nothing' if either decryption or deserialization fails.
decodeSecureCookie :: Serialize a
                   => Key                     -- ^ Encryption key
                   -> ByteString              -- ^ Encrypted payload
                   -> Maybe (SecureCookie a)
decodeSecureCookie key value = do
    cv <- decrypt key value
    -- 'decode' reports failure as 'Left'; the message is discarded so that
    -- both failure modes collapse into 'Nothing'.
    (i, val) <- either (const Nothing) Just (decode cv)
    -- The timestamp is stored as whole seconds since the epoch (an 'Integer').
    return (posixSecondsToUTCTime (fromInteger i), val)


------------------------------------------------------------------------------
-- | Inject the payload.
--
-- Stamps the payload with the current time, encrypts it with
-- 'encodeSecureCookie' and adds the resulting cookie to the response. When a
-- max age is given, the cookie's expiry is set to the current time plus that
-- many seconds; otherwise no expiry is set.
setSecureCookie :: (MonadSnap m, Serialize t)
                => ByteString       -- ^ Cookie name
                -> Maybe ByteString -- ^ Cookie domain
                -> Key              -- ^ Encryption key
                -> Maybe Int        -- ^ Max age in seconds
                -> t                -- ^ Serializable payload
                -> m ()
setSecureCookie name domain key to val = do
    t    <- liftIO getCurrentTime
    val' <- encodeSecureCookie key (t, val)
    let expire = fmap (\secs -> addUTCTime (fromIntegral secs) t) to
        -- Positional 'Cookie' fields: name, value, expiry, domain, path,
        -- then two flags which, per Snap.Core's 'Cookie' record, are
        -- "secure" ('False') and "HTTP-only" ('True').
        nc = Cookie name val' expire domain (Just "/") False True
    modifyResponse (addResponseCookie nc)


------------------------------------------------------------------------------
-- | Encode SecureCookie with key into injectable payload.
--
-- The timestamp is rounded to whole POSIX seconds and serialized together
-- with the payload before the result is encrypted with 'encryptIO'.
encodeSecureCookie :: (MonadIO m, Serialize t)
                   => Key            -- ^ Encryption key
                   -> SecureCookie t -- ^ Payload
                   -> m ByteString
encodeSecureCookie key (t, val) =
    liftIO $ encryptIO key (encode (seconds, val))
  where
    -- Sub-second precision is dropped here; 'decodeSecureCookie' reads the
    -- timestamp back as an 'Integer' number of seconds.
    seconds :: Integer
    seconds = round (utcTimeToPOSIXSeconds t)


------------------------------------------------------------------------------
-- | Expire secure cookie.
--
-- Builds a cookie with the given name and domain, an empty value and path
-- @\"/\"@, and passes it to 'expireCookie'.
expireSecureCookie :: MonadSnap m
                   => ByteString       -- ^ Cookie name
                   -> Maybe ByteString -- ^ Cookie domain
                   -> m ()
expireSecureCookie name domain = expireCookie cookie
  where
    -- Positional 'Cookie' fields: name, value, expiry, domain, path and two
    -- boolean flags (both 'False' here).
    cookie = Cookie name "" Nothing domain (Just "/") False False


------------------------------------------------------------------------------
-- | Validate session against timeout policy.
--
-- Returns 'True' when the session has timed out and 'False' when it is
-- still valid.
--
-- * If timeout is set to 'Nothing', never trigger a time-out.
--
-- * Otherwise, do a regular time-out check based on current time and given
--   timestamp: the session has timed out once the current time is strictly
--   later than the timestamp plus the timeout in seconds.
checkTimeout :: (MonadSnap m) => Maybe Int -> UTCTime -> m Bool
checkTimeout Nothing  _  = return False
checkTimeout (Just x) t0 = do
    t1 <- liftIO getCurrentTime
    let deadline = addUTCTime (fromIntegral x) t0
    return (t1 > deadline)