module PostgREST.GucHeader
  ( GucHeader
  , unwrapGucHeader
  , addHeadersIfNotIncluded
  ) where

import qualified Data.Aeson           as JSON
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict  as M

import Network.HTTP.Types.Header (Header)

import Protolude


-- | One custom response header that SQL code requested through the
-- @response.headers@ GUC, for example:
--
-- > SET LOCAL "response.headers" = '[{"Set-Cookie": ".."}]'
--
-- It wraps a header name, compared case-insensitively via 'CI.CI', paired
-- with the raw header value bytes.
newtype GucHeader
  = GucHeader (CI.CI ByteString, ByteString)

-- | Decode a JSON object with exactly one string-valued member, such as
-- @{"Set-Cookie": ".."}@, into a 'GucHeader'.
--
-- The member name becomes the case-insensitive header name. The member
-- value becomes the header value. Both are converted with 'toUtf8'.
--
-- Any other shape fails with 'mzero'. That includes an object with zero or
-- several members, a member whose value is not a JSON string, and a
-- non-object value.
instance JSON.FromJSON GucHeader where
  parseJSON (JSON.Object o) =
    case M.toList o of
      [(name, JSON.String value)] ->
        pure $ GucHeader (CI.mk $ toUtf8 name, toUtf8 value)
      _ -> mzero
  parseJSON _ = mzero

-- | Extract the plain HTTP 'Header' (case-insensitive name and raw value)
-- carried by a 'GucHeader'.
unwrapGucHeader :: GucHeader -> Header
unwrapGucHeader (GucHeader header) = header

-- | Add headers not already included, so the user can override them
-- instead of duplicating them.
--
-- Each header in @newHeaders@ is kept, in its original order, only if its
-- name does not appear in @initialHeaders@. Names are compared with the
-- 'Eq' instance of 'CI.CI', i.e. case-insensitively. @initialHeaders@ is
-- then appended unchanged.
--
-- Whenever a name appears in both lists, the entry from @initialHeaders@
-- is the one that survives. Repeated names within @newHeaders@ itself are
-- not collapsed.
addHeadersIfNotIncluded :: [Header] -> [Header] -> [Header]
addHeadersIfNotIncluded newHeaders initialHeaders =
  filter (not . alreadyIncluded) newHeaders ++ initialHeaders
  where
    -- A new header is dropped when some initial header has the same name.
    alreadyIncluded (name, _) = any ((== name) . fst) initialHeaders