{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}

module Metro.Server
  ( startServer
  , startServer_
  , ServerEnv
  , ServerT
  , Servable (..)
  , getNodeEnvList
  , getServ
  , serverEnv
  , initServerEnv

  -- server env action
  , setServerName
  , setNodeMode
  , setSessionMode
  , setDefaultSessionTimeout
  , setKeepalive

  , setOnNodeLeave

  , runServerT
  , stopServerT
  , handleConn
  ) where

import           Control.Monad              (forM_, forever, mzero, unless,
                                             void, when)
import           Control.Monad.Reader.Class (MonadReader (ask), asks)
import           Control.Monad.Trans.Class  (MonadTrans, lift)
import           Control.Monad.Trans.Maybe  (runMaybeT)
import           Control.Monad.Trans.Reader (ReaderT (..), runReaderT)
import           Data.Either                (isLeft)
import           Data.Hashable
import           Data.Int                   (Int64)
import           Metro.Class                (GetPacketId, RecvPacket,
                                             Servable (..), Transport,
                                             TransportConfig)
import           Metro.Conn                 hiding (close)
import           Metro.IOHashMap            (IOHashMap, newIOHashMap)
import qualified Metro.IOHashMap            as HM (delete, elems, insertSTM,
                                                   lookupSTM)
import           Metro.Node                 (NodeEnv1, NodeMode (..),
                                             SessionMode (..), getNodeId,
                                             getTimer, initEnv1, runNodeT1,
                                             startNodeT_, stopNodeT)
import qualified Metro.Node                 as Node
import           Metro.Session              (SessionT)
import           Metro.Utils                (getEpochTime)
import           System.Log.Logger          (errorM, infoM)
import           UnliftIO
import           UnliftIO.Concurrent        (threadDelay)

-- | Shared configuration and mutable state of a server.
--
-- Built by 'initServerEnv' and adjusted with the @set*@ functions before it
-- is handed to 'startServer' / 'startServer_'.
data ServerEnv serv u nid k rpkt tp = ServerEnv
    { serveServ    :: serv
      -- ^ the underlying server value; 'initServerEnv' creates it with 'newServer'
    , serveState   :: TVar Bool
      -- ^ 'True' while the accept loop should keep running; 'stopServerT' sets it to 'False'
    , nodeEnvList  :: IOHashMap nid (NodeEnv1 u nid k rpkt tp)
      -- ^ node envs keyed by node id; 'handleConn' inserts entries and the
      -- keepalive checker removes expired ones
    , prepare      :: SID serv -> ConnEnv tp -> IO (Maybe (nid, u))
      -- ^ run for every accepted connection; 'Nothing' means the connection
      -- is not served as a node
    , gen          :: IO k
      -- ^ handed to 'initEnv1' for every node (presumably a session-id
      -- generator — confirm in "Metro.Node")
    , keepalive    :: Int64
      -- ^ liveness check interval in seconds; values <= 0 (default 0) disable the checker
    , defSessTout  :: Int64
      -- ^ default session timeout forwarded to every node (default 300;
      -- units are defined by "Metro.Node")
    , nodeMode     :: NodeMode
      -- ^ node mode forwarded to every node (default 'Multi')
    , sessionMode  :: SessionMode
      -- ^ session mode forwarded to every node (default 'SingleAction')
    , serveName    :: String
      -- ^ prefix used in log messages (default @"Metro"@)
    , onNodeLeave  :: TVar (Maybe (nid -> u -> IO ()))
      -- ^ optional callback run with the node id and user env after a node disconnects
    , mapTransport :: TransportConfig (STP serv) -> TransportConfig tp
      -- ^ converts the transport config yielded by the server into the
      -- per-connection transport config
    }


-- | Server monad transformer: a 'ReaderT' over 'ServerEnv', so server actions
-- can read the shared configuration and state through 'ask' / 'asks'.
newtype ServerT serv u nid k rpkt tp m a = ServerT {unServerT :: ReaderT (ServerEnv serv u nid k rpkt tp) m a}
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadReader (ServerEnv serv u nid k rpkt tp)
    )

-- | Lifting ignores the server environment: the base action runs unchanged.
instance MonadTrans (ServerT serv u nid k rpkt tp) where
  lift = ServerT . ReaderT . const

-- | Unlifting captures the current 'ServerEnv' and re-enters 'ServerT' with
-- it via 'runServerT', delegating to the base monad's own unlifting.
instance MonadUnliftIO m => MonadUnliftIO (ServerT serv u nid k rpkt tp m) where
  withRunInIO inner = ServerT . ReaderT $ \env ->
    withRunInIO $ \runBase ->
      inner $ \action -> runBase (runServerT env action)

