{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE UnboxedTuples #-}
module Snap.Internal.Http.Server.Parser
( IRequest(..)
, HttpParseException(..)
, readChunkedTransferEncoding
, writeChunkedTransferEncoding
, parseRequest
, parseFromStream
, parseCookie
, parseUrlEncoded
, getStdContentLength
, getStdHost
, getStdTransferEncoding
, getStdCookie
, getStdContentType
, getStdConnection
) where
#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative              ((<$>))
#endif
import           Control.Exception                (Exception, throwIO)
import qualified Control.Exception                as E
import           Control.Monad                    (void, when)
import           Data.Attoparsec.ByteString.Char8 (Parser, hexadecimal, skipWhile, take, takeWhile1)
import qualified Data.ByteString.Char8            as S
import           Data.ByteString.Internal         (ByteString (..), c2w, memchr, w2c)
#if MIN_VERSION_bytestring(0, 10, 6)
import           Data.ByteString.Internal         (accursedUnutterablePerformIO)
#else
import           Data.ByteString.Internal         (inlinePerformIO)
#endif
import qualified Data.ByteString.Unsafe           as S
#if !MIN_VERSION_io_streams(1,2,0)
import           Data.IORef                       (newIORef, readIORef, writeIORef)
#endif
import           Data.List                        (sort)
import           Data.Typeable                    (Typeable)
import qualified Data.Vector                      as V
import qualified Data.Vector.Mutable              as MV
import           Foreign.ForeignPtr               (withForeignPtr)
import           Foreign.Ptr                      (minusPtr, nullPtr, plusPtr)
import           Prelude                          hiding (take)

import           Blaze.ByteString.Builder.HTTP    (chunkedTransferEncoding, chunkedTransferTerminator)
import           Data.ByteString.Builder          (Builder)
import           System.IO.Streams                (InputStream, OutputStream)
import qualified System.IO.Streams                as Streams
import           System.IO.Streams.Attoparsec     (parseFromStream)

import           Snap.Internal.Http.Types         (Method (..))
import           Snap.Internal.Parsing            (crlf, parseCookie, parseUrlEncoded, unsafeFromNat, (<?>))
import           Snap.Types.Headers               (Headers)
import qualified Snap.Types.Headers               as H
------------------------------------------------------------------------------
-- | Frozen table of selected request-header values, one slot per @*Tag@
-- constant. A slot holds 'Nothing' unless the parser recorded a value for it.
newtype StandardHeaders =
    StandardHeaders (V.Vector (Maybe ByteString))

-- | Mutable form of 'StandardHeaders', filled in while the request head is
-- being parsed and frozen afterwards.
type MStandardHeaders =
    MV.IOVector (Maybe ByteString)
------------------------------------------------------------------------------
-- | Slot positions in the standard-header table. 'findStdHeaderIndex' maps
-- lower-cased header names onto these indices.
contentLengthTag, hostTag, transferEncodingTag, cookieTag, contentTypeTag,
  connectionTag :: Int
contentLengthTag    = 0
hostTag             = 1
transferEncodingTag = 2
cookieTag           = 3
contentTypeTag      = 4
connectionTag       = 5

-- | Number of slots in the standard-header table (one per tag above).
nStandardHeaders :: Int
nStandardHeaders = connectionTag + 1
------------------------------------------------------------------------------
-- | Slot index for a header name in the standard-header table, or @-1@ when
-- the name is not one of the tracked headers. Matching is exact, so the
-- name must already be lower-cased.
findStdHeaderIndex :: ByteString -> Int
findStdHeaderIndex name = case name of
    "content-length"    -> contentLengthTag
    "host"              -> hostTag
    "transfer-encoding" -> transferEncodingTag
    "cookie"            -> cookieTag
    "content-type"      -> contentTypeTag
    "connection"        -> connectionTag
    _                   -> -1
------------------------------------------------------------------------------
-- | Read one slot of a frozen standard-header table. No bounds check is
-- performed; callers pass one of the @*Tag@ constants.
stdHeaderAt :: Int -> StandardHeaders -> Maybe ByteString
stdHeaderAt tag (StandardHeaders v) = V.unsafeIndex v tag
{-# INLINE stdHeaderAt #-}

-- | Accessors for the individual standard headers. Each returns the value
-- stored in that header's slot, or 'Nothing' if none was stored.
getStdContentLength, getStdHost, getStdTransferEncoding, getStdCookie,
  getStdConnection, getStdContentType :: StandardHeaders -> Maybe ByteString
getStdContentLength    = stdHeaderAt contentLengthTag
getStdHost             = stdHeaderAt hostTag
getStdTransferEncoding = stdHeaderAt transferEncodingTag
getStdCookie           = stdHeaderAt cookieTag
getStdContentType      = stdHeaderAt contentTypeTag
getStdConnection       = stdHeaderAt connectionTag

-- | Allocate a fresh mutable table with every slot set to 'Nothing'.
newMStandardHeaders :: IO MStandardHeaders
newMStandardHeaders = MV.replicate nStandardHeaders Nothing
------------------------------------------------------------------------------
-- | A parsed HTTP request head, as produced by 'parseRequest'.
data IRequest = IRequest
    { iMethod         :: !Method
      -- ^ request method from the request line
    , iRequestUri     :: !ByteString
      -- ^ request target (scheme and authority removed for absolute form)
    , iHttpVersion    :: (Int, Int)
      -- ^ HTTP version as @(major, minor)@
    , iRequestHeaders :: Headers
      -- ^ every header, with lower-cased names
    , iStdHeaders     :: StandardHeaders
      -- ^ indexed table of the frequently used headers
    }
------------------------------------------------------------------------------
-- | Requests compare equal when method, request URI, HTTP version and the
-- sorted header lists agree. The 'iStdHeaders' table is not compared.
instance Eq IRequest where
    a == b = iMethod a       == iMethod b
          && iRequestUri a   == iRequestUri b
          && iHttpVersion a  == iHttpVersion b
          && sortedHeaders a == sortedHeaders b
      where
        sortedHeaders = sort . H.toList . iRequestHeaders
------------------------------------------------------------------------------
-- | Renders as @method uri major.minor headers@, separated by single
-- spaces. The standard-header table is omitted.
instance Show IRequest where
    show (IRequest m u (major, minor) hdrs _) =
        unwords [ show m
                , show u
                , show major ++ "." ++ show minor
                , show hdrs
                ]
------------------------------------------------------------------------------
-- | Thrown when a request line, header block or chunked body is malformed.
-- The payload is a human-readable description of the problem.
data HttpParseException = HttpParseException String
  deriving (Typeable, Show)

instance Exception HttpParseException
------------------------------------------------------------------------------
-- | Parse an HTTP request head (request line plus header block) from the
-- input stream.
--
-- The request line is split at spaces into method, request target and
-- version. For a target in absolute form (@http://@ or @https://@ prefix)
-- the authority is stored in the host slot of the standard-header table and
-- only the rest of the target is kept as the URI; an empty URI becomes
-- @\"/\"@. A version that does not start with @HTTP/@ defaults to @(1, 0)@.
-- A @Host@ header, if present, overwrites the host slot afterwards.
--
-- Throws 'HttpParseException' (from 'pLine') on malformed line endings or
-- premature end of input.
parseRequest :: InputStream ByteString -> IO IRequest
{-# INLINE parseRequest #-}
parseRequest input = do
    line <- pLine input
    let (!mStr, !s)   = bSp line
    let (!uri, !vStr) = bSp s
    let method        = methodFromString mStr
    let !version      = pVer vStr
    let (host, uri')  = getHost uri
    let uri''         = if S.null uri' then "/" else uri'

    stdHdrs <- newMStandardHeaders
    MV.unsafeWrite stdHdrs hostTag host
    hdrs    <- pHeaders stdHdrs input
    outStd  <- StandardHeaders <$> V.unsafeFreeze stdHdrs
    return $! IRequest method uri'' version hdrs outStd

  where
    getHost s | "http://" `S.isPrefixOf` s  = splitAuthority (S.unsafeDrop 7 s)
              | "https://" `S.isPrefixOf` s = splitAuthority (S.unsafeDrop 8 s)
              | otherwise                   = (Nothing, s)

    -- The authority of an absolute URI ends at the first '/', '?' or '#'
    -- (RFC 3986 section 3.2). Splitting on '/' alone turned
    -- "http://host?q=1" into host "host?q=1" with URI "/", dropping the
    -- query. A query/fragment found in the authority part is now moved back
    -- onto the URI behind a "/".
    splitAuthority s' =
        let (!auth, !rest)  = breakCh '/' s'
            (!host, !extra) = S.break isQueryOrFragment auth
            !path           = if S.null extra
                                then rest
                                else S.concat ["/", extra, rest]
        in (Just $! host, path)

    isQueryOrFragment c = c == '?' || c == '#'

    pVer s = if "HTTP/" `S.isPrefixOf` s
               then pVers (S.unsafeDrop 5 s)
               else (1, 0)

    bSp = splitCh ' '

    pVers s = (c, d)
      where
        (!a, !b) = splitCh '.' s
        !c       = unsafeFromNat a
        !d       = unsafeFromNat b
------------------------------------------------------------------------------
-- | Read one CRLF-terminated line from the stream and return it without the
-- CRLF.
--
-- The line may span several input chunks. Bytes that follow the CRLF in the
-- last chunk are pushed back with 'Streams.unRead'. Throws
-- 'HttpParseException' if the stream ends before a CRLF, or if a CR is
-- followed by anything other than LF.
pLine :: InputStream ByteString -> IO ByteString
pLine input = scan []
  where
    missingCRLF =
        throwIO $
        HttpParseException "parse error: expected line ending in crlf"

    strayCR =
        throwIO $
        HttpParseException "parse error: got cr without subsequent lf"

    -- Next input chunk; running out of input mid-line is an error.
    nextChunk = Streams.read input >>= maybe missingCRLF return

    -- Earlier chunks are kept newest-first. The common single-chunk case
    -- returns the final fragment without copying.
    joinLine pieces final
        | null pieces = final
        | otherwise   = S.concat (reverse (final : pieces))

    scan !pieces = do
        !chunk <- nextChunk
        let !cr = elemIndex '\r' chunk
        onChunk pieces chunk cr

    onChunk pieces chunk !cr
        | cr < 0                             = scan (chunk : pieces)
        | cr + 1 >= S.length chunk           = crAtChunkEnd pieces chunk cr
        | S.unsafeIndex chunk (cr + 1) == 10 = crlfInChunk pieces chunk cr
        | otherwise                          = strayCR

    -- CR and LF are both inside this chunk.
    crlfInChunk pieces chunk !cr = do
        let !lineEnd = cr + 2
            !final   = S.unsafeTake cr chunk
        when (lineEnd < S.length chunk) $ do
            let !leftover = S.unsafeDrop lineEnd chunk
            Streams.unRead leftover input
        return $! joinLine pieces final

    -- CR is the last byte of this chunk, so the LF must open the next
    -- non-empty chunk.
    crAtChunkEnd pieces chunk !cr = do
        !next <- nextChunk
        if S.null next
          then crAtChunkEnd pieces chunk cr
          else if S.unsafeHead next /= 10
                 then strayCR
                 else do
                   let !final    = S.unsafeTake cr chunk
                       !leftover = S.unsafeDrop 1 next
                   when (not (S.null leftover)) $ Streams.unRead leftover input
                   return $! joinLine pieces final
------------------------------------------------------------------------------
-- | Split at the first occurrence of a character, dropping that character.
-- If it does not occur, the whole input is returned with an empty
-- remainder.
splitCh :: Char -> ByteString -> (ByteString, ByteString)
splitCh !c !s
    | pos < 0   = (s, S.empty)
    | otherwise = let !before = S.unsafeTake pos s
                      !after  = S.unsafeDrop (pos + 1) s
                  in (before, after)
  where
    !pos = elemIndex c s
{-# INLINE splitCh #-}
------------------------------------------------------------------------------
-- | Break before the first occurrence of a character; the character stays at
-- the front of the second component. If it does not occur, the whole input
-- is returned with an empty remainder.
breakCh :: Char -> ByteString -> (ByteString, ByteString)
breakCh !c !s
    | pos < 0   = (s, S.empty)
    | otherwise = let !before = S.unsafeTake pos s
                      !from   = S.unsafeDrop pos s
                  in (before, from)
  where
    !pos = elemIndex c s
{-# INLINE breakCh #-}
------------------------------------------------------------------------------
-- | Split a header line into name and value at the first @':'@. Spaces and
-- tabs after the colon are skipped. A line without a colon is returned
-- whole as the name, with an empty value.
splitHeader :: ByteString -> (ByteString, ByteString)
splitHeader !s
    | colon < 0 = (s, S.empty)
    | otherwise = let !name = S.unsafeTake colon s
                  in (name, valueFrom (colon + 1))
  where
    !colon = elemIndex ':' s
    len    = S.length s

    -- An all-whitespace (or missing) value yields S.empty.
    valueFrom !j
        | j >= len                        = S.empty
        | isLWS (w2c (S.unsafeIndex s j)) = valueFrom (j + 1)
        | otherwise                       = S.unsafeDrop j s
{-# INLINE splitHeader #-}
------------------------------------------------------------------------------
-- | Linear whitespace within a header line: a space or a horizontal tab.
isLWS :: Char -> Bool
isLWS c = case c of
    ' '  -> True
    '\t' -> True
    _    -> False
{-# INLINE isLWS #-}
------------------------------------------------------------------------------
-- | Read header lines up to the blank line that ends the request head.
--
-- Header names are lower-cased. Continuation lines (chunks starting with a
-- space or tab) are trimmed and appended to the value, joined by a single
-- space. When a name is one of the standard headers, its value is also
-- written into @stdHdrs@, so a later duplicate overwrites an earlier one.
-- The list handed to 'H.unsafeFromCaseFoldedList' is in reverse order of
-- appearance.
pHeaders :: MStandardHeaders -> InputStream ByteString -> IO Headers
pHeaders stdHdrs input = H.unsafeFromCaseFoldedList <$> readAll []
  where
    readAll !acc = do
        line <- pLine input
        if S.null line
          then return acc
          else do
            let (!rawName, !firstPart) = splitHeader line
                !name = toLower rawName
            continuation <- collectContinuations id
            let !value = S.concat (firstPart : continuation [])
                slot   = findStdHeaderIndex name
            when (slot >= 0) $ MV.unsafeWrite stdHdrs slot $! Just value
            readAll ((name, value) : acc)

    -- Accumulate folded continuation lines as a difference list; stop at
    -- end of input or at a chunk that does not start with whitespace.
    collectContinuations !dl = do
        next <- Streams.peek input
        case next of
          Nothing -> return dl
          Just chunk
            | S.null chunk -> Streams.read input >> collectContinuations dl
            | isLWS (w2c (S.unsafeHead chunk)) -> appendContinuation dl
            | otherwise -> return dl

    appendContinuation !dl = do
        line <- pLine input
        let !trimmed = S.dropWhile isLWS line
        collectContinuations (dl . (" " :) . (trimmed :))
------------------------------------------------------------------------------
-- | Map a request-line method token to a 'Method'. Matching is
-- case-sensitive; any unrecognised token is wrapped in 'Method'.
methodFromString :: ByteString -> Method
methodFromString s = case s of
    "GET"     -> GET
    "POST"    -> POST
    "HEAD"    -> HEAD
    "PUT"     -> PUT
    "DELETE"  -> DELETE
    "TRACE"   -> TRACE
    "OPTIONS" -> OPTIONS
    "CONNECT" -> CONNECT
    "PATCH"   -> PATCH
    _         -> Method s
------------------------------------------------------------------------------
-- | Decode a chunked request body. Each read runs 'pGetTransferChunk' on the
-- underlying stream, yielding the chunk data. The terminal zero-length chunk
-- ends the returned stream.
readChunkedTransferEncoding :: InputStream ByteString
                            -> IO (InputStream ByteString)
readChunkedTransferEncoding =
    Streams.makeInputStream . parseFromStream pGetTransferChunk
------------------------------------------------------------------------------
-- | Wrap an output stream so that each 'Builder' written to it is framed as
-- one HTTP chunk. End-of-stream writes the chunked-encoding terminator and
-- then signals end-of-stream to the wrapped stream.
writeChunkedTransferEncoding :: OutputStream Builder
                             -> IO (OutputStream Builder)
#if MIN_VERSION_io_streams(1,2,0)
writeChunkedTransferEncoding os = Streams.makeOutputStream f
  where
    f Nothing = do
        Streams.write (Just chunkedTransferTerminator) os
        Streams.write Nothing os
    f x = Streams.write (chunkedTransferEncoding `fmap` x) os
#else
writeChunkedTransferEncoding os = do
    -- Make sure the terminator is sent only once. The flag means
    -- "terminator still pending": it starts True and is cleared on the first
    -- end-of-stream. The old code initialised it to True and then set it to
    -- True again, so the guard never took effect and every Nothing wrote
    -- another terminator.
    pending <- newIORef True
    Streams.makeOutputStream $ f pending
  where
    f pending Nothing = readIORef pending >>= flip when (do
        writeIORef pending False
        Streams.write (Just chunkedTransferTerminator) os
        Streams.write Nothing os)
    f _ x = Streams.write (chunkedTransferEncoding `fmap` x) os
#endif
------------------------------------------------------------------------------
-- | Chunk-size limit for chunked request bodies (2^18 bytes).
-- 'pGetTransferChunk' rejects chunks of this size or larger.
mAX_CHUNK_SIZE :: Int
mAX_CHUNK_SIZE = 2 ^ (18 :: Int)
------------------------------------------------------------------------------
-- | Parse one chunk of a chunked transfer-encoded body.
--
-- Reads the hexadecimal chunk size, skips any chunk extension up to the CR,
-- and consumes the CRLF.
--
-- * A zero size is the last chunk: its trailing CRLF is consumed and
--   'Nothing' is returned.
-- * Otherwise exactly that many bytes are read, followed by a CRLF, and
--   returned in 'Just'.
-- * A size of 'mAX_CHUNK_SIZE' or more makes the parser's result a thrown
--   'HttpParseException'.
--
-- The size is bounded by its significant-digit count before conversion to
-- 'Int'. Previously 'hexadecimal' at 'Int' wrapped silently, so an
-- over-long numeral could come out small or zero. That bypassed the size
-- limit and misframed the body.
pGetTransferChunk :: Parser (Maybe ByteString)
pGetTransferChunk = parser <?> "pGetTransferChunk"
  where
    parser = do
        !digits <- takeWhile1 isHexChar <?> "hexadecimal"
        skipWhile (/= '\r') <?> "skipToEOL"
        void crlf <?> "linefeed"
        -- Leading zeros are legal in a chunk size, so drop them before
        -- counting significant digits.
        let !significant = S.dropWhile (== '0') digits
        if S.length significant > maxHexDigits
          then tooLong ("0x" ++ S.unpack significant)
          else chunkOfSize (S.foldl' addHexDigit 0 significant)

    chunkOfSize !hex
        | hex >= mAX_CHUNK_SIZE = tooLong (show hex)
        | hex == 0 =
            (crlf >> return Nothing) <?> "terminal crlf after 0 length"
        | otherwise = do
            !x <- take hex <?> "reading data chunk"
            void crlf <?> "linefeed after data chunk"
            return $! Just x

    tooLong size = return $! E.throw $! HttpParseException $!
        "pGetTransferChunk: chunk of size " ++ size ++ " too long."

    -- Seven hex digits are at most 0x0fffffff. That fits even a 32-bit Int
    -- and is far above mAX_CHUNK_SIZE, so the fold cannot overflow.
    maxHexDigits :: Int
    maxHexDigits = 7

    isHexChar c = (c >= '0' && c <= '9')
               || (c >= 'a' && c <= 'f')
               || (c >= 'A' && c <= 'F')

    addHexDigit :: Int -> Char -> Int
    addHexDigit !acc c = acc * 16 + hexValue c

    hexValue c
        | c >= '0' && c <= '9' = fromEnum c - fromEnum '0'
        | c >= 'a' && c <= 'f' = fromEnum c - fromEnum 'a' + 10
        | otherwise            = fromEnum c - fromEnum 'A' + 10
------------------------------------------------------------------------------
-- | ASCII-only lower-casing: bytes @'A'@..@'Z'@ are shifted by 32. Every
-- other byte, including non-ASCII ones, is left unchanged.
toLower :: ByteString -> ByteString
toLower = S.map asciiLower
  where
    asciiLower ch
        | ch >= 'A' && ch <= 'Z' = w2c $! c2w ch + 32
        | otherwise              = ch
------------------------------------------------------------------------------
-- | Index of the first occurrence of a character in a 'ByteString', or @-1@
-- if it does not occur.
--
-- The character is compared as a single byte, as in
-- "Data.ByteString.Char8". This delegates to bytestring's own
-- 'S.elemIndex' instead of a hand-rolled pointer search through
-- 'accursedUnutterablePerformIO'/'inlinePerformIO' behind version CPP.
-- Keeping the INLINE pragma lets GHC fuse away the intermediate 'Maybe'.
-- NOTE(review): bytestring's elemIndex is also memchr-based in the versions
-- checked. Confirm against the pinned bytestring before relying on the
-- performance parity.
elemIndex :: Char -> ByteString -> Int
elemIndex c s = case S.elemIndex c s of
                  Just i  -> i
                  Nothing -> -1
{-# INLINE elemIndex #-}