{-# LANGUAGE FlexibleContexts #-}

{- |Description: This module provides access to cookie data, in the
 form of a SessionMap.
-}
module Servant.API.Cookies where

import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toStrict)
import Data.ByteString.Builder (toLazyByteString)
import Data.Functor ((<&>))
import Data.Kind (Type)
import Data.Map.Strict (Map)
import Data.Time.Clock (getCurrentTime, secondsToDiffTime)
import Network.Wai
import Servant
import Servant.Server.Internal.Delayed (addHeaderCheck)
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFailFatal, withRequest)
import Web.ClientSession
import Web.Cookie

import qualified Data.Map.Strict as Map
import qualified Data.Vault.Lazy as Vault
import qualified Network.HTTP.Types.Header as NTH

-- |A 'SessionMap' is an ordered map ("Data.Map.Strict") from cookie
-- names to cookie values, as parsed from a request.
type SessionMap = Map ByteString ByteString

{- |
  A 'SetCookieHeader' is a convenience alias for a response of type @a@
  that carries a single "Set-Cookie" header whose value is a
  'SetCookie' record.

  The header name is spelled out as the type-level string
  @"Set-Cookie"@ rather than taken from @NTH.hSetCookie@: the latter
  would be the "known correct value", but it is a term-level value and
  using it breaks the type-level machinery Servant relies upon.
-}
type SetCookieHeader a = Headers '[Servant.Header "Set-Cookie" SetCookie] a

{- |
  The @ProvideCookies@ and @WithCookies@ combinators work in tandem:
  @ProvideCookies@ parses the cookies from the request and stores them
  in the WAI request Vault, and @WithCookies@ hands the stored cookies
  to the handler as a map.

  This module provides server instances for the modifier lists
  @'[Required]@ and @'[Optional]@.
-}
data ProvideCookies (mods :: [Type])

{- |
  The @WithCookies@ combinator provides the cookies that
  @ProvideCookies@ has already parsed to the handler: as a
  'SessionMap' for @'[Required]@, or as a @Maybe SessionMap@ for
  @'[Optional]@.

  With @ProvideCookies '[Required]@ the cookie data is assumed to be
  encrypted with a @Web.ClientSession.Key@. Likewise, @updateCookies@
  encrypts the cookies on the outbound side via this mechanism.

  Example:

@
import Control.Monad.IO.Class (liftIO)
import Servant
import Servant.API.Cookies

import qualified Data.Map.Strict as Map

type MyAPI = "my-cookie-enabled-endpoint"
           :> ProvideCookies '[Required]
           :> WithCookies '[Required]
           :> Get '[JSON] NoContent

myServer :: Server MyAPI
myServer = cookieEndpointHandler
 where
   cookieEndpointHandler :: SessionMap -> Handler NoContent
   cookieEndpointHandler sMap =
      let mCookieValue = Map.lookup "MerlinWasHere" sMap in
      case mCookieValue of
       Nothing -> do
         liftIO $ print "Merlin was *NOT* here!"
         throwError err400 { errBody = "Clearly you've missed something." }
       Just message -> do
         liftIO $ do
           print "Merlin WAS here, and he left us a message!"
           print message
         pure NoContent
@
-}
data WithCookies (mods :: [Type])

{- |
  @HasCookies@ and @HasCookiesMaybe@ are internal utility types. You
  should only need to use @ProvideCookies@ and @WithCookies@.

  @HasCookies@ is a marker value: the @ProvideCookies '[Required]@
  instance pushes it onto the Servant context, and the
  @WithCookies '[Required]@ instance demands it via 'HasContextEntry'.

  As an aside, they're separate types (rather than a single type with
  a (mods :: [Type]) phantom type) because the term-level values show up
  in the instances, and I didn't see a clean way to separate them out
  by case, and only covering one value from the sum type made Haskell
  (rightly) complain.
-}
data HasCookies = HasCookies

{- |
  @HasCookies@ and @HasCookiesMaybe@ are internal utility types. You
  should only need to use @ProvideCookies@ and @WithCookies@.

  @HasCookiesMaybe@ is the marker value that the
  @ProvideCookies '[Optional]@ instance pushes onto the Servant context
  and that the @WithCookies '[Optional]@ instance demands via
  'HasContextEntry'.
-}
data HasCookiesMaybe = HasCookiesMaybe

