{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
module Network.Wai.Handler.Warp.Request (
recvRequest
, headerLines
, pauseTimeoutKey
, getFileInfoKey
) where
import qualified Control.Concurrent as Conc (yield)
import Control.Exception (throwIO)
import Data.Array ((!))
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.CaseInsensitive as CI
import qualified Data.IORef as I
import qualified Network.HTTP.Types as H
import Network.Socket (SockAddr)
import Network.Wai
import Network.Wai.Handler.Warp.Conduit
import Network.Wai.Handler.Warp.FileInfoCache
import Network.Wai.Handler.Warp.HashMap (hashByteString)
import Network.Wai.Handler.Warp.Header
import Network.Wai.Handler.Warp.ReadInt
import Network.Wai.Handler.Warp.RequestHeader
import Network.Wai.Handler.Warp.Settings (Settings, settingsNoParsePath)
import qualified Network.Wai.Handler.Warp.Timeout as Timeout
import Network.Wai.Handler.Warp.Types
import Network.Wai.Internal
import Prelude hiding (lines)
import Control.Monad (when)
import qualified Data.Vault.Lazy as Vault
import System.IO.Unsafe (unsafePerformIO)
-- | Limit, in bytes, against which 'push' compares its running header-length
-- counter; exceeding it makes 'push' throw 'OverLargeHeader'.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 1024 * 50
-- | Read the header section of one HTTP request from the 'Source' and build
-- the corresponding WAI 'Request'.
--
-- The result tuple holds, in order:
--
-- * the constructed 'Request' (its body reader runs 'handleExpect' on first
--   use);
-- * the remaining-length reference returned by 'bodyAndSource' ('Nothing'
--   for a chunked body);
-- * the indexed request headers;
-- * a second reader over the same body whose first-use action is a no-op,
--   i.e. it never sends @100 Continue@;
-- * the 'InternalInfo' obtained from 'toInternalInfo' and the hash of the
--   raw path.
recvRequest :: Settings
            -> Connection
            -> InternalInfo1
            -> SockAddr
            -> Source
            -> IO (Request
                  ,Maybe (I.IORef Int)
                  ,IndexedHeader
                  ,IO ByteString
                  ,InternalInfo)
recvRequest settings conn ii1 addr src = do
    rawLines <- headerLines src
    (method, unparsedPath, path, query, httpversion, hdr) <- parseHeaderLines rawLines
    let indexed = indexRequestHeader hdr
        -- Look up one of the pre-indexed headers.
        headerAt key = indexed ! fromEnum key
        -- The unparsed path is used verbatim when 'settingsNoParsePath' is set.
        rawPath
          | settingsNoParsePath settings = unparsedPath
          | otherwise = path
        info = toInternalInfo ii1 (hashByteString rawPath)
        timeoutH = threadHandle info
        requestVault =
            Vault.insert pauseTimeoutKey (Timeout.pause timeoutH)
                (Vault.insert getFileInfoKey (getFileInfo info) Vault.empty)
        sendContinue = handleExpect conn httpversion (headerAt ReqExpect)
    (rbody, remainingRef, bodyLength) <-
        bodyAndSource src (headerAt ReqContentLength) (headerAt ReqTransferEncoding)
    -- Reader handed to the application: answers @Expect: 100-continue@ on first use.
    appBody <- timeoutBody remainingRef timeoutH rbody sendContinue
    -- Reader returned separately to the caller: never answers @Expect@.
    flushBody <- timeoutBody remainingRef timeoutH rbody (return ())
    let req = Request
            { requestMethod = method
            , httpVersion = httpversion
            , rawPathInfo = rawPath
            , pathInfo = H.decodePathSegments path
            , rawQueryString = query
            , queryString = H.parseQuery query
            , requestHeaders = hdr
            , isSecure = False
            , remoteHost = addr
            , requestBody = appBody
            , requestBodyLength = bodyLength
            , vault = requestVault
            , requestHeaderHost = headerAt ReqHost
            , requestHeaderRange = headerAt ReqRange
            , requestHeaderReferer = headerAt ReqReferer
            , requestHeaderUserAgent = headerAt ReqUserAgent
            }
    return (req, remainingRef, indexed, flushBody, info)
-- | Read the raw header lines of a request from the 'Source'.
--
-- Throws 'ConnectionClosedByPeer' when the very first read yields no bytes;
-- otherwise hands the first chunk to 'push' with an empty parsing state.
headerLines :: Source -> IO [ByteString]
headerLines src = readSource src >>= start
  where
    start firstChunk
      | S.null firstChunk = throwIO ConnectionClosedByPeer
      | otherwise = push src (THStatus 0 id id) firstChunk
-- | React to the request's @Expect@ header.
--
-- When the header value is exactly @100-continue@, an interim
-- @100 Continue@ response (with the status line matching the request's HTTP
-- version) is written to the connection and the thread then yields.
-- Any other value, or no header at all, results in no action.
handleExpect :: Connection
             -> H.HttpVersion
             -> Maybe HeaderValue
             -> IO ()
handleExpect conn ver expectation =
    case expectation of
        Just "100-continue" -> connSendAll conn interim >> Conc.yield
        _ -> return ()
  where
    interim =
        if ver == H.http11
            then "HTTP/1.1 100 Continue\r\n\r\n"
            else "HTTP/1.0 100 Continue\r\n\r\n"
-- | Choose how the request body is read, based on the @Content-Length@ and
-- @Transfer-Encoding@ header values (second and third arguments).
--
-- Returns the body reader, the remaining-length reference (only for the
-- non-chunked case), and the 'RequestBodyLength' to put into the 'Request'.
bodyAndSource :: Source
              -> Maybe HeaderValue
              -> Maybe HeaderValue
              -> IO (IO ByteString
                    ,Maybe (I.IORef Int)
                    ,RequestBodyLength
                    )
bodyAndSource src cl te =
    if isChunked te
        then do
            chunkedSrc <- mkCSource src
            return (readCSource chunkedSrc, Nothing, ChunkedBody)
        else do
            -- A missing Content-Length is treated as a zero-length body.
            let declared = toLength cl
            sizedSrc@(ISource _ remaining) <- mkISource src declared
            return
                ( readISource sizedSrc
                , Just remaining
                , KnownLength (fromIntegral declared)
                )
-- | Body length from an optional @Content-Length@ value: 0 when the header
-- is absent, otherwise whatever 'readInt' yields for the value.
toLength :: Maybe HeaderValue -> Int
toLength = maybe 0 readInt
-- | 'True' only when a @Transfer-Encoding@ value is present and, after case
-- folding, equals @chunked@ exactly.
isChunked :: Maybe HeaderValue -> Bool
isChunked = maybe False ((== "chunked") . CI.foldCase)
-- | Wrap a body reader with timeout bookkeeping.
--
-- The returned action behaves as follows:
--
-- * on its first invocation only, it runs the supplied first-use action
--   (the fourth argument) and resumes the timeout handle;
-- * it then reads one chunk from the underlying reader;
-- * if that chunk is empty, or the remaining-length reference (when given)
--   holds a value @<= 0@, it pauses the timeout handle;
-- * finally it returns the chunk.
timeoutBody :: Maybe (I.IORef Int)
            -> Timeout.Handle
            -> IO ByteString
            -> IO ()
            -> IO (IO ByteString)
timeoutBody remainingRef timeoutHandle rbody handle100Continue = do
    firstCallRef <- I.newIORef True
    let bodyExhausted chunk
          | S.null chunk = return True
          | otherwise =
              case remainingRef of
                  -- Without a length counter only an empty chunk ends the body.
                  Nothing -> return False
                  Just ref -> do
                      left <- I.readIORef ref
                      return $! left <= 0
    return $ do
        firstCall <- I.readIORef firstCallRef
        when firstCall $ do
            handle100Continue
            Timeout.resume timeoutHandle
            I.writeIORef firstCallRef False
        chunk <- rbody
        done <- bodyExhausted chunk
        when done (Timeout.pause timeoutHandle)
        return chunk
-- | A function from one 'ByteString' to another; 'push' uses it to prepend
-- bytes carried over from an earlier chunk.
type BSEndo = ByteString -> ByteString

-- | A function on lists of 'ByteString's; 'push' uses it to accumulate
-- header lines by composition and finally applies it to @[]@.
type BSEndoList = [ByteString] -> [ByteString]

-- | State threaded through 'push' while header lines are being collected.
data THStatus
    = THStatus
        -- Running header-length counter checked against 'maxTotalHeaderLength'.
        {-# UNPACK #-} !Int
        -- Header lines gathered so far.
        BSEndoList
        -- Prepends bytes left over from previous chunks onto the next chunk.
        BSEndo
-- | Split incoming bytes into header lines, reading more from the 'Source'
-- as needed, until the blank line that ends the header section is found.
--
-- Arguments: the source to refill from, the parsing state accumulated so
-- far, and the chunk that was just read.
--
-- Returns the collected header lines (without their line terminators).
-- Bytes that follow the terminating blank line are given back to the source
-- via 'leftoverSource'.
--
-- Throws 'OverLargeHeader' when the running length counter exceeds
-- 'maxTotalHeaderLength', and 'IncompleteHeaders' when a refill yields an
-- empty chunk before the header section is complete.
push :: Source -> THStatus -> ByteString -> IO [ByteString]
push src (THStatus len lines prepend) bs'
    | len > maxTotalHeaderLength = throwIO OverLargeHeader
    | otherwise = push' mnl
  where
    -- Re-attach whatever was carried over from previous chunks.
    bs = prepend bs'
    bsLen = S.length bs

    -- Position of the first LF (byte 10), paired with a flag telling whether
    -- the line continues on the next one: the line itself must not be blank
    -- (neither "\n" nor "\r\n" at the very start) and the byte right after
    -- the LF must be a space (32) or a tab (9). When no byte follows the LF
    -- yet, the line is treated as not continued.
    mnl = do
        nl <- S.elemIndex 10 bs
        if bsLen > nl + 1
            then
                let c = S.index bs (nl + 1)
                    b = case nl of
                        0 -> True
                        1 -> S.index bs 0 == 13
                        _ -> False
                 in Just (nl, not b && (c == 32 || c == 9))
            else Just (nl, False)

    -- Read the next chunk and continue parsing with the given state. The
    -- emptiness check must look at the chunk that was just read: an empty
    -- chunk is reported as 'IncompleteHeaders' instead of being parsed.
    pushFromSource :: THStatus -> IO [ByteString]
    pushFromSource status = do
        bst <- readSource' src
        when (S.null bst) $ throwIO IncompleteHeaders
        push src status bst

    {-# INLINE push' #-}
    push' :: Maybe (Int, Bool) -> IO [ByteString]
    -- No LF in the buffer yet: keep the whole buffer and ask for more.
    push' Nothing = pushFromSource (THStatus len' lines prepend')
      where
        -- NOTE(review): 'bs' already contains bytes carried over through
        -- 'prepend', so those bytes are added to the counter again on every
        -- refill; confirm this cumulative accounting is intended.
        len' = len + bsLen
        prepend' = S.append bs
    -- Continued line: drop the line terminator and glue the text before it
    -- onto what follows, then parse again.
    push' (Just (end, True)) = push src status rest
      where
        rest = S.drop (end + 1) bs
        prepend' = S.append (SU.unsafeTake (checkCR bs end) bs)
        len' = len + end
        status = THStatus len' lines prepend'
    push' (Just (end, False))
        -- Blank line: the header section is finished. Hand unread bytes back
        -- to the source and deliver the collected lines.
        | S.null line = do
            when (start < bsLen) $ leftoverSource src (SU.unsafeDrop start bs)
            return (lines [])
        -- A complete line with more bytes behind it in this buffer.
        | start < bsLen = push src status (SU.unsafeDrop start bs)
        -- A complete line that ends exactly at the buffer end: refill.
        | otherwise = pushFromSource status
      where
        start = end + 1
        line = SU.unsafeTake (checkCR bs end) bs
        status = THStatus (len + start) (lines . (line :)) id
-- | Given a buffer and the index of an LF in it, return the index where the
-- line's content ends: one position earlier when the byte just before the
-- LF is a CR (13), otherwise the LF index itself.
{-# INLINE checkCR #-}
checkCR :: ByteString -> Int -> Int
checkCR bs pos
    | pos > 0 && S.index bs prev == 13 = prev
    | otherwise = pos
  where
    prev = pos - 1
-- | Vault key under which 'recvRequest' stores the action that pauses the
-- timeout of the thread handling the request.
--
-- The key is a top-level value created with 'unsafePerformIO'; the NOINLINE
-- pragma below must stay, since inlining could duplicate the 'Vault.newKey'
-- call and produce keys that do not match each other.
pauseTimeoutKey :: Vault.Key (IO ())
pauseTimeoutKey = unsafePerformIO Vault.newKey
{-# NOINLINE pauseTimeoutKey #-}
-- | Vault key under which 'recvRequest' stores the file-information lookup
-- function taken from the request's 'InternalInfo'.
--
-- The key is a top-level value created with 'unsafePerformIO'; the NOINLINE
-- pragma below must stay, since inlining could duplicate the 'Vault.newKey'
-- call and produce keys that do not match each other.
getFileInfoKey :: Vault.Key (FilePath -> IO FileInfo)
getFileInfoKey = unsafePerformIO Vault.newKey
{-# NOINLINE getFileInfoKey #-}