{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Machine.Fanout (fanout, fanoutSteps) where
import Data.List.NonEmpty (NonEmpty (..))
import Data.Machine
import Data.Semigroup (Semigroup (sconcat))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid (..))
import Data.Traversable (traverse)
#endif
-- | Build the step that drives every awaiting process with the same input.
--
-- Each pair holds a process's input continuation and its stop
-- continuation. The resulting 'Await' applies every input continuation to
-- the received value, or, if upstream stops, takes every stop
-- continuation. It then passes the resulting list to @wrap@. With no
-- awaiting processes there is nothing left to drive, so the step is 'Stop'.
continue :: ([b] -> r) -> [(a -> b, b)] -> Step (Is a) o r
continue _    []      = Stop
continue wrap pending = Await onInput Refl onStop
  where
    -- Variable patterns keep both lists as lazy as the original
    -- 'traverse'/'map' formulation: pairs are only forced per element.
    onInput x = wrap [fst pair x | pair <- pending]
    onStop    = wrap [snd pair | pair <- pending]
-- | Materialise a difference list and combine its elements with 'sconcat'.
--
-- Returns 'Nothing' for an empty difference list, because a plain
-- 'Semigroup' has no identity element to fall back on.
semigroupDlist :: Semigroup a => ([a] -> [a]) -> Maybe a
semigroupDlist build = combine (build [])
  where
    combine []       = Nothing
    combine (y : ys) = Just (sconcat (y :| ys))
-- | Share inputs with each of a list of processes in lockstep.
--
-- Each round, every process runs until it awaits input or stops. All
-- values yielded during the round are combined with 'sconcat' into a
-- single 'Yield' of the composite process. If nothing was yielded, the
-- composite yields nothing for that round. The composite then awaits one
-- input and hands it to every awaiting process. If upstream stops
-- instead, each awaiting process is switched to its own stop
-- continuation. The composite stops once no process is left awaiting.
fanout :: forall m a r. (Monad m, Semigroup r)
       => [ProcessT m a r] -> ProcessT m a r
fanout ps = MachineT $ stepAll ps >>= \(waiting, outputs) ->
  let next = continue fanout waiting
  in return $ maybe next (\x -> Yield x (encased next))
                    (semigroupDlist (outputs ++))

-- | Like 'fanout', but combines each round's outputs with 'mconcat'. It
-- therefore yields exactly once per round, yielding 'mempty' when no
-- process produced a value.
--
-- NOTE(review): the combined value is yielded before checking whether any
-- process is still awaiting. The composite therefore emits one last yield
-- in the round where the final process stops; in particular
-- @fanoutSteps []@ yields 'mempty' once and then stops.
fanoutSteps :: forall m a r. (Monad m, Monoid r)
            => [ProcessT m a r] -> ProcessT m a r
fanoutSteps ps = MachineT $ stepAll ps >>= \(waiting, outputs) ->
  return $ Yield (mconcat outputs) (encased (continue fanoutSteps waiting))

-- | Run one lockstep round; shared by 'fanout' and 'fanoutSteps'.
--
-- Every process runs until it either awaits input or stops. After a
-- yield, its continuation runs again immediately. The result holds:
--
-- * the awaiting processes, in their original order, as
--   @(input continuation, stop continuation)@ pairs (stopped processes
--   are dropped), and
-- * every value yielded during the round, in the order it was produced.
--
-- NOTE(review): a process that keeps yielding without ever awaiting or
-- stopping makes this loop forever.
stepAll :: forall m a r. Monad m
        => [ProcessT m a r]
        -> m ([(a -> ProcessT m a r, ProcessT m a r)], [r])
stepAll = go id id
  where
    -- Both accumulators are difference lists, giving O(1) appends that
    -- keep order.
    go :: ([(a -> ProcessT m a r, ProcessT m a r)]
           -> [(a -> ProcessT m a r, ProcessT m a r)])
       -> ([r] -> [r])
       -> [ProcessT m a r]
       -> m ([(a -> ProcessT m a r, ProcessT m a r)], [r])
    go waiting acc [] = return (waiting [], acc [])
    go waiting acc (p : ps) = runMachineT p >>= \v -> case v of
      Stop           -> go waiting acc ps
      -- Push the continuation back on the front so the same process can
      -- yield several values in one round before it next awaits.
      Yield x k      -> go waiting (acc . (x :)) (k : ps)
      Await f Refl k -> go (waiting . ((f, k) :)) acc ps