{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
#if !(MIN_VERSION_transformers(0,6,0))
{-# OPTIONS_GHC -Wno-deprecations #-}
#endif

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.Free.Class
-- Copyright   :  (C) 2008-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable (fundeps, MPTCs)
--
-- Monads for free.
----------------------------------------------------------------------------
module Control.Monad.Free.Class
  ( MonadFree(..)
  , liftF
  , wrapT
  ) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.RWS.Strict as Strict
import qualified Control.Monad.Trans.RWS.Lazy as Lazy
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Except
import Control.Monad.Trans.Identity

#if !(MIN_VERSION_transformers(0,6,0))
import Control.Monad.Trans.Error
import Control.Monad.Trans.List
#endif

-- |
-- Monads provide substitution ('fmap') and renormalization ('Control.Monad.join'):
--
-- @m '>>=' f = 'Control.Monad.join' ('fmap' f m)@
--
-- A free 'Monad' is one that does no work during the normalization step beyond simply grafting the two monadic values together.
--
-- @[]@ is not a free 'Monad' (in this sense) because @'Control.Monad.join' [[a]]@ smashes the lists flat.
--
-- On the other hand, consider:
--
-- @
-- data Tree a = Bin (Tree a) (Tree a) | Tip a
-- @
--
-- @
-- instance 'Monad' Tree where
--   'return' = Tip
--   Tip a '>>=' f = f a
--   Bin l r '>>=' f = Bin (l '>>=' f) (r '>>=' f)
-- @
--
-- This 'Monad' is the free 'Monad' of Pair:
--
-- @
-- data Pair a = Pair a a
-- @
--
-- And we could make an instance of 'MonadFree' for it directly:
--
-- @
-- instance 'MonadFree' Pair Tree where
--    'wrap' (Pair l r) = Bin l r
-- @
--
-- Or we could choose to program with @'Control.Monad.Free.Free' Pair@ instead of 'Tree'
-- and thereby avoid having to define our own 'Monad' instance.
--
-- Moreover, "Control.Monad.Free.Church" provides a 'MonadFree'
-- instance that can improve the /asymptotic/ complexity of code that
-- constructs free monads by effectively reassociating the use of
-- ('>>='). You may also want to take a look at the @kan-extensions@
-- package (<http://hackage.haskell.org/package/kan-extensions>).
--
-- See 'Control.Monad.Free.Free' for a more formal definition of the free 'Monad'
-- for a 'Functor'.
class Monad m => MonadFree f m | m -> f where
  -- | Add a layer.
  --
  -- @
  -- wrap (fmap f x) ≡ wrap (fmap return x) >>= f
  -- @
  --
  -- The default implementation is available when @m@ is a monad transformer
  -- @t@ applied to a monad @n@ that is itself a 'MonadFree' for @f@. Each
  -- computation inside the layer is injected into @n@ with 'return', the
  -- layer is wrapped in @n@, the result is lifted into @t n@ with 'lift', and
  -- the two nested @t n@ levels are collapsed with 'join'.
  wrap :: f (m a) -> m a
  default wrap :: (m ~ t n, MonadTrans t, MonadFree f n, Functor f) => f (m a) -> m a
  wrap = join . lift . wrap . fmap return

-- | Wraps a layer of 'ReaderT' computations by running every computation in
-- the layer with the same environment and wrapping the results in the
-- underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (ReaderT e m) where
  wrap fm = ReaderT $ \e -> wrap $ flip runReaderT e <$> fm

-- | Wraps a layer of lazy 'Lazy.StateT' computations by running every
-- computation in the layer from the same incoming state and wrapping the
-- results in the underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (Lazy.StateT s m) where
  wrap fm = Lazy.StateT $ \s -> wrap $ flip Lazy.runStateT s <$> fm

-- | Wraps a layer of strict 'Strict.StateT' computations by running every
-- computation in the layer from the same incoming state and wrapping the
-- results in the underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (Strict.StateT s m) where
  wrap fm = Strict.StateT $ \s -> wrap $ flip Strict.runStateT s <$> fm

-- | Wraps a layer of 'ContT' computations by running every computation in
-- the layer with the same continuation and wrapping the results in the
-- underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (ContT r m) where
  wrap t = ContT $ \h -> wrap (fmap (\p -> runContT p h) t)

-- | Wraps a layer of lazy 'Lazy.WriterT' computations by unwrapping each one
-- with 'Lazy.runWriterT', wrapping the layer in the underlying monad, and
-- re-applying the 'Lazy.WriterT' constructor.
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Lazy.WriterT w m) where
  wrap = Lazy.WriterT . wrap . fmap Lazy.runWriterT

-- | Wraps a layer of strict 'Strict.WriterT' computations by unwrapping each
-- one with 'Strict.runWriterT', wrapping the layer in the underlying monad,
-- and re-applying the 'Strict.WriterT' constructor.
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Strict.WriterT w m) where
  wrap = Strict.WriterT . wrap . fmap Strict.runWriterT

-- | Wraps a layer of strict 'Strict.RWST' computations by running every
-- computation in the layer with the same environment and incoming state and
-- wrapping the results in the underlying monad.
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Strict.RWST r w s m) where
  wrap fm = Strict.RWST $ \r s -> wrap $ fmap (\m -> Strict.runRWST m r s) fm

