module Control.Monad.Coroutine
(
Coroutine(Coroutine, resume), CoroutineStepResult, suspend,
mapMonad, mapSuspension, mapFirstSuspension,
Naught, runCoroutine, bounce, pogoStick, pogoStickM, foldRun,
PairBinder, sequentialBinder, parallelBinder, liftBinder,
Weaver, WeaveStepper, weave, merge
)
where
import Control.Applicative (Applicative(..))
import Control.Monad (Monad(..), ap, liftM, (<=<))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Either (partitionEithers)
import Control.Monad.Parallel (MonadParallel(..))
-- | A suspendable computation over the base monad @m@, with suspension
-- functor @s@ and final result type @r@.
newtype Coroutine s m r = Coroutine
  { resume :: m (Either (s (Coroutine s m r)) r)
    -- ^ Run one step in @m@: either the coroutine suspends ('Left', the rest
    -- of the computation wrapped in @s@) or it completes ('Right').
  }
-- | The outcome of a single 'resume' step: a suspension functor holding the
-- rest of the coroutine ('Left'), or the final result ('Right').
type CoroutineStepResult f m a = Either (f (Coroutine f m a)) a
-- | Mapping over a coroutine transforms its final result. When a step
-- suspends, the mapping is pushed into the continuation held by the
-- suspension functor, so it applies once the coroutine finishes.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f t = Coroutine (fmap (either (Left . fmap (fmap f)) (Right . f)) (resume t))
-- | 'pure' completes immediately with the given value; '<*>' is derived from
-- the 'Monad' instance via 'ap'.
--
-- '*>' is defined explicitly (rather than relying on the class default) so
-- that sequencing runs the first coroutine and, once it completes, continues
-- directly with the second; if the first suspends, the second is pushed
-- inside the suspension.
instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure x = Coroutine (return (Right x))
  (<*>) = ap
  t *> next = Coroutine (resume t >>= step)
    where
      step (Right _) = resume next
      step (Left s) = return (Left (fmap (*> next) s))

-- | Binding runs one step of the first coroutine. If it completes, the
-- continuation takes over within the same step; if it suspends, the
-- continuation is pushed inside the suspension functor so it runs after the
-- coroutine is resumed.
--
-- The instance is canonical: 'return' defaults to 'pure' and '>>' is '*>'.
-- This avoids GHC's @-Wnoncanonical-monad-instances@ warning and the
-- circular @pure = return@ arrangement.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  t >>= f = Coroutine (resume t >>= step)
    where
      step (Right x) = resume (f x)
      step (Left s) = return (Left (fmap (>>= f) s))
  (>>) = (*>)
-- | Parallel binding of two coroutines: each step of both coroutines is
-- combined with the base monad's own 'bindM2', lifted through 'liftBinder'.
instance (Functor s, MonadParallel m) => MonadParallel (Coroutine s m) where
  bindM2 f left right = liftBinder bindM2 f left right
-- | Lifting turns a base-monad action into a single step that always
-- completes ('Right') and never suspends.
instance Functor s => MonadTrans (Coroutine s) where
  lift action = Coroutine (action >>= return . Right)
-- | I/O is lifted into the base monad first and then into the coroutine with
-- 'lift', so it never causes a suspension.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
  liftIO io = lift (liftIO io)
-- | A suspension functor with no values. A @'Coroutine' 'Naught' m@ therefore
-- can never genuinely suspend; it is the suspension type accepted by
-- 'runCoroutine'.
data Naught x

instance Functor Naught where
  -- 'Naught' has no constructors, so this can only ever be reached with a
  -- bottom argument. Fail with a descriptive message rather than an
  -- anonymous 'undefined' so the source of the crash is identifiable.
  fmap _ _ = error "Control.Monad.Coroutine: fmap applied to a value of the uninhabited type Naught"
