-- | This module contains functions for writing webservers. These servers
--   process requests in a state monad pipeline, and several useful actions
--   are provided herein. See examples/fileserver.hs for an example of how to
--   use this module.
module Network.MiniHTTP.Server
  ( -- * Sources
    Source
  , SourceResult(..)
  , bsSource
  , hSource
  , nullSource

  -- * The processing monad
  , WebMonad
  , WebState
  , getRequest
  , getReply
  , setReply
  , setHeader

  -- * WebMonad actions
  , handleConditionalRequest
  , handleHandleToSource
  , handleRangeRequests
  , handleDecoration
  , handleFromFilesystem

  -- * Running the server
  , serve
  ) where

import Prelude hiding (foldl, elem, catch)

import Control.Concurrent.STM
import Control.Monad (liftM)
import Control.Monad.State.Strict
import Control.Exception (catch)

import Data.Foldable
import Data.Word (Word16)
import Data.Bits (shiftL, shiftR, (.|.), (.&.))
import Data.Maybe (isNothing, isJust, fromJust, catMaybes, maybe)
import Data.Time.Clock.POSIX
import Data.IORef
import Data.Int (Int64)
import qualified Data.ByteString as B
import Data.ByteString.Internal (w2c, c2w)
import Data.ByteString.Char8 ()
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as Map
import qualified Data.Sequence as Seq
import System.IO
import System.Posix
import System.FilePath (combine, splitDirectories, joinPath, takeExtension)
import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString
import Text.Printf (printf)

import qualified Data.Binary.Put as P
import qualified Data.Binary.Strict.IncrementalGet as IG
import Network.MiniHTTP.Marshal
import qualified Network.MiniHTTP.Connection as C
import Network.MiniHTTP.MimeTypesParse

-- | Swap the two bytes of a 16-bit word, turning a port number in host byte
--   order into network byte order. This assumes a little endian system
--   because there is no nice way to probe the endianness; on a big endian
--   host the swap would be wrong.
htons :: Word16 -> Word16
htons x = highToLow .|. lowToHigh
  where
    -- the high byte moves down into the low position
    highToLow = x `shiftR` 8
    -- the low byte moves up into the high position
    lowToHigh = (x .&. 255) `shiftL` 8

-- | A source is a stream of data, like a lazy data structure, but without
--   some of the dangers that such entail. It is an 'IO' action which yields
--   the next 'SourceResult' every time it is run.
type Source = IO SourceResult

-- | What a single evaluation of a 'Source' produced.
data SourceResult
  = SourceError              -- ^ error - please don't read this source again
  | SourceEOF                -- ^ end of data
  | SourceData B.ByteString  -- ^ some data
  deriving (Show)

-- | Construct a source from a ByteString. The first evaluation of the source
--   yields the whole string as a single 'SourceData'; every evaluation after
--   that yields 'SourceEOF'.
bsSource :: B.ByteString -> IO Source
bsSource bs = newIORef (SourceData bs) >>= return . drain
  where
    -- hand out whatever is pending and leave EOF behind for the next read
    drain cell = do
      pending <- readIORef cell
      writeIORef cell SourceEOF
      return pending

-- | Construct a source from a Handle. The handle is seeked to the first byte
--   of the range when the source is built; each evaluation of the source then
--   reads at most 128KiB and yields it as 'SourceData'. 'SourceEOF' is
--   returned once exactly the requested range has been delivered. A short
--   read (the handle ran dry before the end of the range) or any exception
--   while reading yields 'SourceError'. The handle is never closed here.
hSource :: (Int64, Int64)  -- ^ the first and last byte to include
               -> Handle  -- ^ the handle to read from
               -> IO Source
