module Network.Wai.Middleware.ValidateHeaders
(
validateHeadersMiddleware
, ValidateHeadersSettings (..)
, defaultValidateHeadersSettings
, InvalidHeader (..)
, InvalidHeaderReason (..)
) where
import Data.CaseInsensitive (original)
import Data.Char (chr)
import Data.Word (Word8)
import Network.HTTP.Types (Header, ResponseHeaders, internalServerError500)
import Network.Wai (Middleware, Response, responseHeaders, responseLBS)
import Text.Printf (printf)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as BSL
-- | Middleware that checks the headers of every response produced by the
-- wrapped application before it is handed to the server.
--
-- The response headers are passed to 'getInvalidHeader'. If it reports a
-- problem, the original response is dropped and the 'onInvalidHeader'
-- handler from the supplied settings is run instead (with the same
-- application, request and continuation). Otherwise the response is
-- forwarded unchanged.
validateHeadersMiddleware :: ValidateHeadersSettings -> Middleware
validateHeadersMiddleware settings app req respond =
    app req checkedRespond
  where
    -- Continuation given to the wrapped application: validate first, then
    -- either delegate to the settings' handler or to the real continuation.
    checkedRespond response =
        case getInvalidHeader (responseHeaders response) of
            Just invalidHeader ->
                onInvalidHeader settings invalidHeader app req respond
            Nothing ->
                respond response
-- | Configuration for 'validateHeadersMiddleware'.
data ValidateHeadersSettings = ValidateHeadersSettings
    { onInvalidHeader :: InvalidHeader -> Middleware
    -- ^ Handler run in place of the normal response when a response header
    -- fails validation. It receives the offending header and is applied to
    -- the wrapped application, the request and the response continuation.
    }
-- | Default settings: when an invalid header is found, ignore the wrapped
-- application and the request, and respond with the diagnostic response
-- built by 'invalidHeaderResponse'.
defaultValidateHeadersSettings :: ValidateHeadersSettings
defaultValidateHeadersSettings =
    ValidateHeadersSettings
        { onInvalidHeader = \invalidHeader _app _req respond ->
            respond (invalidHeaderResponse invalidHeader)
        }
-- | A response header that failed validation, paired with the reason it was
-- rejected.
data InvalidHeader = InvalidHeader Header InvalidHeaderReason
-- | Why a header was rejected by the validation.
data InvalidHeaderReason
    = -- | The header name contains the given octet, which is not accepted
      -- in header names.
      InvalidOctetInHeaderName Word8
    | -- | The header value contains the given octet, which is not accepted
      -- in header values.
      InvalidOctetInHeaderValue Word8
    | -- | The header value failed the 'hasTrailingWhitespace' check.
      TrailingWhitespaceInHeaderValue
-- | Find the first header in the list that fails validation, together with
-- the reason it was rejected. Returns 'Nothing' when every header passes.
--
-- Headers are examined in list order. For a single header the checks run in
-- a fixed order (name octets, then value octets, then the whitespace check
-- on the value) and only the first failing check is reported.
getInvalidHeader :: ResponseHeaders -> Maybe InvalidHeader
getInvalidHeader = firstJust . map go
  where
    -- Leftmost 'Just' in the list, or 'Nothing' if there is none.
    firstJust :: [Maybe a] -> Maybe a
    firstJust [] = Nothing
    firstJust (Just x : _) = Just x
    firstJust (_ : xs) = firstJust xs

    -- Validate one header. The name is checked in its original
    -- (non-case-folded) byte form.
    go :: Header -> Maybe InvalidHeader
    go header@(name, value) =
        InvalidHeader header
            <$> firstJust
                [ InvalidOctetInHeaderName
                    <$> BS.find (not . isValidHeaderNameOctet) (original name)
                , InvalidOctetInHeaderValue
                    <$> BS.find (not . isValidHeaderValueOctet) value
                , if hasTrailingWhitespace value
                    then Just TrailingWhitespaceInHeaderValue
                    else Nothing
                ]