-- | Suspend the coroutine, exposing to its driver a suspension functor that
-- holds the rest of the computation. No base-monad effects are performed.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left
-- | Change the base monad of a coroutine with a natural transformation. The
-- transformation is applied to the current step and, recursively, to every
-- continuation reached through a suspension.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort = Coroutine (liftM (either (Left . fmap (mapMonad f)) Right) (f (resume cort)))
-- | Change the suspension functor of a coroutine with a natural
-- transformation. The continuations are converted first, then the
-- transformation is applied to every suspension the coroutine produces.
mapSuspension :: (Functor s, Monad m) => (forall y. s y -> s' y) -> Coroutine s m x -> Coroutine s' m x
mapSuspension f cort = Coroutine (liftM (either (Left . f . fmap (mapSuspension f)) Right) (resume cort))
-- | Apply a transformation to the next suspension only; later suspensions
-- reached after resuming are left untouched.
mapFirstSuspension :: forall s m x. (Functor s, Monad m) =>
                      (forall y. s y -> s y) -> Coroutine s m x -> Coroutine s m x
mapFirstSuspension f cort = Coroutine (liftM (either (Left . f) Right) (resume cort))
-- | Run a coroutine whose suspension functor is 'Naught' to completion in the
-- base monad. Should a suspension ever be observed, this fails with an error.
runCoroutine :: Monad m => Coroutine Naught m x -> m x
runCoroutine = pogoStick unexpected
  where
    unexpected _ = error "runCoroutine can run only a non-suspending coroutine!"
-- | Run one step of the coroutine. If it suspends, the given function decides
-- how to continue; if it completes, its result is returned.
bounce :: (Monad m, Functor s) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> Coroutine s m x
bounce spring c = do
  step <- lift (resume c)
  either spring return step
-- | Drive a coroutine to completion in the base monad. Every suspension is
-- turned back into a coroutine by the given pure function, which is then
-- resumed again.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = do
      step <- resume c
      case step of
        Left s -> go (spring s)
        Right x -> return x
-- | Like 'pogoStick', but each suspension is handled by a base-monad action
-- that yields the coroutine to resume next.
pogoStickM :: Monad m => (s (Coroutine s m x) -> m (Coroutine s m x)) -> Coroutine s m x -> m x
pogoStickM spring = go
  where
    go c = do
      step <- resume c
      case step of
        Left s -> spring s >>= go
        Right x -> return x
-- | Drive a coroutine to completion while threading an accumulator. At each
-- suspension the step function receives the current accumulator and the
-- suspension, and returns the new accumulator with the coroutine to resume.
-- The final accumulator is returned alongside the coroutine's result.
foldRun :: Monad m => (a -> s (Coroutine s m x) -> (a, Coroutine s m x)) -> a -> Coroutine s m x -> m (a, x)
foldRun f = go
  where
    go acc c = do
      step <- resume c
      case step of
        Right result -> return (acc, result)
        -- lazy pattern binding, matching the laziness of 'uncurry'
        Left s -> let (acc', next) = f acc s in go acc' next
-- | A strategy for running two base-monad computations and feeding both of
-- their results to a continuation (see 'sequentialBinder' and
-- 'parallelBinder').
type PairBinder m = forall a b c. (a -> b -> m c) -> m a -> m b -> m c
-- | A 'PairBinder' that runs the first computation, then the second, then the
-- continuation.
sequentialBinder :: Monad m => PairBinder m
sequentialBinder f mx my = mx >>= \x -> my >>= \y -> f x y
-- | A 'PairBinder' that delegates to the base monad's 'bindM2'.
parallelBinder :: MonadParallel m => PairBinder m
parallelBinder f mx my = bindM2 f mx my
-- | Lift a base-monad 'PairBinder' to coroutines. Both coroutines are resumed
-- through the given binder.
--
-- * If both complete, the continuation runs.
-- * If only one suspends, the continuation is pushed into that suspension.
-- * If both suspend, the result suspends with the second coroutine's
--   functor; the first coroutine's suspension is re-wrapped with 'suspend'
--   and paired again on the next step.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 = Coroutine (binder combine (resume t1) (resume t2))
  where
    combine r1 r2 = case (r1, r2) of
      (Right x, Right y) -> resume (f x y)
      (Left s, Right y) -> return (Left (fmap (>>= \x -> f x y) s))
      (Right x, Left s) -> return (Left (fmap (>>= f x) s))
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))
-- | A function combining two coroutines, possibly with different suspension
-- functors and result types, into a third coroutine.
type Weaver sa sb sc m a b c = Coroutine sa m a -> Coroutine sb m b -> Coroutine sc m c

-- | One weaving step: given the weaver itself (to continue with the
-- remainders) and the step results of both inputs, produce the combined
-- coroutine.
type WeaveStepper sa sb sc m a b c =
  Weaver sa sb sc m a b c -> CoroutineStepResult sa m a -> CoroutineStepResult sb m b -> Coroutine sc m c
-- | Build a 'Weaver' from a base-monad 'PairBinder' and a 'WeaveStepper'.
-- Both input coroutines are resumed through the binder. The stepper is then
-- given the two step results together with this weaver, so it can keep
-- weaving what remains.
weave :: forall s1 s2 s3 m x y z. (Monad m, Functor s1, Functor s2, Functor s3) =>
         PairBinder m -> WeaveStepper s1 s2 s3 m x y z -> Weaver s1 s2 s3 m x y z
weave runPair weaveStep = zipC
  where
    zipC left right = Coroutine {resume = runPair advance (resume left) (resume right)}
    advance l r = resume (weaveStep zipC l r)
-- | Weave a list of coroutines sharing one suspension functor into a single
-- coroutine that advances them in lockstep.
--
-- The first argument sequences base-monad steps; the second combines the
-- pending suspensions into one. When every coroutine has completed, the
-- results are returned.
--
-- Coroutines that finished earlier are re-wrapped with 'return' and placed
-- ahead of the resumed ones. The result list is therefore not necessarily in
-- input order.
merge :: forall s m x. (Monad m, Functor s) =>
         (forall y. [m y] -> m [y]) -> (forall y. [s y] -> s [y])
      -> [Coroutine s m x] -> Coroutine s m [x]
merge sequence1 sequence2 corts = Coroutine {resume = liftM advance (sequence1 (map resume corts))}
  where
    advance :: [CoroutineStepResult s m x] -> CoroutineStepResult s m [x]
    advance results =
      let (pending, finished) = partitionEithers results
          rejoin resumed = merge sequence1 sequence2 (map return finished ++ resumed)
      in if null pending
           then Right finished
           else Left (fmap rejoin (sequence2 pending))