-- RankNTypes: 'PairBinder' and the natural-transformation arguments below are rank-2 types.
-- ScopedTypeVariables: local signatures (e.g. in 'merge') refer to the type variables of the enclosing signature.
-- EmptyDataDecls: 'Naught' is declared without constructors.
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This module defines the 'Coroutine' monad transformer: a computation in an underlying monad @m@ that runs in
-- steps, where each step either finishes with a result or suspends with a value of the functor @s@ that holds the
-- rest of the computation.
module Control.Monad.Coroutine
   (
    -- * Coroutine definition
    Coroutine(Coroutine, resume), CoroutineStepResult, suspend,
    -- * Coroutine operations
    mapMonad, mapSuspension, mapFirstSuspension,
    -- * Running Coroutine computations
    Naught, runCoroutine, bounce, pogoStick, pogoStickM, foldRun,
    -- * Weaving coroutines together
    PairBinder, sequentialBinder, parallelBinder, liftBinder,
    Weaver, WeaveStepper, weave, merge
   )
where
import Control.Applicative (Applicative(..))
import Control.Monad (Monad(..), ap, liftM, (<=<))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Either (partitionEithers)
import Control.Monad.Parallel (MonadParallel(..))
-- | Suspending, resumable monadic computations. @s@ is the suspension functor, @m@ the underlying monad and @r@
-- the final result.
newtype Coroutine s m r = Coroutine {
   -- | Run the next step of the computation in the underlying monad. The step yields either a suspension
   -- ('Left', a functor value holding the rest of the coroutine) or the final result ('Right').
   resume :: m (Either (s (Coroutine s m r)) r)
   }

-- | The outcome of a single 'resume' step: either a suspension holding the continuation, or the final result.
type CoroutineStepResult s m r = Either (s (Coroutine s m r)) r
-- | Mapping over a coroutine transforms its final result. When a step suspends, the mapping is pushed inside the
-- suspension onto every coroutine it holds, so it is applied once the coroutine finally finishes.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
   fmap f t = Coroutine (fmap (either (Left . fmap (fmap f)) (Right . f)) (resume t))
-- | 'pure' finishes immediately without suspending. '<*>' is 'ap', so the left coroutine runs (through all of its
-- suspensions) before the right one; '*>' does the same but discards the left result.
instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
   pure x = Coroutine (return (Right x))
   (<*>) = ap
   -- Specialised sequencing: if @t@ suspends, @next@ is appended to every coroutine held in the suspension.
   t *> next = Coroutine (resume t >>= apply)
      where apply (Right _) = resume next
            apply (Left s) = return (Left (fmap (*> next) s))

-- | Binding runs the first coroutine; when it suspends, the continuation is bound onto every coroutine held in
-- the suspension, so the combined coroutine suspends in the same way and continues where it left off.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
   -- Canonical definitions (see -Wnoncanonical-monad-instances): the real code lives in the Applicative instance.
   return = pure
   (>>) = (*>)
   t >>= f = Coroutine (resume t >>= apply)
      where apply (Right x) = resume (f x)
            apply (Left s) = return (Left (fmap (>>= f) s))
-- | Pairs of coroutines are bound by running their steps with the underlying monad's 'bindM2', lifted to
-- coroutines by 'liftBinder'.
instance (Functor s, MonadParallel m) => MonadParallel (Coroutine s m) where
   bindM2 f first second = liftBinder bindM2 f first second
-- | A lifted action never suspends: its result becomes the coroutine's final ('Right') step.
instance Functor s => MonadTrans (Coroutine s) where
   lift action = Coroutine (action >>= return . Right)
-- | IO actions are performed by the underlying monad and then lifted into the coroutine as a non-suspending step.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
   liftIO io = lift (liftIO io)
-- | The 'Naught' functor has no constructors, so no (non-bottom) value of it can be built. It is the suspension
-- functor of coroutines that cannot suspend, which is what 'runCoroutine' accepts.
data Naught x

