-- | This module provides a middleware to validate response headers.
-- [RFC 9110](https://www.rfc-editor.org/rfc/rfc9110.html#section-5) constrains the allowed octets in header names and values:
--
-- * Header names are [tokens](https://www.rfc-editor.org/rfc/rfc9110#section-5.6.2), i.e. visible ASCII characters (octets 33 to 126 inclusive) except delimiters.
-- * Header values should be limited to visible ASCII characters, the whitespace characters space and horizontal tab, and octets 128 to 255. Header values may not have leading or trailing whitespace (see [RFC 9110 Section 5.5](https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5)). Folding is not allowed; this follows from CR and LF not being permitted octets.
--
-- 'validateHeadersMiddleware' enforces these constraints for response headers. With 'defaultValidateHeadersSettings' it responds with a 500 Internal Server Error when a header violates them; a different reaction can be configured through 'onInvalidHeader'. This is meant to catch programmer errors early on and reduce attack surface.
module Network.Wai.Middleware.ValidateHeaders
    ( -- * Middleware
      validateHeadersMiddleware
      -- * Settings
    , ValidateHeadersSettings (..)
    , defaultValidateHeadersSettings
      -- * Types
    , InvalidHeader (..)
    , InvalidHeaderReason (..)
    ) where

import Data.CaseInsensitive (original)
import Data.Char (chr)
import Data.Maybe (catMaybes, listToMaybe, mapMaybe)
import Data.Word (Word8)
import Network.HTTP.Types (Header, ResponseHeaders, internalServerError500)
import Network.Wai (Middleware, Response, responseHeaders, responseLBS)
import Text.Printf (printf)

import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as BSL

-- | Middleware to validate response headers.
--
-- The headers of every 'Response' produced by the wrapped application are
-- checked against the rules listed at the top of this module before the
-- response is handed to @respond@.
--
-- * If all headers are valid, the response is passed to @respond@ unchanged.
-- * Otherwise the first invalid header is given to 'onInvalidHeader' from the
--   settings. The resulting 'Middleware' is then applied to the wrapped
--   application, the request and the original @respond@ callback. The
--   middleware itself does not send the offending response.
--
-- @since 3.1.15
validateHeadersMiddleware :: ValidateHeadersSettings -> Middleware
validateHeadersMiddleware settings app req respond =
    app req respond'
    where
        -- Intercepts the application's response and validates its headers
        -- before anything is sent.
        respond' response = case getInvalidHeader $ responseHeaders response of
            Just invalidHeader -> onInvalidHeader settings invalidHeader app req respond
            Nothing -> respond response

-- | Configuration for 'validateHeadersMiddleware'.
--
-- @since 3.1.15
data ValidateHeadersSettings = ValidateHeadersSettings
  { -- | Called when an invalid header is present.
    --
    -- It receives the first invalid header of the response. The resulting
    -- 'Middleware' is applied to the wrapped application, the request and the
    -- original @respond@ callback, and must produce the response to send
    -- (typically by calling @respond@).
    onInvalidHeader :: InvalidHeader -> Middleware
  }

-- | Default configuration for 'validateHeadersMiddleware'.
--
-- Used with 'validateHeadersMiddleware', each response header is checked
-- against the requirements listed at the top of this module:
--
-- * allowed octets in name and value;
-- * no leading or trailing whitespace in the value.
--
-- On the first violation, the wrapped application's response is replaced by a
-- @500 Internal Server Error@ with a @text/plain@ body that names the offending
-- header, its value and the reason. The wrapped application is not run again.
--
-- @since 3.1.15
defaultValidateHeadersSettings :: ValidateHeadersSettings
defaultValidateHeadersSettings = ValidateHeadersSettings
  { onInvalidHeader = \invalidHeader _app _req respond ->
      respond $ invalidHeaderResponse invalidHeader
  }

-- | Description of an invalid header: the offending response header together
-- with the reason it was rejected.
--
-- @since 3.1.15
data InvalidHeader
  = InvalidHeader
      Header
      -- ^ The header (name and value) exactly as it appeared in the response.
      InvalidHeaderReason
      -- ^ The first rule the header was found to violate.

-- | Reasons a header might be invalid.
--
-- @since 3.1.15
data InvalidHeaderReason
  = -- | The header name contains the given octet, which is not allowed in a
    -- token: it is either outside visible ASCII or a delimiter.
    InvalidOctetInHeaderName Word8
  | -- | The header value contains the given octet, which is neither visible
    -- ASCII, space, horizontal tab, nor in the range 0x80 to 0xFF.
    InvalidOctetInHeaderValue Word8
  | -- | The header value starts or ends with a space or horizontal tab.
    -- Despite the name, leading whitespace is reported this way too.
    TrailingWhitespaceInHeaderValue

-- Internal stuff.
-- 'getInvalidHeader' returns the first invalid header in a list of response headers, together with the reason it is invalid, if there is one.
-- 'invalidHeaderResponse' creates a plain-text 500 'Response' describing a given 'InvalidHeader'.

-- | Find the first header in the list that violates the rules of this module.
--
-- Checks for each header, in this order:
--
-- 1. octets of the name;
-- 2. octets of the value;
-- 3. leading or trailing whitespace in the value.
--
-- The first violation found is reported. Returns 'Nothing' when every header
-- is valid. Evaluation is lazy, so checking stops at the first violation.
getInvalidHeader :: ResponseHeaders -> Maybe InvalidHeader
getInvalidHeader = listToMaybe . mapMaybe checkHeader
    where
        checkHeader :: Header -> Maybe InvalidHeader
        checkHeader header@(name, value) =
            InvalidHeader header <$> listToMaybe (catMaybes
                [ InvalidOctetInHeaderName <$> BS.find (not . isValidHeaderNameOctet) (original name)
                , InvalidOctetInHeaderValue <$> BS.find (not . isValidHeaderValueOctet) value
                , if hasTrailingWhitespace value
                    then Just TrailingWhitespaceInHeaderValue
                    else Nothing
                ])

-- | Whether an octet may appear in a header name. Names are tokens, so the
-- octet must be visible ASCII and must not be a delimiter.
isValidHeaderNameOctet :: Word8 -> Bool
isValidHeaderNameOctet octet = isVisibleASCII octet && not (isDelimiter octet)

-- | Whether an octet may appear in a header value. Allowed octets are:
--
-- * visible ASCII;
-- * space or horizontal tab;
-- * obs-text (0x80 to 0xFF).
--
-- CR and LF are therefore rejected.
isValidHeaderValueOctet :: Word8 -> Bool
isValidHeaderValueOctet octet =
    isVisibleASCII octet || isWhitespace octet || isObsText octet

-- | Whether an octet is a visible (printable, non-space) ASCII character,
-- i.e. in the range 33 to 126 inclusive.
isVisibleASCII :: Word8 -> Bool
isVisibleASCII octet = octet >= 33 && octet <= 126

-- | Whether an octet is a delimiter as defined in RFC 9110 Section 5.6.2,
-- i.e. one of the characters in the string below. Delimiters may not appear
-- in tokens such as header names.
isDelimiter :: Word8 -> Bool
isDelimiter octet = chr (fromIntegral octet) `elem` ("\"(),/:;<=>?@[\\]{}" :: String)

-- Whitespace characters are only horizontal tab (0x09) and space (0x20) here.
-- CR and LF are deliberately not included.
isWhitespace :: Word8 -> Bool
isWhitespace octet = octet == 0x9 || octet == 0x20

-- | Whether an octet is obs-text (0x80 to 0xFF), which RFC 9110 still permits
-- in header values.
isObsText :: Word8 -> Bool
isObsText octet = octet >= 0x80

-- | Whether a header value starts or ends with whitespace (space or
-- horizontal tab).
--
-- Despite the name, leading whitespace is detected as well. This matches
-- RFC 9110 Section 5.5, which excludes both from field values. The empty
-- value has neither.
hasTrailingWhitespace :: BS.ByteString -> Bool
hasTrailingWhitespace bs =
    -- 'BS.uncons' and 'BS.unsnoc' are total, so the empty value needs no
    -- separate guard before inspecting the first and last octet.
    maybe False (isWhitespace . fst) (BS.uncons bs)
        || maybe False (isWhitespace . snd) (BS.unsnoc bs)

-- | Build the response sent by 'defaultValidateHeadersSettings'.
--
-- It is a @500 Internal Server Error@ with a @text/plain@ body naming the
-- offending header, its value and the reason it was rejected. The body echoes
-- the raw header value, which may itself contain the invalid octets.
invalidHeaderResponse :: InvalidHeader -> Response
invalidHeaderResponse (InvalidHeader (headerName, headerValue) reason) =
    responseLBS internalServerError500 [("Content-Type", "text/plain")] $ BSL.concat
        [ "Invalid response header found:\n"
        , "In header '"
        , BSL.fromStrict $ original headerName
        , "' with value '"
        , BSL.fromStrict headerValue
        , "': "
        , showReason reason
        , "\nYou are seeing this error message because validateHeadersMiddleware is enabled."
        ]
    where
        showReason :: InvalidHeaderReason -> BSL.ByteString
        showReason (InvalidOctetInHeaderName octet) = "Name contains invalid octet " <> showOctet octet
        showReason (InvalidOctetInHeaderValue octet) = "Value contains invalid octet " <> showOctet octet
        showReason TrailingWhitespaceInHeaderValue = "Value contains trailing whitespace."

        -- Printable ASCII octets are shown as the character plus its hex code;
        -- all other octets are shown only as hex.
        showOctet :: Word8 -> BSL.ByteString
        showOctet octet = BSL.fromStrict $ BS8.pack $
            if isVisibleASCII octet
                then printf "'%c' (0x%02X)" (chr $ fromIntegral octet) octet
                else printf "0x%02X" octet