hSource (from, to) handle = do
  -- absolute offset of the next byte which has not been delivered yet
  bytesSoFar <- newIORef (from :: Int64)
  hSeek handle AbsoluteSeek (fromIntegral from)
  return $ do
    catch
      (do done <- readIORef bytesSoFar
          -- Clamp while still in Int64 and only then narrow to Int. Narrowing
          -- first would wrap for ranges larger than maxBound :: Int on
          -- systems where Int is 32 bits.
          let remaining = (to + 1) - done
              want = fromIntegral $ min (128 * 1024) remaining :: Int
          bytes <- B.hGet handle want
          if B.null bytes
             then return $ if remaining == 0 then SourceEOF else SourceError
             else do writeIORef bytesSoFar $! done + fromIntegral (B.length bytes)
                     return $ SourceData bytes)
      (const $ return SourceError)

-- | A source with no data (e.g. /dev/null): every evaluation immediately
--   yields 'SourceEOF'.
nullSource :: Source
nullSource = return SourceEOF

-- | Processing a request involves running a number of actions in a StateT
--   monad where the state for that monad is this record. This contains both a
--   @Source@ and a @Handle@ element. Often something will fill in the @Handle@
--   and expect later processing to convert it to a @Source@. Somehow, you have
--   to end up with a @Source@, however.
data WebState =
  WebState { wsRequest :: Request  -- ^ the original request
             -- | the system mime types db, mapping file extensions to media
             --   types
           , wsMimeTypes :: Map.Map B.ByteString MediaType
           , wsReply :: Reply   -- ^ the current reply
           , wsSource :: Maybe Source  -- ^ the current source, if any
           , wsHandle :: Maybe Handle  -- ^ the current handle, if any
             -- | an action to be performed before sending the reply.
             --   NOTE(review): nothing visible in this module ever runs this
             --   action; it is only initialised and reset to @Nothing@ here.
             --   Confirm where (or whether) it is consumed.
           , wsAction :: Maybe (IO ())
           }

-- | The processing monad: a state monad over 'IO' whose state is the
--   'WebState' of the request currently being handled.
type WebMonad = StateT WebState IO

-- | Return the request which is currently being processed
getRequest :: WebMonad Request
getRequest = gets wsRequest

-- | Return the reply as it currently stands in the state
getReply :: WebMonad Reply
getReply = gets wsReply

-- | Replace the current reply with a fresh one carrying the given status
--   code, the default message for that status code, an empty body and no
--   headers other than a Content-Length of 0. Any source, handle or pending
--   action held in the state is dropped.
setReply :: Int -> WebMonad ()
setReply code = modify reset
  where
    blank = Reply 1 1 code (statusToMessage code) $
              emptyHeaders { httpContentLength = Just 0 }
    reset st = st { wsReply = blank
                  , wsSource = Nothing
                  , wsHandle = Nothing
                  , wsAction = Nothing }

-- | Apply a transformation to the headers of the current reply. Because of
--   the way records work, you use this function like this:
--
--   > setHeader $ \h -> h { httpSomeHeader = Just value }
setHeader :: (Headers -> Headers) -> WebMonad ()
setHeader f = modify $ \st ->
  let reply = wsReply st
  in st { wsReply = reply { replyHeaders = f (replyHeaders reply) } }


