{-# LANGUAGE OverloadedStrings #-}

module Cookie.Secure.Middleware (secureCookies) where

import Network.Wai (Middleware
                  , Request
                  , ResponseReceived
                  , responseLBS
                  , requestHeaders
                  , responseHeaders)
import Network.Wai.Internal (Response(..))
import Network.HTTP.Types.Header (Header
                                , RequestHeaders
                                , ResponseHeaders)
import Network.HTTP.Types.Status (status200)
import qualified Data.ByteString.Char8 as BS
import Data.Maybe (catMaybes)
import Cookie.Secure (encryptAndSignIO, verifyAndDecryptIO)
import Data.List.Split (splitOn)

-- | WAI middleware that protects cookies in transit.
--
-- Before the wrapped application sees the request, its @Cookie@ headers are
-- run through 'verifyAndDecryptCookies'. The continuation handed to the
-- application is wrapped with 'encryptAndSignCookies', so @Set-Cookie@
-- headers in the application's response are encrypted and signed before
-- they reach the original responder.
secureCookies :: Middleware
secureCookies app request respondWith = do
  decryptedRequest <- verifyAndDecryptCookies request
  app decryptedRequest (encryptAndSignCookies respondWith)

-- | Pass every request header through 'verifyAndDecryptIfCookieHeader' and
-- install the resulting header list on the request. Non-@Cookie@ headers
-- come back unchanged, and header order is preserved.
verifyAndDecryptCookies :: Request -> IO Request
verifyAndDecryptCookies request = do
  checkedHeaders <- traverse verifyAndDecryptIfCookieHeader (requestHeaders request)
  return (replaceRequestHeaders request checkedHeaders)

-- | Wrap a response continuation so that, before a response is sent, every
-- one of its headers is passed through 'encryptAndSignIfSetCookieHeader'.
-- The resulting list replaces the response's headers, and the rewritten
-- response is handed to the original continuation.
encryptAndSignCookies
  :: (Response -> IO ResponseReceived)
  -> Response
  -> IO ResponseReceived
encryptAndSignCookies respondWith response = do
  protectedHeaders <- traverse encryptAndSignIfSetCookieHeader (responseHeaders response)
  respondWith (replaceResponseHeaders response protectedHeaders)

-- | Encrypt and sign the header when its name is @Set-Cookie@; any other
-- header is returned as-is.
encryptAndSignIfSetCookieHeader :: Header -> IO Header
encryptAndSignIfSetCookieHeader header@(headerName, _)
  | headerName == "Set-Cookie" = encryptAndSignCookieHeader header
  | otherwise = pure header

-- | Encrypt and sign the cookie carried by one @Set-Cookie@ header.
--
-- The header value is split at its first @;@:
--
-- * The text before it is the cookie's @name=value@ pair.
-- * The text from the @;@ onward (the cookie attributes, if any) is
--   re-appended verbatim.
--
-- Only the value half of the pair is passed through 'encryptAndSignIO'.
-- The cookie name and the header name are returned unchanged.
encryptAndSignCookieHeader :: Header -> IO Header
encryptAndSignCookieHeader (name, value) = do
  protectedPair <- encryptAndSignPair pair
  return (name, protectedPair `BS.append` attributes)
  where
    (pair, attributes) = BS.break (== ';') value

    encryptAndSignPair :: BS.ByteString -> IO BS.ByteString
    encryptAndSignPair nameValue = do
      -- Split at the FIRST '=' only. Cookie values may themselves contain
      -- '=' (for instance base64 padding). Splitting on every '=' and
      -- keeping the last piece would encrypt a truncated, often empty,
      -- value. A pair with no '=' is treated as a name with an empty value.
      let (cName, rest) = BS.break (== '=') nameValue
          cValue = BS.drop 1 rest
      encryptedValue <- encryptAndSignIO cValue
      return (BS.intercalate "=" [cName, encryptedValue])

-- | Return a copy of the request whose header list is exactly the given
-- list. All other request fields are kept.
replaceRequestHeaders :: Request -> RequestHeaders -> Request
replaceRequestHeaders req hdrs =
  req { requestHeaders = hdrs }

-- | Return a copy of the response whose header list is exactly @newHeaders@.
-- Status and body are kept. For a 'ResponseRaw', the replacement is applied
-- recursively to its wrapped fallback response.
--
-- OPTIMIZE: the 'Response' constructors matched here are imported from
-- Network.Wai.Internal, whose interface is not guaranteed to be stable.
-- NOTE(review): before switching to Network.Wai's public
-- 'mapResponseHeaders', confirm it also rewrites a 'ResponseRaw' fallback;
-- if it does not, it would change behaviour here.
replaceResponseHeaders :: Response -> ResponseHeaders -> Response
replaceResponseHeaders response newHeaders =
  case response of
    ResponseFile status _ filepath possibleFilepart ->
      ResponseFile status newHeaders filepath possibleFilepart
    ResponseBuilder status _ builder ->
      ResponseBuilder status newHeaders builder
    ResponseStream status _ body ->
      ResponseStream status newHeaders body
    ResponseRaw toStreaming fallback ->
      ResponseRaw toStreaming (replaceResponseHeaders fallback newHeaders)

-- | Verify and decrypt the header when its name is @Cookie@; any other
-- header is returned as-is.
verifyAndDecryptIfCookieHeader :: Header -> IO Header
verifyAndDecryptIfCookieHeader header@(headerName, _)
  | headerName == "Cookie" = verifyAndDecryptCookieHeader header
  | otherwise = pure header

-- | Verify and decrypt every cookie in one @Cookie@ request header.
--
-- The header value is split on @;@ into @name=value@ pairs. Spaces after
-- each separator are skipped and empty segments are ignored. The value half
-- of each pair is passed through 'verifyAndDecryptIO':
--
-- * Pairs for which it returns 'Nothing' are dropped.
-- * The remaining pairs are re-joined with @"; "@.
--
-- The header name is returned unchanged.
verifyAndDecryptCookieHeader :: Header -> IO Header
verifyAndDecryptCookieHeader (name, value) = do
  verified <- mapM verifyAndDecryptCookie cookies
  return (name, BS.intercalate "; " (catMaybes verified))
  where
    -- Split on ';' rather than the exact string "; " so that both
    -- "a=1; b=2" and "a=1;b=2" are parsed into separate pairs.
    cookies :: [BS.ByteString]
    cookies =
      filter (not . BS.null) . map (BS.dropWhile (== ' ')) $ BS.split ';' value

    verifyAndDecryptCookie :: BS.ByteString -> IO (Maybe BS.ByteString)
    verifyAndDecryptCookie cookie = do
      -- Split at the FIRST '=' only. Protected values may contain '='
      -- (for instance base64 padding). Splitting on every '=' and keeping
      -- the last piece would hand a truncated value to verification, and
      -- the cookie would then be dropped. A pair with no '=' is treated as
      -- a name with an empty value.
      let (cName, rest) = BS.break (== '=') cookie
          cValue = BS.drop 1 rest
      decryptedValue <- verifyAndDecryptIO cValue
      -- OPTIMIZE: silently dropping cookies that fail to verify or decrypt
      -- may not be the best policy.
      return $ (\plain -> BS.intercalate "=" [cName, plain]) <$> decryptedValue