{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Auth.ClientSession
  ( loadCookieValue
  , saveCookieValue
  , deleteCookieValue
  , Key
  , getDefaultKey
  ) where

import           Blaze.ByteString.Builder   (toByteString)
import           Control.Monad              (guard)
import           Data.Binary                (Binary, decodeOrFail, encode)
import qualified Data.ByteString            as S
import qualified Data.ByteString.Base64.URL as B64
import qualified Data.ByteString.Lazy       as L
import           Data.Int                   (Int64)
import           Data.Maybe                 (listToMaybe)
import           Data.Time.Clock            (UTCTime(UTCTime))
import           Data.Time.Calendar         (fromGregorian)
import           Foreign.C.Types            (CTime (..))
import           GHC.Generics               (Generic)
import           Network.HTTP.Types         (Header)
import           Network.Wai                (Request, requestHeaders)
import           System.PosixCompat.Time    (epochTime)
import           Web.ClientSession          (Key, decrypt, encryptIO,
                                             getDefaultKey)
import           Web.Cookie                 (def, parseCookies, renderSetCookie,
                                             sameSiteLax, setCookieExpires,
                                             setCookieHttpOnly, setCookieMaxAge,
                                             setCookieName, setCookiePath,
                                             setCookieSameSite, setCookieValue)

-- | Envelope that is actually serialised into the cookie: the caller's value
-- plus the absolute time (seconds since the Unix epoch) after which it must
-- be rejected.
--
-- The 'Binary' instance is derived via 'Generic', so the field order below
-- is the on-the-wire format. Reordering or adding fields would make cookies
-- issued before the change undecodable.
data Wrapper value = Wrapper
  { contained :: value
    -- ^ the payload stored by 'saveCookieValue'
  , expires   :: !Int64
    -- ^ expiry in seconds since the Unix epoch; should really be EpochTime
    -- or CTime, but there's no Binary instance
  } deriving (Generic)

instance Binary value => Binary (Wrapper value)

-- | Look up the cookie called @name@ in the request, decrypt it with @key@
-- and return the stored value if it has not expired yet.
--
-- Every @Cookie@ request header is scanned, and the first cookie with a
-- matching name that passes all the checks below wins. Any failure along the
-- way skips that candidate instead of raising an error. Failures include bad
-- base64, failed decryption, an undecodable payload, or an expiry in the
-- past. The result is 'Nothing' when no usable cookie is present.
loadCookieValue
  :: Binary value
  => Key
  -> S.ByteString -- ^ cookie name
  -> Request
  -> IO (Maybe value)
loadCookieValue key name req = do
  CTime now <- epochTime
  return $ listToMaybe $ do
    (k, v) <- requestHeaders req
    -- HeaderName is case-insensitive, so this also matches "Cookie".
    guard $ k == "cookie"
    (name', v') <- parseCookies v
    guard $ name == name'
    -- Each step below yields [] on failure, discarding only this candidate.
    v'' <- either (const []) pure $ B64.decode v'
    v''' <- maybe [] pure $ decrypt key v''
    Wrapper res expi <-
      either (const []) (\(_, _, w) -> pure w) $
        decodeOrFail $ L.fromStrict v'''
    -- The expiry embedded by saveCookieValue is enforced server-side,
    -- independently of whether the browser honoured Max-Age.
    guard $ expi >= fromIntegral now
    return res

-- | Build a @Set-Cookie@ header that stores @value@ under @name@.
--
-- The value is processed in four steps:
--
-- 1. It is wrapped together with its absolute expiry time (now + @age@
--    seconds).
-- 2. The wrapper is serialised with 'Binary'.
-- 3. The bytes are encrypted with @key@.
-- 4. The ciphertext is base64url-encoded so it can be placed in a cookie
--    value.
--
-- The cookie is scoped to path @/@ and marked HttpOnly and SameSite=Lax.
-- NOTE(review): the Secure attribute is not set here, so the cookie may also
-- be sent over plain HTTP.
saveCookieValue
  :: Binary value
  => Key
  -> S.ByteString -- ^ cookie name
  -> Int -- ^ age in seconds
  -> value
  -> IO Header
saveCookieValue key name age value = do
  CTime now <- epochTime
  let payload = Wrapper
        { contained = value
        , expires   = fromIntegral now + fromIntegral age
        }
  encrypted <- encryptIO key $ L.toStrict $ encode payload
  return
    ( "Set-Cookie"
    , toByteString $
      renderSetCookie
        def
          { setCookieName     = name
          , setCookieValue    = B64.encode encrypted
          , setCookiePath     = Just "/"
          , setCookieHttpOnly = True
            -- Let the browser drop the cookie at the same moment the
            -- embedded expiry makes it invalid.
          , setCookieMaxAge   = Just $ fromIntegral age
          , setCookieSameSite = Just sameSiteLax
          })

-- | Build a @Set-Cookie@ header that removes the cookie called @name@.
--
-- The cookie is overwritten with an empty value and an @Expires@ date at the
-- Unix epoch, which is always in the past, so the browser discards it.
-- Path, HttpOnly and SameSite mirror the attributes used by
-- 'saveCookieValue'. A cookie is only replaced when the path matches, so
-- @/@ must stay in sync with that function.
deleteCookieValue
  :: S.ByteString -- ^ cookie name
  -> Header
deleteCookieValue name =
  ( "Set-Cookie"
  , toByteString $
    renderSetCookie
      def
        { setCookieName     = name
        , setCookieValue    = ""
        , setCookiePath     = Just "/"
        , setCookieHttpOnly = True
        , setCookieExpires  = Just epoch
        , setCookieSameSite = Just sameSiteLax
        })
  where
    -- 1970-01-01T00:00:00Z: a timestamp guaranteed to lie in the past.
    epoch = UTCTime (fromGregorian 1970 01 01) 0