-- | This handles the If-Match, If-None-Match, If-Modified-Since and
--   If-Unmodified-Since conditional headers. It takes its information from
--   the Last-Modified and ETag headers of the current reply. Note that, for
--   the purposes of ETag matching, a reply without an ETag header is
--   considered not to exist from the point of view of, say, If-Match: *.
--
--   Every check is made against the ETag and Last-Modified values which the
--   reply carried when this action started. A failed check replaces the reply
--   (via 'setReply') with:
--
--   * 412 for If-Match and If-Unmodified-Since;
--
--   * 304 for If-Modified-Since;
--
--   * 304 for If-None-Match on GET and HEAD requests and 412 for every other
--     method (RFC 2616, section 14.26).
handleConditionalRequest :: WebMonad ()
handleConditionalRequest = do
  req <- getRequest
  reply <- getReply
  let metag = httpETag $ replyHeaders reply
      mmtime = httpLastModified $ replyHeaders reply
      -- the status to answer with when If-None-Match finds a match
      noneMatchFailure :: Int
      noneMatchFailure
        | reqMethod req `elem` [GET, HEAD] = 304
        | otherwise = 412

  case httpIfMatch $ reqHeaders req of
       Just (Left ()) -> when (isNothing $ metag) $ setReply 412
       Just (Right tags) ->
         case metag of
              Nothing -> setReply 412
              Just (False, etag) -> when (not $ elem etag tags) $ setReply 412
              -- a weak ETag never satisfies If-Match
              Just (True, _) -> setReply 412
       Nothing -> return ()

  case httpIfNoneMatch $ reqHeaders req of
       Just (Left ()) -> when (isJust $ metag) $ setReply noneMatchFailure
       Just (Right tags) ->
         case metag of
              Nothing -> return ()
              Just tag -> when (elem tag tags) $ setReply noneMatchFailure
       Nothing -> return ()

  case httpIfModifiedSince $ reqHeaders req of
       Just rmtime -> case mmtime of
                           Just mtime -> when (mtime <= rmtime) $ setReply 304
                           Nothing -> return ()
       Nothing -> return ()

  case httpIfUnmodifiedSince $ reqHeaders req of
       Just rmtime -> case mmtime of
                           -- Only a modification strictly after the given time
                           -- fails the precondition; a client echoing back the
                           -- exact Last-Modified value must be let through.
                           Just mtime -> when (mtime > rmtime) $ setReply 412
                           Nothing -> return ()
       Nothing -> return ()

-- | If the current state includes a Handle, this turns it into a Source
--   covering the whole resource and removes the Handle from the state. The
--   length of the resource is taken from the Content-Length header of the
--   current reply; when that header is absent the size reported by the handle
--   itself is used instead (the headers are left untouched in that case).
--   Without a Handle this does nothing.
handleHandleToSource :: WebMonad ()
handleHandleToSource = do
  st <- get
  case wsHandle st of
       Nothing -> return ()
       Just handle -> do
         len <- case httpContentLength $ replyHeaders $ wsReply st of
                     Just n -> return n
                     -- no declared length: ask the handle rather than crash
                     Nothing -> lift $ liftM fromIntegral $ hFileSize handle
         source <- lift $ hSource (0, len - 1) handle
         modify $ \s -> s { wsHandle = Nothing, wsSource = Just source }

-- | Given the length of the resource, filter any unsatisfiable ranges and
--   convert them all into RangeOf form. Both ends of every resulting range
--   are inclusive byte offsets inside the resource:
--
--   * a range starting at or beyond the end of the resource, or whose last
--     byte precedes its first byte, is dropped;
--
--   * a range running past the end is cut off at the last byte;
--
--   * a suffix range longer than the resource selects the whole resource.
satisfiableRanges :: Int64 -> [Range] -> [Range]
satisfiableRanges contentLength = catMaybes . map f where
  -- offset of the final byte of the resource
  lastByte = contentLength - 1
  f (RangeFrom a)
    | a >= 0 && a < contentLength = Just $ RangeOf a lastByte
    | otherwise = Nothing
  f (RangeOf a b)
    | a >= 0 && a <= b && a < contentLength = Just $ RangeOf a $ min b lastByte
    | otherwise = Nothing
  f (RangeSuffix a)
    | a > 0 && contentLength > 0 = Just $ RangeOf (max 0 $ contentLength - a) lastByte
    | otherwise = Nothing

