{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Snap.Internal.Http.Server.Socket
  ( bindSocket
  , bindSocketImpl
  , bindUnixSocket
  , httpAcceptFunc
  , haProxyAcceptFunc
  , sendFileFunc
  , acceptAndInitialize
  ) where

------------------------------------------------------------------------------
import           Control.Exception                 (bracketOnError, finally, throwIO)
import           Control.Monad                     (when)
import           Data.Bits                         (complement, (.&.))
import           Data.ByteString.Char8             (ByteString)
import           Network.Socket                    (Socket, SocketOption (NoDelay, ReuseAddr), accept, close, getSocketName, setSocketOption, socket)
import qualified Network.Socket                    as N
#ifdef HAS_SENDFILE
import           Network.Socket                    (fdSocket)
import           System.Posix.IO                   (OpenMode (..), closeFd, defaultFileFlags, openFd)
import           System.Posix.Types                (Fd (..))
import           System.SendFile                   (sendFile, sendHeaders)
#else
import           Data.ByteString.Builder           (byteString)
import           Data.ByteString.Builder.Extra     (flush)
import           Network.Socket.ByteString         (sendAll)
#endif
#ifdef HAS_UNIX_SOCKETS
import           Control.Exception                 (bracket)
import qualified Control.Exception                 as E (catch)
import           System.FilePath                   (isRelative)
import           System.IO.Error                   (isDoesNotExistError)
import           System.Posix.Files                (accessModes, removeLink, setFileCreationMask)
#endif

------------------------------------------------------------------------------
import qualified System.IO.Streams                 as Streams
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Address (AddressNotSupportedException (..), getAddress, getSockAddr)
import           Snap.Internal.Http.Server.Types   (AcceptFunc (..), SendFileHandler)
import qualified System.IO.Streams.Network.HAProxy as HA


------------------------------------------------------------------------------
bindSocket :: ByteString -> Int -> IO Socket
bindSocket = bindSocketImpl setSocketOption bind N.listen
  where
#if MIN_VERSION_network(2,7,0)
    bind = N.bind
#else
    bind = N.bindSocket
#endif
{-# INLINE bindSocket #-}


------------------------------------------------------------------------------
bindSocketImpl
    :: (Socket -> SocketOption -> Int -> IO ()) -- ^ mock setSocketOption
    -> (Socket -> N.SockAddr -> IO ())          -- ^ bindSocket
    -> (Socket -> Int -> IO ())                 -- ^ listen
    -> ByteString
    -> Int
    -> IO Socket
bindSocketImpl _setSocketOption _bindSocket _listen bindAddr bindPort = do
    (family, addr) <- getSockAddr bindPort bindAddr
    bracketOnError (socket family N.Stream 0) N.close $ \sock -> do
        _setSocketOption sock ReuseAddr 1
        _setSocketOption sock NoDelay 1
        _bindSocket sock addr
        _listen sock 150
        return $! sock

bindUnixSocket :: Maybe Int -> String -> IO Socket
#if HAS_UNIX_SOCKETS
bindUnixSocket mode path = do
   when (isRelative path) $
      throwIO $ AddressNotSupportedException
                $! "Refusing to bind unix socket to non-absolute path: " ++ path

   bracketOnError (socket N.AF_UNIX N.Stream 0) N.close $ \sock -> do
      E.catch (removeLink path) $ \e -> when (not $ isDoesNotExistError e) $ throwIO e
      case mode of
         Nothing -> bind sock (N.SockAddrUnix path)
         Just mode' -> bracket (setFileCreationMask $ modeToMask mode')
                              setFileCreationMask
                              (const $ bind sock (N.SockAddrUnix path))
      N.listen sock 150
      return $! sock
   where
#if MIN_VERSION_network(2,7,0)
     bind = N.bind
#else
     bind = N.bindSocket
#endif
     modeToMask p = accessModes .&. complement (fromIntegral p)
#else
bindUnixSocket _ path = throwIO (AddressNotSupportedException $ "unix:" ++ path)
#endif

------------------------------------------------------------------------------
-- TODO(greg): move buffer size configuration into config
bUFSIZ :: Int
bUFSIZ = 4064


------------------------------------------------------------------------------
acceptAndInitialize :: Socket        -- ^ bound socket
                    -> (forall b . IO b -> IO b)
                    -> ((Socket, N.SockAddr) -> IO a)
                    -> IO a
acceptAndInitialize boundSocket restore f =
    bracketOnError (restore $ accept boundSocket)
                   (close . fst)
                   f


------------------------------------------------------------------------------
haProxyAcceptFunc :: Socket     -- ^ bound socket
                  -> AcceptFunc
haProxyAcceptFunc boundSocket =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, saddr) -> do
        (readEnd, writeEnd)      <- Streams.socketToStreamsWithBufferSize
                                        bUFSIZ sock
        localPInfo               <- HA.socketToProxyInfo sock saddr
        pinfo                    <- HA.decodeHAProxyHeaders localPInfo readEnd
        (localPort, localHost)   <- getAddress $ HA.getDestAddr pinfo
        (remotePort, remoteHost) <- getAddress $ HA.getSourceAddr pinfo
        let cleanup              =  Streams.write Nothing writeEnd
                                        `finally` close sock
        return $! ( sendFileFunc sock
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
httpAcceptFunc :: Socket                     -- ^ bound socket
               -> AcceptFunc
httpAcceptFunc boundSocket =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
        localAddr                <- getSocketName sock
        (localPort, localHost)   <- getAddress localAddr
        (remotePort, remoteHost) <- getAddress remoteAddr
        (readEnd, writeEnd)      <- Streams.socketToStreamsWithBufferSize bUFSIZ
                                                                          sock
        let cleanup              =  Streams.write Nothing writeEnd
                                      `finally` close sock
        return $! ( sendFileFunc sock
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
sendFileFunc :: Socket -> SendFileHandler
#ifdef HAS_SENDFILE
sendFileFunc sock !_ builder fPath offset nbytes = bracket acquire closeFd go
  where
    acquire   = openFd fPath ReadOnly Nothing defaultFileFlags
#if MIN_VERSION_network(3,0,0)
    go fileFd = do sockFd <- Fd `fmap` fdSocket sock
                   sendHeaders builder sockFd
                   sendFile sockFd fileFd offset nbytes
#else
    go fileFd = do let sockFd = Fd $ fdSocket sock
                   sendHeaders builder sockFd
                   sendFile sockFd fileFd offset nbytes
#endif

#else
sendFileFunc sock buffer builder fPath offset nbytes =
    Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $
            \fileInput0 -> do
        fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
                     Streams.map byteString
        input     <- Streams.fromList [builder] >>=
                     flip Streams.appendInputStream fileInput
        output    <- Streams.makeOutputStream sendChunk >>=
                     Streams.unsafeBuilderStream (return buffer)
        Streams.supply input output
        Streams.write (Just flush) output

  where
    sendChunk (Just s) = sendAll sock s
    sendChunk Nothing  = return $! ()
#endif