{-# OPTIONS -Wno-orphans #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_HADDOCK hide #-}

module Control.Functor.Linear.Internal.State
  ( StateT (..),
    State,
    state,
    get,
    put,
    gets,
    modify,
    replace,
    runStateT,
    runState,
    mapStateT,
    mapState,
    execStateT,
    execState,
    withStateT,
    withState,
  )
where

import Control.Functor.Linear.Internal.Class
import Control.Functor.Linear.Internal.Instances (Data (..))
import Control.Functor.Linear.Internal.MonadTrans
import qualified Control.Monad as NonLinear ()
import qualified Control.Monad.Trans.State.Strict as NonLinear
import Data.Functor.Identity
import qualified Data.Functor.Linear.Internal.Applicative as Data
import qualified Data.Functor.Linear.Internal.Functor as Data
import Data.Unrestricted.Linear.Internal.Consumable
import Data.Unrestricted.Linear.Internal.Dupable
import Prelude.Linear.Internal

-- # StateT
-------------------------------------------------------------------------------

-- | A (strict) linear state monad transformer.
--
-- A computation of type @StateT s m a@ consumes the incoming state exactly
-- once and produces, inside @m@, a result paired with the outgoing state.
newtype StateT s m a = StateT (s %1 -> m (a, s))
  deriving (Data.Applicative) via Data (StateT s m)

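-- We derive Data.Applicative (via 'Data') but write Data.Functor by hand:
-- the hand-written Data.Functor instance only needs Data.Functor m, a weaker
-- constraint on m than the Control.Functor m a derived instance would need.
-- Data.Applicative, by contrast, needs a Monad instance on m just like
-- Control.Applicative does, so deriving it loses nothing.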

-- | A (strict) linear state monad: 'StateT' over 'Identity'.
type State s = StateT s Identity

-- | Return the current state. The state is duplicated with 'dup' (hence the
-- 'Dupable' constraint): one copy is returned as the result and the other
-- remains as the state.
get :: (Applicative m, Dupable s) => StateT s m s
get = state dup

-- | Replace the current state with the given value. The previous state,
-- returned by 'replace', is discarded with 'Data.void', which is why it must
-- be 'Consumable'.
put :: (Applicative m, Consumable s) => s %1 -> StateT s m ()
put = Data.void . replace

-- | Apply a function to a copy of the current state and return the result,
-- leaving the state unchanged. The state is duplicated with 'dup' (hence the
-- 'Dupable' constraint) so that @f@ can consume one copy.
gets :: (Applicative m, Dupable s) => (s %1 -> a) %1 -> StateT s m a
gets f = state ((\(s1, s2) -> (f s1, s2)) . dup)

-- | Unwrap a 'StateT' computation into its underlying state-passing function.
runStateT :: StateT s m a %1 -> s %1 -> m (a, s)
runStateT (StateT f) = f

-- | Embed a pure, linear state transition into 'StateT'. The
-- @(result, newState)@ pair produced by @f@ is returned with 'pure'.
state :: Applicative m => (s %1 -> (a, s)) %1 -> StateT s m a
state f = StateT (pure . f)

-- | Run a 'State' computation from an initial state, returning the result
-- together with the final state.
runState :: State s a %1 -> s %1 -> (a, s)
runState f = runIdentity' . runStateT f

-- | Transform the underlying monadic action of a 'StateT' computation: @r@ is
-- applied to the @m (a, s)@ that the wrapped function produces.
mapStateT :: (m (a, s) %1 -> n (b, s)) %1 -> StateT s m a %1 -> StateT s n b
mapStateT r (StateT f) = StateT (r . f)

-- | @withStateT r m@ runs @m@ on the state obtained by applying @r@ to the
-- incoming state.
withStateT :: (s %1 -> s) %1 -> StateT s m a %1 -> StateT s m a
withStateT r (StateT f) = StateT (f . r)

-- | Run a 'StateT' computation whose result is @()@ and return only the final
-- state. The unit result is consumed by pattern matching on it.
execStateT :: Functor m => StateT s m () %1 -> s %1 -> m s
execStateT f = fmap (\((), s) -> s) . runStateT f

-- | 'mapStateT' specialised to 'State': transform the @(result, state)@ pair
-- produced by a 'State' computation.
mapState :: ((a, s) %1 -> (b, s)) %1 -> State s a %1 -> State s b
mapState f = mapStateT (Identity . f . runIdentity')

-- | 'withStateT' specialised to 'State'.
withState :: (s %1 -> s) %1 -> State s a %1 -> State s a
withState = withStateT

-- | Run a 'State' computation whose result is @()@ and return only the final
-- state.
execState :: State s () %1 -> s %1 -> s
execState f = runIdentity' . execStateT f

-- | Apply a function to the current state. The new state @f s@ is stored
-- without being forced.
modify :: Applicative m => (s %1 -> s) %1 -> StateT s m ()
modify f = state $ \s -> ((), f s)

-- TODO: add strict version of `modify`

-- | @replace s@ will replace the current state with the new given state, and
-- return the old state.
replace :: Applicative m => s %1 -> StateT s m s
replace s = state (\old -> (old, s))

-- # Instances of StateT
-------------------------------------------------------------------------------

-- Orphan instance: the linear Control.Functor class for the non-linear
-- transformers 'NonLinear.StateT'. The function is applied to the result and
-- the state is passed through unchanged.
instance Functor m => Functor (NonLinear.StateT s m) where
  fmap f (NonLinear.StateT x) =
    NonLinear.StateT $ \s -> fmap (\(a, s') -> (f a, s')) $ x s

-- Written by hand so that it only requires Data.Functor m. The function is
-- applied to the result and the state is passed through unchanged.
instance Data.Functor m => Data.Functor (StateT s m) where
  fmap f (StateT x) = StateT (\s -> Data.fmap (\(a, s') -> (f a, s')) (x s))

-- Linear (Control) functor: the linearly-used function is applied to the
-- result and the state is passed through unchanged.
instance Functor m => Functor (StateT s m) where
  fmap f (StateT x) = StateT (\s -> fmap (\(a, s') -> (f a, s')) (x s))

-- State is threaded left to right: the function-producing computation runs
-- first, and the argument computation runs on the state it leaves behind.
instance Monad m => Applicative (StateT s m) where
  pure x = StateT (\s -> return (x, s))
  StateT mf <*> StateT mx = StateT $ \s -> do
    (f, s') <- mf s
    (x, s'') <- mx s'
    return (f x, s'')

-- Runs the first computation, then feeds its result to @f@ and runs the
-- resulting computation on the updated state.
instance Monad m => Monad (StateT s m) where
  StateT mx >>= f = StateT $ \s -> do
    (x, s') <- mx s
    runStateT (f x) s'

-- Lifting runs the underlying action and pairs its result with the unchanged
-- state.
instance MonadTrans (StateT s) where
  lift x = StateT (\s -> fmap (,s) x)