-- | This handles Range requests and also translates from Handles to Sources.
--   If the WebMonad has a Handle at this point, then we can construct sources
--   from any subrange of the file. (We also assume that Content-Length is
--   correctly set.) Without a Handle this does nothing.
--
--   * No Content-Length or no Range header: the whole Handle becomes the
--     Source.
--
--   * No satisfiable range: the Handle is closed and the reply becomes a 416
--     with a Content-Range giving the length of the resource.
--
--   * Exactly one satisfiable range: the reply becomes a 206 for that range.
--
--   * Several ranges: these are not supported, so the Range header is ignored
--     and the whole resource is served.
--
--   See RFC 2616, section 14.35
handleRangeRequests :: WebMonad ()
handleRangeRequests = do
  mhandle <- get >>= return . wsHandle
  req <- getRequest
  reply <- getReply
  case mhandle of
       Nothing -> return ()
       Just handle ->
         case httpContentLength $ replyHeaders reply of
              Nothing -> handleHandleToSource
              Just contentLength -> do
                setHeader (\h -> h { httpAcceptRanges = True })
                case httpRange $ reqHeaders req of
                     Nothing -> handleHandleToSource
                     Just ranges -> do
                       let ranges' = satisfiableRanges contentLength ranges
                       case ranges' of
                          [] -> do
                            -- setReply forgets the handle, so release it
                            -- first rather than leaving it to the GC
                            lift $ hClose handle
                            setReply 416
                            setHeader (\h -> h { httpContentRange = Just (Nothing, Just contentLength) })
                          [RangeOf a b] -> do
                            s <- get
                            source <- lift $ hSource (a, b) handle
                            put $ s { wsReply = (wsReply s) { replyStatus = 206
                                                            , replyMessage = "Partial Content" }
                                    , wsHandle = Nothing
                                    , wsSource = Just source }
                            setHeader (\h -> h { httpContentRange = Just (Just (a, b), Just contentLength)})
                            setHeader (\h -> h { httpContentLength = Just ((b - a) + 1)})
                          -- We don't support multiple ranges. Serve the whole
                          -- entity; leaving the Handle in place would end up
                          -- as an empty body.
                          _ -> handleHandleToSource

-- | Decorate the current reply. At the moment, this just adds the header
--   Server: Network.MiniHTTP
handleDecoration :: WebMonad ()
handleDecoration = setHeader addServer
  where
    addServer hdrs = hdrs { httpServer = Just "Network.MiniHTTP" }

-- | The last step of every pipeline; afterwards the state always holds a
--   source.
--
--   * If a source is missing, install a null source and set the content
--     length to 0.
--
--   * If this was a HEAD request, replace the current source with a null
--     source so that no body is sent. A Content-Length which is already set
--     is kept, so that HEAD advertises the same length as GET would. When it
--     is missing it is set to 0 and the transfer encodings are cleared.
--
--   The header change for HEAD used to be overwritten by a stale copy of the
--   state; it is now applied through 'setHeader' and 'modify' only.
handleFinal :: StateT WebState IO ()
handleFinal = do
  msource <- gets wsSource
  when (isNothing msource) $ do
    setHeader (\h -> h { httpContentLength = Just 0 })
    modify $ \s -> s { wsSource = Just nullSource }

  req <- getRequest
  when (reqMethod req == HEAD) $ do
    reply <- getReply
    when (isNothing $ httpContentLength $ replyHeaders reply) $
      setHeader $ \h -> h { httpContentLength = Just 0
                          , httpTransferEncoding = [] }
    modify $ \s -> s { wsSource = Just nullSource }

-- | This is a very simple handler which deals with requests by returning the
--   requested file from the filesystem. It sets a Handle in the state and sets
--   the Content-Type, Content-Length and Last-Modified headers. The
--   Content-Type is looked up in the mime types db by the extension of the
--   requested path.
--
--   Requests other than GET and HEAD make the action @fail@. If the file
--   cannot be opened, or the path does not name a regular file, the reply
--   becomes a 404.
handleFromFilesystem :: FilePath -- ^ the root of the filesystem to serve from
                     -> WebMonad ()