-- | Whether an octet may appear in a header name: it must be visible ASCII
-- (see 'isVisibleASCII') and must not be a delimiter (see 'isDelimiter').
isValidHeaderNameOctet :: Word8 -> Bool
isValidHeaderNameOctet octet =
    isVisibleASCII octet && not (isDelimiter octet)
-- | Whether an octet may appear in a header value: visible ASCII
-- ('isVisibleASCII'), whitespace ('isWhitespace'), or an octet accepted by
-- 'isObsText'.
isValidHeaderValueOctet :: Word8 -> Bool
isValidHeaderValueOctet octet =
    isVisibleASCII octet || isWhitespace octet || isObsText octet
-- | Whether an octet lies in the inclusive range 33..126, i.e. the printable
-- ASCII characters from @!@ to @~@ (space and DEL excluded).
isVisibleASCII :: Word8 -> Bool
isVisibleASCII octet = octet >= 33 && octet <= 126
-- | Whether an octet is one of the delimiter characters
-- @\"(),\/:;\<=\>?\@[\\]{}@.
--
-- NOTE(review): this looks like the delimiter set of RFC 9110 section
-- 5.6.2 -- confirm against the specification.
isDelimiter :: Word8 -> Bool
isDelimiter octet =
    chr (fromIntegral octet) `elem` ("\"(),/:;<=>?@[\\]{}" :: String)
-- | Whether an octet is a horizontal tab (0x09) or a space (0x20).
isWhitespace :: Word8 -> Bool
isWhitespace octet = octet == 0x9 || octet == 0x20
-- | Whether an octet is at or above 0x80, i.e. outside the 7-bit ASCII
-- range.
--
-- NOTE(review): the name suggests the @obs-text@ rule of RFC 9110 -- confirm
-- against the specification.
isObsText :: Word8 -> Bool
isObsText octet = octet >= 0x80
-- | Whether a header value begins or ends with whitespace (see
-- 'isWhitespace'). An empty value never does.
--
-- Despite the name, the first octet is checked as well as the last one.
hasTrailingWhitespace :: BS.ByteString -> Bool
hasTrailingWhitespace bs =
    -- 'BS.uncons' / 'BS.unsnoc' give the first and last octets without any
    -- partial indexing; both are 'Nothing' exactly when the value is empty.
    case (BS.uncons bs, BS.unsnoc bs) of
        (Just (firstOctet, _), Just (_, lastOctet)) ->
            isWhitespace firstOctet || isWhitespace lastOctet
        _ -> False
-- | Build the diagnostic response for an invalid header: status 500, a
-- @text/plain@ content type, and a body naming the header, its value and
-- the reason it was rejected.
--
-- The string literals here are used at 'ByteString' and header-name types,
-- so the module relies on @OverloadedStrings@.
invalidHeaderResponse :: InvalidHeader -> Response
invalidHeaderResponse (InvalidHeader (headerName, headerValue) reason) =
    responseLBS internalServerError500 [("Content-Type", "text/plain")] $
        BSL.concat
            [ "Invalid response header found:\n"
            , "In header '"
            , BSL.fromStrict (original headerName)
            , "' with value '"
            , BSL.fromStrict headerValue
            , "': "
            , showReason reason
            , "\nYou are seeing this error message because validateHeadersMiddleware is enabled."
            ]
  where
    -- Human-readable description of the rejection reason.
    showReason :: InvalidHeaderReason -> BSL.ByteString
    showReason (InvalidOctetInHeaderName octet) =
        "Name contains invalid octet " <> showOctet octet
    showReason (InvalidOctetInHeaderValue octet) =
        "Value contains invalid octet " <> showOctet octet
    showReason TrailingWhitespaceInHeaderValue =
        "Value contains trailing whitespace."

    -- Render an octet as two-digit hex; visible ASCII octets are also shown
    -- as the character itself.
    showOctet :: Word8 -> BSL.ByteString
    showOctet octet = BSL.fromStrict (BS8.pack rendered)
      where
        rendered :: String
        rendered
            | isVisibleASCII octet =
                printf "'%c' (0x%02X)" (chr (fromIntegral octet)) octet
            | otherwise = printf "0x%02X" octet