-- | Wraps a layer of lazy 'Lazy.RWST' computations by running every
-- computation in the layer with the same environment and incoming state and
-- wrapping the results in the underlying monad.
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Lazy.RWST r w s m) where
  wrap fm = Lazy.RWST $ \r s -> wrap $ fmap (\m -> Lazy.runRWST m r s) fm

-- | Wraps a layer of 'MaybeT' computations by unwrapping each one with
-- 'runMaybeT', wrapping the layer in the underlying monad, and re-applying
-- the 'MaybeT' constructor.
instance (Functor f, MonadFree f m) => MonadFree f (MaybeT m) where
  wrap = MaybeT . wrap . fmap runMaybeT

-- | Wraps a layer of 'IdentityT' computations by unwrapping each one with
-- 'runIdentityT', wrapping the layer in the underlying monad, and
-- re-applying the 'IdentityT' constructor.
instance (Functor f, MonadFree f m) => MonadFree f (IdentityT m) where
  wrap = IdentityT . wrap . fmap runIdentityT

-- | Wraps a layer of 'ExceptT' computations by unwrapping each one with
-- 'runExceptT', wrapping the layer in the underlying monad, and re-applying
-- the 'ExceptT' constructor.
instance (Functor f, MonadFree f m) => MonadFree f (ExceptT e m) where
  wrap = ExceptT . wrap . fmap runExceptT

-- instance (Functor f, MonadFree f m) => MonadFree f (EitherT e m) where
--   wrap = EitherT . wrap . fmap runEitherT

#if !(MIN_VERSION_transformers(0,6,0))
-- | Wraps a layer of 'ErrorT' computations by unwrapping each one with
-- 'runErrorT', wrapping the layer in the underlying monad, and re-applying
-- the 'ErrorT' constructor.
instance (Functor f, MonadFree f m, Error e) => MonadFree f (ErrorT e m) where
  wrap = ErrorT . wrap . fmap runErrorT

-- | Wraps a layer of 'ListT' computations by unwrapping each one with
-- 'runListT', wrapping the layer in the underlying monad, and re-applying
-- the 'ListT' constructor.
instance (Functor f, MonadFree f m) => MonadFree f (ListT m) where
  wrap = ListT . wrap . fmap runListT
#endif

-- | A version of lift that can be used with just a Functor for f.
--
-- Every value held in the @f@ layer is turned into a computation with
-- 'return', and the resulting layer of computations is attached with 'wrap'.
liftF :: (Functor f, MonadFree f m) => f a -> m a
liftF = wrap . fmap return

-- | A version of wrap for monad transformers over a free monad.
--
-- The layer of @t m@ computations is first lifted into the underlying monad
-- with 'liftF', then into the transformer with 'lift', and finally the two
-- nested @t m@ levels are collapsed with 'join'.
--
-- /Note:/ this is the same definition as the default implementation of
-- 'wrap' for @MonadFree f (t m)@.
wrapT :: (Functor f, MonadFree f m, MonadTrans t, Monad (t m)) => f (t m a) -> t m a
wrapT = join . lift . liftF