handleFromFilesystem docroot = do
  req <- getRequest
  when (not $ reqMethod req `elem` [GET, HEAD]) $
    fail "Can only handle GET and HEAD from the filesystem"

  -- stopping directory traversal needs to be done a little carefully.
  let path = reqUrl req
      -- First, make sure that there aren't any NULs in the path
      path' = B.takeWhile (/= 0) path
      path'' = map w2c $ B.unpack path'
      elems = splitDirectories path''
      -- Remove any '..' and every element containing a separator. The latter
      -- covers "/" as well as the "//" which splitDirectories yields for a
      -- path starting with two slashes; keeping that element would make the
      -- joined path absolute, and combine would then discard docroot.
      safe x = x /= ".." && not ('/' `elem` x)
      elems' = filter safe elems
      ext = takeExtension path''
      filepath = combine docroot $ joinPath elems'
      notFound st = st { wsReply = Reply 1 1 404 "Not found" $ emptyHeaders }
  mimeTypes <- get >>= return . wsMimeTypes
  s <- get
  s' <- lift $ catch
    (do fd <- openFd filepath ReadOnly Nothing (OpenFileFlags False False True False False)
        stat <- getFdStatus fd
        if isRegularFile stat
           then do
             let size = fromIntegral $ fileSize stat
                 mtime = posixSecondsToUTCTime $ fromRational $ toRational $ modificationTime stat
             handle <- fdToHandle fd
             return $ s { wsHandle = Just handle
                        , wsSource = Nothing
                        , wsReply = Reply 1 1 200 "Ok" $ emptyHeaders
                           { httpLastModified = Just mtime
                           , httpContentLength = Just size
                           , httpContentType = Map.lookup (B.pack $ map c2w ext) mimeTypes } }
           else do
             -- directories and other special files cannot be streamed
             closeFd fd
             return $ notFound s)
    (const $ return $ notFound s)
  put s'

-- | Run the processing action, followed by 'handleFinal', for one request.
--   Processing starts from a 500 \"Server error\" reply with no source, handle
--   or pending action; the reply and source left in the state at the end are
--   returned.
pipeline :: Map.Map B.ByteString MediaType
         -> WebMonad ()
         -> Request
         -> IO (Reply, Source)
pipeline mimetypes action req = do
  final <- execStateT (action >> handleFinal) initial
  return (wsReply final, fromJust $ wsSource final)
  where
    initial = WebState { wsRequest = req
                       , wsMimeTypes = mimetypes
                       , wsReply = Reply 1 1 500 "Server error" emptyHeaders
                       , wsSource = Nothing
                       , wsHandle = Nothing
                       , wsAction = Nothing }

-- | Add a ByteString to the output queue of a connection, blocking for as
--   long as the queue holds more than the given number of bytes. The check
--   and the enqueue happen in a single STM transaction.
waitForLowWaterAndEnqueue :: Int  -- ^ the max number of bytes in the queue before we enqueue anything
                          -> C.Connection  -- ^ the connection to write to
                          -> B.ByteString  -- ^ the data to enqueue
                          -> IO ()
waitForLowWaterAndEnqueue lw conn bs = atomically $ do
  let outq = C.connoutq conn
  pending <- readTVar outq
  let queued = foldl (\total chunk -> total + B.length chunk) 0 pending
  -- still above the low water mark: wait until the queue changes
  when (queued > lw) retry
  writeTVar outq $ bs Seq.<| pending

-- | Read a single request from a socket. The parser is first fed the data
--   which has already been read and then further data from the socket, 256
--   bytes at a time, until it finishes. Returns the bytes left over after the
--   request together with the request itself.
--
--   The action @fail@s if the request does not parse, or if the socket yields
--   no more data before a complete request was read (an empty read would
--   otherwise be handed to the parser over and over again).
readRequest :: Socket
            -> B.ByteString  -- ^ data which has already been read from the socket
            -> IO (B.ByteString, Request)
