{-# LANGUAGE BlockArguments #-}

module Hercules.Agent.Conduit where

import Data.Conduit (ConduitT, await, awaitForever, yield, (.|))
import Data.IORef (IORef, modifyIORef, modifyIORef', newIORef, readIORef)
import Data.Sequence qualified as Seq
import Protolude hiding (pred, yield)

-- | Re-emit the items retained by @'sinkTail' n@.
--
-- All upstream input is consumed by 'sinkTail' before anything is yielded;
-- the retained items are then yielded in the order 'sinkTail' returns them.
tailC :: (Monad m) => Int -> ConduitT i i m ()
tailC n = sinkTail n >>= \lastItems -> for_ lastItems yield

-- | Consume the entire input and return its last @n@ items, oldest first.
--
-- At most @n@ items are held at any time. When @n <= 0@ the input is still
-- drained completely, but nothing is retained and 'mempty' is returned.
sinkTail :: (Monad m) => Int -> ConduitT i o m (Seq i)
sinkTail n = go mempty
  where
    -- Number of items to keep; clamping at 0 also keeps the drop count below
    -- from overflowing for very negative @n@.
    keep = max 0 n

    go buf =
      await >>= \case
        Nothing -> pure buf
        Just item ->
          -- Append first, then trim to the newest @keep@ items. Trimming
          -- before appending would always retain at least one item, even
          -- when @keep == 0@. 'Seq.drop' with a count >= the length yields
          -- an empty sequence; a negative count drops nothing.
          let appended = buf Seq.:|> item
           in go $! Seq.drop (Seq.length appended - keep) appended

-- | Pass items through until @limit@ of them have satisfied @counts@, then
-- stop consuming, even if the next item would not satisfy @counts@.
--
-- Every consumed item is yielded downstream, whether or not it satisfies
-- @counts@. Consumption also stops when upstream is exhausted. If
-- @limit <= 0@, nothing is consumed.
--
-- Returns the number of counted items (those satisfying @counts@) and the
-- total number of items written downstream.
takeCWhileStopEarly :: (Monad m) => (i -> Bool) -> Int -> ConduitT i i m (Int, Int)
takeCWhileStopEarly counts limit = go 0 0
  where
    go counted total
      | counted >= limit = pure (counted, total)
      | otherwise =
          await >>= \case
            Nothing -> pure (counted, total)
            Just item -> do
              yield item
              let counted' = if counts item then counted + 1 else counted
                  total' = total + 1
              -- @counted'@ is forced by the guard on the next step, but
              -- @total'@ is only demanded by the caller; force it here so it
              -- does not accumulate one thunk per item.
              go counted' $! total'

-- | Pass every input item through unchanged, adding one to @counter@ for each
-- item that satisfies @pred@.
--
-- The counter is incremented before the item is yielded downstream.
countProduction :: (Num n, MonadIO m) => (i -> Bool) -> IORef n -> ConduitT i i m ()
countProduction pred counter = awaitForever \item -> do
  when (pred item) $
    -- Strict modify: the lazy 'modifyIORef' would store a growing chain of
    -- (+ 1) thunks in the ref until it is finally read.
    liftIO $ modifyIORef' counter (+ 1)
  yield item

-- | Run @conduit@ behind a 'countProduction' stage and return, together with
-- @conduit@'s result, how many of the items passed to it satisfied @pred@.
withInputProductionCount :: (MonadIO m) => (i -> Bool) -> ConduitT i o m a -> ConduitT i o m (Int, a)
withInputProductionCount pred conduit = do
  counter <- liftIO (newIORef 0)
  result <- countProduction pred counter .| conduit
  -- Read only after the fused pipeline has finished, so the count is final.
  count <- liftIO (readIORef counter)
  pure (count, result)

-- | Pass input through, keeping only its beginning and its end.
--
-- Items are forwarded until @firstLimit@ of them have satisfied @pred@; if
-- that many were forwarded, @afterFirst@ runs. The remaining input is then
-- buffered, keeping only its last @tailLimit@ items. If any @pred@-satisfying
-- items were dropped from that remainder, @beforeTail@ and @afterTail@ run
-- before and after the retained tail is yielded, each receiving the number
-- of dropped @pred@-satisfying items.
withMessageLimit ::
  (MonadIO m) =>
  (a -> Bool) ->
  -- | First limit
  Int ->
  -- | Max tail part limit
  Int ->
  -- | What to do when truncatable output starts (waiting starts)
  ConduitT a a m () ->
  -- | What to do before yielding a truncated tail
  (Int -> ConduitT a a m ()) ->
  -- | What to do after yielding a truncated tail
  (Int -> ConduitT a a m ()) ->
  ConduitT a a m ()
withMessageLimit pred firstLimit tailLimit afterFirst beforeTail afterTail = do
  (counted, _inclFlush) <- takeCWhileStopEarly pred firstLimit
  when (counted == firstLimit) afterFirst
  (matchingRest, tailItems) <- withInputProductionCount pred do
    sinkTail tailLimit
  -- 'withInputProductionCount' counts only @pred@-satisfying items, so only
  -- the @pred@-satisfying part of the retained tail may be subtracted.
  -- Subtracting the full tail length would under-report the number dropped
  -- (or hide the truncation entirely) whenever the tail also contains items
  -- that do not satisfy @pred@.
  let between = matchingRest - Seq.length (Seq.filter pred tailItems)
  when (between > 0) do
    beforeTail between
  for_ tailItems yield
  when (between > 0) do
    afterTail between