instance Functor Naught where
   -- There is nothing to map over: the argument can only be bottom. Forcing it first re-raises whatever failure
   -- it carries; the 'error' after it is unreachable but, unlike a bare 'undefined', names its origin.
   fmap _ naught = naught `seq` error "Control.Monad.Coroutine: fmap on a Naught value, which cannot exist"
-- | Suspend the current 'Coroutine': the first step of the result is the given suspension, which holds the
-- coroutine(s) to continue with.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left
-- | Change the underlying monad of a 'Coroutine' with the natural transformation @f@. The transformation is
-- applied to every step, including the steps of the coroutines held in each suspension.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort =
   Coroutine (liftM (either (Left . fmap (mapMonad f)) Right) (f (resume cort)))
-- | Change the suspension functor of a 'Coroutine' with the natural transformation @f@, applied to every
-- suspension (the nested coroutines are converted first, then @f@ is applied to the suspension holding them).
mapSuspension :: (Functor s, Monad m) => (forall y. s y -> s' y) -> Coroutine s m x -> Coroutine s' m x
mapSuspension f cort =
   Coroutine (liftM (either (Left . f . fmap (mapSuspension f)) Right) (resume cort))
-- | Apply @f@ to the suspension of the coroutine's first step, if that step suspends. Later steps are untouched.
mapFirstSuspension :: forall s m x. (Functor s, Monad m) =>
                      (forall y. s y -> s y) -> Coroutine s m x -> Coroutine s m x
mapFirstSuspension f cort = Coroutine (liftM (either (Left . f) Right) (resume cort))
-- | Run a coroutine whose suspension functor is 'Naught' to completion in the underlying monad. Should a
-- suspension ever be reached, it is reported with an 'error'.
runCoroutine :: Monad m => Coroutine Naught m x -> m x
runCoroutine = pogoStick impossible
   where impossible = error "runCoroutine can run only a non-suspending coroutine!"
-- | Run one step of @c@ in the underlying monad. A suspension is handed to @spring@, whose result is the coroutine
-- to continue with; a final result is simply returned.
bounce :: (Monad m, Functor s) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> Coroutine s m x
bounce spring c = do
   step <- lift (resume c)
   case step of
      Left suspension -> spring suspension
      Right result -> return result
-- | Run a coroutine to completion in its underlying monad, using @spring@ to turn every suspension into the
-- coroutine to continue with.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
   where
      go c = do
         step <- resume c
         case step of
            Left suspension -> go (spring suspension)
            Right result -> return result
-- | Like 'pogoStick', but @spring@ runs in the underlying monad to produce the coroutine to continue with.
pogoStickM :: Monad m => (s (Coroutine s m x) -> m (Coroutine s m x)) -> Coroutine s m x -> m x
pogoStickM spring = go
   where
      go c = do
         step <- resume c
         case step of
            Left suspension -> spring suspension >>= go
            Right result -> return result
-- | Run a coroutine to completion while threading an accumulator. Every suspension is passed to @f@ together
-- with the current accumulator; @f@ returns the new accumulator and the coroutine to continue with. The result
-- pairs the final accumulator with the coroutine's result.
foldRun :: Monad m => (a -> s (Coroutine s m x) -> (a, Coroutine s m x)) -> a -> Coroutine s m x -> m (a, x)
foldRun f a c = resume c >>= either next finish
   where
      next suspension = uncurry (foldRun f) (f a suspension)
      finish result = return (a, result)
-- | A function that runs two monadic computations and binds both of their results with the given function. How
-- the two computations are combined (for example, in sequence or in parallel) is up to the binder.
type PairBinder m = forall a b c. (a -> b -> m c) -> m a -> m b -> m c
-- | A 'PairBinder' that runs the first computation, then the second, then binds both results with the function.
sequentialBinder :: Monad m => PairBinder m
sequentialBinder f mx my = mx >>= \x -> my >>= f x
-- | A 'PairBinder' that delegates to the underlying monad's 'bindM2'.
parallelBinder :: MonadParallel m => PairBinder m
parallelBinder f mx my = bindM2 f mx my
-- | Lift a 'PairBinder' for the underlying monad to a 'PairBinder' for coroutines. Each round resumes both
-- coroutines through @binder@:
--
-- * if both finish, @f@ is applied to their results;
-- * if exactly one suspends, the combined coroutine suspends with that suspension, binding the other's result
--   with @f@ inside it;
-- * if both suspend, the combined coroutine suspends with the second one's suspension, and every coroutine held
--   in it is paired again with the first one's suspension (re-wrapped by 'suspend').
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 = Coroutine (binder combine (resume t1) (resume t2))
   where
      combine step1 step2 = case (step1, step2) of
         (Right x, Right y) -> resume (f x y)
         (Left s, Right y) -> return (Left (fmap (\c -> c >>= flip f y) s))
         (Right x, Left s) -> return (Left (fmap (\c -> c >>= f x) s))
         (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))
