{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Metro.Session
( SessionEnv (..)
, SessionEnv1 (..)
, newSessionEnv
, SessionT
, runSessionT
, runSessionT1
, send
, sessionState
, feed
, receive
, readerSize
, getSessionId
, getNodeId
, getSessionEnv1
, env
, ident
, isTimeout
, makeResponse
, makeResponse_
) where
import Control.Monad.Reader.Class (MonadReader, ask, asks)
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Trans.Reader (ReaderT (..), runReaderT)
import Data.Int (Int64)
import Metro.Class (SendPacket, SetPacketId, Transport,
setPacketId)
import Metro.Conn (ConnEnv, ConnT, FromConn (..),
runConnT, statusTVar)
import qualified Metro.Conn as Conn (send)
import Metro.Utils (getEpochTime)
import UnliftIO
-- | Per-session state shared by every 'SessionT' action of one session.
data SessionEnv u nid k rpkt = SessionEnv
  { sessionData :: TVar [Maybe rpkt]
  -- ^ FIFO of incoming packets: 'feed' appends at the end and 'receive'
  -- pops from the front.
  , sessionNid :: nid
  -- ^ Node identifier, exposed through 'getNodeId'.
  , sessionId :: k
  -- ^ Session identifier. 'send' stamps it on every outgoing packet with
  -- 'setPacketId'.
  , sessionUEnv :: u
  -- ^ User-supplied environment, exposed through 'env'.
  , sessionTimer :: TVar Int64
  -- ^ 'getEpochTime' value of the last activity. It is set when the env is
  -- created ('newSessionEnv') and refreshed on every 'feed'.
  , sessionTimeout :: Int64
  -- ^ Inactivity limit checked by 'isTimeout'. A value @<= 0@ disables the
  -- timeout. It uses the same unit as 'getEpochTime' (presumably seconds;
  -- confirm in "Metro.Utils").
  }
-- | A 'SessionEnv' paired with the 'ConnEnv' of the connection it runs on.
-- Together they are everything 'runSessionT1' needs to run a 'SessionT'
-- down to the base monad.
data SessionEnv1 u nid k rpkt tp = SessionEnv1
  { sessionEnv :: SessionEnv u nid k rpkt
  -- ^ Session-level state.
  , connEnv :: ConnEnv tp
  -- ^ Connection-level state that the inner 'ConnT' is run with.
  }
-- | Allocate a fresh 'SessionEnv'.
--
-- The arguments are, in order: the user environment, the node id, the
-- session id, the inactivity timeout (@<= 0@ disables it) and the packets to
-- pre-load into the receive queue. The activity timer starts at the current
-- 'getEpochTime'.
newSessionEnv :: MonadIO m => u -> nid -> k -> Int64 -> [Maybe rpkt] -> m (SessionEnv u nid k rpkt)
newSessionEnv uEnv nid sid timeout initial = do
  queue <- newTVarIO initial
  now <- getEpochTime
  timer <- newTVarIO now
  pure SessionEnv
    { sessionData = queue
    , sessionNid = nid
    , sessionId = sid
    , sessionUEnv = uEnv
    , sessionTimer = timer
    , sessionTimeout = timeout
    }
-- | The session monad: a reader over 'SessionEnv' stacked on the connection
-- monad 'ConnT'. Session code can therefore also run connection actions
-- (see the 'FromConn' instance). 'MonadReader' gives access to the
-- 'SessionEnv'.
newtype SessionT u nid k rpkt tp m a = SessionT { unSessionT :: ReaderT (SessionEnv u nid k rpkt) (ConnT tp m) a }
  deriving (Functor, Applicative, Monad, MonadIO, MonadReader (SessionEnv u nid k rpkt))
-- | Lift a base-monad action through both the 'ConnT' layer and the session
-- 'ReaderT' layer.
instance MonadTrans (SessionT u nid k rpkt tp) where
  lift action = SessionT $ lift $ lift action
-- | Unlifting captures the current 'SessionEnv', re-runs each unlifted action
-- against it with 'runSessionT', and delegates the rest to the
-- 'MonadUnliftIO' instance of the underlying 'ConnT'.
instance MonadUnliftIO m => MonadUnliftIO (SessionT u nid k rpkt tp m) where
  withRunInIO inner =
    SessionT . ReaderT $ \sEnv ->
      withRunInIO $ \runConn ->
        inner $ \action -> runConn $ runSessionT sEnv action
-- | Run a connection-level action from inside a session by lifting it past
-- the session 'ReaderT' layer.
instance FromConn (SessionT u nid k rpkt) where
  fromConn connAction = SessionT (lift connAction)
-- | Run a session action inside the connection monad, using the given
-- 'SessionEnv' as its environment.
runSessionT :: SessionEnv u nid k rpkt -> SessionT u nid k rpkt tp m a -> ConnT tp m a
runSessionT sEnv action = runReaderT (unSessionT action) sEnv
-- | Run a session action all the way down to the base monad. The session
-- layer uses the bundled 'SessionEnv' and the connection layer uses the
-- bundled 'ConnEnv'.
runSessionT1 :: SessionEnv1 u nid k rpkt tp -> SessionT u nid k rpkt tp m a -> m a
runSessionT1 (SessionEnv1 sEnv cEnv) = \action -> runConnT cEnv (runSessionT sEnv action)
-- | Read the connection's status flag ('statusTVar'). 'receive' keeps
-- blocking on an empty queue while this is 'True'.
sessionState :: MonadIO m => SessionT u nid k rpkt tp m Bool
sessionState = do
  status <- fromConn statusTVar
  readTVarIO status
-- | Send a packet over the underlying connection. The packet is first
-- stamped with this session's id via 'setPacketId', so the peer can route
-- the reply back to this session.
send
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt)
  => spkt -> SessionT u nid k rpkt tp m ()
