{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}

module Metro.Session
  ( SessionEnv (..)
  , SessionEnv1 (..)
  , newSessionEnv
  , SessionT
  , runSessionT
  , runSessionT1
  , send
  , sessionState
  , feed
  , receive
  , readerSize
  , getSessionId
  , getNodeId
  , getSessionEnv1
  , env
  , ident

  , isTimeout

  , makeResponse
  , makeResponse_
  ) where

import           Control.Monad.Reader.Class (MonadReader, ask, asks)
import           Control.Monad.Trans.Class  (MonadTrans (..))
import           Control.Monad.Trans.Reader (ReaderT (..), runReaderT)
import           Data.Int                   (Int64)
import           Metro.Class                (SendPacket, SetPacketId, Transport,
                                             setPacketId)
import           Metro.Conn                 (ConnEnv, ConnT, FromConn (..),
                                             runConnT, statusTVar)
import qualified Metro.Conn                 as Conn (send)
import           Metro.Utils                (getEpochTime)
import           UnliftIO

-- | Per-session state read by every action that runs inside a 'SessionT'.
data SessionEnv u nid k rpkt = SessionEnv
    { sessionData    :: TVar [Maybe rpkt]
    -- ^ Inbound queue: 'feed' appends to the end, 'receive' removes from the front.
    , sessionNid     :: nid
    -- ^ Node identifier, returned by 'getNodeId'.
    , sessionId      :: k
    -- ^ Session identifier, returned by 'getSessionId'; 'send' stamps it onto
    -- outgoing packets via 'setPacketId'.
    , sessionUEnv    :: u
    -- ^ Caller-supplied value, returned by 'env'.
    , sessionTimer   :: TVar Int64
    -- ^ Timestamp taken from 'getEpochTime' when the environment is created by
    -- 'newSessionEnv' and again on every 'feed'.
    , sessionTimeout :: Int64
    -- ^ Allowed gap after 'sessionTimer' before 'isTimeout' reports 'True';
    -- a value @<= 0@ disables the check. Same unit as 'getEpochTime'
    -- (presumably seconds -- confirm in "Metro.Utils").
    }

-- | A 'SessionEnv' bundled with the 'ConnEnv' of its connection, which is
-- everything 'runSessionT1' needs to run a 'SessionT' action directly in the
-- base monad.
data SessionEnv1 u nid k rpkt tp = SessionEnv1
    { sessionEnv :: SessionEnv u nid k rpkt
    -- ^ The session-level environment.
    , connEnv    :: ConnEnv tp
    -- ^ The connection-level environment handed to 'runConnT'.
    }

-- | Allocate a fresh 'SessionEnv'.
--
-- Arguments, in order: the caller-supplied environment, the node id, the
-- session id, the timeout, and the packets the inbound queue starts out with.
-- The activity timer is initialised from 'getEpochTime'.
newSessionEnv :: MonadIO m => u -> nid -> k -> Int64 -> [Maybe rpkt] -> m (SessionEnv u nid k rpkt)
newSessionEnv uEnv nodeId sessId tout initial = do
  queueVar <- newTVarIO initial
  timerVar <- getEpochTime >>= newTVarIO
  pure SessionEnv
    { sessionData    = queueVar
    , sessionNid     = nodeId
    , sessionId      = sessId
    , sessionUEnv    = uEnv
    , sessionTimer   = timerVar
    , sessionTimeout = tout
    }

-- | The session monad transformer: a reader of 'SessionEnv' layered on top of
-- 'ConnT'. Use 'runSessionT' to peel off the reader layer, or 'runSessionT1'
-- to run all the way down to the base monad.
newtype SessionT u nid k rpkt tp m a = SessionT { unSessionT :: ReaderT (SessionEnv u nid k rpkt) (ConnT tp m) a }
  deriving (Functor, Applicative, Monad, MonadIO, MonadReader (SessionEnv u nid k rpkt))

-- | Lifts a base-monad action through 'ConnT' and then through the reader layer.
instance MonadTrans (SessionT u nid k rpkt tp) where
  lift action = SessionT (lift (lift action))

-- | Unlifting captures the current 'SessionEnv' and delegates to the
-- 'MonadUnliftIO' instance of the underlying 'ConnT'.
instance MonadUnliftIO m => MonadUnliftIO (SessionT u nid k rpkt tp m) where
  withRunInIO inner =
    SessionT . ReaderT $ \sEnv ->
      -- This inner 'withRunInIO' is the one for 'ConnT tp m'.
      withRunInIO $ \runConn ->
        inner (\action -> runConn (runSessionT sEnv action))

-- | Embeds a 'ConnT' action by lifting it through the reader layer.
instance FromConn (SessionT u nid k rpkt) where
  fromConn action = SessionT (lift action)

-- | Run a session action against the given 'SessionEnv', producing an action
-- in the underlying 'ConnT'.
runSessionT :: SessionEnv u nid k rpkt -> SessionT u nid k rpkt tp m a -> ConnT tp m a
runSessionT aEnv action = runReaderT (unSessionT action) aEnv

-- | Run a session action all the way down to the base monad, using the session
-- and connection environments carried by a 'SessionEnv1'.
runSessionT1 :: SessionEnv1 u nid k rpkt tp -> SessionT u nid k rpkt tp m a -> m a
runSessionT1 SessionEnv1 { sessionEnv = sEnv, connEnv = cEnv } =
  runConnT cEnv . runSessionT sEnv

