{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
module Network.Wai.Handler.Warp.Run where
import "iproute" Data.IP (toHostAddress, toHostAddress6)
import Control.Arrow (first)
import qualified Control.Concurrent as Conc (yield)
import Control.Exception as E
import qualified Data.ByteString as S
import Data.Char (chr)
import Data.IORef (IORef, newIORef, readIORef, writeIORef, atomicModifyIORef')
import Data.Streaming.Network (bindPortTCP)
import Foreign.C.Error (Errno(..), eCONNABORTED)
import GHC.IO.Exception (IOException(..))
import Network.Socket (Socket, close, accept, withSocketsDo, SockAddr(SockAddrInet, SockAddrInet6), setSocketOption, SocketOption(..))
import qualified Network.Socket.ByteString as Sock
import Network.Wai
import Network.Wai.Internal (ResponseReceived (ResponseReceived))
import System.Environment (getEnvironment)
import System.Timeout (timeout)
import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Counter
import qualified Network.Wai.Handler.Warp.Date as D
import qualified Network.Wai.Handler.Warp.FdCache as F
import qualified Network.Wai.Handler.Warp.FileInfoCache as I
import Network.Wai.Handler.Warp.HTTP2 (http2, isHTTP2)
import Network.Wai.Handler.Warp.Header
import Network.Wai.Handler.Warp.Imports hiding (readInt)
import Network.Wai.Handler.Warp.ReadInt
import Network.Wai.Handler.Warp.Recv
import Network.Wai.Handler.Warp.Request
import Network.Wai.Handler.Warp.Response
import Network.Wai.Handler.Warp.SendFile
import Network.Wai.Handler.Warp.Settings
import qualified Network.Wai.Handler.Warp.Timeout as T
import Network.Wai.Handler.Warp.Types
#if WINDOWS
import Network.Wai.Handler.Warp.Windows
#else
import Network.Socket (fdSocket)
#endif
-- | Wrap an accepted 'Socket' as a Warp 'Connection'.
--
-- Each call allocates a fresh receive buffer pool and one write buffer of
-- size 'bufferSize'; the write buffer is released by 'connFree'.
socketConnection :: Socket -> IO Connection
socketConnection sock = do
    pool <- newBufferPool
    outBuf <- allocateBuffer bufferSize
    let sendAllBytes = Sock.sendAll sock
        sendFileVia = sendFile sock outBuf bufferSize sendAllBytes
    return $ Connection
        { connSendMany = Sock.sendMany sock
        , connSendAll = sendAllBytes
        , connSendFile = sendFileVia
        , connClose = close sock
        , connFree = freeBuffer outBuf
        , connRecv = receive sock pool
        , connRecvBuf = receiveBuf sock
        , connWriteBuffer = outBuf
        , connBufferSize = bufferSize
        }
-- | Run an 'Application' on the given port, with all other settings at
-- their defaults.
run :: Port -> Application -> IO ()
run port app = runSettings (defaultSettings { settingsPort = port }) app
-- | Like 'run', but takes the port from the @PORT@ environment variable
-- when it is set, falling back to the given port otherwise.
--
-- Fails (in 'IO') with @Invalid value in $PORT: ...@ when @PORT@ is set
-- but is not a well-formed 'Port'. Surrounding whitespace is allowed.
-- Trailing garbage (e.g. @"8080abc"@) is rejected instead of being
-- silently ignored.
runEnv :: Port -> Application -> IO ()
runEnv p app = do
    mp <- lookup "PORT" <$> getEnvironment
    maybe (run p app) runReadPort mp
  where
    runReadPort :: String -> IO ()
    runReadPort sp = case parsePort sp of
        Just p' -> run p' app
        Nothing -> fail $ "Invalid value in $PORT: " ++ sp
    -- Accept a parse only if 'lex' finds nothing but whitespace after it
    -- (the same check 'Text.Read.readMaybe' performs).
    parsePort :: String -> Maybe Port
    parsePort sp = case [p' | (p', rest) <- reads sp, ("", "") <- lex rest] of
        (p':_) -> Just p'
        _ -> Nothing
-- | Bind a TCP listening socket for 'settingsPort' / 'settingsHost', mark
-- it close-on-exec, and serve the 'Application' on it. The listening
-- socket is closed when serving ends, whether normally or by an exception.
runSettings :: Settings -> Application -> IO ()
runSettings set app = withSocketsDo $ bracket openListener close serveOn
  where
    openListener = bindPortTCP (settingsPort set) (settingsHost set)
    serveOn sock = setSocketCloseOnExec sock >> runSettingsSocket set sock app
-- | Serve an 'Application' on an already-listening 'Socket'.
--
-- The shutdown hook ('settingsInstallShutdownHandler') receives an action
-- that closes the listening socket. For each accepted socket, this
-- function:
--
-- * marks it close-on-exec,
-- * sets @NoDelay@ on a best-effort basis, and
-- * wraps it as a 'Connection'.
--
-- If any of that post-accept setup throws, the accepted socket is closed
-- before the exception propagates, so it is not leaked.
runSettingsSocket :: Settings -> Socket -> Application -> IO ()
runSettingsSocket set socket app = do
    settingsInstallShutdownHandler set closeListenSocket
    runSettingsConnection set getConn app
  where
    getConn = do
#if WINDOWS
        (s, sa) <- windowsThreadBlockHack $ accept socket
#else
        (s, sa) <- accept socket
#endif
        -- Nothing else holds a reference to the new socket yet, so release
        -- it here if the setup below fails.
        conn <- setupAccepted s `E.onException` close s
        return (conn, sa)
    -- Per-connection socket setup performed right after accept.
    setupAccepted s = do
        setSocketCloseOnExec s
        -- Best effort: any failure to set NoDelay is deliberately ignored
        -- (NOTE(review): presumably for non-TCP sockets such as AF_UNIX).
        setSocketOption s NoDelay 1 `E.catch` \(E.SomeException _) -> return ()
        socketConnection s
    closeListenSocket = close socket
-- | Serve using an action that yields ready-made connections. Each
-- connection is lifted into a trivial connection maker ('return conn').
runSettingsConnection :: Settings -> IO (Connection, SockAddr) -> Application -> IO ()
runSettingsConnection set getConn app =
    runSettingsConnectionMaker set (fmap asMaker getConn) app
  where
    asMaker (conn, sa) = (return conn, sa)
-- | Serve using connection makers, tagging every connection with the plain
-- 'TCP' transport.
runSettingsConnectionMaker :: Settings -> IO (IO Connection, SockAddr) -> Application -> IO ()
runSettingsConnectionMaker set getConnMaker =
    runSettingsConnectionMakerSecure set (fmap tagTCP getConnMaker)
  where
    -- Lazy pattern mirrors Control.Arrow.first on functions.
    tagTCP ~(mkConn, sa) = (fmap (\conn -> (conn, TCP)) mkConn, sa)
-- | Core server entry shared by all @runSettings*@ variants.
--
-- Runs 'settingsBeforeMainLoop', creates the connection counter, and sets
-- up the following resources before handing them to 'acceptConnection':
--
-- * the timeout manager: the one from 'settingsManager' if given,
--   otherwise a fresh one that is stopped afterwards;
-- * the date cache;
-- * the fd cache;
-- * the file-info cache.
runSettingsConnectionMakerSecure :: Settings -> IO (IO (Connection, Transport), SockAddr) -> Application -> IO ()
runSettingsConnectionMakerSecure set getConnMaker app = do
    settingsBeforeMainLoop set
    counter <- newCounter
    withInternalInfo0 (acceptConnection set getConnMaker app counter)
  where
    -- Each setting is scaled by 1000000 (presumably seconds -> microseconds,
    -- the unit the managers/caches expect).
    !fdCacheMicros = settingsFdCacheDuration set * 1000000
    !fileInfoMicros = settingsFileInfoCacheDuration set * 1000000
    !timeoutMicros = settingsTimeout set * 1000000

    withManager k =
        maybe (bracket (T.initialize timeoutMicros) T.stopManager k) k
              (settingsManager set)

    withInternalInfo0 body =
        withManager $ \tm ->
        D.withDateCache $ \dc ->
        F.withFdCache fdCacheMicros $ \fdc ->
        I.withFileInfoCache fileInfoMicros $ \fic ->
            body (InternalInfo0 tm dc fdc fic)
-- | Accept connections and fork a handler for each, until accepting fails.
-- Then wait for open connections to finish via 'gracefulShutdown'.
acceptConnection :: Settings
                 -> IO (IO (Connection, Transport), SockAddr)
                 -> Application
                 -> Counter
                 -> InternalInfo0
                 -> IO ()
acceptConnection set getConnMaker app counter ii0 = do
    -- The loop runs with asynchronous exceptions masked ('mask_'). They can
    -- then only be delivered at interruptible operations or at the explicit
    -- 'allowInterrupt' in each iteration. This narrows the window in which
    -- one could arrive after a connection is obtained but before 'fork'
    -- takes ownership of it.
    void $ mask_ acceptLoop
    -- Reached only when 'acceptLoop' returns normally, i.e. when
    -- 'acceptNewConnection' gave up. An exception escaping the loop skips
    -- the graceful wait.
    gracefulShutdown set counter
  where
    acceptLoop = do
        -- Explicit point where pending asynchronous exceptions may be
        -- delivered despite the surrounding mask.
        allowInterrupt
        -- Nothing: accepting failed irrecoverably, so stop looping.
        -- Just: fork a handler for the connection and accept the next one.
        mx <- acceptNewConnection
        case mx of
            Nothing             -> return ()
            Just (mkConn, addr) -> do
                fork set mkConn addr app counter ii0
                acceptLoop
    -- Obtain the next connection maker:
    --   * an IOException with errno ECONNABORTED is simply retried;
    --   * any other IOException is reported through 'settingsOnException'
    --     and yields Nothing;
    --   * non-IO exceptions are not caught here.
    acceptNewConnection = do
        ex <- try getConnMaker
        case ex of
            Right x -> return $ Just x
            Left e -> do
                let eConnAborted = getErrno eCONNABORTED
                    getErrno (Errno cInt) = cInt
                if ioe_errno e == Just eConnAborted
                    then acceptNewConnection
                    else do
                        settingsOnException set Nothing $ toException e
                        return Nothing
-- | Serve one accepted connection in a thread created by 'settingsFork'.
--
-- Only the actual serving is run through the @unmask@ function passed by
-- 'settingsFork'. NOTE(review): presumably 'settingsFork' behaves like
-- 'forkIOWithUnmask'; confirm.
fork :: Settings
     -> IO (Connection, Transport)
     -> SockAddr
     -> Application
     -> Counter
     -> InternalInfo0
     -> IO ()
fork set mkConn addr app counter ii0 = settingsFork set $ \unmask ->
    -- Any exception escaping the handler is passed to 'settingsOnException'
    -- (with no request available) instead of propagating.
    handle (settingsOnException set Nothing) .
    -- Per-connection "already closed" flag, shared by 'cleanUp' and the
    -- action registered with the timeout manager.
    withClosedRef $ \ref ->
        -- Create the connection, serve it, and always run 'cleanUp'
        -- afterwards, even if serving throws. 'cleanUp' closes the
        -- connection at most once, then frees it.
        bracket mkConn (cleanUp ref) (serve unmask ref)
  where
    withClosedRef inner = newIORef False >>= inner
    -- Close the connection unless it was already closed. The flag is set
    -- with atomicModifyIORef', so only the first caller closes it.
    closeConn ref conn = do
        isClosed <- atomicModifyIORef' ref $ \x -> (True, x)
        unless isClosed $ connClose conn
    -- Close (if still open), then free the connection's resources; the
    -- free runs even if closing throws.
    cleanUp ref (conn, _) = closeConn ref conn `finally` connFree conn
    -- Register with the timeout manager, passing an action that closes the
    -- connection (presumably run when the timeout fires). The registration
    -- is cancelled when serving ends.
    serve unmask ref (conn, transport) = bracket register cancel $ \th -> do
        let ii1 = toInternalInfo1 ii0 th
        -- The onOpen/onClose bracket and the serving itself run unmasked.
        -- onOpen increments the counter and returns 'settingsOnOpen''s
        -- verdict; onClose always decrements it again.
        unmask .
            -- The connection is served only if settingsOnOpen returned
            -- True (goingon).
           bracket (onOpen addr) (onClose addr) $ \goingon ->
           -- When goingon is False the connection is just cleaned up by
           -- the outer bracket.
           when goingon $ serveConnection conn ii1 addr transport set app
      where
        register = T.registerKillThread (timeoutManager0 ii0)
                                        (closeConn ref conn)
        cancel   = T.cancel
    onOpen adr    = increase counter >> settingsOnOpen  set adr
    onClose adr _ = decrease counter >> settingsOnClose set adr
-- | Serve one connection until it is finished.
--
-- The connection goes to 'http2' when HTTP/2 is enabled and either the
-- transport reports HTTP/2 ('isHTTP2') or the first chunk starts with
-- @"PRI "@. Otherwise the HTTP/1.x loop ('http1') runs, after an optional
-- PROXY protocol header has been parsed.
serveConnection :: Connection
                -> InternalInfo1
                -> SockAddr
                -> Transport
                -> Settings
                -> Application
                -> IO ()
serveConnection conn ii1 origAddr transport settings app = do
    -- For a non-HTTP/2 transport, the first chunk is read here to sniff
    -- the protocol; it is handed on (bs) rather than discarded.
    (h2,bs) <- if isHTTP2 transport then
                   return (True, "")
                 else do
                   bs0 <- connRecv conn
                   if S.length bs0 >= 4 && "PRI " `S.isPrefixOf` bs0 then
                       return (True, bs0)
                     else
                       return (False, bs0)
    -- istatus gates 'sendErrorResponse'. It is set to True when input
    -- arrives (wrappedRecv / wrappedRecvN) and to False just before a
    -- response is sent (processRequest).
    istatus <- newIORef False
    if settingsHTTP2Enabled settings && h2 then do
        rawRecvN <- makeReceiveN bs (connRecv conn) (connRecvBuf conn)
        let recvN = wrappedRecvN th istatus (settingsSlowlorisSize settings) rawRecvN
        -- HTTP/2 path: the already-read chunk bs was passed to makeReceiveN.
        http2 conn ii1 origAddr transport settings recvN app
      else do
        src <- mkSource (wrappedRecv conn th istatus (settingsSlowlorisSize settings))
        writeIORef istatus True
        -- Push the sniffed chunk back so HTTP/1.x parsing sees it first.
        leftoverSource src bs
        addr <- getProxyProtocolAddr src
        http1 True addr istatus src `E.catch` \e ->
          case fromException e of
            -- NoKeepAliveRequest ends the connection quietly; any other
            -- exception gets an error response (if allowed) and is rethrown.
            Just NoKeepAliveRequest -> return ()
            Nothing -> do
              sendErrorResponse (dummyreq addr) istatus e
              throwIO e
  where
    -- Determine the client address according to settingsProxyProtocol:
    --   * None: keep the socket address;
    --   * Required: the first segment must carry a PROXY header;
    --   * Optional: parse a header only if the segment starts with "PROXY ",
    --     otherwise push the segment back and keep the socket address.
    getProxyProtocolAddr src =
        case settingsProxyProtocol settings of
            ProxyProtocolNone ->
                return origAddr
            ProxyProtocolRequired -> do
                seg <- readSource src
                parseProxyProtocolHeader src seg
            ProxyProtocolOptional -> do
                seg <- readSource src
                if S.isPrefixOf "PROXY " seg
                    then parseProxyProtocolHeader src seg
                    else do leftoverSource src seg
                            return origAddr
    -- Parse a PROXY header line (up to the first CR, 0x0d) from seg,
    -- splitting it on spaces (0x20). TCP4/TCP6 yield the client address and
    -- port; UNKNOWN keeps the socket address. Anything else throws
    -- BadProxyHeader. The two bytes after the header (assumed CRLF) are
    -- dropped and the rest is pushed back. NOTE(review): the header must
    -- fit in this single segment.
    parseProxyProtocolHeader src seg = do
        let (header,seg') = S.break (== 0x0d) seg 
            maybeAddr = case S.split 0x20 header of 
                ["PROXY","TCP4",clientAddr,_,clientPort,_] ->
                    case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of
                        [a] -> Just (SockAddrInet (readInt clientPort)
                                                       (toHostAddress a))
                        _ -> Nothing
                ["PROXY","TCP6",clientAddr,_,clientPort,_] ->
                    case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of
                        [a] -> Just (SockAddrInet6 (readInt clientPort)
                                                        0
                                                        (toHostAddress6 a)
                                                        0)
                        _ -> Nothing
                ("PROXY":"UNKNOWN":_) ->
                    Just origAddr
                _ ->
                    Nothing
        case maybeAddr of
            Nothing -> throwIO (BadProxyHeader (decodeAscii header))
            Just a -> do leftoverSource src (S.drop 2 seg') 
                         return a
    -- Byte-per-Char conversion of a ByteString into a String.
    decodeAscii = map (chr . fromEnum) . S.unpack
    th = threadHandle1 ii1
    -- A peer that closed the connection does not get an error response.
    shouldSendErrorResponse se
        | Just ConnectionClosedByPeer <- fromException se = False
        | otherwise                                       = True
    -- Send the settingsOnExceptionResponse for e, but only if istatus is
    -- still True (no response started) and the error is worth reporting.
    sendErrorResponse req istatus e = do
        status <- readIORef istatus
        when (shouldSendErrorResponse e && status) $ do
           let ii = toInternalInfo ii1 0 
           void $ sendResponse settings conn ii req defaultIndexRequestHeader (return S.empty) (errorResponse e)
    dummyreq addr = defaultRequest { remoteHost = addr }
    errorResponse e = settingsOnExceptionResponse settings e
    -- HTTP/1.x loop: receive one request, process it, and repeat while
    -- processRequest reports that the connection may be kept alive.
    http1 firstRequest addr istatus src = do
        (req', mremainingRef, idxhdr, nextBodyFlush, ii) <- recvRequest firstRequest settings conn ii1 addr src
        let req = req' { isSecure = isTransportSecure transport }
        keepAlive <- processRequest istatus src req mremainingRef idxhdr nextBodyFlush ii
            `E.catch` \e -> do
                -- Errors while processing: try an error response, report
                -- the exception, and do not keep the connection alive.
                sendErrorResponse req istatus e
                settingsOnException settings (Just req) e
                -- Closing the connection after an error.
                return False
        -- Serve the next request on the same connection (not the first).
        when keepAlive $ http1 False addr istatus src
    -- Run the application for one request. Returns whether the connection
    -- may be kept alive for another request.
    processRequest istatus src req mremainingRef idxhdr nextBodyFlush ii = do
        -- Pause the timeout handle while the application runs; it is resumed
        -- once the application responds.
        T.pause th
        -- Filled in by the respond callback. If the application returns
        -- without responding, reading it below throws this ErrorCall, which
        -- the caller's catch handles.
        keepAliveRef <- newIORef $ error "keepAliveRef not filled"
        _ <- app req $ \res -> do
            T.resume th
            -- From here on a response is being sent, so no error response
            -- may be sent for this request.
            writeIORef istatus False
            keepAlive <- sendResponse settings conn ii req idxhdr (readSource src) res
            writeIORef keepAliveRef keepAlive
            return ResponseReceived
        keepAlive <- readIORef keepAliveRef
        -- NOTE(review): presumably yields so the peer has time to send the
        -- next request before this thread tries to read it.
        Conc.yield
        if not keepAlive then
            return False
          else
            -- Drain any unread request body before reusing the connection.
            -- With settingsMaximumBodyFlush = Just n, keep-alive is
            -- abandoned if more than n bytes would have to be drained.
            case settingsMaximumBodyFlush settings of
                Nothing -> do
                    flushEntireBody nextBodyFlush
                    T.resume th
                    return True
                Just maxToRead -> do
                    let tryKeepAlive = do
                            -- True iff the body ended within maxToRead bytes.
                            isComplete <- flushBody nextBodyFlush maxToRead
                            if isComplete then do
                                T.resume th
                                return True
                              else
                                return False
                    case mremainingRef of
                        Just ref -> do
                            remaining <- readIORef ref
                            if remaining <= maxToRead then
                                tryKeepAlive
                              else
                                return False
                        Nothing -> tryKeepAlive
-- | Drain a body source until it yields an empty chunk, discarding the data.
flushEntireBody :: IO ByteString -> IO ()
flushEntireBody src = do
    chunk <- src
    if S.null chunk
        then return ()
        else flushEntireBody src
-- | Drain a body source, reading at most the given number of bytes.
--
-- Returns 'True' if the body ended (an empty chunk was read) within that
-- budget. Returns 'False' as soon as the bytes read exceed it.
flushBody :: IO ByteString -- ^ body source; an empty chunk marks the end
          -> Int           -- ^ maximum number of bytes to drain
          -> IO Bool
flushBody src = go
  where
    go budget = do
        chunk <- src
        if S.null chunk
            then return True
            else let left = budget - S.length chunk
                 in if left < 0 then return False else go left
-- | Receive one chunk via 'connRecv'.
--
-- A non-empty chunk sets the @istatus@ flag to 'True'. A chunk of at least
-- @slowlorisSize@ bytes additionally tickles the timeout handle.
wrappedRecv :: Connection -> T.Handle -> IORef Bool -> Int -> IO ByteString
wrappedRecv Connection { connRecv = readChunk } th istatus slowlorisSize = do
    chunk <- readChunk
    if S.null chunk
        then return ()
        else do
            writeIORef istatus True
            when (S.length chunk >= slowlorisSize) (T.tickle th)
    return chunk
-- | Wrap a sized receive function (used for HTTP/2).
--
-- A non-empty chunk sets @istatus@ to 'True' and tickles the timeout handle
-- in either of two cases:
--
-- * the chunk is at least @slowlorisSize@ bytes;
-- * the requested size itself is at most @slowlorisSize@. NOTE(review):
--   presumably because such a read can never return a chunk large enough
--   to satisfy the first condition.
wrappedRecvN :: T.Handle -> IORef Bool -> Int -> (BufSize -> IO ByteString) -> (BufSize -> IO ByteString)
wrappedRecvN th istatus slowlorisSize readN bufsize = readN bufsize >>= \chunk -> do
    let bigChunk = S.length chunk >= slowlorisSize
        smallRequest = bufsize <= slowlorisSize
    when (not (S.null chunk)) $ do
        writeIORef istatus True
        when (bigChunk || smallRequest) (T.tickle th)
    return chunk
-- | Mark the socket's file descriptor close-on-exec via
-- 'F.setFileCloseOnExec'. Does nothing on Windows.
setSocketCloseOnExec :: Socket -> IO ()
#if WINDOWS
setSocketCloseOnExec _ = return ()
#else
setSocketCloseOnExec sock =
#if MIN_VERSION_network(3,0,0)
    fdSocket sock >>= F.setFileCloseOnExec . fromIntegral
#else
    F.setFileCloseOnExec (fromIntegral (fdSocket sock))
#endif
#endif
-- | Wait until the connection counter reaches zero.
--
-- If 'settingsGracefulShutdownTimeout' is set, give up after that many
-- seconds ('timeout' takes microseconds).
gracefulShutdown :: Settings -> Counter -> IO ()
gracefulShutdown set counter =
    maybe waitAll waitAtMost (settingsGracefulShutdownTimeout set)
  where
    waitAll = waitForZero counter
    waitAtMost secs = void $ timeout (secs * 1000000) waitAll