{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}

module Metro.Conn
  ( ConnEnv
  , ConnT
  , FromConn (..)
  , runConnT
  , initConnEnv
  , receive
  , send
  , close
  , statusTVar
  ) where

import           Control.Monad.Reader.Class (MonadReader (ask))
import           Control.Monad.Trans.Class  (MonadTrans, lift)
import           Control.Monad.Trans.Reader (ReaderT (..), runReaderT)
import           Data.ByteString            (ByteString)
import qualified Data.ByteString            as B (empty)
import           Metro.Class
import qualified Metro.Lock                 as L (Lock, new, with)
import           Metro.Utils                (recvEnough)
import           UnliftIO

data ConnEnv tp = ConnEnv
    { ConnEnv tp -> tp
transport :: tp
    , ConnEnv tp -> Lock
readLock  :: L.Lock
    , ConnEnv tp -> Lock
writeLock :: L.Lock
    , ConnEnv tp -> TVar ByteString
buffer    :: TVar ByteString
    , ConnEnv tp -> TVar Bool
status    :: TVar Bool
    }

newtype ConnT tp m a = ConnT { ConnT tp m a -> ReaderT (ConnEnv tp) m a
unConnT :: ReaderT (ConnEnv tp) m a }
  deriving
    ( a -> ConnT tp m b -> ConnT tp m a
(a -> b) -> ConnT tp m a -> ConnT tp m b
(forall a b. (a -> b) -> ConnT tp m a -> ConnT tp m b)
-> (forall a b. a -> ConnT tp m b -> ConnT tp m a)
-> Functor (ConnT tp m)
forall a b. a -> ConnT tp m b -> ConnT tp m a
forall a b. (a -> b) -> ConnT tp m a -> ConnT tp m b
forall tp (m :: * -> *) a b.
Functor m =>
a -> ConnT tp m b -> ConnT tp m a
forall tp (m :: * -> *) a b.
Functor m =>
(a -> b) -> ConnT tp m a -> ConnT tp m b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
<$ :: a -> ConnT tp m b -> ConnT tp m a
$c<$ :: forall tp (m :: * -> *) a b.
Functor m =>
a -> ConnT tp m b -> ConnT tp m a
fmap :: (a -> b) -> ConnT tp m a -> ConnT tp m b
$cfmap :: forall tp (m :: * -> *) a b.
Functor m =>
(a -> b) -> ConnT tp m a -> ConnT tp m b
Functor
    , Functor (ConnT tp m)
a -> ConnT tp m a
Functor (ConnT tp m) =>
(forall a. a -> ConnT tp m a)
-> (forall a b.
    ConnT tp m (a -> b) -> ConnT tp m a -> ConnT tp m b)
-> (forall a b c.
    (a -> b -> c) -> ConnT tp m a -> ConnT tp m b -> ConnT tp m c)
-> (forall a b. ConnT tp m a -> ConnT tp m b -> ConnT tp m b)
-> (forall a b. ConnT tp m a -> ConnT tp m b -> ConnT tp m a)
-> Applicative (ConnT tp m)
ConnT tp m a -> ConnT tp m b -> ConnT tp m b
ConnT tp m a -> ConnT tp m b -> ConnT tp m a
ConnT tp m (a -> b) -> ConnT tp m a -> ConnT tp m b
(a -> b -> c) -> ConnT tp m a -> ConnT tp m b -> ConnT tp m c
forall a. a -> ConnT tp m a
forall a b. ConnT tp m a -> ConnT tp m b -> ConnT tp m a
forall a b. ConnT tp m a -> ConnT tp m b -> ConnT tp m b
forall a b. ConnT tp m (a -> b) -> ConnT tp m a -> ConnT tp m b
forall a b c.
(a -> b -> c) -> ConnT tp m a -> ConnT tp m b -> ConnT tp m c
forall tp (m :: * -> *). Applicative m => Functor (ConnT tp m)
forall tp (m :: * -> *) a. Applicative m => a -> ConnT tp m a
forall tp (m :: * -> *) a b.
Applicative m =>
ConnT tp m a -> ConnT tp m b -> ConnT tp m a
forall tp (m :: * -> *) a b.
Applicative m =>
ConnT tp m a -> ConnT tp m b -> ConnT tp m b
forall tp (m :: * -> *) a b.
Applicative m =>
ConnT tp m (a -> b) -> ConnT tp m a -> ConnT tp m b
forall tp (m :: * -> *) a b c.
Applicative m =>
(a -> b -> c) -> ConnT tp m a -> ConnT tp m b -> ConnT tp m c
forall (f :: * -> *).
Functor f =>
(forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
<* :: ConnT tp m a -> ConnT tp m b -> ConnT tp m a
$c<* :: forall tp (m :: * -> *) a b.
Applicative m =>
ConnT tp m a -> ConnT tp m b -> ConnT tp m a
*> :: ConnT tp m a -> ConnT tp m b -> ConnT tp m b
$c*> :: forall tp (m :: * -> *) a b.
Applicative m =>
ConnT tp m a -> ConnT tp m b -> ConnT tp m b
liftA2 :: (a -> b -> c) -> ConnT tp m a -> ConnT tp m b -> ConnT tp m c
$cliftA2 :: forall tp (m :: * -> *) a b c.
Applicative m =>
(a -> b -> c) -> ConnT tp m a -> ConnT tp m b -> ConnT tp m c
<*> :: ConnT tp m (a -> b) -> ConnT tp m a -> ConnT tp m b
$c<*> :: forall tp (m :: * -> *) a b.
Applicative m =>
ConnT tp m (a -> b) -> ConnT tp m a -> ConnT tp m b
pure :: a -> ConnT tp m a
$cpure :: forall tp (m :: * -> *) a. Applicative m => a -> ConnT tp m a
$cp1Applicative :: forall tp (m :: * -> *). Applicative m => Functor (ConnT tp m)
Applicative
    , Applicative (ConnT tp m)
a -> ConnT tp m a
Applicative (ConnT tp m) =>
(forall a b. ConnT tp m a -> (a -> ConnT tp m b) -> ConnT tp m b)
-> (forall a b. ConnT tp m a -> ConnT tp m b -> ConnT tp m b)
-> (forall a. a -> ConnT tp m a)
-> Monad (ConnT tp m)
ConnT tp m a -> (a -> ConnT tp m b) -> ConnT tp m b
ConnT tp m a -> ConnT tp m b -> ConnT tp m b
forall a. a -> ConnT tp m a
forall a b. ConnT tp m a -> ConnT tp m b -> ConnT tp m b
forall a b. ConnT tp m a -> (a -> ConnT tp m b) -> ConnT tp m b
forall tp (m :: * -> *). Monad m => Applicative (ConnT tp m)
forall tp (m :: * -> *) a. Monad m => a -> ConnT tp m a
forall tp (m :: * -> *) a b.
Monad m =>
ConnT tp m a -> ConnT tp m b -> ConnT tp m b
forall tp (m :: * -> *) a b.
Monad m =>
ConnT tp m a -> (a -> ConnT tp m b) -> ConnT tp m b
forall (m :: * -> *).
Applicative m =>
(forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
return :: a -> ConnT tp m a
$creturn :: forall tp (m :: * -> *) a. Monad m => a -> ConnT tp m a
>> :: ConnT tp m a -> ConnT tp m b -> ConnT tp m b
$c>> :: forall tp (m :: * -> *) a b.
Monad m =>
ConnT tp m a -> ConnT tp m b -> ConnT tp m b
>>= :: ConnT tp m a -> (a -> ConnT tp m b) -> ConnT tp m b
$c>>= :: forall tp (m :: * -> *) a b.
Monad m =>
ConnT tp m a -> (a -> ConnT tp m b) -> ConnT tp m b
$cp1Monad :: forall tp (m :: * -> *). Monad m => Applicative (ConnT tp m)
Monad
    , m a -> ConnT tp m a
(forall (m :: * -> *) a. Monad m => m a -> ConnT tp m a)
-> MonadTrans (ConnT tp)
forall tp (m :: * -> *) a. Monad m => m a -> ConnT tp m a
forall (m :: * -> *) a. Monad m => m a -> ConnT tp m a
forall (t :: (* -> *) -> * -> *).
(forall (m :: * -> *) a. Monad m => m a -> t m a) -> MonadTrans t
lift :: m a -> ConnT tp m a
$clift :: forall tp (m :: * -> *) a. Monad m => m a -> ConnT tp m a
MonadTrans
    , Monad (ConnT tp m)
Monad (ConnT tp m) =>
(forall a. IO a -> ConnT tp m a) -> MonadIO (ConnT tp m)
IO a -> ConnT tp m a
forall a. IO a -> ConnT tp m a
forall tp (m :: * -> *). MonadIO m => Monad (ConnT tp m)
forall tp (m :: * -> *) a. MonadIO m => IO a -> ConnT tp m a
forall (m :: * -> *).
Monad m =>
(forall a. IO a -> m a) -> MonadIO m
liftIO :: IO a -> ConnT tp m a
$cliftIO :: forall tp (m :: * -> *) a. MonadIO m => IO a -> ConnT tp m a
$cp1MonadIO :: forall tp (m :: * -> *). MonadIO m => Monad (ConnT tp m)
MonadIO
    , MonadReader (ConnEnv tp)
    )

instance MonadUnliftIO m => MonadUnliftIO (ConnT tp m) where
  withRunInIO :: ((forall a. ConnT tp m a -> IO a) -> IO b) -> ConnT tp m b
withRunInIO inner :: (forall a. ConnT tp m a -> IO a) -> IO b
inner = ReaderT (ConnEnv tp) m b -> ConnT tp m b
forall tp (m :: * -> *) a. ReaderT (ConnEnv tp) m a -> ConnT tp m a
ConnT (ReaderT (ConnEnv tp) m b -> ConnT tp m b)
-> ReaderT (ConnEnv tp) m b -> ConnT tp m b
forall a b. (a -> b) -> a -> b
$
    (ConnEnv tp -> m b) -> ReaderT (ConnEnv tp) m b
forall r (m :: * -> *) a. (r -> m a) -> ReaderT r m a
ReaderT ((ConnEnv tp -> m b) -> ReaderT (ConnEnv tp) m b)
-> (ConnEnv tp -> m b) -> ReaderT (ConnEnv tp) m b
forall a b. (a -> b) -> a -> b
$ \r :: ConnEnv tp
r ->
      ((forall a. m a -> IO a) -> IO b) -> m b
forall (m :: * -> *) b.
MonadUnliftIO m =>
((forall a. m a -> IO a) -> IO b) -> m b
withRunInIO (((forall a. m a -> IO a) -> IO b) -> m b)
-> ((forall a. m a -> IO a) -> IO b) -> m b
forall a b. (a -> b) -> a -> b
$ \run :: forall a. m a -> IO a
run ->
        (forall a. ConnT tp m a -> IO a) -> IO b
inner (m a -> IO a
forall a. m a -> IO a
run (m a -> IO a) -> (ConnT tp m a -> m a) -> ConnT tp m a -> IO a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConnEnv tp -> ConnT tp m a -> m a
forall tp (m :: * -> *) a. ConnEnv tp -> ConnT tp m a -> m a
runConnT ConnEnv tp
r)

class FromConn m where
  fromConn :: Monad n => ConnT tp n a -> m tp n a

instance FromConn ConnT where
  fromConn :: ConnT tp n a -> ConnT tp n a
fromConn = ConnT tp n a -> ConnT tp n a
forall a. a -> a
id

runConnT :: ConnEnv tp -> ConnT tp m a -> m a
runConnT :: ConnEnv tp -> ConnT tp m a -> m a
runConnT connEnv :: ConnEnv tp
connEnv = (ReaderT (ConnEnv tp) m a -> ConnEnv tp -> m a)
-> ConnEnv tp -> ReaderT (ConnEnv tp) m a -> m a
forall a b c. (a -> b -> c) -> b -> a -> c
flip ReaderT (ConnEnv tp) m a -> ConnEnv tp -> m a
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT ConnEnv tp
connEnv (ReaderT (ConnEnv tp) m a -> m a)
-> (ConnT tp m a -> ReaderT (ConnEnv tp) m a)
-> ConnT tp m a
-> m a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConnT tp m a -> ReaderT (ConnEnv tp) m a
forall tp (m :: * -> *) a. ConnT tp m a -> ReaderT (ConnEnv tp) m a
unConnT

initConnEnv :: (MonadIO m, Transport tp) => TransportConfig tp -> m (ConnEnv tp)
initConnEnv :: TransportConfig tp -> m (ConnEnv tp)
initConnEnv config :: TransportConfig tp
config = do
  Lock
readLock <- m Lock
forall (m :: * -> *). MonadIO m => m Lock
L.new
  Lock
writeLock <- m Lock
forall (m :: * -> *). MonadIO m => m Lock
L.new
  TVar Bool
status <- Bool -> m (TVar Bool)
forall (m :: * -> *) a. MonadIO m => a -> m (TVar a)
newTVarIO Bool
True
  TVar ByteString
buffer <- ByteString -> m (TVar ByteString)
forall (m :: * -> *) a. MonadIO m => a -> m (TVar a)
newTVarIO ByteString
B.empty
  tp
transport <- IO tp -> m tp
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO tp -> m tp) -> IO tp -> m tp
forall a b. (a -> b) -> a -> b
$ TransportConfig tp -> IO tp
forall transport.
Transport transport =>
TransportConfig transport -> IO transport
newTransport TransportConfig tp
config
  ConnEnv tp -> m (ConnEnv tp)
forall (m :: * -> *) a. Monad m => a -> m a
return ConnEnv :: forall tp.
tp -> Lock -> Lock -> TVar ByteString -> TVar Bool -> ConnEnv tp
ConnEnv{..}

receive :: (MonadUnliftIO m, Transport tp, RecvPacket pkt) => ConnT tp m pkt
receive :: ConnT tp m pkt
receive = do
  ConnEnv{..} <- ConnT tp m (ConnEnv tp)
forall r (m :: * -> *). MonadReader r m => m r
ask
  Lock -> ConnT tp m pkt -> ConnT tp m pkt
forall (m :: * -> *) a. MonadUnliftIO m => Lock -> m a -> m a
L.with Lock
readLock (ConnT tp m pkt -> ConnT tp m pkt)
-> ConnT tp m pkt -> ConnT tp m pkt
forall a b. (a -> b) -> a -> b
$ m pkt -> ConnT tp m pkt
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m pkt -> ConnT tp m pkt) -> m pkt -> ConnT tp m pkt
forall a b. (a -> b) -> a -> b
$ (Int -> m ByteString) -> m pkt
forall rpkt (m :: * -> *).
(RecvPacket rpkt, MonadIO m) =>
(Int -> m ByteString) -> m rpkt
recvPacket (TVar ByteString -> tp -> Int -> m ByteString
forall (m :: * -> *) tp.
(MonadIO m, Transport tp) =>
TVar ByteString -> tp -> Int -> m ByteString
recvEnough TVar ByteString
buffer tp
transport)

