-- | A transformer wrapper that attaches a phantom @tag@ to an existing
-- monad transformer and forwards the usual type class instances to it.
--
-- NOTE(review): the code below uses explicit @forall@, associated type
-- families, multi-parameter classes and newtype deriving of non-stock
-- classes; the required language extensions are presumably enabled
-- elsewhere (e.g. project-wide) -- confirm.
module Ether.TaggedTrans
  ( TaggedTrans(..)
  ) where

import Control.Applicative
import Control.Monad (MonadPlus)
import Control.Monad.Fix (MonadFix)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Morph (MFunctor, MMonad)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)

import qualified Control.Monad.Base as MB
import qualified Control.Monad.Trans.Control as MC

import qualified Control.Monad.Trans.Lift.StT    as Lift
import qualified Control.Monad.Trans.Lift.Local  as Lift
import qualified Control.Monad.Trans.Lift.Catch  as Lift
import qualified Control.Monad.Trans.Lift.Listen as Lift
import qualified Control.Monad.Trans.Lift.Pass   as Lift
import qualified Control.Monad.Trans.Lift.CallCC as Lift

import qualified Control.Monad.Cont.Class   as Mtl
import qualified Control.Monad.Reader.Class as Mtl
import qualified Control.Monad.State.Class  as Mtl
import qualified Control.Monad.Writer.Class as Mtl
import qualified Control.Monad.Error.Class  as Mtl

import GHC.Generics (Generic)
import Data.Coerce (coerce)

-- | Newtype around @trans m a@ with an extra phantom parameter @tag@.
--
-- The @tag@ does not occur in the wrapped value, so the wrapper only
-- changes the type, not the representation. The listed classes are derived
-- from the corresponding instances of the wrapped transformer.
newtype TaggedTrans tag trans m a = TaggedTrans (trans m a)
  deriving
    ( Generic
    , Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadFix
    , MonadTrans
    , MonadIO
    , MFunctor
    , MMonad
    , MonadThrow
    , MonadCatch
    , MonadMask
    )

-- | Type of a function that wraps a @trans m a@ into 'TaggedTrans'.
type Pack tag trans m a = trans m a -> TaggedTrans tag trans m a

-- | Type of a function that unwraps a 'TaggedTrans' back to @trans m a@.
type Unpack tag trans m a = TaggedTrans tag trans m a -> trans m a