{- |
  Server instance for @ProvideCookies '[Required]@.

  The combinator does not change the handler type. At routing time it
  wraps the downstream application so that, for every request, it:

  1. looks up the @Cookie@ request header,
  2. decrypts the header value with the clientsession 'Key' taken from
     the context,
  3. parses the plaintext with 'parseCookies' into a 'SessionMap', and
  4. stores that map in the request 'vault' under the
     @Vault.Key SessionMap@ taken from the context.

  A missing header or a failed decryption results in an empty map.
  The 'HasCookies' marker is pushed onto the context handed to the
  rest of the API.
-}
instance
  ( HasServer api (HasCookies ': ctx)
  , HasContextEntry ctx (Vault.Key SessionMap)
  , HasContextEntry ctx (Key) -- for encrypting/decrypting the cookie
  ) =>
  HasServer (ProvideCookies '[Required] :> api) ctx
  where
  type ServerT (ProvideCookies '[Required] :> api) m = ServerT api m

  hoistServerWithContext _ _ nt server =
    hoistServerWithContext (Proxy @api) (Proxy @(HasCookies ': ctx)) nt server

  route _ ctx server =
    route (Proxy @api) (HasCookies :. ctx) server <&> \app req respK -> do
      let
        -- Raw value of the Cookie request header, if any.
        mCookie = lookup NTH.hCookie (requestHeaders req)
        key = getContextEntry ctx :: Vault.Key SessionMap
        encKey = getContextEntry ctx :: Key
        -- NOTE(review): the *entire* Cookie header value is passed to
        -- 'decrypt', whereas 'updateCookies' in this module encrypts
        -- only the value of one named cookie. Confirm that the two are
        -- meant to round-trip.
        mCookie' = mCookie >>= decrypt encKey
        -- No header, or a value that does not decrypt, yields an empty map.
        cookies = maybe Map.empty (Map.fromList . parseCookies) mCookie'
        req' = req {vault = Vault.insert key cookies (vault req)}
      app req' respK

{- |
  Server instance for @ProvideCookies '[Optional]@.

  The combinator does not change the handler type. At routing time it
  wraps the downstream application so that, for every request, it
  looks up the @Cookie@ request header, parses it with 'parseCookies'
  when present, and stores the resulting @Maybe SessionMap@ in the
  request 'vault' under the @Vault.Key (Maybe SessionMap)@ taken from
  the context. 'Nothing' is stored when the header is absent.

  The 'HasCookiesMaybe' marker is pushed onto the context handed to
  the rest of the API.
-}
instance
  ( HasServer api (HasCookiesMaybe ': ctx)
  , HasContextEntry ctx (Vault.Key (Maybe SessionMap))
  , HasContextEntry ctx (Key) -- for encrypting/decrypting the cookie
  ) =>
  HasServer (ProvideCookies '[Optional] :> api) ctx
  where
  type ServerT (ProvideCookies '[Optional] :> api) m = ServerT api m

  hoistServerWithContext _ _ nt server =
    hoistServerWithContext (Proxy @api) (Proxy @(HasCookiesMaybe ': ctx)) nt server

  route _ ctx server =
    route (Proxy @api) (HasCookiesMaybe :. ctx) server <&> \app req respK -> do
      let
        -- NOTE(review): unlike the '[Required] instance, the header is
        -- parsed without being decrypted, and the 'Key' required by
        -- the instance context is never used. Confirm whether this
        -- difference is intentional.
        mCookie =
          Map.fromList . parseCookies
            <$> lookup NTH.hCookie (requestHeaders req)
        key = getContextEntry ctx :: Vault.Key (Maybe SessionMap)
        req' = req {vault = Vault.insert key mCookie (vault req)}
      app req' respK

{- |
  Server instance for @WithCookies '[Required]@.

  The handler gains a leading 'SessionMap' argument. The map is read
  from the request 'vault' using the @Vault.Key SessionMap@ found in
  the context. If the vault holds no entry under that key, the request
  fails fatally with a 500 response.

  The @HasContextEntry ctx HasCookies@ constraint requires the
  'HasCookies' marker to be present in the context; the
  @ProvideCookies '[Required]@ instance is what adds it.
-}
instance
  ( HasServer api ctx
  , HasContextEntry ctx HasCookies
  , HasContextEntry ctx (Vault.Key SessionMap)
  ) =>
  HasServer (WithCookies '[Required] :> api) ctx
  where
  type ServerT (WithCookies '[Required] :> api) m = SessionMap -> ServerT api m

  hoistServerWithContext _ ctx nt server =
    hoistServerWithContext (Proxy @api) ctx nt . server

  route _ ctx server =
    route (Proxy @api) ctx $
      server `addHeaderCheck` retrieveCookies
    where
      -- Fetch the map cached in the request vault.
      retrieveCookies :: DelayedIO SessionMap
      retrieveCookies = withRequest $ \req -> do
        let key = getContextEntry ctx :: Vault.Key SessionMap
        case Vault.lookup key (vault req) of
          Just cookies -> pure cookies
          Nothing ->
            delayedFailFatal $
              err500
                { errBody = "Something has gone horribly wrong; could not find cached cookies."
                }

{- |
  Server instance for @WithCookies '[Optional]@.

  The handler gains a leading @Maybe SessionMap@ argument. The value
  is read from the request 'vault' using the
  @Vault.Key (Maybe SessionMap)@ found in the context. If the vault
  holds no entry at all under that key, the request fails fatally with
  a 500 response; a stored 'Nothing' is passed through to the handler.

  The @HasContextEntry ctx HasCookiesMaybe@ constraint requires the
  'HasCookiesMaybe' marker to be present in the context; the
  @ProvideCookies '[Optional]@ instance is what adds it.
-}
instance
  ( HasServer api ctx
  , HasContextEntry ctx (HasCookiesMaybe)
  , HasContextEntry ctx (Vault.Key (Maybe SessionMap))
  ) =>
  HasServer (WithCookies '[Optional] :> api) ctx
  where
  type ServerT (WithCookies '[Optional] :> api) m = Maybe SessionMap -> ServerT api m

  hoistServerWithContext _ ctx nt server =
    hoistServerWithContext (Proxy @api) ctx nt . server

  route _ ctx server =
    route (Proxy @api) ctx $
      server `addHeaderCheck` retrieveCookies
    where
      -- Fetch the (possibly absent) map cached in the request vault.
      retrieveCookies :: DelayedIO (Maybe SessionMap)
      retrieveCookies = withRequest $ \req -> do
        let key = getContextEntry ctx :: Vault.Key (Maybe SessionMap)
        case Vault.lookup key (vault req) of
          Just cookies -> pure cookies
          Nothing ->
            delayedFailFatal $
              err500
                { -- TODO: Maybe the error message should be pulled from
                  -- the Context?
                  errBody = "Something has gone horribly wrong; could not find cached cookies."
                }

{- |
  Render a 'SessionMap' into a single encrypted cookie and attach it
  to a response value as a "Set-Cookie" header.

  The map is rendered with 'renderCookies', the rendered bytes are
  encrypted with 'encryptIO', and the ciphertext becomes the value of
  a cookie with the given name. All other 'SetCookie' fields are taken
  from the supplied defaults.
-}
updateCookies ::
  -- | clientsession key used to encrypt the rendered session map
  Key ->
  -- | session data to store in the cookie
  SessionMap ->
  -- | template supplying every 'SetCookie' field except name and value
  SetCookie ->
  -- | name of the cookie to set
  ByteString ->
  -- | response payload the header is attached to
  a ->
  IO (SetCookieHeader a)
updateCookies cookieEncryptKey sessionMap setCookieDefaults cookieName value = do
  let
    -- We use renderCookies with a long laborious function chain to
    -- avoid depending on the version of Web.Cookie that has the
    -- @renderCookiesBS@ function, which was introduced in a very
    -- recent release of the cookie library. The prod code I'm writing
    -- this library for is still on lts-18.27, so I take some extra
    -- pains to still support that release.
    cookieBS :: ByteString
    cookieBS = toStrict . toLazyByteString . renderCookies $ Map.toList sessionMap

  sessionMapEncrypted <- encryptIO cookieEncryptKey cookieBS

  let
    setCookie =
      setCookieDefaults
        { setCookieName = cookieName
        , setCookieValue = sessionMapEncrypted
        }

  pure $ addHeader setCookie value

{- |
  This function clears session data, for a fresh, minty-clean
  experience. The archetypal use case is when a user logs out from
  your server.

  It attaches a "Set-Cookie" header that blanks the cookie's value and
  expires it immediately (@Expires@ set to the current time, @Max-Age@
  set to zero seconds).

  The cookie name is taken from the supplied 'SetCookie' defaults, so
  pass defaults carrying the same name that was used when the session
  cookie was set. Earlier revisions forced the name to the empty
  string, which produced a header that could not match the session
  cookie.
-}
clearSession ::
  -- | template identifying the cookie to clear (its name is preserved)
  SetCookie ->
  -- | response payload the header is attached to
  a ->
  IO (SetCookieHeader a)
clearSession setCookieDefaults value = do
  now <- getCurrentTime
  let
    immediateMaxAge = secondsToDiffTime 0
    setCookie =
      setCookieDefaults
        { setCookieValue = ""
        , setCookieExpires = Just now
        , setCookieMaxAge = Just immediateMaxAge
        }
  pure $ addHeader setCookie value