-- | This module defines nestable suspension functors for use with the 'Coroutine' monad transformer, as well as
-- functions for running nested coroutines of this sort.
-- Coroutines can be run from within another coroutine. In this case, the nested coroutines always suspend to their
-- invoker. If a function from this module, such as 'pogoStickNested', is used to run a nested coroutine, the parent
-- coroutine can be automatically suspended as well. A single suspension can thus suspend an entire chain of nested
-- coroutines.
-- Nestable coroutines of this kind should group their suspension functors into an 'EitherFunctor'. You can adjust a
-- normal suspension, such as the one produced by 'yield', using functions 'mapSuspension' and 'liftAncestor'. To run nested
-- coroutines, use functions 'pogoStickNested', 'seesawNested', and 'coupleNested'.

{-# LANGUAGE ScopedTypeVariables, Rank2Types, MultiParamTypeClasses, TypeFamilies,
             FlexibleContexts, FlexibleInstances, OverlappingInstances, UndecidableInstances #-}

module Control.Monad.Coroutine.Nested
   (
    pogoStickNested, coupleNested, seesawNested, seesawNestedSteps,
    ChildFunctor(..), AncestorFunctor(..),
    liftParent, liftAncestor
   )
where

import Control.Monad (join, liftM)
import Control.Monad.Trans.Class (lift)

import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors (EitherFunctor(..))

import Data.Functor.Compose (Compose(..))

-- | Run a nested 'Coroutine' that can suspend both itself and the current 'Coroutine'.
--
-- Suspensions in the outer functor @s1@ ('LeftF') are propagated outwards, with the rest of the nested computation
-- wrapped inside them. Each suspension in the nested functor @s2@ ('RightF') is handed to @reveal@, and the
-- resulting coroutine is resumed immediately without suspending the current coroutine.
pogoStickNested :: forall s1 s2 m x. (Functor s1, Functor s2, Monad m) => 
                   (s2 (Coroutine (EitherFunctor s1 s2) m x) -> Coroutine (EitherFunctor s1 s2) m x)
                   -> Coroutine (EitherFunctor s1 s2) m x -> Coroutine s1 m x
pogoStickNested reveal = go
   where go :: Coroutine (EitherFunctor s1 s2) m x -> Coroutine s1 m x
         go nested = Coroutine{resume= resume nested >>= step}
         -- Dispatch on the outcome of a single step of the nested coroutine.
         step (Right result) = return (Right result)
         step (Left (LeftF outer)) = return (Left (fmap go outer))
         step (Left (RightF inner)) = resume (go (reveal inner))

-- | Much like 'couple', but with two nested coroutines.
--
-- Both coroutines are stepped together using the given 'PairBinder'. Their nested suspensions are combined into a
-- 'SomeFunctor' suspension ('LeftSome', 'RightSome' or 'Both'). A suspension of the shared outer functor @s0@ is
-- passed through as a 'LeftF' suspension. If both coroutines try to suspend to @s0@ in the same step, the left
-- coroutine's suspension has precedence, as in 'seesawNestedSteps'. The right coroutine's step result is then kept
-- and returned again when the pair is resumed.
coupleNested :: forall s0 s1 s2 m x y r. (Monad m, Functor s0, Monad s0, Functor s1, Functor s2) => 
                PairBinder m
             -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y
             -> Coroutine (EitherFunctor s0 (SomeFunctor s1 s2)) m (x, y)
coupleNested runPair = coupleNested' where
   coupleNested' t1 t2 = Coroutine{resume= runPair (\ st1 st2 -> return (proceed st1 st2)) (resume t1) (resume t2)}
   proceed (Right x) (Right y) = Right (x, y)
   proceed (Left (RightF s)) (Right y) = Left $ RightF $ fmap (flip coupleNested' (return y)) (LeftSome s)
   proceed (Right x) (Left (RightF s)) = Left $ RightF $ fmap (coupleNested' (return x)) (RightSome s)
   proceed (Left (RightF s1)) (Left (RightF s2)) = Left $ RightF $ fmap (uncurry coupleNested') (Both $ composePair s1 s2)
   -- The left coroutine's s0 suspension is matched first, so it takes precedence when both coroutines suspend to s0
   -- in the same step. The other step result is re-wrapped in a coroutine so that it is not lost.
   proceed (Left (LeftF s)) r = Left $ LeftF $ fmap (flip coupleNested' (Coroutine $ return r)) s
   proceed l (Left (LeftF s)) = Left $ LeftF $ fmap (coupleNested' (Coroutine $ return l)) s

-- | Like 'seesaw', but for nested coroutines that are allowed to suspend the current coroutine as well as themselves.
-- If both coroutines try to suspend the current coroutine in the same step, the left coroutine's suspension will have
-- precedence.
--
-- Local suspensions are dispatched to the @resolver@. When both coroutines are suspended locally, 'resumeBoth'
-- receives the continuation and both suspensions. When only one is suspended and the other has finished, the
-- suspended one is run with 'pogoStickNested', using the resolver's 'resumeLeft' or 'resumeRight' for its local
-- suspensions.
seesawNested :: (Monad m, Functor s0, Functor s1, Functor s2) =>
                PairBinder m
             -> SeesawResolver s1 s2 (EitherFunctor s0 s1) (EitherFunctor s0 s2)
             -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y -> Coroutine s0 m (x, y)
seesawNested runPair resolver = seesawNestedSteps runPair decide
   where decide continue (Left left) (Left right) = resumeBoth resolver continue left right
         decide _ (Left left) (Right y) =
            liftM (\x -> (x, y)) (pogoStickNested (resumeLeft resolver) (resumeLeft resolver left))
         decide _ (Right x) (Left right) =
            liftM (\y -> (x, y)) (pogoStickNested (resumeRight resolver) (resumeRight resolver right))
         decide _ (Right x) (Right y) = return (x, y)

-- | Like 'seesawSteps', but for nested coroutines that are allowed to suspend the current coroutine as well
-- as themselves.  If both coroutines try to suspend the current coroutine in the same step, the left coroutine's
-- suspension will have precedence.
--
-- Both coroutines are stepped together with the given 'PairBinder'. A suspension of the shared outer functor @s0@
-- ('LeftF') is propagated immediately, and the other coroutine's step result is kept to be returned again on the
-- next resumption. Otherwise both step results are stripped of their 'RightF' wrapper and handed to @proceed@,
-- together with the continuation that steps both coroutines again. The coroutine @proceed@ returns is resumed
-- within the same step.
seesawNestedSteps :: forall m c1 c2 s0 s1 s2 s1' s2' x y. 
                     (Monad m, Functor s0, Functor s1, Functor s2, 
                      s1' ~ EitherFunctor s0 s1, s2' ~ EitherFunctor s0 s2,
                      c1 ~ Coroutine s1' m x, c2 ~ Coroutine s2' m y) =>
                     PairBinder m
                  -> ((c1 -> c2 -> Coroutine s0 m (x, y)) 
                      -> Either (s1 c1) x -> Either (s2 c2) y -> Coroutine s0 m (x, y))
                  -> c1 -> c2 -> Coroutine s0 m (x, y)
seesawNestedSteps runPair proceed t1 t2 = seesaw' t1 t2 where
   seesaw' left right = Coroutine{resume= bouncePair left right}
   bouncePair left right = runPair proceed' (resume left) (resume right)
   proceed' :: CoroutineStepResult s1' m x -> CoroutineStepResult s2' m y -> m (CoroutineStepResult s0 m (x, y))
   -- The left coroutine's outer suspension is matched first, which gives it precedence.
   proceed' (Left (LeftF s1)) step2 = return $ Left $ fmap ((flip seesaw' (Coroutine $ return step2))) s1
   proceed' step1 (Left (LeftF s2)) = return $ Left $ fmap (seesaw' (Coroutine $ return step1)) s2
   -- Neither step suspended to s0 here, so both results can be converted to their local form.
   proceed' step1 step2 = resume $ proceed seesaw' (local step1) (local step2)
   -- Strips the RightF wrapper from a local suspension. Only called after proceed' has handled every LeftF case.
   local :: forall s x. 
            CoroutineStepResult (EitherFunctor s0 s) m x -> Either (s (Coroutine (EitherFunctor s0 s) m x)) x
   local (Left (RightF s)) = Left s
   local (Right r) = Right r
   local (Left (LeftF _)) =
      error "Control.Monad.Coroutine.Nested.seesawNestedSteps: unexpected outer suspension in local step"

-- | Class of functors that can contain another functor, their designated 'Parent'.
class Functor c => ChildFunctor c where
   -- | The functor whose values can be embedded into @c@.
   type Parent c :: * -> *
   -- | Embed a value of the parent functor into the child functor.
   wrap :: Parent c x -> c x

-- An 'EitherFunctor' is a child of its left component: parent values are embedded with 'LeftF'.
instance (Functor p, Functor s) => ChildFunctor (EitherFunctor p s) where
   type Parent (EitherFunctor p s) = p
   wrap parentValue = LeftF parentValue

-- | Class of functors that can be lifted. @AncestorFunctor a d@ holds when @a@ is @d@ itself, or when @a@ can be
-- reached from @d@ by repeatedly taking the 'Parent' of a 'ChildFunctor'.
class (Functor a, Functor d) => AncestorFunctor a d where
   -- | Convert the ancestor functor into its descendant. The descendant functor typically contains the ancestor.
   liftFunctor :: a x -> d x

-- Base case: every functor is its own ancestor, and lifting is the identity.
instance Functor a => AncestorFunctor a a where
   liftFunctor = id
-- Recursive case: @a@ is an ancestor of @d@ if it is an ancestor of @d@'s 'Parent'. The value is first lifted into
-- the parent and then embedded with 'wrap'. This instance overlaps the base case above. The OverlappingInstances
-- extension enabled at the top of the module makes GHC prefer the more specific @AncestorFunctor a a@ instance
-- when both match.
instance (Functor a, ChildFunctor d, d' ~ Parent d, AncestorFunctor a d') => AncestorFunctor a d where
   liftFunctor = wrap . (liftFunctor :: a x -> d' x)

-- | Converts a coroutine into a child nested coroutine. Every suspension of the parent functor is embedded into the
-- child functor with 'wrap'.
liftParent :: forall m p c x. (Monad m, Functor p, ChildFunctor c, p ~ Parent c) => Coroutine p m x -> Coroutine c m x
liftParent = mapSuspension wrap

-- | Converts a coroutine into a descendant nested coroutine. Every suspension of the ancestor functor is lifted into
-- the descendant functor with 'liftFunctor'.
liftAncestor :: forall m a d x. (Monad m, Functor a, AncestorFunctor a d) => Coroutine a m x -> Coroutine d m x
liftAncestor = mapSuspension liftFunctor