{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE UnboxedTuples #-}
module Snap.Internal.Http.Server.Parser
( IRequest(..)
, HttpParseException(..)
, readChunkedTransferEncoding
, writeChunkedTransferEncoding
, parseRequest
, parseFromStream
, parseCookie
, parseUrlEncoded
, getStdContentLength
, getStdHost
, getStdTransferEncoding
, getStdCookie
, getStdContentType
, getStdConnection
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Control.Exception (Exception, throwIO)
import qualified Control.Exception as E
import Control.Monad (void, when)
import Data.Attoparsec.ByteString.Char8 (Parser, hexadecimal, skipWhile, take)
import qualified Data.ByteString.Char8 as S
import Data.ByteString.Internal (ByteString (..), c2w, memchr, w2c)
#if MIN_VERSION_bytestring(0, 10, 6)
import Data.ByteString.Internal (accursedUnutterablePerformIO)
#else
import Data.ByteString.Internal (inlinePerformIO)
#endif
import qualified Data.ByteString.Unsafe as S
#if !MIN_VERSION_io_streams(1,2,0)
import Data.IORef (newIORef, readIORef, writeIORef)
#endif
import Data.List (sort)
import Data.Typeable (Typeable)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (minusPtr, nullPtr, plusPtr)
import Prelude hiding (take)
import Blaze.ByteString.Builder.HTTP (chunkedTransferEncoding, chunkedTransferTerminator)
import Data.ByteString.Builder (Builder)
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
import System.IO.Streams.Attoparsec (parseFromStream)
import Snap.Internal.Http.Types (Method (..))
import Snap.Internal.Parsing (crlf, parseCookie, parseUrlEncoded, unsafeFromNat, (<?>))
import Snap.Types.Headers (Headers)
import qualified Snap.Types.Headers as H
newtype StandardHeaders = StandardHeaders (V.Vector (Maybe ByteString))
type MStandardHeaders = MV.IOVector (Maybe ByteString)
contentLengthTag, hostTag, transferEncodingTag, cookieTag, contentTypeTag,
connectionTag, nStandardHeaders :: Int
contentLengthTag = 0
hostTag = 1
transferEncodingTag = 2
cookieTag = 3
contentTypeTag = 4
connectionTag = 5
nStandardHeaders = 6
findStdHeaderIndex :: ByteString -> Int
findStdHeaderIndex "content-length" = contentLengthTag
findStdHeaderIndex "host" = hostTag
findStdHeaderIndex "transfer-encoding" = transferEncodingTag
findStdHeaderIndex "cookie" = cookieTag
findStdHeaderIndex "content-type" = contentTypeTag
findStdHeaderIndex "connection" = connectionTag
findStdHeaderIndex _ = -1
getStdContentLength, getStdHost, getStdTransferEncoding, getStdCookie,
getStdConnection, getStdContentType :: StandardHeaders -> Maybe ByteString
getStdContentLength (StandardHeaders v) = V.unsafeIndex v contentLengthTag
getStdHost (StandardHeaders v) = V.unsafeIndex v hostTag
getStdTransferEncoding (StandardHeaders v) = V.unsafeIndex v transferEncodingTag
getStdCookie (StandardHeaders v) = V.unsafeIndex v cookieTag
getStdContentType (StandardHeaders v) = V.unsafeIndex v contentTypeTag
getStdConnection (StandardHeaders v) = V.unsafeIndex v connectionTag
newMStandardHeaders :: IO MStandardHeaders
newMStandardHeaders = MV.replicate nStandardHeaders Nothing
data IRequest = IRequest
{ iMethod :: !Method
, iRequestUri :: !ByteString
, iHttpVersion :: (Int, Int)
, iRequestHeaders :: Headers
, iStdHeaders :: StandardHeaders
}
instance Eq IRequest where
a == b =
and [ iMethod a == iMethod b
, iRequestUri a == iRequestUri b
, iHttpVersion a == iHttpVersion b
, sort (H.toList (iRequestHeaders a))
== sort (H.toList (iRequestHeaders b))
]
instance Show IRequest where
show (IRequest m u (major, minor) hdrs _) =
concat [ show m
, " "
, show u
, " "
, show major
, "."
, show minor
, " "
, show hdrs
]
data HttpParseException = HttpParseException String deriving (Typeable, Show)
instance Exception HttpParseException
{-# INLINE parseRequest #-}
parseRequest :: InputStream ByteString -> IO IRequest
parseRequest input = do
line <- pLine input
let (!mStr, !s) = bSp line
let (!uri, !vStr) = bSp s
let method = methodFromString mStr
let !version = pVer vStr
let (host, uri') = getHost uri
let uri'' = if S.null uri' then "/" else uri'
stdHdrs <- newMStandardHeaders
MV.unsafeWrite stdHdrs hostTag host
hdrs <- pHeaders stdHdrs input
outStd <- StandardHeaders <$> V.unsafeFreeze stdHdrs
return $! IRequest method uri'' version hdrs outStd
where
getHost s | "http://" `S.isPrefixOf` s
= let s' = S.unsafeDrop 7 s
(!host, !uri) = breakCh '/' s'
in (Just $! host, uri)
| "https://" `S.isPrefixOf` s
= let s' = S.unsafeDrop 8 s
(!host, !uri) = breakCh '/' s'
in (Just $! host, uri)
| otherwise = (Nothing, s)
pVer s = if "HTTP/" `S.isPrefixOf` s
then pVers (S.unsafeDrop 5 s)
else (1, 0)
bSp = splitCh ' '
pVers s = (c, d)
where
(!a, !b) = splitCh '.' s
!c = unsafeFromNat a
!d = unsafeFromNat b
pLine :: InputStream ByteString -> IO ByteString
pLine input = go []
where
throwNoCRLF =
throwIO $
HttpParseException "parse error: expected line ending in crlf"
throwBadCRLF =
throwIO $
HttpParseException "parse error: got cr without subsequent lf"
go !l = do
!mb <- Streams.read input
!s <- maybe throwNoCRLF return mb
let !i = elemIndex '\r' s
if i < 0
then noCRLF l s
else case () of
!_ | i+1 >= S.length s -> lastIsCR l s i
| S.unsafeIndex s (i+1) == 10 -> foundCRLF l s i
| otherwise -> throwBadCRLF
foundCRLF l s !i1 = do
let !i2 = i1 + 2
let !a = S.unsafeTake i1 s
when (i2 < S.length s) $ do
let !b = S.unsafeDrop i2 s
Streams.unRead b input
let !out = if null l then a else S.concat (reverse (a:l))
return out
noCRLF l s = go (s:l)
lastIsCR l s !idx = do
!t <- Streams.read input >>= maybe throwNoCRLF return
if S.null t
then lastIsCR l s idx
else do
let !c = S.unsafeHead t
if c /= 10
then throwBadCRLF
else do
let !a = S.unsafeTake idx s
let !b = S.unsafeDrop 1 t
when (not $ S.null b) $ Streams.unRead b input
let !out = if null l then a else S.concat (reverse (a:l))
return out
splitCh :: Char -> ByteString -> (ByteString, ByteString)
splitCh !c !s = if idx < 0
then (s, S.empty)
else let !a = S.unsafeTake idx s
!b = S.unsafeDrop (idx + 1) s
in (a, b)
where
!idx = elemIndex c s
{-# INLINE splitCh #-}
breakCh :: Char -> ByteString -> (ByteString, ByteString)
breakCh !c !s = if idx < 0
then (s, S.empty)
else let !a = S.unsafeTake idx s
!b = S.unsafeDrop idx s
in (a, b)
where
!idx = elemIndex c s
{-# INLINE breakCh #-}
splitHeader :: ByteString -> (ByteString, ByteString)
splitHeader !s = if idx < 0
then (s, S.empty)
else let !a = S.unsafeTake idx s
in (a, skipSp (idx + 1))
where
!idx = elemIndex ':' s
l = S.length s
skipSp !i | i >= l = S.empty
| otherwise = let c = S.unsafeIndex s i
in if isLWS $ w2c c
then skipSp $ i + 1
else S.unsafeDrop i s
{-# INLINE splitHeader #-}
isLWS :: Char -> Bool
isLWS c = c == ' ' || c == '\t'
{-# INLINE isLWS #-}
pHeaders :: MStandardHeaders -> InputStream ByteString -> IO Headers
pHeaders stdHdrs input = do
hdrs <- H.unsafeFromCaseFoldedList <$> go []
return hdrs
where
go !list = do
line <- pLine input
if S.null line
then return list
else do
let (!k0,!v) = splitHeader line
let !k = toLower k0
vf <- pCont id
let vs = vf []
let !v' = S.concat (v:vs)
let idx = findStdHeaderIndex k
when (idx >= 0) $ MV.unsafeWrite stdHdrs idx $! Just v'
let l' = ((k, v'):list)
go l'
trimBegin = S.dropWhile isLWS
pCont !dlist = do
mbS <- Streams.peek input
maybe (return dlist)
(\s -> if not (S.null s)
then if not $ isLWS $ w2c $ S.unsafeHead s
then return dlist
else procCont dlist
else Streams.read input >> pCont dlist)
mbS
procCont !dlist = do
line <- pLine input
let !t = trimBegin line
pCont (dlist . (" ":) . (t:))
methodFromString :: ByteString -> Method
methodFromString "GET" = GET
methodFromString "POST" = POST
methodFromString "HEAD" = HEAD
methodFromString "PUT" = PUT
methodFromString "DELETE" = DELETE
methodFromString "TRACE" = TRACE
methodFromString "OPTIONS" = OPTIONS
methodFromString "CONNECT" = CONNECT
methodFromString "PATCH" = PATCH
methodFromString s = Method s
readChunkedTransferEncoding :: InputStream ByteString
-> IO (InputStream ByteString)
readChunkedTransferEncoding input =
Streams.makeInputStream $ parseFromStream pGetTransferChunk input
writeChunkedTransferEncoding :: OutputStream Builder
-> IO (OutputStream Builder)
#if MIN_VERSION_io_streams(1,2,0)
writeChunkedTransferEncoding os = Streams.makeOutputStream f
where
f Nothing = do
Streams.write (Just chunkedTransferTerminator) os
Streams.write Nothing os
f x = Streams.write (chunkedTransferEncoding `fmap` x) os
#else
writeChunkedTransferEncoding os = do
eof <- newIORef True
Streams.makeOutputStream $ f eof
where
f eof Nothing = readIORef eof >>= flip when (do
writeIORef eof True
Streams.write (Just chunkedTransferTerminator) os
Streams.write Nothing os)
f _ x = Streams.write (chunkedTransferEncoding `fmap` x) os
#endif
mAX_CHUNK_SIZE :: Int
mAX_CHUNK_SIZE = (2::Int)^(18::Int)
pGetTransferChunk :: Parser (Maybe ByteString)
pGetTransferChunk = parser <?> "pGetTransferChunk"
where
parser = do
!hex <- hexadecimal <?> "hexadecimal"
skipWhile (/= '\r') <?> "skipToEOL"
void crlf <?> "linefeed"
if hex >= mAX_CHUNK_SIZE
then return $! E.throw $! HttpParseException $!
"pGetTransferChunk: chunk of size " ++ show hex ++ " too long."
else if hex <= 0
then (crlf >> return Nothing) <?> "terminal crlf after 0 length"
else do
!x <- take hex <?> "reading data chunk"
void crlf <?> "linefeed after data chunk"
return $! Just x
toLower :: ByteString -> ByteString
toLower = S.map lower
where
lower c0 = let !c = c2w c0
in if 65 <= c && c <= 90
then w2c $! c + 32
else c0
elemIndex :: Char -> ByteString -> Int
#if MIN_VERSION_bytestring(0, 10, 6)
elemIndex c (PS !fp !start !len) = accursedUnutterablePerformIO $
#else
elemIndex c (PS !fp !start !len) = inlinePerformIO $
#endif
withForeignPtr fp $ \p0 -> do
let !p = plusPtr p0 start
q <- memchr p w8 (fromIntegral len)
return $! if q == nullPtr then (-1) else q `minusPtr` p
where
!w8 = c2w c
{-# INLINE elemIndex #-}