{-# LANGUAGE DisambiguateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}
module Network.HTTP.Client.Headers
    ( parseStatusHeaders
    , validateHeaders
    , HeadersValidationResult (..)
    ) where

import           Control.Applicative            as A ((<$>), (<*>))
import           Control.Monad
import qualified Data.ByteString                as S
import qualified Data.ByteString.Char8          as S8
import qualified Data.CaseInsensitive           as CI
import           Data.Maybe (mapMaybe)
import           Data.Monoid
import           Data.Word (Word8)
import           Network.HTTP.Client.Connection
import           Network.HTTP.Client.Types
import           Network.HTTP.Types
import           System.Timeout                 (timeout)

-- | ASCII space (@' '@): separates the fields of a status line and is
-- trimmed from around header names and values.
charSpace :: Word8
charSpace = 0x20

-- | ASCII colon (@':'@): separates a header name from its value.
charColon :: Word8
charColon = 0x3A

-- | ASCII full stop (@'.'@): separates the major and minor HTTP version.
charPeriod :: Word8
charPeriod = 0x2E


-- | Read the status line and headers of the next final response from the
-- connection.
--
-- * @mhl@ is passed to every line read ('connectionReadLine',
--   'connectionReadLineWith', 'connectionDropTillBlankLine').
-- * @timeout'@, when present, bounds reading the response head with
--   'timeout' (applied separately before and after @cont@ runs); expiry
--   throws 'ResponseTimeout'.
-- * @onEarlyHintHeaders@ receives the headers of every @103 Early Hints@
--   response as soon as they are parsed.
-- * @cont@ is the expect-continue action. It runs only if the first
--   response read is a @100 Continue@. If the server answers with any other
--   status first, that response is returned and @cont@ never runs.
--
-- @100 Continue@ responses are skipped (their remaining lines are dropped).
-- @103 Early Hints@ headers are also prepended to the early-hint headers of
-- the response that follows them.
--
-- The following exceptions are thrown:
--
-- * 'NoResponseDataReceived' when a read for a status line yields no data.
-- * 'InvalidStatusLine' for a malformed status line.
-- * 'OverlongHeaders' once 100 header lines have been consumed without
--   reaching the terminating blank line.
parseStatusHeaders :: Maybe MaxHeaderLength -> Connection -> Maybe Int -> ([Header] -> IO ()) -> Maybe (IO ()) -> IO StatusHeaders
parseStatusHeaders mhl conn timeout' onEarlyHintHeaders cont
    | Just k <- cont = getStatusExpectContinue k
    | otherwise      = getStatus
  where
    -- Apply the optional timeout (passed straight to 'timeout', i.e.
    -- microseconds) and turn an expiry into 'ResponseTimeout'.
    withTimeout :: IO c -> IO c
    withTimeout = case timeout' of
        Nothing -> id
        Just  t -> timeout t >=> maybe (throwHttp ResponseTimeout) return

    -- Keep reading responses until one that is not a 100 Continue arrives.
    getStatus :: IO StatusHeaders
    getStatus = withTimeout next
      where
        next = nextStatusHeaders >>= maybe next return

    -- Expect-continue flow: a final status on the first read is returned
    -- as-is; a 100 Continue triggers the body-sending action, after which
    -- we wait for the final response.
    getStatusExpectContinue :: IO a -> IO StatusHeaders
    getStatusExpectContinue sendBody = do
        status <- withTimeout nextStatusHeaders
        case status of
            Just  s -> return s
            Nothing -> sendBody >> getStatus

    -- Read one response head: 'Nothing' for a 100 Continue, 'Just' for
    -- anything else. A 103 is consumed here and the response after it is
    -- returned with the early-hint headers attached.
    nextStatusHeaders :: IO (Maybe StatusHeaders)
    nextStatusHeaders = do
        (s, v) <- nextStatusLine
        if | statusCode s == 100 ->
                 connectionDropTillBlankLine mhl conn >> return Nothing
           | statusCode s == 103 -> do
                 earlyHeaders <- parseEarlyHintHeadersUntilFailure 0 id
                 onEarlyHintHeaders earlyHeaders
                 nextStatusHeaders >>= \case
                     Nothing -> return Nothing
                     Just (StatusHeaders s' v' earlyHeaders' reqHeaders) ->
                         return $ Just $
                             StatusHeaders s' v' (earlyHeaders <> earlyHeaders') reqHeaders
           | otherwise ->
                 Just . StatusHeaders s v mempty A.<$> parseHeaders 0 id

    -- Read the status line.
    nextStatusLine :: IO (Status, HttpVersion)
    nextStatusLine = do
        -- Ensure that there is some data coming in. If not, we want to signal
        -- this as a connection problem and not a protocol problem.
        bs <- connectionRead conn
        when (S.null bs) $ throwHttp NoResponseDataReceived
        connectionReadLineWith mhl conn bs >>= parseStatus 3

    -- Parse "HTTP/x.y CODE message". While @i > 0@, an empty line is
    -- skipped and the next line is tried instead.
    parseStatus :: Int -> S.ByteString -> IO (Status, HttpVersion)
    parseStatus i bs
        | S.null bs && i > 0 = connectionReadLine mhl conn >>= parseStatus (i - 1)
    parseStatus _ bs = do
        let (ver, bs2) = S.break (== charSpace) bs
            (code, bs3) = S.break (== charSpace) $ S.dropWhile (== charSpace) bs2
            msg = S.dropWhile (== charSpace) bs3
        case (,) <$> parseVersion ver A.<*> readInt code of
            Just (ver', code') -> return (Status code' msg, ver')
            Nothing -> throwHttp $ InvalidStatusLine bs

    -- Drop @x@ from the front of @y@ if @y@ starts with it.
    stripPrefixBS :: S.ByteString -> S.ByteString -> Maybe S.ByteString
    stripPrefixBS x y
        | x `S.isPrefixOf` y = Just $ S.drop (S.length x) y
        | otherwise = Nothing

    -- Parse "HTTP/<major>.<minor>".
    parseVersion :: S.ByteString -> Maybe HttpVersion
    parseVersion bs0 = do
        bs1 <- stripPrefixBS "HTTP/" bs0
        let (num1, rest) = S.break (== charPeriod) bs1
            num2 = S.drop 1 rest
        HttpVersion <$> readInt num1 <*> readInt num2

    -- A decimal integer that must consume the whole input.
    readInt :: S.ByteString -> Maybe Int
    readInt bs =
        case S8.readInt bs of
            Just (i, "") -> Just i
            _ -> Nothing

    -- Collect header lines up to the blank line ending the header block.
    -- @front@ is a difference list, so headers come out in wire order.
    -- Unparseable (colon-less) lines are ignored for robustness, but they
    -- still count toward the 100-line limit. Otherwise a peer could stream
    -- such lines forever without ever tripping 'OverlongHeaders'.
    parseHeaders :: Int -> ([Header] -> [Header]) -> IO [Header]
    parseHeaders 100 _ = throwHttp OverlongHeaders
    parseHeaders count front = do
        line <- connectionReadLine mhl conn
        if S.null line
            then return $ front []
            else
                parseHeader line >>= \case
                    Just header ->
                        parseHeaders (count + 1) $ front . (header:)
                    Nothing ->
                        parseHeaders (count + 1) front

    -- Like 'parseHeaders' for a 103 response. It stops at the blank line
    -- or at the first unparseable line; that line is pushed back with
    -- 'connectionUnreadLine' rather than discarded.
    parseEarlyHintHeadersUntilFailure :: Int -> ([Header] -> [Header]) -> IO [Header]
    parseEarlyHintHeadersUntilFailure 100 _ = throwHttp OverlongHeaders
    parseEarlyHintHeadersUntilFailure count front = do
        line <- connectionReadLine mhl conn
        if S.null line
            then return $ front []
            else
                parseHeader line >>= \case
                    Just header ->
                        parseEarlyHintHeadersUntilFailure (count + 1) $ front . (header:)
                    Nothing -> do
                        connectionUnreadLine conn line
                        return $ front []

    -- Split "name: value" at the first colon and trim both parts;
    -- 'Nothing' when the line contains no colon.
    parseHeader :: S.ByteString -> IO (Maybe Header)
    parseHeader bs = do
        let (key, bs2) = S.break (== charColon) bs
        if S.null bs2
            then return Nothing
            else return (Just (CI.mk $! strip key, strip $! S.drop 1 bs2))

    -- Trim optional whitespace from both ends. Per RFC 9110, OWS is SP or
    -- HTAB, so tabs are trimmed as well as spaces.
    strip :: S.ByteString -> S.ByteString
    strip = S.dropWhile isOWS . fst . S.spanEnd isOWS
      where
        isOWS w = w == charSpace || w == charTab
        charTab :: Word8
        charTab = 9

-- | Verdict returned by 'validateHeaders'.
data HeadersValidationResult
    = GoodHeaders
      -- ^ No header was rejected.
    | BadHeaders S.ByteString
      -- ^ At least one header was rejected; the payload is a
      -- human-readable message with the reason(s).

-- | Reject request headers that could break the framing of the request.
--
-- A header is rejected when its name or value contains a line break
-- (LF or CR). Each rejected header contributes one
-- @"Header <name> has newlines"@ reason, and the reasons are joined with
-- 'S8.unlines'. When nothing is rejected the result is 'GoodHeaders'.
validateHeaders :: RequestHeaders -> HeadersValidationResult
validateHeaders headers =
    case mapMaybe validateHeader headers of
        [] -> GoodHeaders
        reasons -> BadHeaders (S8.unlines reasons)
  where
    validateHeader (k, v)
        | hasLineBreak (CI.original k) || hasLineBreak v =
            Just ("Header " <> CI.original k <> " has newlines")
        | otherwise = Nothing

    -- A bare CR is rejected as well as LF. RFC 9110 forbids both in field
    -- values, and some peers treat a lone CR as a line terminator, which
    -- would allow header injection.
    hasLineBreak = S8.any (\c -> c == '\n' || c == '\r')