{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies    #-}

module Metro.TP.BS
  ( BSTransport
  , BSHandle
  , newBSHandle
  , newBSHandle_
  , feed
  , closeBSHandle
  , bsTransportConfig

  , makePipe
  ) where

import           Control.Monad   (when)
import           Data.ByteString (ByteString, empty)
import qualified Data.ByteString as B (drop, length, null, splitAt, take)
import           Metro.Class     (Transport (..))
import           UnliftIO

-- | An in-memory byte buffer that a 'BSTransport' reads from.
--
-- Positional fields, in order:
--
-- * the size limit (in bytes) above which 'feed' blocks;
-- * the open flag: 'feed' only appends while it is 'True';
-- * the buffered bytes not yet consumed by a reader.
data BSHandle = BSHandle !Int !(TVar Bool) !(TVar ByteString)

-- | Create a 'BSHandle' pre-loaded with the given bytes, using the default
-- buffer limit of 40 MiB (see 'newBSHandle_').
newBSHandle :: MonadIO m => ByteString -> m BSHandle
newBSHandle = newBSHandle_ defaultLimit
  where
    -- 40 MiB (41943040 bytes), spelled out so the unit is obvious.
    defaultLimit :: Int
    defaultLimit = 40 * 1024 * 1024

-- | Create a 'BSHandle' with an explicit buffer limit (in bytes) and
-- initial contents.
--
-- The handle starts in the closed state: 'feed' discards data until a
-- transport built from it via 'newTransport' marks it open.
newBSHandle_ :: MonadIO m => Int -> ByteString -> m BSHandle
newBSHandle_ size bs = do
  state <- newTVarIO False
  buf   <- newTVarIO bs
  pure $ BSHandle size state buf

-- | Append bytes to a handle's buffer, as the peer of a 'BSTransport' does.
--
-- * If the handle is not open (see 'newBSHandle_' and 'closeBSHandle'),
--   the bytes are silently dropped.
-- * If the buffer already holds more than the handle's limit, the call
--   blocks ('retrySTM') until a reader drains the buffer or the handle
--   is closed.
--
-- The limit is checked before appending, so a single call may push the
-- buffer past the limit by up to the length of @bs@.
feed :: MonadIO m => BSHandle -> ByteString -> m ()
feed (BSHandle size state h) bs = atomically $ do
  st <- readTVar state
  when st $ do
    bs0 <- readTVar h
    when (B.length bs0 > size) retrySTM
    -- Evaluate the concatenation here rather than storing a thunk that
    -- the reader's transaction would have to force later.
    writeTVar h $! bs0 <> bs

-- | Mark a handle closed.
--
-- Subsequent 'feed' calls drop their data, and a reader that finds the
-- buffer empty gets an empty result instead of blocking. Bytes that are
-- already buffered are left in place and can still be read.
closeBSHandle :: MonadIO m => BSHandle -> m ()
closeBSHandle (BSHandle _ state _) =
  atomically $ writeTVar state False

-- | A 'Transport' backed by an in-memory buffer.
data BSTransport = BS
    { bsHandle :: !(TVar ByteString)
    -- ^ incoming bytes, consumed by 'recvData'
    , bsWriter :: !(ByteString -> IO ())
    -- ^ sink that 'sendData' forwards outgoing bytes to
    , bsState  :: !(TVar Bool)
    -- ^ open flag: set to 'True' by 'newTransport', 'False' once closed
    }

instance Transport BSTransport where
  -- Configuration: the handle to read from, and the function used to
  -- deliver outgoing bytes.
  data TransportConfig BSTransport = BSConfig BSHandle (ByteString -> IO ())

  -- Open the handle (so 'feed' calls into it are accepted) and wrap it.
  newTransport (BSConfig (BSHandle _ state buf) writer) = do
    atomically $ writeTVar state True
    pure BS { bsHandle = buf, bsWriter = writer, bsState = state }

  -- Take up to @nbytes@ bytes from the buffer. Blocks while the buffer is
  -- empty and the transport is open; returns an empty string once the
  -- buffer is empty and the transport is closed (end of stream).
  -- NOTE(review): with nbytes <= 0 this also returns an empty string on a
  -- non-empty buffer, which a caller cannot tell apart from end of stream.
  recvData BS {..} nbytes = atomically $ do
    bs <- readTVar bsHandle
    if B.null bs
      then do
        isOpen <- readTVar bsState
        if isOpen then retrySTM else pure empty
      else do
        let (chunk, rest) = B.splitAt nbytes bs
        writeTVar bsHandle rest
        pure chunk

  -- Forward outgoing bytes to the writer; drop them once this side is closed.
  sendData BS {..} bs = do
    isOpen <- readTVarIO bsState
    when isOpen $ bsWriter bs

  -- Close this side and discard any bytes that were never read.
  closeTransport BS {..} = atomically $ do
    writeTVar bsState False
    writeTVar bsHandle empty

-- | Build a 'BSTransport' configuration from the handle to read from and
-- the function used to send outgoing bytes.
bsTransportConfig :: BSHandle -> (ByteString -> IO ()) -> TransportConfig BSTransport
bsTransportConfig = BSConfig

-- | Create two cross-connected in-memory transport configurations.
--
-- Each side reads from its own handle and sends by 'feed'ing the other
-- side's handle, so bytes sent on one end are received on the other.
-- Both handles start closed, so bytes sent to a side before its
-- transport has been created with 'newTransport' are dropped.
makePipe :: MonadIO m => m (TransportConfig BSTransport, TransportConfig BSTransport)
makePipe = do
  rHandle <- newBSHandle empty
  wHandle <- newBSHandle empty
  pure ( bsTransportConfig rHandle (feed wHandle)
       , bsTransportConfig wHandle (feed rHandle)
       )