{-# LANGUAGE Safe #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-- Search for UndecidableInstances to see why this is needed
{-# LANGUAGE UndecidableInstances #-}
-- Needed because the CPSed versions of Writer and State are secretly State
-- wrappers, which don't force such constraints, even though they should legally
-- be there.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.State.Class
-- Copyright   :  (c) Andy Gill 2001,
--                (c) Oregon Graduate Institute of Science and Technology, 2001
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  libraries@haskell.org
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)
--
-- MonadState class.
--
--      This module is inspired by the paper
--      /Functional Programming with Overloading and Higher-Order Polymorphism/,
--        Mark P Jones (<http://web.cecs.pdx.edu/~mpj/>)
--          Advanced School of Functional Programming, 1995.

-----------------------------------------------------------------------------

module Control.Monad.State.Class (
    MonadState(..),
    modify,
    modify',
    gets
  ) where

import Control.Monad.Trans.Cont (ContT)
import Control.Monad.Trans.Except (ExceptT)
import Control.Monad.Trans.Identity (IdentityT)
import Control.Monad.Trans.Maybe (MaybeT) 
import Control.Monad.Trans.Reader (ReaderT)
import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS
import qualified Control.Monad.Trans.RWS.Strict as StrictRWS
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import Control.Monad.Trans.Accum (AccumT)
import Control.Monad.Trans.Select (SelectT)
import qualified Control.Monad.Trans.RWS.CPS as CPSRWS
import qualified Control.Monad.Trans.Writer.CPS as CPS
import Control.Monad.Trans.Class (lift)

-- ---------------------------------------------------------------------------

-- | Monads @m@ that carry a state of type @s@.
--
-- The functional dependency @m -> s@ means that the monad determines its
-- state type.
--
-- Minimal complete definition: either both of 'get' and 'put', or just
-- 'state'.
class Monad m => MonadState s m | m -> s where
    -- | Return the state from the internals of the monad.
    --
    -- The default implementation is written in terms of 'state'.
    get :: m s
    get = state (\s -> (s, s))

    -- | Replace the state inside the monad.
    --
    -- The default implementation is written in terms of 'state'.
    put :: s -> m ()
    put s = state (\_ -> ((), s))

    -- | Embed a simple state action into the monad.
    --
    -- The default implementation is written in terms of 'get' and 'put'.
    state :: (s -> (a, s)) -> m a
    state f = do
      s <- get
      -- Irrefutable pattern: the pair returned by @f@ is not forced by this
      -- binding itself, only when @a@ or @s'@ is demanded.
      let ~(a, s') = f s
      put s'
      return a
    {-# MINIMAL state | get, put #-}

-- | Monadic state transformer.
--
--      Maps an old state to a new state inside a state monad.
--      The old state is thrown away.
--
-- >      Main> :t modify ((+1) :: Int -> Int)
-- >      modify (...) :: (MonadState Int a) => a ()
--
--    This says that @modify (+1)@ acts over any
--    Monad that is a member of the @MonadState@ class,
--    with an @Int@ state.
--
--    The new state @f s@ is stored without being forced; see @modify'@ for
--    a variant that forces it.
modify :: MonadState s m => (s -> s) -> m ()
modify f = state (\s -> ((), f s))

-- | A variant of 'modify' in which the computation is strict in the
-- new state.
--
-- The new state is forced (via '$!') before it is handed to 'put'.
--
-- @since 2.2
modify' :: MonadState s m => (s -> s) -> m ()
modify' f = do
  s' <- get
  put $! f s'

-- | Get a specific component of the state, using the supplied projection
-- function.
--
-- The current state is read with 'get' and the projection @f@ is applied to
-- it; the state itself is left unchanged.
gets :: MonadState s m => (s -> a) -> m a
gets f = do
    s <- get
    return (f s)

-- | The lazy @StateT@ transformer provides the state directly: every method
-- is the corresponding primitive from "Control.Monad.Trans.State.Lazy".
instance Monad m => MonadState s (Lazy.StateT s m) where
    get = Lazy.get
    put = Lazy.put
    state = Lazy.state

-- | The strict @StateT@ transformer provides the state directly: every method
-- is the corresponding primitive from "Control.Monad.Trans.State.Strict".
instance Monad m => MonadState s (Strict.StateT s m) where
    get = Strict.get
    put = Strict.put
    state = Strict.state

-- | The CPS @RWST@ transformer provides the state directly: every method is
-- the corresponding primitive from "Control.Monad.Trans.RWS.CPS".
--
-- @since 2.3
instance (Monad m, Monoid w) => MonadState s (CPSRWS.RWST r w s m) where
    -- The CPS primitives themselves only need @Monad m@; the @Monoid w@
    -- constraint is kept on purpose (see the note next to the
    -- -Wno-redundant-constraints pragma at the top of this module).
    get = CPSRWS.get
    put = CPSRWS.put
    state = CPSRWS.state

-- | The lazy @RWST@ transformer provides the state directly: every method is
-- the corresponding primitive from "Control.Monad.Trans.RWS.Lazy".
instance (Monad m, Monoid w) => MonadState s (LazyRWS.RWST r w s m) where
    get = LazyRWS.get
    put = LazyRWS.put
    state = LazyRWS.state

-- | The strict @RWST@ transformer provides the state directly: every method
-- is the corresponding primitive from "Control.Monad.Trans.RWS.Strict".
instance (Monad m, Monoid w) => MonadState s (StrictRWS.RWST r w s m) where
    get = StrictRWS.get
    put = StrictRWS.put
    state = StrictRWS.state

-- ---------------------------------------------------------------------------
-- Instances for other mtl transformers
--
-- All of these instances need UndecidableInstances,
-- because they do not satisfy the coverage condition.

-- | @ContT@ has no state of its own: each operation is run in the underlying
-- monad @m@ and brought up with 'lift'.
instance MonadState s m => MonadState s (ContT r m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | @ExceptT@ has no state of its own: each operation is run in the
-- underlying monad @m@ and brought up with 'lift'.
--
-- @since 2.2
instance MonadState s m => MonadState s (ExceptT e m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | @IdentityT@ has no state of its own: each operation is run in the
-- underlying monad @m@ and brought up with 'lift'.
instance MonadState s m => MonadState s (IdentityT m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | @MaybeT@ has no state of its own: each operation is run in the
-- underlying monad @m@ and brought up with 'lift'.
instance MonadState s m => MonadState s (MaybeT m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | @ReaderT@ has no state of its own: each operation is run in the
-- underlying monad @m@ and brought up with 'lift'.
instance MonadState s m => MonadState s (ReaderT r m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | The CPS @WriterT@ exposes no state through this class itself: each
-- operation is run in the underlying monad @m@ and brought up with 'lift'.
--
-- @since 2.3
instance (Monoid w, MonadState s m) => MonadState s (CPS.WriterT w m) where
    -- The @Monoid w@ constraint is kept on purpose even if the compiler
    -- considers it redundant (see the note next to the
    -- -Wno-redundant-constraints pragma at the top of this module).
    get = lift get
    put = lift . put
    state = lift . state

-- | The lazy @WriterT@ has no state of its own: each operation is run in the
-- underlying monad @m@ and brought up with 'lift'.
instance (Monoid w, MonadState s m) => MonadState s (Lazy.WriterT w m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | The strict @WriterT@ has no state of its own: each operation is run in
-- the underlying monad @m@ and brought up with 'lift'.
instance (Monoid w, MonadState s m) => MonadState s (Strict.WriterT w m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | @AccumT@ does not provide the @MonadState@ state itself: each operation
-- is run in the underlying monad @m@ and brought up with 'lift'.
--
-- @since 2.3
instance
  ( Monoid w
  , MonadState s m
  ) => MonadState s (AccumT w m) where
    get = lift get
    put = lift . put
    state = lift . state

-- | @SelectT@ has no state of its own: each operation is run in the
-- underlying monad @m@ and brought up with 'lift'.
--
-- @since 2.3
instance MonadState s m => MonadState s (SelectT r m) where
    get = lift get
    put = lift . put
    state = lift . state