{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE CPP                #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MagicHash          #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE Rank2Types         #-}
{-# LANGUAGE Trustworthy        #-}
{-# LANGUAGE UnboxedTuples      #-}

module Snap.Internal.Http.Server.Parser
  ( IRequest(..)
  , HttpParseException(..)
  , readChunkedTransferEncoding
  , writeChunkedTransferEncoding
  , parseRequest
  , parseFromStream
  , parseCookie
  , parseUrlEncoded
  , getStdContentLength
  , getStdHost
  , getStdTransferEncoding
  , getStdCookie
  , getStdContentType
  , getStdConnection
  ) where

------------------------------------------------------------------------------
#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative              ((<$>))
#endif
import           Control.Exception                (Exception, throwIO)
import qualified Control.Exception                as E
import           Control.Monad                    (void, when)
import           Data.Attoparsec.ByteString.Char8 (Parser, hexadecimal, skipWhile, take)
import qualified Data.ByteString.Char8            as S
import           Data.ByteString.Internal         (ByteString (..), c2w, memchr, w2c)
#if MIN_VERSION_bytestring(0, 10, 6)
import           Data.ByteString.Internal         (accursedUnutterablePerformIO)
#else
import           Data.ByteString.Internal         (inlinePerformIO)
#endif
import qualified Data.ByteString.Unsafe           as S
#if !MIN_VERSION_io_streams(1,2,0)
import           Data.IORef                       (newIORef, readIORef, writeIORef)
#endif
import           Data.List                        (sort)
import           Data.Typeable                    (Typeable)
import qualified Data.Vector                      as V
import qualified Data.Vector.Mutable              as MV
import           Foreign.ForeignPtr               (withForeignPtr)
import           Foreign.Ptr                      (minusPtr, nullPtr, plusPtr)
import           Prelude                          hiding (take)
------------------------------------------------------------------------------
import           Blaze.ByteString.Builder.HTTP    (chunkedTransferEncoding, chunkedTransferTerminator)
import           Data.ByteString.Builder          (Builder)
import           System.IO.Streams                (InputStream, OutputStream)
import qualified System.IO.Streams                as Streams
import           System.IO.Streams.Attoparsec     (parseFromStream)
------------------------------------------------------------------------------
import           Snap.Internal.Http.Types         (Method (..))
import           Snap.Internal.Parsing            (crlf, parseCookie, parseUrlEncoded, unsafeFromNat, (<?>))
import           Snap.Types.Headers               (Headers)
import qualified Snap.Types.Headers               as H


------------------------------------------------------------------------------
-- | Immutable table of the request headers the parser records separately,
-- indexed by the tag constants ('contentLengthTag', 'hostTag', ...). A slot
-- holds 'Nothing' when no value was recorded for that header.
newtype StandardHeaders = StandardHeaders (V.Vector (Maybe ByteString))
-- | Mutable counterpart of 'StandardHeaders'; it is filled in while headers
-- are parsed and then frozen into a 'StandardHeaders'.
type MStandardHeaders = MV.IOVector (Maybe ByteString)


------------------------------------------------------------------------------
-- | Slot index of @content-length@ in a standard-headers table.
contentLengthTag :: Int
contentLengthTag = 0

-- | Slot index of @host@.
hostTag :: Int
hostTag = 1

-- | Slot index of @transfer-encoding@.
transferEncodingTag :: Int
transferEncodingTag = 2

-- | Slot index of @cookie@.
cookieTag :: Int
cookieTag = 3

-- | Slot index of @content-type@.
contentTypeTag :: Int
contentTypeTag = 4

-- | Slot index of @connection@.
connectionTag :: Int
connectionTag = 5

-- | Number of slots in a standard-headers table: one past the highest tag.
nStandardHeaders :: Int
nStandardHeaders = connectionTag + 1


------------------------------------------------------------------------------
-- | Map a header name to its slot in a standard-headers table, or @-1@ if
-- the header is not one of the tracked ones. Matching is exact, so the name
-- must already be lower-cased.
findStdHeaderIndex :: ByteString -> Int
findStdHeaderIndex name = case name of
    "content-length"    -> contentLengthTag
    "host"              -> hostTag
    "transfer-encoding" -> transferEncodingTag
    "cookie"            -> cookieTag
    "content-type"      -> contentTypeTag
    "connection"        -> connectionTag
    _                   -> -1


------------------------------------------------------------------------------
-- | Accessors for individual slots of a 'StandardHeaders' table. Each one
-- returns the recorded value for its header, or 'Nothing' if none was set.
getStdContentLength, getStdHost, getStdTransferEncoding, getStdCookie,
    getStdConnection, getStdContentType :: StandardHeaders -> Maybe ByteString
getStdContentLength    = stdSlot contentLengthTag
getStdHost             = stdSlot hostTag
getStdTransferEncoding = stdSlot transferEncodingTag
getStdCookie           = stdSlot cookieTag
getStdContentType      = stdSlot contentTypeTag
getStdConnection       = stdSlot connectionTag

-- | Unchecked read of one slot; @tag@ must be below 'nStandardHeaders'.
stdSlot :: Int -> StandardHeaders -> Maybe ByteString
stdSlot tag (StandardHeaders slots) = V.unsafeIndex slots tag
{-# INLINE stdSlot #-}


------------------------------------------------------------------------------
-- | Allocate a mutable standard-headers table with every slot empty.
newMStandardHeaders :: IO MStandardHeaders
newMStandardHeaders = MV.replicateM nStandardHeaders (return Nothing)


------------------------------------------------------------------------------
-- | An internal representation of the request line and header block of an
-- HTTP request, as produced by 'parseRequest'.
data IRequest = IRequest
    { iMethod         :: !Method
      -- ^ request method, decoded from the first token of the request line
    , iRequestUri     :: !ByteString
      -- ^ request target, minus any @http://host@ or @https://host@ prefix
    , iHttpVersion    :: (Int, Int)
      -- ^ protocol version as @(major, minor)@
    , iRequestHeaders :: Headers
      -- ^ all request headers; the parser lower-cases header names
    , iStdHeaders     :: StandardHeaders
      -- ^ direct-access slots for the headers known to 'findStdHeaderIndex'
    }

------------------------------------------------------------------------------
-- | Equality on the request line and header block. Header order is ignored
-- (both header lists are sorted before comparing), and 'iStdHeaders' is not
-- compared.
instance Eq IRequest where
    a == b = iMethod a == iMethod b
          && iRequestUri a == iRequestUri b
          && iHttpVersion a == iHttpVersion b
          && sortedHeaders a == sortedHeaders b
      where
        sortedHeaders = sort . H.toList . iRequestHeaders

------------------------------------------------------------------------------
-- | Single-line rendering: method, URI, @major.minor@ version and headers,
-- separated by single spaces. 'iStdHeaders' is omitted.
instance Show IRequest where
    show (IRequest meth uri (major, minor) hdrs _) =
        unwords [ show meth
                , show uri
                , show major ++ "." ++ show minor
                , show hdrs
                ]


------------------------------------------------------------------------------
-- | Thrown when an incoming request cannot be parsed. The payload is a
-- human-readable description of what went wrong.
data HttpParseException = HttpParseException String
  deriving (Show, Typeable)

instance Exception HttpParseException


------------------------------------------------------------------------------
{-# INLINE parseRequest #-}
-- | Parse the request line and header block of an HTTP request from the
-- input stream.
--
-- The request line is split on spaces into method, target and version.
--
-- * A target in absolute form (@http://host/path@ or @https://host/path@)
--   has its authority removed from the URI. That authority is then recorded
--   as the standard Host value, taking precedence over any @Host@ header
--   (RFC 9112 §3.2.2).
-- * A remaining path that is empty becomes @"/"@.
-- * A version string that does not start with @HTTP/@ is treated as
--   HTTP/1.0.
--
-- Throws 'HttpParseException' (raised by the line reader) on bad line
-- endings or premature end of input.
parseRequest :: InputStream ByteString -> IO IRequest
parseRequest input = do
    line <- pLine input
    let (!mStr, !s)     = bSp line
    let (!uri, !vStr)   = bSp s
    let method          = methodFromString mStr
    let !version        = pVer vStr
    let (host, uri')    = getHost uri
    let uri''           = if S.null uri' then "/" else uri'

    stdHdrs <- newMStandardHeaders
    hdrs    <- pHeaders stdHdrs input
    -- Written after the headers so that an absolute-form authority wins over
    -- whatever Host header the client sent.
    maybe (return ()) (MV.unsafeWrite stdHdrs hostTag . Just) host
    outStd  <- StandardHeaders <$> V.unsafeFreeze stdHdrs
    return $! IRequest method uri'' version hdrs outStd

  where
    -- Split an absolute-form target into (Just authority, path); any other
    -- target is returned unchanged with no authority.
    getHost s | "http://"  `S.isPrefixOf` s = splitAuthority 7 s
              | "https://" `S.isPrefixOf` s = splitAuthority 8 s
              | otherwise                   = (Nothing, s)

    -- Drop the n-byte scheme prefix, then break at the first '/'; the path
    -- keeps its leading slash (and is empty if there was none).
    splitAuthority n s = let (!authority, !path) = breakCh '/' (S.unsafeDrop n s)
                         in (Just $! authority, path)

    pVer s = if "HTTP/" `S.isPrefixOf` s
               then pVers (S.unsafeDrop 5 s)
               else (1, 0)

    bSp   = splitCh ' '

    pVers s = (c, d)
      where
        (!a, !b)   = splitCh '.' s
        !c         = unsafeFromNat a
        !d         = unsafeFromNat b


------------------------------------------------------------------------------
-- | Read one CRLF-terminated line from the stream and return it without the
-- terminator. The line may span several input chunks. Anything left in the
-- final chunk after the CRLF is pushed back onto the stream.
--
-- Throws 'HttpParseException' if the stream ends before a CRLF is seen, or
-- if a CR is followed by anything other than LF.
pLine :: InputStream ByteString -> IO ByteString
pLine input = readLoop []
  where
    eofError =
        throwIO $
        HttpParseException "parse error: expected line ending in crlf"

    loneCRError =
        throwIO $
        HttpParseException "parse error: got cr without subsequent lf"

    -- Next chunk from the stream; running out of input is an error.
    nextChunk = Streams.read input >>= maybe eofError return

    -- Earlier fragments are kept newest-first. In the common case the whole
    -- line arrived in one chunk, so skip the concatenation entirely.
    assemble earlier lastFrag
        | null earlier = lastFrag
        | otherwise    = S.concat (reverse (lastFrag : earlier))

    readLoop !earlier = do
        !chunk <- nextChunk
        let !crAt = elemIndex '\r' chunk
        case () of
          _ | crAt < 0                             -> readLoop (chunk : earlier)
            | crAt + 1 >= S.length chunk           -> crEndsChunk earlier chunk crAt
            | S.unsafeIndex chunk (crAt + 1) == 10 -> crlfInChunk earlier chunk crAt
            | otherwise                            -> loneCRError

    -- CR and LF are both inside this chunk.
    crlfInChunk earlier chunk !crAt = do
        let rest = S.unsafeDrop (crAt + 2) chunk
        when (not (S.null rest)) $ Streams.unRead rest input
        return $! assemble earlier (S.unsafeTake crAt chunk)

    -- CR is the last byte of the chunk: the LF must open the next non-empty
    -- chunk.
    crEndsChunk earlier chunk !crAt = do
        !next <- nextChunk
        if S.null next
          then crEndsChunk earlier chunk crAt
          else if S.unsafeHead next /= 10
            then loneCRError
            else do
              let rest = S.unsafeDrop 1 next
              when (not (S.null rest)) $ Streams.unRead rest input
              return $! assemble earlier (S.unsafeTake crAt chunk)


------------------------------------------------------------------------------
-- | Split at the first occurrence of the character, dropping the separator.
-- If the character is absent, the whole input is the first component and
-- the second is empty.
splitCh :: Char -> ByteString -> (ByteString, ByteString)
splitCh !c !s
    | pos < 0   = (s, S.empty)
    | otherwise = let !before = S.unsafeTake pos s
                      !after  = S.unsafeDrop (pos + 1) s
                  in (before, after)
  where
    !pos = elemIndex c s
{-# INLINE splitCh #-}


------------------------------------------------------------------------------
-- | Break at the first occurrence of the character. Unlike 'splitCh', the
-- separator is kept as the first byte of the second component. If the
-- character is absent, the whole input is the first component.
breakCh :: Char -> ByteString -> (ByteString, ByteString)
breakCh !c !s
    | pos < 0   = (s, S.empty)
    | otherwise = let !before = S.unsafeTake pos s
                      !from   = S.unsafeDrop pos s
                  in (before, from)
  where
    !pos = elemIndex c s
{-# INLINE breakCh #-}


------------------------------------------------------------------------------
-- | Split a header line at its first colon into @(name, value)@.
--
-- The colon and any spaces or tabs directly after it are dropped; trailing
-- whitespace in the value is kept. A line with no colon yields
-- @(line, empty)@.
splitHeader :: ByteString -> (ByteString, ByteString)
splitHeader !s
    | colon < 0 = (s, S.empty)
    | otherwise = let !name = S.unsafeTake colon s
                  in (name, S.dropWhile isLWS (S.unsafeDrop (colon + 1) s))
  where
    !colon = elemIndex ':' s
{-# INLINE splitHeader #-}



------------------------------------------------------------------------------
-- | Linear whitespace within header lines: a space or a horizontal tab.
isLWS :: Char -> Bool
isLWS ch = case ch of
    ' '  -> True
    '\t' -> True
    _    -> False
{-# INLINE isLWS #-}


------------------------------------------------------------------------------
-- | Read header lines up to and including the blank line that ends the
-- header block.
--
-- * Header names are lower-cased with 'toLower'.
-- * A following line that starts with a space or tab is a continuation. It
--   is appended to the value, with one space replacing its leading
--   whitespace.
-- * Every header recognised by 'findStdHeaderIndex' is also written into
--   the given standard-headers table, so a later duplicate overwrites an
--   earlier one.
pHeaders :: MStandardHeaders -> InputStream ByteString -> IO Headers
pHeaders stdHdrs input = H.unsafeFromCaseFoldedList <$> collect []
  where
    -- Accumulates (name, value) pairs, most recent first.
    collect !acc = do
        line <- pLine input
        if S.null line
          then return acc
          else do
            let (!rawName, !firstPart) = splitHeader line
            let !name = toLower rawName
            extra <- continuations []
            let !value = S.concat (firstPart : extra)
            let slot   = findStdHeaderIndex name
            when (slot >= 0) $ MV.unsafeWrite stdHdrs slot $! Just value
            collect ((name, value) : acc)

    -- Gathers folded continuation lines as [" ", part1, " ", part2, ...].
    -- Parts are collected newest-first and reversed at the end.
    continuations !revParts = do
        next <- Streams.peek input
        case next of
          Nothing -> return (reverse revParts)
          Just s
            | S.null s ->
                -- discard an empty chunk and look again
                Streams.read input >> continuations revParts
            | isLWS (w2c (S.unsafeHead s)) -> do
                folded <- pLine input
                let !part = S.dropWhile isLWS folded
                continuations (part : " " : revParts)
            | otherwise -> return (reverse revParts)


------------------------------------------------------------------------------
-- | Decode a request-method token. Matching is exact (case-sensitive). Any
-- token other than the nine standard ones is kept verbatim in 'Method'.
methodFromString :: ByteString -> Method
methodFromString token = case token of
    "GET"     -> GET
    "POST"    -> POST
    "HEAD"    -> HEAD
    "PUT"     -> PUT
    "DELETE"  -> DELETE
    "TRACE"   -> TRACE
    "OPTIONS" -> OPTIONS
    "CONNECT" -> CONNECT
    "PATCH"   -> PATCH
    _         -> Method token


------------------------------------------------------------------------------
-- | Wrap a raw input stream so that each read from the result yields the
-- data of one chunk of a chunked transfer-encoded body. The stream ends
-- when the zero-size terminating chunk is parsed.
readChunkedTransferEncoding :: InputStream ByteString
                            -> IO (InputStream ByteString)
readChunkedTransferEncoding =
    Streams.makeInputStream . parseFromStream pGetTransferChunk


------------------------------------------------------------------------------
-- | Wrap an output stream so that each 'Builder' written to the result is
-- encoded with 'chunkedTransferEncoding'. End-of-stream ('Nothing') first
-- writes the chunked-encoding terminator, then closes the underlying stream.
writeChunkedTransferEncoding :: OutputStream Builder
                             -> IO (OutputStream Builder)
#if MIN_VERSION_io_streams(1,2,0)
-- NOTE(review): no "sent once" guard here; presumably io-streams >= 1.2
-- stops calling 'f' after it has received 'Nothing' — confirm.
writeChunkedTransferEncoding os = Streams.makeOutputStream f
  where
    f Nothing = do
        Streams.write (Just chunkedTransferTerminator) os
        Streams.write Nothing os
    f x = Streams.write (chunkedTransferEncoding `fmap` x) os

#else
writeChunkedTransferEncoding os = do
    -- True until the terminator has been sent. It is cleared after sending,
    -- so any later 'Nothing' does not emit a second terminator or re-close
    -- the underlying stream.
    pending <- newIORef True
    Streams.makeOutputStream $ f pending
  where
    f pending Nothing = do
        stillPending <- readIORef pending
        when stillPending $ do
            writeIORef pending False
            Streams.write (Just chunkedTransferTerminator) os
            Streams.write Nothing os
    f _ x = Streams.write (chunkedTransferEncoding `fmap` x) os
#endif


                             ---------------------
                             -- parse functions --
                             ---------------------

------------------------------------------------------------------------------
-- Upper bound on the size of a single chunk accepted from a client. Chunk
-- sizes at or above this value are rejected as a denial-of-service measure.
-- 256 KiB (2^18 bytes) is plenty of buffer for legitimate clients.
mAX_CHUNK_SIZE :: Int
mAX_CHUNK_SIZE = 256 * 1024


------------------------------------------------------------------------------
-- | Parse one chunk of a chunked transfer-encoded body.
--
-- The parser:
--
-- * reads the hexadecimal chunk size;
-- * skips any chunk extensions up to the CRLF;
-- * for the zero-size last chunk, expects one more CRLF immediately
--   (trailer fields are not accepted) and returns 'Nothing';
-- * otherwise returns @Just@ the chunk data.
--
-- The size is parsed as an 'Integer', so an overlong hex string cannot wrap
-- around in 'Int'. Without this, @10000000000000000@ would overflow to 0 and
-- be taken as the last chunk, and @ffffffffffffffff@ would become negative.
-- Sizes at or above 'mAX_CHUNK_SIZE' throw 'HttpParseException'.
pGetTransferChunk :: Parser (Maybe ByteString)
pGetTransferChunk = parser <?> "pGetTransferChunk"
  where
    parser = do
        !hex <- (hexadecimal :: Parser Integer) <?> "hexadecimal"
        skipWhile (/= '\r') <?> "skipToEOL"
        void crlf <?> "linefeed"
        if hex >= toInteger mAX_CHUNK_SIZE
          then return $! E.throw $! HttpParseException $!
               "pGetTransferChunk: chunk of size " ++ show hex ++ " too long."
          else if hex == 0
            then (crlf >> return Nothing) <?> "terminal crlf after 0 length"
            else do
                -- Here 0 < hex < mAX_CHUNK_SIZE, so the Int conversion is
                -- exact and it is safe to take this many bytes.
                !x <- take (fromInteger hex) <?> "reading data chunk"
                void crlf <?> "linefeed after data chunk"
                return $! Just x


------------------------------------------------------------------------------
-- | ASCII-only lower-casing: bytes @A@..@Z@ become @a@..@z@, and every other
-- byte (including non-ASCII) is left unchanged.
toLower :: ByteString -> ByteString
toLower = S.map lowerAscii
  where
    lowerAscii ch
        | ch >= 'A' && ch <= 'Z' = w2c (c2w ch + 32)
        | otherwise              = ch


------------------------------------------------------------------------------
-- | Index of the first occurrence of a character in a 'ByteString', or @-1@
-- if it does not occur. Unlike the library @elemIndex@, this does not
-- allocate a 'Maybe'. The character is narrowed to a byte with 'c2w', so
-- only characters in the 0-255 range can be found as themselves.
--
-- The search is one C @memchr@ call over the @len@ bytes that start at
-- offset @start@ in the string's buffer. It runs under
-- 'accursedUnutterablePerformIO' (or 'inlinePerformIO' on older
-- bytestring). That is presumably safe because the call only reads the
-- buffer, and 'withForeignPtr' keeps the buffer alive for the duration of
-- the call. NOTE(review): this relies on the buffer never being mutated —
-- confirm no caller hands in a mutable buffer.
elemIndex :: Char -> ByteString -> Int
#if MIN_VERSION_bytestring(0, 10, 6)
elemIndex c (PS !fp !start !len) = accursedUnutterablePerformIO $
#else
elemIndex c (PS !fp !start !len) = inlinePerformIO $
#endif
                                   withForeignPtr fp $ \p0 -> do
    -- first byte of this slice within the underlying buffer
    let !p = plusPtr p0 start
    -- memchr returns NULL when the byte is not within the len bytes
    q <- memchr p w8 (fromIntegral len)
    -- convert the hit pointer back into an offset relative to the slice
    return $! if q == nullPtr then (-1) else q `minusPtr` p
  where
    !w8 = c2w c
{-# INLINE elemIndex #-}