module Control.Monad.Trans.Task 
  ( 
    TaskT (..)
    
  , Trace (..)
  , runTrace
    
  , taskToTrace
  , runTask
  ) where
import Control.Applicative
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.IO.Class
import Control.Monad.Trans
import Control.Monad.Trans.Cont
import Control.Monad.Task.Class
import Data.Either (partitionEithers)
-- | One step of a cooperative task over the base monad @m@ with events of
--   type @e@.  Every @m (Trace m e)@ field is a computation that, when run,
--   produces the following step.  The per-constructor notes describe how
--   'runTrace' treats each step.
data Trace m e where
  -- | Stop the whole run: 'runTrace' returns immediately, without running
  --   any computation that is still queued.
  EXIT   :: Trace m e
  -- | The current task has finished; 'runTrace' continues with the next
  --   queued computation.
  RET    :: Trace m e
  -- | Give up control: 'runTrace' appends the carried continuation to the
  --   back of its ready queue.
  YIELD  :: m (Trace m e) -> Trace m e
  -- | Two computations to continue with; 'runTrace' puts both at the front
  --   of its ready queue, the first argument ahead of the second.
  FORK   :: m (Trace m e) -> m (Trace m e) -> Trace m e
  -- | Wait for an event.  The first argument selects events (and extracts a
  --   value of the hidden type @v@); the second is the continuation fed with
  --   that value.  'runTrace' parks this step until a matching 'SIGNAL'.
  WATCH  :: (e -> Maybe v) -> (v -> m (Trace m e)) -> Trace m e 
  -- | Raise an event.  'runTrace' offers it to every parked 'WATCH'; the
  --   continuations of those that accept it are queued first, then the
  --   carried continuation of the signalling task.
  SIGNAL :: e -> m (Trace m e) -> Trace m e
-- | Run a trace to completion in the base monad.
--
--   Scheduling is cooperative.  A queue of ready computations is consumed
--   from the front, next to a list of parked 'WATCH' steps.  The run ends
--   when the ready queue becomes empty (any watchers still parked are
--   dropped) or as soon as some task produces 'EXIT'.
runTrace :: Monad m => m (Trace m e) -> m ()
runTrace prog = schedule [prog] []
  where
    -- @schedule ready parked@ runs the head of @ready@ for one step and
    -- dispatches on the step it produced.
    schedule [] _ = return ()
    schedule (current : ready) parked = do
      trace <- current
      case trace of
        -- Abort everything, ignoring whatever is still queued.
        EXIT -> return ()
        -- Task finished: carry on with the remaining ready tasks.
        RET -> schedule ready parked
        -- Yielding task goes to the back of the queue.
        YIELD next -> schedule (ready ++ [next]) parked
        -- Both branches jump the queue, first argument first.
        FORK first second -> schedule (first : second : ready) parked
        -- Park the watcher; the newest one sits at the head of the list.
        WATCH f g -> schedule ready (WATCH f g : parked)
        -- Offer the event to every parked watcher.  Accepting watchers are
        -- resumed ahead of the signalling task, which itself runs ahead of
        -- the tasks that were already ready; the others stay parked.
        SIGNAL e next ->
          let outcomes =
                [ maybe (Right w) (Left . g) (f e) | w@(WATCH f g) <- parked ]
              (woken, stillParked) = partitionEithers outcomes
          in schedule (woken ++ (next : ready)) stillParked
-- | Task monad transformer: a continuation-passing computation over the base
--   monad @m@ whose final answer is a 'Trace' of @m@ with event type @e@.
--   'runTaskT' exposes the wrapped 'ContT' computation.
--
--   The 'Functor', 'Applicative' and 'MonadIO' instances are derived from
--   the wrapped 'ContT' (presumably via GeneralizedNewtypeDeriving; the
--   pragma is not visible here -- TODO confirm).
newtype TaskT e m a 
  = TaskT { runTaskT :: ContT (Trace m e) m a }
  deriving (Functor, Applicative, MonadIO)
-- | Convert a task into the base-monad computation of its first 'Trace'
--   step.  The task is run with a final continuation that discards the
--   task's result and produces 'RET'.
taskToTrace :: Monad m => TaskT e m a -> m (Trace m e)
taskToTrace task =
  case task of
    TaskT (ContT run) -> run (const (return RET))
-- | Run a task to completion in the base monad: the task is first turned
--   into a trace with 'taskToTrace', which is then executed by 'runTrace'.
--   The task's result value is discarded.
runTask :: Monad m => TaskT e m a -> m ()
runTask task = runTrace (taskToTrace task)
-- | Sequencing is that of the wrapped 'ContT'; 'fail' ignores its message
--   and abandons the continuation, producing an 'EXIT' step.
--
--   NOTE(review): 'fail' is defined here as a 'Monad' method, which only
--   compiles on compilers where 'Monad' still has that method; newer ones
--   would need it in a @MonadFail@ instance -- confirm the targeted GHC.
instance Monad m => Monad (TaskT e m) where
  return x = TaskT (return x)
  TaskT m >>= k = TaskT (m >>= \x -> runTaskT (k x))
  fail _ = TaskT (ContT (const (return EXIT)))
-- | A base-monad action is lifted into the wrapped 'ContT' and then
--   re-wrapped as a task.
instance MonadTrans (TaskT e) where
  lift action = TaskT (lift action)
-- | Reader operations are forwarded to the wrapped 'ContT' computation,
--   which in turn needs the base monad to be a 'MonadReader'.
instance MonadReader s m => MonadReader s (TaskT e m) where
  ask = TaskT ask
  -- Unwrap the task, apply 'local' to the inner computation, wrap it again.
  local f task = TaskT (local f (runTaskT task))
-- | State operations are forwarded to the wrapped 'ContT' computation,
--   which in turn needs the base monad to be a 'MonadState'.
instance MonadState s m => MonadState s (TaskT e m) where
  get = TaskT get
  put s = TaskT (put s)
-- | Each task primitive captures the current continuation @k@ and returns
--   the matching 'Trace' step instead of running @k@ directly.
instance Monad m => MonadTask e (TaskT e m) where
  -- Drop the continuation entirely and produce 'EXIT'.
  exit = TaskT (ContT (const (return EXIT)))
  -- Hand the rest of this task over as a 'YIELD' step.
  yield = TaskT (ContT (\k -> return (YIELD (k ()))))
  -- 'FORK' carries the trace of @p@ first and the rest of this task second.
  fork p = TaskT (ContT (\k -> return (FORK (taskToTrace p) (k ()))))
  -- The continuation itself becomes the handler of the 'WATCH' step, so it
  -- receives whatever value @f@ extracts from an event.
  watch f = TaskT (ContT (\k -> return (WATCH f k)))
  -- 'SIGNAL' carries the event together with the rest of this task.
  signal e = TaskT (ContT (\k -> return (SIGNAL e (k ()))))