{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module Network.Wai.Middleware.CombineHeaders
( combineHeaders
, CombineSettings
, defaultCombineSettings
, HeaderMap
, HandleType
, defaultHeaderMap
, setHeader
, removeHeader
, setHeaderMap
, regular
, keepOnly
, setRequestHeaders
, setResponseHeaders
) where
import qualified Data.ByteString as B
import qualified Data.List as L (foldl', reverse)
import qualified Data.Map.Strict as M
import Data.Word8 (_comma, _space, _tab)
import Network.HTTP.Types (Header, HeaderName, RequestHeaders)
import qualified Network.HTTP.Types.Header as H
import Network.Wai (Middleware, requestHeaders, mapResponseHeaders)
import Network.Wai.Util (dropWhileEnd)
-- | Mapping from a header name to the way repeated occurrences of that
-- header are combined (cf. 'HandleType').
type HeaderMap = M.Map HeaderName HandleType
-- | Settings for 'combineHeaders': which headers are combined (and how), and
-- whether combining happens on the incoming request headers and/or on the
-- outgoing response headers.
--
-- Every header present in 'combineHeaderMap' gets its values joined with
-- @", "@; there is no check that the header actually allows a
-- comma-separated list, so custom headers can be added freely.
data CombineSettings = CombineSettings
    { combineHeaderMap :: HeaderMap
    -- ^ Which headers should be combined, and how (cf. 'HandleType').
    , combineRequestHeaders :: Bool
    -- ^ Whether request headers are combined before reaching the application.
    , combineResponseHeaders :: Bool
    -- ^ Whether response headers are combined before being sent.
    }
    deriving (Eq, Show)
-- | Default settings: combine the headers listed in 'defaultHeaderMap' on
-- request headers only; response headers are passed through untouched.
defaultCombineSettings :: CombineSettings
defaultCombineSettings =
    CombineSettings
        { combineHeaderMap = defaultHeaderMap
        , combineRequestHeaders = True
        , combineResponseHeaders = False
        }
-- | Replace the entire 'HeaderMap' that decides which headers are combined
-- and how.
setHeaderMap :: HeaderMap -> CombineSettings -> CombineSettings
setHeaderMap mp set = set{combineHeaderMap = mp}
-- | Enable ('True') or disable ('False') combining of request headers.
setRequestHeaders :: Bool -> CombineSettings -> CombineSettings
setRequestHeaders b set = set{combineRequestHeaders = b}
-- | Enable ('True') or disable ('False') combining of response headers.
setResponseHeaders :: Bool -> CombineSettings -> CombineSettings
setResponseHeaders b set = set{combineResponseHeaders = b}
-- | Add a header to the 'HeaderMap' of the settings, or overwrite the
-- 'HandleType' of a header that is already present.
setHeader :: HeaderName -> HandleType -> CombineSettings -> CombineSettings
setHeader name typ settings =
    settings{combineHeaderMap = M.insert name typ $ combineHeaderMap settings}
-- | Remove a header from the 'HeaderMap' of the settings, so that repeated
-- occurrences of it are no longer combined.
removeHeader :: HeaderName -> CombineSettings -> CombineSettings
removeHeader name settings =
    settings{combineHeaderMap = M.delete name $ combineHeaderMap settings}
-- | Middleware that merges repeated occurrences of the headers listed in
-- 'combineHeaderMap' into a single header. The new value is the
-- comma-separated list of all occurrences, in their original order.
--
-- 'combineRequestHeaders' and 'combineResponseHeaders' select whether this
-- is applied to the incoming request headers and/or the outgoing response
-- headers. Headers that are not in the map keep their separate entries.
--
-- Note: the resulting header list is grouped and sorted by header name (it
-- is rebuilt by folding over a 'M.Map'), not kept in the original order.
combineHeaders :: CombineSettings -> Middleware
combineHeaders CombineSettings{..} app req resFunc =
    app newReq $ resFunc . adjustRes
  where
    newReq
        | combineRequestHeaders = req{requestHeaders = mkNewHeaders oldHeaders}
        | otherwise = req
    oldHeaders = requestHeaders req
    adjustRes
        | combineResponseHeaders = mapResponseHeaders mkNewHeaders
        | otherwise = id

    -- Group all values by header name, then flatten the groups again,
    -- combining those whose name has a 'HandleType' in the settings.
    mkNewHeaders :: RequestHeaders -> RequestHeaders
    mkNewHeaders =
        M.foldrWithKey' finishHeaders [] . L.foldl' go mempty

    go :: M.Map HeaderName HeaderHandling -> Header -> M.Map HeaderName HeaderHandling
    go acc hdr@(name, _) =
        M.alter (checkHeader hdr) name acc

    checkHeader :: Header -> Maybe HeaderHandling -> Maybe HeaderHandling
    checkHeader (name, newVal) =
        Just . \case
            -- First occurrence: remember how (and whether) to combine it.
            Nothing -> (name `M.lookup` combineHeaderMap, [newVal])
            -- Values are prepended, so each list is in reverse arrival
            -- order; 'finishHeaders' reverses it back.
            Just (mHandleType, hdrs) -> (mHandleType, newVal : hdrs)
-- | Turn one grouped entry back into header(s) and prepend them to the
-- accumulated header list.
--
-- @xs@ holds the values of header @name@ in reverse arrival order (they are
-- collected by prepending).
--
-- * With a 'HandleType', a single header is emitted. Its value is the
--   values joined with @", "@ in arrival order; trailing commas and
--   whitespace are stripped from each value and empty elements are dropped.
-- * With 'Nothing', every value is emitted as its own header, in arrival
--   order.
finishHeaders :: HeaderName -> HeaderHandling -> RequestHeaders -> RequestHeaders
finishHeaders name (shouldCombine, xs) hdrs =
    case shouldCombine of
        Just typ -> (name, combinedHeader typ) : hdrs
        Nothing ->
            -- Prepending each value undoes the earlier reversal of @xs@.
            L.foldl' (\acc el -> (name, el) : acc) hdrs xs
  where
    combinedHeader :: HandleType -> B.ByteString
    combinedHeader Regular = combineHdrs xs
    combinedHeader (KeepOnly val)
        -- Compare against cleaned values so that e.g. a trailing space or
        -- comma does not hide the special value.
        | val `elem` fmap clean xs = val
        | otherwise = combineHdrs xs

    -- Values are in reverse order, so reverse before joining. Empty
    -- elements are dropped: senders must not generate empty list elements
    -- (RFC 9110, section 5.6.1).
    combineHdrs :: [B.ByteString] -> B.ByteString
    combineHdrs =
        B.intercalate ", " . filter (not . B.null) . fmap clean . L.reverse

    -- Strip trailing commas, spaces and tabs from a single value.
    clean :: B.ByteString -> B.ByteString
    clean = dropWhileEnd $ \w -> w == _comma || w == _space || w == _tab
-- | Per-header state while grouping headers by name. It holds:
--
-- * the 'HandleType' found in the settings' 'HeaderMap', where 'Nothing'
--   means the values stay separate headers;
-- * the values collected so far, most recent first.
type HeaderHandling =
    ( Maybe HandleType
    , [B.ByteString]
    )
-- | How the values of a repeated header are combined.
--
-- Both constructors join the values with @", "@. 'KeepOnly' first checks
-- whether the given value is one of the header's values. If it is, only that
-- value is kept and all others are dropped. This is useful for wildcard or
-- special values such as @Vary: *@, where listing further values is
-- meaningless.
data HandleType
    = Regular
    | KeepOnly B.ByteString
    deriving (Eq, Show)
-- | Combine all values of the header with @", "@.
regular :: HandleType
regular = Regular
-- | Keep only the given value if it is one of the header's values; otherwise
-- combine all values with @", "@ like 'regular'.
keepOnly :: B.ByteString -> HandleType
keepOnly = KeepOnly
-- | The headers combined by 'defaultCombineSettings'.
--
-- * Combined with 'Regular': @Accept@, @Accept-CH@, @Accept-Charset@,
--   @Accept-Encoding@, @Accept-Language@, @Accept-Post@,
--   @Access-Control-Allow-Headers@, @Access-Control-Allow-Methods@,
--   @Access-Control-Expose-Headers@, @Access-Control-Request-Headers@,
--   @Allow@, @Cache-Control@, @Connection@, @Content-Encoding@,
--   @Content-Language@, @Digest@, @If-Match@, @Link@,
--   @Permissions-Policy@, @TE@, @Trailer@, @Transfer-Encoding@, @Upgrade@,
--   @Via@, @Want-Digest@.
--
-- * Combined with 'KeepOnly', where one special value makes the others
--   redundant: @Alt-Svc@ (@clear@), @Clear-Site-Data@ (the quoted
--   wildcard), @If-None-Match@ (@*@), @Timing-Allow-Origin@ (@*@) and
--   @Vary@ (@*@).
defaultHeaderMap :: HeaderMap
defaultHeaderMap =
    M.fromList
        [ (H.hAccept, Regular)
        , ("Accept-CH", Regular)
        , (H.hAcceptCharset, Regular)
        , (H.hAcceptEncoding, Regular)
        , (H.hAcceptLanguage, Regular)
        , ("Accept-Post", Regular)
        , ("Access-Control-Allow-Headers", Regular)
        , ("Access-Control-Allow-Methods", Regular)
        , ("Access-Control-Expose-Headers", Regular)
        , ("Access-Control-Request-Headers", Regular)
        , (H.hAllow, Regular)
        , ("Alt-Svc", KeepOnly "clear")
        , (H.hCacheControl, Regular)
        , -- Clear-Site-Data directives are quoted strings, so the wildcard
          -- arrives as "*" including the double quotes. An unquoted *
          -- would never match it.
          ("Clear-Site-Data", KeepOnly "\"*\"")
        , (H.hConnection, Regular)
        , (H.hContentEncoding, Regular)
        , (H.hContentLanguage, Regular)
        , ("Digest", Regular)
        , (H.hIfMatch, Regular)
        , (H.hIfNoneMatch, KeepOnly "*")
        , ("Link", Regular)
        , ("Permissions-Policy", Regular)
        , (H.hTE, Regular)
        , ("Timing-Allow-Origin", KeepOnly "*")
        , (H.hTrailer, Regular)
        , (H.hTransferEncoding, Regular)
        , (H.hUpgrade, Regular)
        , (H.hVia, Regular)
        , (H.hVary, KeepOnly "*")
        , ("Want-Digest", Regular)
        ]