{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE UndecidableInstances #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 704
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE TypeFamilies #-}
#endif
{-# OPTIONS_GHC -fno-warn-deprecations #-}
{-# LANGUAGE Safe #-}
#include "free-common.h"

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.Free.Class
-- Copyright   :  (C) 2008-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable (fundeps, MPTCs)
--
-- Monads for free.
----------------------------------------------------------------------------
module Control.Monad.Free.Class
  ( MonadFree(..)
  , liftF
  , wrapT
  ) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.RWS.Strict as Strict
import qualified Control.Monad.Trans.RWS.Lazy as Lazy
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.List
import Control.Monad.Trans.Error
import Control.Monad.Trans.Except
import Control.Monad.Trans.Identity

#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
import Data.Monoid
#endif

-- |
-- Monads provide substitution ('fmap') and renormalization ('Control.Monad.join'):
--
-- @m '>>=' f = 'Control.Monad.join' ('fmap' f m)@
--
-- A free 'Monad' is one that does no work during the normalization step beyond simply grafting the two monadic values together.
--
-- @[]@ is not a free 'Monad' (in this sense) because @'Control.Monad.join' [[a]]@ smashes the lists flat.
--
-- On the other hand, consider:
--
-- @
-- data Tree a = Bin (Tree a) (Tree a) | Tip a
-- @
--
-- @
-- instance 'Monad' Tree where
--   'return' = Tip
--   Tip a '>>=' f = f a
--   Bin l r '>>=' f = Bin (l '>>=' f) (r '>>=' f)
-- @
--
-- This 'Monad' is the free 'Monad' of @Pair@:
--
-- @
-- data Pair a = Pair a a
-- @
--
-- And we could make an instance of 'MonadFree' for it directly:
--
-- @
-- instance 'MonadFree' Pair Tree where
--    'wrap' (Pair l r) = Bin l r
-- @
--
-- Or we could choose to program with @'Control.Monad.Free.Free' Pair@ instead of 'Tree'
-- and thereby avoid having to define our own 'Monad' instance.
--
-- Moreover, "Control.Monad.Free.Church" provides a 'MonadFree'
-- instance that can improve the /asymptotic/ complexity of code that
-- constructs free monads by effectively reassociating the use of
-- ('>>='). You may also want to take a look at the @kan-extensions@
-- package (<http://hackage.haskell.org/package/kan-extensions>).
--
-- See 'Control.Monad.Free.Free' for a more formal definition of the free 'Monad'
-- for a 'Functor'.
class Monad m => MonadFree f m | m -> f where
  -- | Add a layer.
  --
  -- @
  -- wrap (fmap f x) ≡ wrap (fmap return x) >>= f
  -- @
  wrap :: f (m a) -> m a
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 704
  -- Default for a transformer @t@ stacked on a 'MonadFree' monad @n@:
  -- inject every branch of the layer into @n@ with 'return', add the layer
  -- in @n@ with 'wrap', lift that into @t n@, and collapse the resulting
  -- @t n (t n a)@ with 'join'.
  default wrap :: (m ~ t n, MonadTrans t, MonadFree f n, Functor f) => f (m a) -> m a
  wrap = join . lift . wrap . fmap return
#endif

-- | Adds a layer beneath 'ReaderT': the environment is fed to every branch
-- of the layer, and the layer is then added in the underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (ReaderT e m) where
  wrap fm = ReaderT $ \e -> wrap (fmap (\r -> runReaderT r e) fm)

-- | Adds a layer beneath a lazy 'Lazy.StateT': the current state is fed to
-- every branch of the layer, and the layer is then added in the underlying
-- monad.
instance (Functor f, MonadFree f m) => MonadFree f (Lazy.StateT s m) where
  wrap fm = Lazy.StateT $ \s -> wrap (fmap (\st -> Lazy.runStateT st s) fm)

-- | Adds a layer beneath a strict 'Strict.StateT': the current state is fed
-- to every branch of the layer, and the layer is then added in the
-- underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (Strict.StateT s m) where
  wrap fm = Strict.StateT $ \s -> wrap (fmap (\st -> Strict.runStateT st s) fm)

-- | Adds a layer beneath 'ContT': the continuation @h@ is passed to every
-- branch of the layer, and the layer of final results is then added in the
-- underlying monad.
instance (Functor f, MonadFree f m) => MonadFree f (ContT r m) where
  wrap fm = ContT $ \h -> wrap (fmap (\p -> runContT p h) fm)

-- | Adds a layer beneath a lazy 'Lazy.WriterT' by unwrapping each branch to
-- its underlying @m (a, w)@ action and adding the layer in @m@.
-- (@Monoid w@ is required for the 'Monad' instance of the writer.)
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Lazy.WriterT w m) where
  wrap = Lazy.WriterT . wrap . fmap Lazy.runWriterT

-- | Adds a layer beneath a strict 'Strict.WriterT' by unwrapping each branch
-- to its underlying @m (a, w)@ action and adding the layer in @m@.
-- (@Monoid w@ is required for the 'Monad' instance of the writer.)
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Strict.WriterT w m) where
  wrap = Strict.WriterT . wrap . fmap Strict.runWriterT

-- | Adds a layer beneath a strict 'Strict.RWST': the environment @r@ and
-- state @s@ are fed to every branch of the layer, and the layer is then
-- added in the underlying monad.
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Strict.RWST r w s m) where
  wrap fm = Strict.RWST $ \r s -> wrap (fmap (\act -> Strict.runRWST act r s) fm)

-- | Adds a layer beneath a lazy 'Lazy.RWST': the environment @r@ and state
-- @s@ are fed to every branch of the layer, and the layer is then added in
-- the underlying monad.
instance (Functor f, MonadFree f m, Monoid w) => MonadFree f (Lazy.RWST r w s m) where
  wrap fm = Lazy.RWST $ \r s -> wrap (fmap (\act -> Lazy.runRWST act r s) fm)

-- | Adds a layer beneath 'MaybeT' by unwrapping each branch to its
-- underlying @m (Maybe a)@ action and adding the layer in @m@.
instance (Functor f, MonadFree f m) => MonadFree f (MaybeT m) where
  wrap = MaybeT . wrap . fmap runMaybeT

-- | Adds a layer beneath 'IdentityT' by unwrapping each branch to the
-- underlying @m a@ action and adding the layer in @m@.
instance (Functor f, MonadFree f m) => MonadFree f (IdentityT m) where
  wrap = IdentityT . wrap . fmap runIdentityT

-- | Adds a layer beneath 'ListT' by unwrapping each branch to its underlying
-- @m [a]@ action and adding the layer in @m@.
--
-- 'ListT' comes from the deprecated "Control.Monad.Trans.List" module; the
-- module-wide @-fno-warn-deprecations@ option silences the resulting warning.
instance (Functor f, MonadFree f m) => MonadFree f (ListT m) where
  wrap = ListT . wrap . fmap runListT

-- | Adds a layer beneath 'ErrorT' by unwrapping each branch to its
-- underlying @m (Either e a)@ action and adding the layer in @m@.
-- (@Error e@ is required for the 'Monad' instance of 'ErrorT'.)
--
-- 'ErrorT' comes from the deprecated "Control.Monad.Trans.Error" module;
-- prefer the 'ExceptT' instance in new code.
instance (Functor f, MonadFree f m, Error e) => MonadFree f (ErrorT e m) where
  wrap = ErrorT . wrap . fmap runErrorT

-- | Adds a layer beneath 'ExceptT' by unwrapping each branch to its
-- underlying @m (Either e a)@ action and adding the layer in @m@.
instance (Functor f, MonadFree f m) => MonadFree f (ExceptT e m) where
  wrap = ExceptT . wrap . fmap runExceptT

-- instance (Functor f, MonadFree f m) => MonadFree f (EitherT e m) where
--   wrap = EitherT . wrap . fmap runEitherT

-- | A version of 'lift' that can be used with just a 'Functor' for @f@.
--
-- Every result in the layer is injected with 'return', and the layer is
-- then added with 'wrap'.
liftF :: (Functor f, MonadFree f m) => f a -> m a
liftF = wrap . fmap return

-- | A version of 'wrap' for monad transformers over a free monad.
--
-- The layer is added in the underlying monad @m@ via 'liftF'. The result is
-- lifted into @t m@, and the nested @t m (t m a)@ is flattened with 'join'.
--
-- /Note:/ this is the default implementation of 'wrap' for
-- @MonadFree f (t m)@.
wrapT :: (Functor f, MonadFree f m, MonadTrans t, Monad (t m)) => f (t m a) -> t m a
wrapT = join . lift . liftF