-- | Run a 'ServerT' action against the given server environment.
runServerT :: ServerEnv serv u nid k rpkt tp -> ServerT serv u nid k rpkt tp m a -> m a
runServerT sEnv action = runReaderT (unServerT action) sEnv

-- | Create a fresh server environment.
--
-- The server is created from the config with 'newServer'. The environment
-- starts alive, with an empty node table and no 'onNodeLeave' callback.
-- Defaults: 'Multi' node mode, 'SingleAction' session mode, name @"Metro"@,
-- keepalive 0 (disabled) and default session timeout 300.
initServerEnv
  :: (MonadIO m, Servable serv)
  => ServerConfig serv -> IO k
  -> (TransportConfig (STP serv) -> TransportConfig tp)
  -> (SID serv -> ConnEnv tp -> IO (Maybe (nid, u)))
  -> m (ServerEnv serv u nid k rpkt tp)
initServerEnv sc gen mapTransport prepare = do
  serv      <- newServer sc
  aliveVar  <- newTVarIO True
  envTable  <- newIOHashMap
  leaveHook <- newTVarIO Nothing
  pure ServerEnv
    { serveServ    = serv
    , serveState   = aliveVar
    , nodeEnvList  = envTable
    , prepare      = prepare
    , gen          = gen
    , keepalive    = 0
    , defSessTout  = 300
    , nodeMode     = Multi
    , sessionMode  = SingleAction
    , serveName    = "Metro"
    , onNodeLeave  = leaveHook
    , mapTransport = mapTransport
    }

-- | Choose the 'NodeMode' handed to every node created by 'handleConn'.
setNodeMode
  :: NodeMode -> ServerEnv serv u nid k rpkt tp -> ServerEnv serv u nid k rpkt tp
setNodeMode newMode = \env -> env {nodeMode = newMode}

-- | Choose the 'SessionMode' handed to every node created by 'handleConn'.
setSessionMode
  :: SessionMode -> ServerEnv serv u nid k rpkt tp -> ServerEnv serv u nid k rpkt tp
setSessionMode newMode = \env -> env {sessionMode = newMode}

-- | Set the name that prefixes this server's log messages.
setServerName
  :: String -> ServerEnv serv u nid k rpkt tp -> ServerEnv serv u nid k rpkt tp
setServerName newName = \env -> env {serveName = newName}

-- | Set the node liveness check interval in seconds.
-- 'startServer_' only runs the checker when this is greater than 0.
setKeepalive
  :: Int64 -> ServerEnv serv u nid k rpkt tp -> ServerEnv serv u nid k rpkt tp
setKeepalive interval = \env -> env {keepalive = interval}

-- | Set the default session timeout forwarded to every node
-- (via 'Node.setDefaultSessionTimeout').
setDefaultSessionTimeout
  :: Int64 -> ServerEnv serv u nid k rpkt tp -> ServerEnv serv u nid k rpkt tp
setDefaultSessionTimeout timeout = \env -> env {defSessTout = timeout}

-- | Install the callback run with a node's id and user env after it
-- disconnects, replacing any previously installed callback.
setOnNodeLeave :: MonadIO m => ServerEnv serv u nid k rpkt tp -> (nid -> u -> IO ()) -> m ()
setOnNodeLeave sEnv callback =
  atomically $ writeTVar (onNodeLeave sEnv) (Just callback)

-- | Accept and serve connections until the server stops.
--
-- Repeatedly runs 'serveOnce'. The loop ends when 'serveState' has been set
-- to 'False' (see 'stopServerT') or when an iteration throws.
--
-- An iteration that throws while the server is still marked alive is logged
-- as an error instead of being dropped. Once the server has been stopped, an
-- exception is presumably the closed listener, so it ends the loop quietly.
serveForever
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt, Servable serv)
  => (rpkt -> m Bool)
  -> SessionT u nid k rpkt tp m ()
  -> ServerT serv u nid k rpkt tp m ()
serveForever preprocess sess = do
  name  <- asks serveName
  state <- asks serveState
  liftIO $ infoM "Metro.Server" $ name ++ "Server started"
  void . runMaybeT . forever $ do
    e <- lift $ tryServeOnce preprocess sess
    alive <- readTVarIO state
    case e of
      Left err -> do
        when alive $ liftIO $ errorM "Metro.Server" $
          name ++ "Server stopped by error: " ++ show err
        mzero
      Right () -> unless alive mzero
  liftIO $ infoM "Metro.Server" $ name ++ "Server closed"

