{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module Network.Wai.Middleware.CombineHeaders (
combineHeaders,
CombineSettings,
defaultCombineSettings,
HeaderMap,
HandleType,
defaultHeaderMap,
setHeader,
removeHeader,
setHeaderMap,
regular,
keepOnly,
setRequestHeaders,
setResponseHeaders,
) where
import qualified Data.ByteString as B
import qualified Data.List as L (foldl', reverse)
import qualified Data.Map.Strict as M
import Data.Word8 (_comma, _space, _tab)
import Network.HTTP.Types (Header, HeaderName, RequestHeaders)
import qualified Network.HTTP.Types.Header as H
import Network.Wai (Middleware, mapResponseHeaders, requestHeaders)
import Network.Wai.Util (dropWhileEnd)
-- | Maps each header name that should be combined to the 'HandleType'
-- describing how its values are merged.
type HeaderMap = M.Map HeaderName HandleType
-- | Settings that control which headers 'combineHeaders' merges and
-- whether it acts on request headers, response headers, or both.
--
-- The constructor is not exported: start from 'defaultCombineSettings'
-- and adjust it with the setter functions of this module.
data CombineSettings = CombineSettings
    { combineHeaderMap :: HeaderMap
    -- ^ Headers to combine, each paired with the 'HandleType' to apply.
    , combineRequestHeaders :: Bool
    -- ^ Whether the headers of the incoming request are combined.
    , combineResponseHeaders :: Bool
    -- ^ Whether the headers of the outgoing response are combined.
    }
    deriving (Eq, Show)
-- | Baseline settings: combine the headers listed in 'defaultHeaderMap'
-- on incoming requests only; response headers are left untouched.
defaultCombineSettings :: CombineSettings
defaultCombineSettings =
    CombineSettings
        { combineHeaderMap = defaultHeaderMap
        , combineRequestHeaders = True
        , combineResponseHeaders = False
        }
-- | Replace the whole header map of the settings with the given one.
setHeaderMap :: HeaderMap -> CombineSettings -> CombineSettings
setHeaderMap mp set = set{combineHeaderMap = mp}
-- | Choose whether the headers of incoming requests are combined.
setRequestHeaders :: Bool -> CombineSettings -> CombineSettings
setRequestHeaders b set = set{combineRequestHeaders = b}
-- | Choose whether the headers of outgoing responses are combined.
setResponseHeaders :: Bool -> CombineSettings -> CombineSettings
setResponseHeaders b set = set{combineResponseHeaders = b}
-- | Register a header name to be combined using the given 'HandleType'.
-- If the name is already in the header map, its handling is replaced.
setHeader :: HeaderName -> HandleType -> CombineSettings -> CombineSettings
setHeader name typ settings =
    settings
        { combineHeaderMap = M.insert name typ $ combineHeaderMap settings
        }
-- | Remove a header name from the header map so that it is no longer
-- combined. Names that are not in the map leave the settings unchanged.
removeHeader :: HeaderName -> CombineSettings -> CombineSettings
removeHeader name settings =
    settings
        { combineHeaderMap = M.delete name $ combineHeaderMap settings
        }
-- | Middleware that merges repeated headers into a single header whose
-- value is the comma-separated list of the individual values.
--
-- Only header names present in the settings' header map are merged; every
-- other header is passed through with all of its occurrences kept.
-- Request headers are processed when 'combineRequestHeaders' is set and
-- response headers when 'combineResponseHeaders' is set.
--
-- A processed header list comes back ordered by header name, because the
-- values are grouped in a 'M.Map' keyed on the name before being flattened
-- again.
--
-- NOTE(review): the @", "@ literal below needs @OverloadedStrings@, which is
-- not among the pragmas visible at the top of this file; confirm that it is
-- enabled (e.g. project-wide).
combineHeaders :: CombineSettings -> Middleware
combineHeaders CombineSettings{..} app req resFunc =
    app newReq $ resFunc . adjustRes
  where
    -- Request handed to the wrapped application.
    newReq
        | combineRequestHeaders = req{requestHeaders = mkNewHeaders oldHeaders}
        | otherwise = req
    oldHeaders = requestHeaders req
    -- Transformation applied to the response before it reaches 'resFunc'.
    adjustRes
        | combineResponseHeaders = mapResponseHeaders mkNewHeaders
        | otherwise = id
    -- Group all values by header name, then flatten the groups back into a
    -- plain header list.
    mkNewHeaders =
        M.foldrWithKey' finishHeaders [] . L.foldl' go mempty
    go acc hdr@(name, _) =
        M.alter (checkHeader hdr) name acc
    -- Record one header occurrence. The handling rule is looked up only the
    -- first time a name is seen and then carried along with the values.
    checkHeader :: Header -> Maybe HeaderHandling -> Maybe HeaderHandling
    checkHeader (name, newVal) =
        Just . \case
            Nothing -> (name `M.lookup` combineHeaderMap, [newVal])
            -- Prepending stores the values in reverse order of appearance;
            -- 'finishHeaders' restores the original order.
            Just (mHandleType, hdrs) -> (mHandleType, newVal : hdrs)
    -- Turn the collected values of one header name back into headers and
    -- put them in front of the already finished ones.
    finishHeaders
        :: HeaderName -> HeaderHandling -> RequestHeaders -> RequestHeaders
    finishHeaders name (shouldCombine, xs) hdrs =
        case shouldCombine of
            Just typ -> (name, combinedHeader typ) : hdrs
            Nothing ->
                -- Not configured for combining: emit every value as its own
                -- header. Prepending one by one reverses the (already
                -- reversed) list, so the original order is restored.
                L.foldl' (\acc el -> (name, el) : acc) hdrs xs
      where
        combinedHeader Regular = combineHdrs xs
        combinedHeader (KeepOnly val)
            -- The special value overrides everything else when present.
            | val `elem` xs = val
            | otherwise = combineHdrs xs
        -- The values were collected in reverse, so reverse before joining.
        combineHdrs = B.intercalate ", " . fmap clean . L.reverse
    -- Strip trailing commas, spaces and tabs so that joining with ", " does
    -- not produce empty list elements.
    clean = dropWhileEnd $ \w -> w == _comma || w == _space || w == _tab
-- | Per-header-name accumulator used by 'combineHeaders': the handling rule
-- found in the header map for that name (if any), together with the values
-- collected so far, stored in reverse order of appearance.
type HeaderHandling = (Maybe HandleType, [B.ByteString])
-- | How the values of a repeated header are merged by 'combineHeaders'.
--
-- The constructors are not exported: build values with 'regular' and
-- 'keepOnly'.
data HandleType
    = Regular
    -- ^ Join all values with @", "@.
    | KeepOnly B.ByteString
    -- ^ If the given value occurs among the header's values, use it as the
    -- sole value; otherwise join all values like 'Regular'.
    deriving (Eq, Show)
-- | Handling for headers whose values are simply joined with @", "@.
regular :: HandleType
regular = Regular
-- | Handling for headers that have one overriding value: when the given
-- value is among the header's values it becomes the sole value, otherwise
-- the values are joined as with 'regular'.
keepOnly :: B.ByteString -> HandleType
keepOnly = KeepOnly
-- | The headers that are combined by default, with their handling.
--
-- Most entries are 'Regular'. The 'KeepOnly' entries name a value that,
-- when present, replaces the whole list (@"clear"@ for @Alt-Svc@, @"*"@
-- for the others).
--
-- NOTE(review): the string literals below need @OverloadedStrings@, which is
-- not among the pragmas visible at the top of this file; confirm that it is
-- enabled (e.g. project-wide).
defaultHeaderMap :: HeaderMap
defaultHeaderMap =
    M.fromList
        [ (H.hAccept, Regular)
        , ("Accept-CH", Regular)
        , (H.hAcceptCharset, Regular)
        , (H.hAcceptEncoding, Regular)
        , (H.hAcceptLanguage, Regular)
        , ("Accept-Post", Regular)
        , ("Access-Control-Allow-Headers", Regular)
        , ("Access-Control-Allow-Methods", Regular)
        , ("Access-Control-Expose-Headers", Regular)
        , ("Access-Control-Request-Headers", Regular)
        , (H.hAllow, Regular)
        , ("Alt-Svc", KeepOnly "clear")
        , (H.hCacheControl, Regular)
        , ("Clear-Site-Data", KeepOnly "*")
        , (H.hConnection, Regular)
        , (H.hContentEncoding, Regular)
        , (H.hContentLanguage, Regular)
        , ("Digest", Regular)
        , (H.hIfMatch, Regular)
        , (H.hIfNoneMatch, KeepOnly "*")
        , ("Link", Regular)
        , ("Permissions-Policy", Regular)
        , (H.hTE, Regular)
        , ("Timing-Allow-Origin", KeepOnly "*")
        , (H.hTrailer, Regular)
        , (H.hTransferEncoding, Regular)
        , (H.hUpgrade, Regular)
        , (H.hVia, Regular)
        , (H.hVary, KeepOnly "*")
        , ("Want-Digest", Regular)
        ]