{-# LANGUAGE DisambiguateRecordFields #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}
module Network.HTTP.Client.Headers
    ( parseStatusHeaders
    , validateHeaders
    , HeadersValidationResult (..)
    ) where

import           Control.Applicative            as A ((<$>), (<*>))
import           Control.Monad
import qualified Data.ByteString                as S
import qualified Data.ByteString.Char8          as S8
import qualified Data.CaseInsensitive           as CI
import           Data.Maybe (mapMaybe)
import           Data.Monoid
import           Network.HTTP.Client.Connection
import           Network.HTTP.Client.Types
import           System.Timeout                 (timeout)
import           Network.HTTP.Types
import Data.Word (Word8)

-- | ASCII space (@' '@). Separates the fields of a status line and is the
-- padding trimmed around header names and values.
charSpace :: Word8
charSpace = 0x20

-- | ASCII colon (@':'@). Separates a header name from its value.
charColon :: Word8
charColon = 0x3A

-- | ASCII full stop (@'.'@). Separates the major and minor numbers in
-- @HTTP/x.y@.
charPeriod :: Word8
charPeriod = 0x2E


-- | Read the status line and header block of an HTTP response from @conn@.
--
-- When a response has status code 100, the rest of that response is
-- discarded via 'connectionDropTillBlankLine' and reading continues with the
-- next status line.
--
-- * @timeout'@: when @Just t@, reading is wrapped in
--   'System.Timeout.timeout' @t@ (so @t@ is in microseconds).
--   'ResponseTimeout' is thrown if it expires.
-- * @cont@: when @Just sendBody@, a single response is read first.
--   If it is not a 100 response, it is returned and @sendBody@ never runs.
--   Otherwise @sendBody@ runs, and the next non-100 response is read under a
--   fresh timeout.
--
-- Exceptions, raised through 'throwHttp':
--
-- * 'NoResponseDataReceived' when 'connectionRead' yields an empty chunk
--   before a status line.
-- * 'InvalidStatusLine' when the status line cannot be parsed.
-- * 'OverlongHeaders' when a header block reaches 100 non-blank lines.
parseStatusHeaders :: Connection -> Maybe Int -> Maybe (IO ()) -> IO StatusHeaders
parseStatusHeaders conn timeout' cont
    | Just k <- cont = getStatusExpectContinue k
    | otherwise      = getStatus
  where
    -- Apply the optional response timeout to an action, turning expiry into
    -- 'ResponseTimeout'.
    withTimeout :: IO a -> IO a
    withTimeout = case timeout' of
        Nothing -> id
        Just  t -> timeout t >=> maybe (throwHttp ResponseTimeout) return

    -- Read responses until one is not a 100. A single timeout covers the
    -- whole loop, not each iteration.
    getStatus :: IO StatusHeaders
    getStatus = withTimeout next
      where
        next = nextStatusHeaders >>= maybe next return

    -- Read one response. Only a 100 causes the request body to be sent.
    getStatusExpectContinue :: IO () -> IO StatusHeaders
    getStatusExpectContinue sendBody = do
        status <- withTimeout nextStatusHeaders
        case status of
            Just  s -> return s
            Nothing -> sendBody >> getStatus

    -- 'Nothing' means a 100 response was read and discarded.
    nextStatusHeaders :: IO (Maybe StatusHeaders)
    nextStatusHeaders = do
        (s, v) <- nextStatusLine
        if statusCode s == 100
            then connectionDropTillBlankLine conn >> return Nothing
            else Just . StatusHeaders s v A.<$> parseHeaders (0 :: Int) id

    nextStatusLine :: IO (Status, HttpVersion)
    nextStatusLine = do
        -- Ensure that there is some data coming in. If not, we want to signal
        -- this as a connection problem and not a protocol problem.
        bs <- connectionRead conn
        when (S.null bs) $ throwHttp NoResponseDataReceived
        connectionReadLineWith conn bs >>= parseStatus 3

    -- Parse "HTTP/x.y CODE reason". Up to @i@ empty lines are skipped first.
    parseStatus :: Int -> S.ByteString -> IO (Status, HttpVersion)
    parseStatus i bs | S.null bs && i > 0 = connectionReadLine conn >>= parseStatus (i - 1)
    parseStatus _ bs = do
        let (ver, bs2) = S.break (== charSpace) bs
            (code, bs3) = S.break (== charSpace) $ S.dropWhile (== charSpace) bs2
            msg = S.dropWhile (== charSpace) bs3
        case (,) <$> parseVersion ver A.<*> readInt code of
            Just (ver', code') -> return (Status code' msg, ver')
            Nothing -> throwHttp $ InvalidStatusLine bs

    stripPrefixBS :: S.ByteString -> S.ByteString -> Maybe S.ByteString
    stripPrefixBS x y
        | x `S.isPrefixOf` y = Just $ S.drop (S.length x) y
        | otherwise = Nothing

    -- Parse "HTTP/major.minor".
    parseVersion :: S.ByteString -> Maybe HttpVersion
    parseVersion bs0 = do
        bs1 <- stripPrefixBS "HTTP/" bs0
        let (num1, S.drop 1 -> num2) = S.break (== charPeriod) bs1
        HttpVersion <$> readInt num1 <*> readInt num2

    -- Parse the entire input as an Int. Any trailing bytes make it fail.
    readInt :: S.ByteString -> Maybe Int
    readInt bs =
        case S8.readInt bs of
            Just (i, "") -> Just i
            _ -> Nothing

    -- Collect header lines until a blank line. 'front' is a difference list,
    -- so headers keep their original order.
    --
    -- Every non-blank line counts toward the 100-line limit, including lines
    -- skipped as unparseable. Otherwise a stream of colon-less lines would
    -- never reach 'OverlongHeaders'.
    parseHeaders :: Int -> ([Header] -> [Header]) -> IO [Header]
    parseHeaders 100 _ = throwHttp OverlongHeaders
    parseHeaders count front = do
        line <- connectionReadLine conn
        if S.null line
            then return $ front []
            else case parseHeader line of
                Just header ->
                    parseHeaders (count + 1) $ front . (header:)
                Nothing ->
                    -- Unparseable header line. Rather than throwing an
                    -- exception, skip it for robustness; it still counts
                    -- toward the limit.
                    parseHeaders (count + 1) front

    -- Split a header line at its first colon into a trimmed, case-insensitive
    -- name and a trimmed value. Lines without a colon yield 'Nothing'.
    parseHeader :: S.ByteString -> Maybe Header
    parseHeader bs
        | S.null rest = Nothing
        | otherwise   = Just (CI.mk $! strip key, strip $! S.drop 1 rest)
      where
        (key, rest) = S.break (== charColon) bs

    -- Trim optional whitespace from both ends. RFC 7230 defines OWS as space
    -- or horizontal tab.
    strip :: S.ByteString -> S.ByteString
    strip = S.dropWhile isOWS . fst . S.spanEnd isOWS

    isOWS :: Word8 -> Bool
    isOWS w = w == charSpace || w == charTab

    charTab :: Word8
    charTab = 9

-- | Outcome of checking a set of request headers.
data HeadersValidationResult
    = GoodHeaders
      -- ^ No problems were found.
    | BadHeaders S.ByteString
      -- ^ At least one header was rejected. The payload is a human-readable
      -- message giving the reason(s).

-- | Check request headers for line-break characters.
--
-- A header is rejected if its name or its value contains a carriage return
-- (CR) or a line feed (LF). RFC 7230 section 3.2 allows neither in a field
-- name or field value, and either could inject extra lines into the request.
--
-- All problems are collected and returned together in 'BadHeaders', one
-- message per line.
validateHeaders :: RequestHeaders -> HeadersValidationResult
validateHeaders headers =
    case mapMaybe validateHeader headers of
        [] -> GoodHeaders
        reasons -> BadHeaders (S8.unlines reasons)
  where
    -- A problem in the name is reported in preference to one in the value.
    validateHeader :: Header -> Maybe S.ByteString
    validateHeader (k, v)
        | hasLineBreak name = Just ("Header name " <> name <> " has newlines")
        | hasLineBreak v    = Just ("Header " <> name <> " has newlines")
        | otherwise         = Nothing
      where
        name = CI.original k

    -- True if the input contains a CR or an LF byte.
    hasLineBreak :: S.ByteString -> Bool
    hasLineBreak = S8.any (\c -> c == '\n' || c == '\r')