-- | Run one 'serveOnce' iteration, returning any synchronous exception it
-- throws as a 'Left' instead of propagating it.
tryServeOnce
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt, Servable serv)
  => (rpkt -> m Bool)
  -> SessionT u nid k rpkt tp m ()
  -> ServerT serv u nid k rpkt tp m (Either SomeException ())
tryServeOnce preprocess = tryAny . serveOnce preprocess

-- | Let the underlying server perform one accept step, handling whatever it
-- yields with 'doServeOnce'.
serveOnce
  :: ( MonadUnliftIO m
     , Transport tp
     , Show nid, Eq nid, Hashable nid
     , Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt
     , Servable serv)
  => (rpkt -> m Bool)
  -> SessionT u nid k rpkt tp m ()
  -> ServerT serv u nid k rpkt tp m ()
serveOnce preprocess sess = do
  serv <- asks serveServ
  servOnce serv (doServeOnce preprocess sess)

-- | Handle one accepted connection, or do nothing when the server yields 'Nothing'.
--
-- Builds the connection env and asks 'prepare' for a node id and user env.
-- If one is given, the connection is served via 'handleConn' and this waits
-- for it to finish. A failure in the node is logged, not rethrown.
--
-- NOTE(review): when 'prepare' returns 'Nothing', the 'ConnEnv' created here
-- is not closed in this function. Confirm that 'prepare' or the 'Servable'
-- closes rejected connections.
doServeOnce
  :: ( MonadUnliftIO m
     , Transport tp
     , Show nid, Eq nid, Hashable nid
     , Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt
     , Servable serv)
  => (rpkt -> m Bool)
  -> SessionT u nid k rpkt tp m ()
  -> Maybe (SID serv, TransportConfig (STP serv))
  -> ServerT serv u nid k rpkt tp m ()
doServeOnce _ _ Nothing = pure ()
doServeOnce preprocess sess (Just (servID, stp)) = do
  env <- ask
  connEnv <- initConnEnv (mapTransport env stp)
  prepared <- liftIO (prepare env servID connEnv)
  case prepared of
    Nothing -> pure ()
    Just (nid, uEnv) -> do
      (_, worker) <- handleConn "Client" servID connEnv nid uEnv preprocess sess
      waitCatch worker >>= either logFailure (const (pure ()))
  where
    logFailure e = liftIO $ errorM "Metro.Server" $ "Handle connection error " ++ show e

-- | Register an accepted connection as a node and start serving it.
--
-- Builds a node env for @nid@ from the server settings and stores it in
-- 'nodeEnvList'. Any env previously stored under the same id is stopped.
-- The node is then run in a new 'Async', which completes when the node's
-- 'startNodeT_' loop returns or throws.
--
-- 'onConnEnter' and 'onConnLeave' are paired with 'bracket_'. The leave
-- hook, the 'onNodeLeave' callback and the "disconnected" log line therefore
-- also run when the node loop throws or the 'Async' is cancelled. They run
-- as masked cleanup, so the callback should not block indefinitely. An
-- exception from the node loop still reaches whoever waits on the 'Async'.
handleConn
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt, Servable serv)
  => String                          -- ^ label placed after the server name in log lines
  -> SID serv                        -- ^ connection id passed to the server's enter/leave hooks
  -> ConnEnv tp                      -- ^ connection env of the accepted connection
  -> nid                             -- ^ node id
  -> u                               -- ^ user env for this node
  -> (rpkt -> m Bool)                -- ^ packet preprocessor forwarded to 'startNodeT_'
  -> SessionT u nid k rpkt tp m ()   -- ^ session handler forwarded to 'startNodeT_'
  -> ServerT serv u nid k rpkt tp m (NodeEnv1 u nid k rpkt tp, Async ())
handleConn n servID connEnv nid uEnv preprocess sess = do
    ServerEnv {..} <- ask

    liftIO $ infoM "Metro.Server" (serveName ++ n ++ ": " ++ show nid ++ " connected")
    env0 <- initEnv1
      (Node.setNodeMode nodeMode
      . Node.setSessionMode sessionMode
      . Node.setDefaultSessionTimeout defSessTout) connEnv uEnv nid gen

    -- Look up the previous env and insert the new one in a single STM
    -- transaction, so two connections with the same id cannot interleave here.
    env1 <- atomically $ do
      v <- HM.lookupSTM nodeEnvList nid
      HM.insertSTM nodeEnvList nid env0
      pure v

    -- A new connection with the same node id replaces the old one.
    mapM_ (`runNodeT1` stopNodeT) env1

    -- NOTE(review): the entry in nodeEnvList is not removed when this node
    -- leaves; only the keepalive checker or a reconnect replaces it. Confirm
    -- that stale entries are intended.
    let leave = do
          onConnLeave serveServ servID
          nodeLeave <- readTVarIO onNodeLeave
          forM_ nodeLeave $ \f -> liftIO $ f nid uEnv
          liftIO $ infoM "Metro.Server" (serveName ++ n ++ ": " ++ show nid ++ " disconnected")

    io <- async $
      bracket_ (onConnEnter serveServ servID) leave
        (lift . runNodeT1 env0 $ startNodeT_ preprocess sess)

    return (env0, io)