readRequest sock initbs = step $ IG.runGet parseRequest initbs
  where
    step (IG.Failed _) = fail "Parse failed"
    step (IG.Finished rest req) = return (rest, req)
    step (IG.Partial cont) = do
      chunk <- recv sock 256
      if B.null chunk
         then fail "Connection closed while reading request"
         else step $ cont chunk

-- | Loop, reading and processing requests. For every request the handler is
--   run, the reply header is enqueued on the connection and the body is then
--   streamed from the source: as is when the reply has a Content-Length, and
--   in the chunked transfer coding (RFC 2616, section 3.6.1) when it has not.
--   The loop goes on to the next request once the source reaches EOF; if the
--   source reports an error the connection is closed instead.
readRequests :: (Request -> IO (Reply, IO SourceResult))
             -> C.Connection
             -> B.ByteString  -- ^ previously read data
             -> IO ()
readRequests handler conn initbs = do
  (rest, result) <- readRequest (C.connsocket conn) initbs
  (reply, source) <- handler result
  let lowWater = 32 * 1024
      enqueue = waitForLowWaterAndEnqueue lowWater conn
      -- body with a known length: pass the data through untouched
      stream = do
        next <- source
        case next of
             SourceEOF -> return True
             SourceError -> return False
             SourceData bs -> enqueue bs >> stream
      -- body without a known length: wrap each piece of data in a chunk
      streamChunked = do
        next <- source
        case next of
             SourceEOF -> do
               enqueue "0\r\n\r\n"
               return True
             SourceError -> return False
             SourceData bs
               -- a zero sized chunk would terminate the body early
               | B.null bs -> streamChunked
               | otherwise -> do
                   -- the chunk size is written in hex and is followed by
                   -- exactly one CRLF
                   enqueue $ B.pack $ map c2w $
                     printf "%x\r\n" $ B.length bs
                   enqueue bs
                   enqueue "\r\n"
                   streamChunked
  enqueue $ B.concat $ BL.toChunks $ P.runPut $ putReply reply
  success <- if isNothing $ httpContentLength $ replyHeaders reply
                then streamChunked
                else stream
  if not success
     then C.close conn
     else readRequests handler conn rest

-- | Accept connections on the given socket forever. Each new connection gets
--   NoDelay set, is logged to stdout and is then handed to 'readRequests',
--   which is started through @C.forkThreads@.
acceptLoop :: Socket -> (Request -> IO (Reply, Source)) -> IO ()
acceptLoop acceptingSocket handler = forever $ do
  (client, peer) <- accept acceptingSocket
  setSocketOption client NoDelay 1
  putStrLn $ "Connection from " ++ show peer

  conn <- atomically $ C.new client (return ())
  C.forkThreads conn $ readRequests handler conn B.empty

-- | Listen on the given port (IPv4, any address) and process every request
--   with the given action. The mime types db is read from /etc/mime.types;
--   an empty one is used if that cannot be parsed. If the accept loop throws
--   an exception the listening socket is closed.
serve :: Int  -- ^ the port number to listen on
      -> WebMonad ()  -- ^ the processing action
      -> IO ()
serve portno action = do
  -- Switch these two lines to use IPv6 (which works for IPv4 clients too)
  --acceptingSocket <- socket AF_INET6 Stream 0
  --let sockaddr = SockAddrInet6 (PortNum $ htons $ fromIntegral portno) 0 iN6ADDR_ANY 0

  acceptingSocket <- socket AF_INET Stream 0
  setSocketOption acceptingSocket ReuseAddr 1
  bindSocket acceptingSocket $
    SockAddrInet (PortNum $ htons $ fromIntegral portno) iNADDR_ANY
  listen acceptingSocket 1
  mimetypes <- liftM (maybe Map.empty id) $ parseMimeTypesTotal "/etc/mime.types"

  acceptLoop acceptingSocket (pipeline mimetypes action)
    `catch` (const $ sClose acceptingSocket)