-- | A function that weaves two coroutines, with suspension functors @s1@ and @s2@, into a single coroutine with
-- suspension functor @s3@.
type Weaver s1 s2 s3 m x y z =
      Coroutine s1 m x
   -> Coroutine s2 m y
   -> Coroutine s3 m z

-- | One step of weaving: given the 'Weaver' to continue with and the step results of both coroutines, produce
-- the woven coroutine.
type WeaveStepper s1 s2 s3 m x y z =
      Weaver s1 s2 s3 m x y z
   -> CoroutineStepResult s1 m x
   -> CoroutineStepResult s2 m y
   -> Coroutine s3 m z
-- | Weave two coroutines into one. Each round resumes both coroutines, combining the two resumptions with
-- @runPair@. Both step results are then handed to @weaveStep@ together with the weaver itself, so the stepper can
-- keep weaving whatever remains.
weave :: forall s1 s2 s3 m x y z. (Monad m, Functor s1, Functor s2, Functor s3) =>
         PairBinder m -> WeaveStepper s1 s2 s3 m x y z -> Weaver s1 s2 s3 m x y z
weave runPair weaveStep = zipC
   where
      zipC left right = Coroutine (runPair stepBoth (resume left) (resume right))
      stepBoth leftStep rightStep = resume (weaveStep zipC leftStep rightStep)
-- | Weave a list of coroutines with the same suspension functor into a single coroutine that returns their
-- results in the same order as the input list. Every round resumes all the coroutines with @sequence1@. If any of
-- them suspended, their suspensions are combined with @sequence2@ into one suspension, while the coroutines that
-- already finished keep their results in their own positions for the next round.
merge :: forall s m x. (Monad m, Functor s) =>
         (forall y. [m y] -> m [y]) -> (forall y. [s y] -> s [y])
      -> [Coroutine s m x] -> Coroutine s m [x]
merge sequence1 sequence2 corts = Coroutine {resume= liftM step (sequence1 (map resume corts))}
   where
      step :: [CoroutineStepResult s m x] -> CoroutineStepResult s m [x]
      step results = case partitionEithers results of
         ([], ends) -> Right ends
         (suspensions, _) ->
            Left (fmap (merge sequence1 sequence2 . refill results) (sequence2 suspensions))
      -- Rebuild the coroutine list for the next round in the original order: a finished coroutine becomes
      -- @return result@, and each suspended one is replaced by the next coroutine resumed from the combined
      -- suspension. (Previously finished results were moved to the front, scrambling the result order.)
      refill :: [CoroutineStepResult s m x] -> [Coroutine s m x] -> [Coroutine s m x]
      refill (Right result : rest) resumed = return result : refill rest resumed
      refill (Left _ : rest) (next : resumed) = next : refill rest resumed
      -- These two clauses only matter if sequence2 changed the number of suspensions: missing coroutines are
      -- skipped and surplus ones appended, so no finished result is ever dropped.
      refill (Left _ : rest) [] = refill rest []
      refill [] resumed = resumed