-- | Run the server with a preprocessor that accepts every packet.
-- See 'startServer_'.
startServer
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt, Servable serv)
  => ServerEnv serv u nid k rpkt tp
  -> SessionT u nid k rpkt tp m ()
  -> m ()
startServer sEnv = startServer_ sEnv (\_ -> pure True)

-- | Run the server until its accept loop ends.
--
-- When 'keepalive' is greater than 0, the node liveness checker
-- ('runCheckNodeState') runs alongside the accept loop via 'withAsync'. It
-- is cancelled as soon as the accept loop returns or throws, so the checker
-- thread does not outlive the server. The underlying server is closed with
-- 'servClose' when the loop ends, including when it ends with an exception.
--
-- Nodes still connected when the loop ends are not stopped here, and they
-- are no longer liveness-checked after this returns.
startServer_
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt, Servable serv)
  => ServerEnv serv u nid k rpkt tp
  -> (rpkt -> m Bool)               -- ^ packet preprocessor forwarded to every node
  -> SessionT u nid k rpkt tp m ()  -- ^ session handler forwarded to every node
  -> m ()
startServer_ sEnv preprocess sess =
  withNodeChecker serve `finally` liftIO (servClose (serveServ sEnv))
  where
    serve = runServerT sEnv $ serveForever preprocess sess
    alive = keepalive sEnv
    -- Tie the checker's lifetime to the accept loop.
    withNodeChecker action
      | alive > 0 = withAsync (runCheckNodeState alive (nodeEnvList sEnv)) (const action)
      | otherwise = action

-- | Stop the accept loop: mark the server as no longer alive, then close the
-- underlying server. Node connections that are already running are not
-- stopped here.
stopServerT :: (MonadIO m, Servable serv) => ServerT serv u nid k rpkt tp m ()
stopServerT = do
  ServerEnv {..} <- ask
  atomically $ writeTVar serveState False
  liftIO $ servClose serveServ

-- | Periodically stop and remove nodes that have been idle too long.
--
-- Every @alive@ seconds, each registered node is checked. A node is expired
-- when the current epoch time is greater than its 'getTimer' value plus
-- @alive@ (presumably 'getTimer' is the node's last-activity time; confirm in
-- "Metro.Node"). Expired nodes are stopped and removed from the table.
--
-- This loops forever; the caller controls its lifetime (see 'startServer_').
-- An exception from a single node check is logged and does not end the loop.
runCheckNodeState
  :: (MonadUnliftIO m, Eq nid, Hashable nid, Transport tp)
  => Int64 -> IOHashMap nid (NodeEnv1 u nid k rpkt tp) -> m ()
runCheckNodeState alive envList = forever $ do
  threadDelay $ fromIntegral alive * 1000 * 1000
  mapM_ checkSafely =<< HM.elems envList

  where checkSafely env1 = do
          r <- tryAny $ checkAlive envList env1
          case r of
            Left e  -> liftIO $ errorM "Metro.Server" $ "Check node state error " ++ show e
            Right _ -> pure ()

        checkAlive
          :: (MonadUnliftIO m, Eq nid, Hashable nid, Transport tp)
          => IOHashMap nid (NodeEnv1 u nid k rpkt tp)
          -> NodeEnv1 u nid k rpkt tp -> m ()
        checkAlive ref env1 = runNodeT1 env1 $ do
              expiredAt <- (alive +) <$> getTimer
              now <- getEpochTime
              when (now > expiredAt) $ do
                nid <- getNodeId
                stopNodeT
                HM.delete ref nid

-- | Read the whole 'ServerEnv' from inside a 'ServerT' action.
serverEnv :: Monad m => ServerT serv u nid k rpkt tp m (ServerEnv serv u nid k rpkt tp)
serverEnv = asks id

-- | The table of currently registered node envs, keyed by node id.
getNodeEnvList :: ServerEnv serv u nid k rpkt tp -> IOHashMap nid (NodeEnv1 u nid k rpkt tp)
getNodeEnvList ServerEnv {nodeEnvList = envs} = envs

-- | The underlying server value created by 'initServerEnv'.
getServ :: ServerEnv serv u nid k rpkt tp -> serv
getServ ServerEnv {serveServ = serv} = serv