send :: (MonadUnliftIO m, Transport tp, SendPacket pkt) => pkt -> ConnT tp m ()
send :: pkt -> ConnT tp m ()
send pkt :: pkt
pkt = do
  ConnEnv{..} <- ConnT tp m (ConnEnv tp)
forall r (m :: * -> *). MonadReader r m => m r
ask
  Lock -> ConnT tp m () -> ConnT tp m ()
forall (m :: * -> *) a. MonadUnliftIO m => Lock -> m a -> m a
L.with Lock
writeLock (ConnT tp m () -> ConnT tp m ()) -> ConnT tp m () -> ConnT tp m ()
forall a b. (a -> b) -> a -> b
$ m () -> ConnT tp m ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m () -> ConnT tp m ()) -> m () -> ConnT tp m ()
forall a b. (a -> b) -> a -> b
$ pkt -> (ByteString -> m ()) -> m ()
forall spkt (m :: * -> *).
(SendPacket spkt, MonadIO m) =>
spkt -> (ByteString -> m ()) -> m ()
sendPacket pkt
pkt (IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> (ByteString -> IO ()) -> ByteString -> m ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. tp -> ByteString -> IO ()
forall transport.
Transport transport =>
transport -> ByteString -> IO ()
sendData tp
transport)

close :: (MonadIO m, Transport tp) => ConnT tp m ()
close :: ConnT tp m ()
close = do
  ConnEnv{..} <- ConnT tp m (ConnEnv tp)
forall r (m :: * -> *). MonadReader r m => m r
ask
  STM () -> ConnT tp m ()
forall (m :: * -> *) a. MonadIO m => STM a -> m a
atomically (STM () -> ConnT tp m ()) -> STM () -> ConnT tp m ()
forall a b. (a -> b) -> a -> b
$ TVar Bool -> Bool -> STM ()
forall a. TVar a -> a -> STM ()
writeTVar TVar Bool
status Bool
False
  IO () -> ConnT tp m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> ConnT tp m ()) -> IO () -> ConnT tp m ()
forall a b. (a -> b) -> a -> b
$ tp -> IO ()
forall transport. Transport transport => transport -> IO ()
closeTransport tp
transport

statusTVar :: Monad m => ConnT tp m (TVar Bool)
statusTVar :: ConnT tp m (TVar Bool)
statusTVar = ConnEnv tp -> TVar Bool
forall tp. ConnEnv tp -> TVar Bool
status (ConnEnv tp -> TVar Bool)
-> ConnT tp m (ConnEnv tp) -> ConnT tp m (TVar Bool)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ConnT tp m (ConnEnv tp)
forall r (m :: * -> *). MonadReader r m => m r
ask