-- | Reuses the @MonadBase b (trans m)@ instance: @liftBase@ of the wrapped
-- transformer is coerced so that its result is wrapped in 'TaggedTrans'.
instance
    ( MB.MonadBase b (trans m)
    ) => MB.MonadBase b (TaggedTrans tag trans m)
  where

  liftBase =
    (coerce :: forall a .
      (b a -> trans m a) ->
      (b a -> TaggedTrans tag trans m a))
    MB.liftBase
  {-# INLINE liftBase #-}

-- | Built from the @monad-control@ default helpers, using 'coerce' as the
-- packing and unpacking functions. The transformer state is the state of
-- the wrapped transformer.
instance
    ( MC.MonadTransControl trans
    ) => MC.MonadTransControl (TaggedTrans tag trans)
  where

  type StT (TaggedTrans tag trans) a = MC.StT trans a

  liftWith =
    MC.defaultLiftWith
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)
  {-# INLINE liftWith #-}

  restoreT =
    MC.defaultRestoreT
      (coerce :: Pack tag trans m a)
  {-# INLINE restoreT #-}

-- | Shape of @liftBaseWith@ for a monad @m@ over base monad @b@.
type LiftBaseWith b m a = (MC.RunInBase m b -> b a) -> m a

-- | Newtype around 'LiftBaseWith'; 'coerceLiftBaseWith' applies 'coerce'
-- to this wrapper rather than to the bare function type (presumably because
-- the bare type contains the higher-rank 'MC.RunInBase').
newtype LiftBaseWith' b m a = LBW { unLBW :: LiftBaseWith b m a }

-- | Retargets a @liftBaseWith@ written for @trans m@ to
-- @TaggedTrans tag trans m@ by coercing through 'LiftBaseWith''.
coerceLiftBaseWith
  :: LiftBaseWith b (trans m) a
  -> LiftBaseWith b (TaggedTrans tag trans m) a
coerceLiftBaseWith inner = unLBW (coerce (LBW inner))
{-# INLINE coerceLiftBaseWith #-}

-- | Reuses the @MonadBaseControl b (trans m)@ instance. The monadic state
-- is the state of @trans m@; both methods are the wrapped monad's methods
-- with their types coerced.
instance
    ( MC.MonadBaseControl b (trans m)
    ) => MC.MonadBaseControl b (TaggedTrans tag trans m)
  where

  type StM (TaggedTrans tag trans m) a = MC.StM (trans m) a

  liftBaseWith = coerceLiftBaseWith MC.liftBaseWith
  {-# INLINE liftBaseWith #-}

  restoreM =
    (coerce :: forall a .
      (MC.StM (trans m) a -> trans m a) ->
      (MC.StM (trans m) a -> TaggedTrans tag trans m a))
    MC.restoreM
  {-# INLINE restoreM #-}

-- | The lifting state of the tagged transformer is that of the wrapped one.
type instance Lift.StT (TaggedTrans tag trans) a = Lift.StT trans a

-- | Delegates to the wrapped transformer via the default helper, with
-- 'coerce' as pack/unpack.
instance Lift.LiftLocal trans => Lift.LiftLocal (TaggedTrans tag trans) where
  liftLocal =
    Lift.defaultLiftLocal
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | Delegates to the wrapped transformer via the default helper, with
-- 'coerce' as pack/unpack.
instance Lift.LiftCatch trans => Lift.LiftCatch (TaggedTrans tag trans) where
  liftCatch =
    Lift.defaultLiftCatch
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | Delegates to the wrapped transformer via the default helper, with
-- 'coerce' as pack/unpack.
instance Lift.LiftListen trans => Lift.LiftListen (TaggedTrans tag trans) where
  liftListen =
    Lift.defaultLiftListen
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | Delegates to the wrapped transformer via the default helper, with
-- 'coerce' as pack/unpack.
instance Lift.LiftPass trans => Lift.LiftPass (TaggedTrans tag trans) where
  liftPass =
    Lift.defaultLiftPass
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | Delegates both call/cc liftings to the wrapped transformer via the
-- default helpers, with 'coerce' as pack/unpack.
instance Lift.LiftCallCC trans => Lift.LiftCallCC (TaggedTrans tag trans) where
  liftCallCC =
    Lift.defaultLiftCallCC
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)
  liftCallCC' =
    Lift.defaultLiftCallCC'
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- Instances for mtl classes

-- | @callCC@ of the base monad @m@, lifted with @liftCallCC'@.
instance
    ( Mtl.MonadCont m
    , Lift.LiftCallCC trans
    , Monad (trans m)
    ) => Mtl.MonadCont (TaggedTrans tag trans m)
  where
  callCC body = Lift.liftCallCC' Mtl.callCC body

-- | Reader operations of the base monad @m@, lifted through the
-- transformer; @local@ goes through @liftLocal@.
instance
    ( Mtl.MonadReader r m
    , Lift.LiftLocal trans
    , Monad (trans m)
    ) => Mtl.MonadReader r (TaggedTrans tag trans m)
  where
  ask = lift Mtl.ask
  local modify action = Lift.liftLocal Mtl.ask Mtl.local modify action
  reader project = lift (Mtl.reader project)

-- | State operations of the base monad @m@, each passed through 'lift'.
instance
    ( Mtl.MonadState s m
    , MonadTrans trans
    , Monad (trans m)
    ) => Mtl.MonadState s (TaggedTrans tag trans m)
  where
  get = lift Mtl.get
  put new = lift (Mtl.put new)
  state step = lift (Mtl.state step)

-- | Writer operations of the base monad @m@; @listen@ and @pass@ go
-- through @liftListen@ and @liftPass@, the rest through 'lift'.
instance
    ( Mtl.MonadWriter w m
    , Lift.LiftListen trans
    , Lift.LiftPass trans
    , Monad (trans m)
    ) => Mtl.MonadWriter w (TaggedTrans tag trans m)
  where
  writer pair = lift (Mtl.writer pair)
  tell out = lift (Mtl.tell out)
  listen action = Lift.liftListen Mtl.listen action
  pass action = Lift.liftPass Mtl.pass action

-- | Error operations of the base monad @m@; @catchError@ goes through
-- @liftCatch@, @throwError@ through 'lift'.
instance
    ( Mtl.MonadError e m
    , Lift.LiftCatch trans
    , Monad (trans m)
    ) => Mtl.MonadError e (TaggedTrans tag trans m)
  where
  throwError err = lift (Mtl.throwError err)
  catchError action handler = Lift.liftCatch Mtl.catchError action handler