-- | Read the current value of the connection's status 'TVar'.
--
-- 'receive' keeps waiting for packets while this flag is 'True' and gives up
-- once it is 'False'.
sessionState :: MonadIO m => SessionT u nid k rpkt tp m Bool
sessionState = fromConn statusTVar >>= readTVarIO

-- | Stamp an outgoing packet with this session's id (via 'setPacketId') and
-- send it over the underlying connection.
send
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt)
  => spkt -> SessionT u nid k rpkt tp m ()
send spkt =
  getSessionId >>= \sid ->
    fromConn (Conn.send (setPacketId sid spkt))

-- | Hand a received item to the session: refresh the activity timer, then
-- append the item to the end of the inbound queue.
--
-- The append is @xs ++ [x]@ on a plain list, so its cost grows with the number
-- of items already queued.
feed :: (MonadIO m) => Maybe rpkt -> SessionT u nid k rpkt tp m ()
feed rpkt = do
  queue <- asks sessionData
  getEpochTime >>= setTimer
  atomically $ modifyTVar' queue (++ [rpkt])

-- | Take the next item from the front of the inbound queue.
--
-- * If the queue is non-empty, the first item is removed and returned as-is
--   (it may itself be 'Nothing' if that is what was passed to 'feed').
-- * If the queue is empty and the connection status flag is 'True', the STM
--   transaction retries, i.e. the call blocks until something is fed or the
--   status changes.
-- * If the queue is empty and the status flag is 'False', 'Nothing' is
--   returned.
receive :: (MonadIO m, Transport tp) => SessionT u nid k rpkt tp m (Maybe rpkt)
receive = do
  queue <- asks sessionData
  status <- fromConn statusTVar
  atomically $ do
    pending <- readTVar queue
    -- Pattern-match instead of null/head/tail so the code is total by
    -- construction and free of partial-function warnings.
    case pending of
      [] -> do
        alive <- readTVar status
        if alive then retrySTM else pure Nothing
      (x : rest) -> do
        -- Force the remainder to WHNF before storing it, as before.
        writeTVar queue $! rest
        pure x

-- | Number of items currently waiting in the inbound queue.
readerSize :: MonadIO m => SessionT u nid k rpkt tp m Int
readerSize = length <$> (asks sessionData >>= readTVarIO)

-- | The id of the current session ('sessionId' of the environment).
getSessionId :: Monad m => SessionT u nid k rpkt tp m k
getSessionId = sessionId <$> ask

-- | The node id of the current session ('sessionNid' of the environment).
getNodeId :: Monad m => SessionT u nid k rpkt tp m nid
getNodeId = sessionNid <$> ask

-- | The caller-supplied value stored in the session ('sessionUEnv').
env :: Monad m => SessionT u nid k rpkt tp m u
env = sessionUEnv <$> ask

-- | Receive one packet, run the handler on it in the base monad, and send the
-- handler's reply.
--
-- Nothing is sent when 'receive' yields 'Nothing' or when the handler returns
-- 'Nothing'.
makeResponse
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt)
  => (rpkt -> m (Maybe spkt)) -> SessionT u nid k rpkt tp m ()
makeResponse f = receive >>= maybe (pure ()) respond
  where
    -- Run the handler on a received packet and send its reply, if any.
    respond rpkt = lift (f rpkt) >>= maybe (pure ()) send

-- | Variant of 'makeResponse' for a pure handler.
makeResponse_
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt)
  => (rpkt -> Maybe spkt) -> SessionT u nid k rpkt tp m ()
makeResponse_ f = makeResponse $ \rpkt -> pure (f rpkt)

-- | Read the session's activity timestamp.
getTimer :: MonadIO m => SessionT u nid k rpkt tp m Int64
getTimer = asks sessionTimer >>= readTVarIO

-- | Overwrite the session's activity timestamp with the given value.
setTimer :: MonadIO m => Int64 -> SessionT u nid k rpkt tp m ()
setTimer t =
  asks sessionTimer >>= \timerVar ->
    atomically (writeTVar timerVar t)

-- | Whether the session has timed out.
--
-- 'True' only when the configured timeout is positive and the activity
-- timestamp plus that timeout lies strictly before the current
-- 'getEpochTime'. A timeout @<= 0@ never expires.
isTimeout :: MonadIO m => SessionT u nid k rpkt tp m Bool
isTimeout = do
  lastSeen <- getTimer
  limit <- asks sessionTimeout
  now <- getEpochTime
  pure $ limit > 0 && lastSeen + limit < now

-- | Capture the current connection and session environments as a
-- 'SessionEnv1', e.g. for later use with 'runSessionT1'.
getSessionEnv1 :: (Monad m, Transport tp) => SessionT u nid k rpkt tp m (SessionEnv1 u nid k rpkt tp)
getSessionEnv1 = bundle <$> fromConn ask <*> ask
  where
    -- Connection environment first, matching the order it is fetched in.
    bundle cEnv sEnv = SessionEnv1 { sessionEnv = sEnv, connEnv = cEnv }

-- | The @(node id, session id)@ pair identifying a session.
ident :: SessionEnv1 u nid k rpkt tp -> (nid, k)
ident SessionEnv1 { sessionEnv = sEnv } = (sessionNid sEnv, sessionId sEnv)