send pkt =
  getSessionId >>= \pid ->
    fromConn (Conn.send (setPacketId pid pkt))
-- | Append a packet (or a 'Nothing' marker) to the end of the session's
-- receive queue, and refresh the activity timer used by 'isTimeout'.
feed :: (MonadIO m) => Maybe rpkt -> SessionT u nid k rpkt tp m ()
feed pkt = do
  queue <- asks sessionData
  now <- getEpochTime
  setTimer now
  -- Strict modify so the queue TVar never accumulates append thunks.
  atomically $ modifyTVar' queue (<> [pkt])
-- | Pop the next packet from the session's receive queue.
--
-- * If the queue is non-empty, its head is removed and returned.
-- * If the queue is empty and the connection's status flag ('statusTVar')
--   is 'True', the transaction retries (via 'retrySTM'). It blocks until
--   either 'feed' adds a packet or the flag changes.
-- * If the queue is empty and the flag is 'False', it returns 'Nothing'.
--
-- A 'Nothing' that was explicitly queued with @'feed' Nothing@ is returned
-- as-is. It cannot be told apart from the closed-connection 'Nothing'.
receive :: (MonadIO m, Transport tp) => SessionT u nid k rpkt tp m (Maybe rpkt)
receive = do
  queue <- asks sessionData
  status <- fromConn statusTVar
  atomically $ do
    pending <- readTVar queue
    -- Total pattern match instead of the partial 'head'/'tail' pair.
    case pending of
      (next : rest) -> do
        writeTVar queue $! rest
        pure next
      [] -> do
        alive <- readTVar status
        if alive then retrySTM else pure Nothing
-- | Number of entries (packets and 'Nothing' markers) currently waiting in
-- the receive queue.
readerSize :: MonadIO m => SessionT u nid k rpkt tp m Int
readerSize = do
  queue <- asks sessionData
  length <$> readTVarIO queue
-- | The id of the current session; 'send' stamps it on outgoing packets.
getSessionId :: Monad m => SessionT u nid k rpkt tp m k
getSessionId = sessionId <$> ask
-- | The node id stored in the current session's environment.
getNodeId :: Monad m => SessionT u nid k rpkt tp m nid
getNodeId = sessionNid <$> ask
-- | The user-supplied environment attached to the current session.
env :: Monad m => SessionT u nid k rpkt tp m u
env = sessionUEnv <$> ask
-- | Handle one request. It 'receive's the next packet and runs the handler
-- on it in the base monad. If the handler yields a reply, that reply is
-- 'send' back on this session. When 'receive' returns 'Nothing', nothing
-- happens.
makeResponse
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt)
  => (rpkt -> m (Maybe spkt)) -> SessionT u nid k rpkt tp m ()
makeResponse handler = do
  incoming <- receive
  case incoming of
    Nothing -> pure ()
    Just pkt -> do
      reply <- lift (handler pkt)
      maybe (pure ()) send reply
-- | Variant of 'makeResponse' for a pure handler.
makeResponse_
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt)
  => (rpkt -> Maybe spkt) -> SessionT u nid k rpkt tp m ()
makeResponse_ toReply = makeResponse $ \pkt -> return (toReply pkt)
-- | The 'getEpochTime' value recorded at the session's last activity.
getTimer :: MonadIO m => SessionT u nid k rpkt tp m Int64
getTimer = do
  timer <- asks sessionTimer
  readTVarIO timer
-- | Overwrite the session's last-activity timestamp.
setTimer :: MonadIO m => Int64 -> SessionT u nid k rpkt tp m ()
setTimer stamp = atomically . flip writeTVar stamp =<< asks sessionTimer
-- | 'True' when the session has a positive 'sessionTimeout' and more than
-- that much time has passed since the timer was last set (on creation and
-- on every 'feed'). A timeout @<= 0@ never expires.
isTimeout :: MonadIO m => SessionT u nid k rpkt tp m Bool
isTimeout = do
  lastSeen <- getTimer
  limit <- asks sessionTimeout
  now <- getEpochTime
  -- Compare the elapsed time against the limit rather than computing
  -- @lastSeen + limit@. With a huge limit (e.g. 'maxBound' meaning
  -- "forever") that sum overflows 'Int64', wraps negative, and would report
  -- an immediate timeout. @now - lastSeen@ is a small difference of two
  -- epoch stamps and cannot overflow.
  pure $ limit > 0 && now - lastSeen > limit
-- | Capture both the session environment and the connection environment.
-- 'runSessionT1' can later use the result to resume this session from the
-- base monad.
getSessionEnv1 :: (Monad m, Transport tp) => SessionT u nid k rpkt tp m (SessionEnv1 u nid k rpkt tp)
getSessionEnv1 = SessionEnv1 <$> ask <*> fromConn ask
-- | The @(node id, session id)@ pair identifying the bundled session.
ident :: SessionEnv1 u nid k rpkt tp -> (nid, k)
ident (SessionEnv1 sEnv _) =
  let SessionEnv { sessionNid = nid, sessionId = sid } = sEnv
  in (nid, sid)