module Ether.State
    (
    -- * MonadState class
      MonadState
    , get
    , put
    , state
    , modify
    , gets
    -- * The State monad
    , State
    , runState
    , evalState
    , execState
    -- * The StateT monad transformer
    , StateT
    , stateT
    , runStateT
    , evalStateT
    , execStateT
    -- * The State monad (lazy)
    , LazyState
    , runLazyState
    , evalLazyState
    , execLazyState
    -- * The StateT monad transformer (lazy)
    , LazyStateT
    , lazyStateT
    , runLazyStateT
    , evalLazyStateT
    , execLazyStateT
    -- * The State monad (flattened)
    , States
    , runStates
    -- * The StateT monad transformer (flattened)
    , StatesT
    , runStatesT
    -- * MonadState class (implicit)
    , MonadState'
    , get'
    , put'
    , state'
    , modify'
    , gets'
    -- * The State monad (implicit)
    , State'
    , runState'
    , evalState'
    , execState'
    -- * The StateT monad transformer (implicit)
    , StateT'
    , stateT'
    , runStateT'
    , evalStateT'
    , execStateT'
    -- * The State monad (lazy, implicit)
    , LazyState'
    , runLazyState'
    , evalLazyState'
    , execLazyState'
    -- * The StateT monad transformer (lazy, implicit)
    , LazyStateT'
    , lazyStateT'
    , runLazyStateT'
    , evalLazyStateT'
    , execLazyStateT'
    -- * Zoom
    , ZoomT
    , zoom
    -- * Internal labels
    , TAGGED
    , STATE
    , STATES
    , ZOOM
    ) where

import qualified Control.Monad.State.Class as T
import qualified Control.Monad.Trans as Lift
import Control.Monad.Trans.Identity
import qualified Control.Monad.Trans.State.Lazy as T.Lazy
import qualified Control.Monad.Trans.State.Strict as T.Strict
import Data.Coerce
import Data.Functor.Identity
import Data.Kind
import Data.Proxy
import Data.Reflection

import Ether.Internal
import Ether.TaggedTrans

class Monad m => MonadState tag s m | m tag -> s where

    {-# MINIMAL state | get, put #-}

    -- | Return the state from the internals of the monad.
    get :: m s
    get = state @tag (\s -> (s, s))

    -- | Replace the state inside the monad.
    put :: s -> m ()
    put s = state @tag (\_ -> ((), s))

    -- | Embed a simple state action into the monad.
    state :: (s -> (a, s)) -> m a
    state f = do
      s <- get @tag
      let ~(a, s') = f s
      put @tag s'
      return a

instance {-# OVERLAPPABLE #-}
    ( Lift.MonadTrans t
    , Monad (t m)
    , MonadState tag s m
    ) => MonadState tag s (t m)
  where

    get = Lift.lift (get @tag)
    {-# INLINE get #-}

    put = Lift.lift . put @tag
    {-# INLINE put #-}

    state = Lift.lift . state @tag
    {-# INLINE state #-}

instance {-# OVERLAPPABLE #-}
    ( Monad (trans m)
    , MonadState tag s (TaggedTrans effs trans m)
    ) => MonadState tag s (TaggedTrans (eff ': effs) trans (m :: Type -> Type))
  where

    get =
      (coerce ::
        TaggedTrans         effs  trans m s ->
        TaggedTrans (eff ': effs) trans m s)
      (get @tag)
    {-# INLINE get #-}

    put =
      (coerce ::
        (s -> TaggedTrans         effs  trans m ()) ->
        (s -> TaggedTrans (eff ': effs) trans m ()))
      (put @tag)
    {-# INLINE put #-}

    state =
      (coerce :: forall a .
        ((s -> (a, s)) -> TaggedTrans         effs  trans m a) ->
        ((s -> (a, s)) -> TaggedTrans (eff ': effs) trans m a))
      (state @tag)
    {-# INLINE state #-}

-- | Modifies the state inside a state monad.
modify :: forall tag s m . MonadState tag s m => (s -> s) -> m ()
modify f = state @tag (\s -> ((), f s))
{-# INLINABLE modify #-}

-- | Gets specific component of the state, using a projection function supplied.
gets :: forall tag s m a . MonadState tag s m => (s -> a) -> m a
gets f = fmap f (get @tag)
{-# INLINABLE gets #-}

-- | Encode type-level information for 'StateT'.
data STATE

type instance HandleSuper      STATE s trans   = ()
type instance HandleConstraint STATE s trans m =
  T.MonadState s (trans m)

instance Handle STATE s (T.Strict.StateT s) where
  handling r = r
  {-# INLINE handling #-}

instance Handle STATE s (T.Lazy.StateT s) where
  handling r = r
  {-# INLINE handling #-}

instance
    ( Handle STATE s trans
    , Monad m, Monad (trans m)
    ) => MonadState tag s (TaggedTrans (TAGGED STATE tag) trans m)
  where

    get =
      handling @STATE @s @trans @m $
      coerce (T.get @s @(trans m))
    {-# INLINE get #-}

    put =
      handling @STATE @s @trans @m $
      coerce (T.put @s @(trans m))
    {-# INLINE put #-}

    state =
      handling @STATE @s @trans @m $
      coerce (T.state @s @(trans m) @a) ::
        forall eff a . (s -> (a, s)) -> TaggedTrans eff trans m a
    {-# INLINE state #-}

instance
    ( HasLens tag payload s
    , Handle STATE payload trans
    , Monad m, Monad (trans m)
    ) => MonadState tag s (TaggedTrans (TAGGED STATE tag ': effs) trans m)
  where

    get =
      handling @STATE @payload @trans @m $
      (coerce :: forall eff a .
                    trans m a ->
        TaggedTrans eff trans m a)
      (T.gets (view (lensOf @tag @payload @s)))
    {-# INLINE get #-}

    put s =
      handling @STATE @payload @trans @m $
      (coerce :: forall eff a .
                    trans m a ->
        TaggedTrans eff trans m a)
      (T.modify (over (lensOf @tag @payload @s) (const s)))
    {-# INLINE put #-}

    state f =
      handling @STATE @payload @trans @m $
      (coerce :: forall eff a .
                    trans m a ->
        TaggedTrans eff trans m a)
      (T.state (lensOf @tag @payload @s f))
    {-# INLINE state #-}

-- | The parametrizable state monad.
--
-- Computations have access to a mutable state.
--
-- The 'return' function leaves the state unchanged, while '>>=' uses
-- the final state of the first computation as the initial state of the second.
type State tag r = StateT tag r Identity

-- | The state monad transformer.
--
-- The 'return' function leaves the state unchanged, while '>>=' uses
-- the final state of the first computation as the initial state of the second.
type StateT tag s = TaggedTrans (TAGGED STATE tag) (T.Strict.StateT s)

-- | Constructor for computations in the state monad transformer.
stateT :: forall tag s m a . (s -> m (a, s)) -> StateT tag s m a
stateT = coerce (T.Strict.StateT @s @m @a)
{-# INLINE stateT #-}

-- | Runs a 'StateT' with the given initial state
-- and returns both the final value and the final state.
runStateT :: forall tag s m a . StateT tag s m a -> s -> m (a, s)
runStateT = coerce (T.Strict.runStateT @s @m @a)
{-# INLINE runStateT #-}

-- | Runs a 'StateT' with the given initial state
-- and returns the final value, discarding the final state.
evalStateT :: forall tag s m a . Monad m => StateT tag s m a -> s -> m a
evalStateT = coerce (T.Strict.evalStateT @m @s @a)
{-# INLINE evalStateT #-}

-- | Runs a 'StateT' with the given initial state
-- and returns the final state, discarding the final value.
execStateT :: forall tag s m a . Monad m => StateT tag s m a -> s -> m s
execStateT = coerce (T.Strict.execStateT @m @s @a)
{-# INLINE execStateT #-}

-- | Runs a 'State' with the given initial state
-- and returns both the final value and the final state.
runState :: forall tag s a . State tag s a -> s -> (a, s)
runState = coerce (T.Strict.runState @s @a)
{-# INLINE runState #-}

-- | Runs a 'State' with the given initial state
-- and returns the final value, discarding the final state.
evalState :: forall tag s a . State tag s a -> s -> a
evalState = coerce (T.Strict.evalState @s @a)
{-# INLINE evalState #-}

-- | Runs a 'State' with the given initial state
-- and returns the final state, discarding the final value.
execState :: forall tag s a . State tag s a -> s -> s
execState = coerce (T.Strict.execState @s @a)
{-# INLINE execState #-}

-- | The parametrizable state monad.
--
-- Computations have access to a mutable state.
--
-- The 'return' function leaves the state unchanged, while '>>=' uses
-- the final state of the first computation as the initial state of the second.
type LazyState tag r = LazyStateT tag r Identity

-- | The state monad transformer.
--
-- The 'return' function leaves the state unchanged, while '>>=' uses
-- the final state of the first computation as the initial state of the second.
type LazyStateT tag s = TaggedTrans (TAGGED STATE tag) (T.Lazy.StateT s)

-- | Constructor for computations in the state monad transformer.
lazyStateT :: forall tag s m a . (s -> m (a, s)) -> LazyStateT tag s m a
lazyStateT = coerce (T.Lazy.StateT @s @m @a)
{-# INLINE lazyStateT #-}

-- | Runs a 'StateT' with the given initial state
-- and returns both the final value and the final state.
runLazyStateT :: forall tag s m a . LazyStateT tag s m a -> s -> m (a, s)
runLazyStateT = coerce (T.Lazy.runStateT @s @m @a)
{-# INLINE runLazyStateT #-}

-- | Runs a 'StateT' with the given initial state
-- and returns the final value, discarding the final state.
evalLazyStateT :: forall tag s m a . Monad m => LazyStateT tag s m a -> s -> m a
evalLazyStateT = coerce (T.Lazy.evalStateT @m @s @a)
{-# INLINE evalLazyStateT #-}

-- | Runs a 'StateT' with the given initial state
-- and returns the final state, discarding the final value.
execLazyStateT :: forall tag s m a . Monad m => LazyStateT tag s m a -> s -> m s
execLazyStateT = coerce (T.Lazy.execStateT @m @s @a)
{-# INLINE execLazyStateT #-}

-- | Runs a 'State' with the given initial state
-- and returns both the final value and the final state.
runLazyState :: forall tag s a . LazyState tag s a -> s -> (a, s)
runLazyState = coerce (T.Lazy.runState @s @a)
{-# INLINE runLazyState #-}

-- | Runs a 'State' with the given initial state
-- and returns the final value, discarding the final state.
evalLazyState :: forall tag s a . LazyState tag s a -> s -> a
evalLazyState = coerce (T.Lazy.evalState @s @a)
{-# INLINE evalLazyState #-}

-- | Runs a 'State' with the given initial state
-- and returns the final state, discarding the final value.
execLazyState :: forall tag s a . LazyState tag s a -> s -> s
execLazyState = coerce (T.Lazy.execState @s @a)
{-# INLINE execLazyState #-}

type family STATES (ts :: HList xs) :: [Type] where
  STATES 'HNil = '[]
  STATES ('HCons t ts) = TAGGED STATE t ': STATES ts

type StatesT s = TaggedTrans (STATES (Tags s)) (T.Strict.StateT s)

type States s = StatesT s Identity

runStatesT :: forall p m a . StatesT p m a -> p -> m (a, p)
runStatesT = coerce (T.Strict.runStateT @p @m @a)
{-# INLINE runStatesT #-}

runStates :: forall p a . States p a -> p -> (a, p)
runStates = coerce (T.Strict.runState @p @a)
{-# INLINE runStates #-}

type StateT' s = StateT s s

stateT' :: (s -> m (a, s)) -> StateT' s m a
stateT' = stateT
{-# INLINE stateT' #-}

runStateT' :: StateT' s m a -> s -> m (a, s)
runStateT' = runStateT
{-# INLINE runStateT' #-}

runState' :: State' s a -> s -> (a, s)
runState' = runState
{-# INLINE runState' #-}

evalStateT' :: Monad m => StateT' s m a -> s -> m a
evalStateT' = evalStateT
{-# INLINE evalStateT' #-}

type State' s = State s s

evalState' :: State' s a -> s -> a
evalState' = evalState
{-# INLINE evalState' #-}

execStateT' :: Monad m => StateT' s m a -> s -> m s
execStateT' = execStateT
{-# INLINE execStateT' #-}

execState' :: State' s a -> s -> s
execState' = execState
{-# INLINE execState' #-}

type LazyStateT' s = LazyStateT s s

lazyStateT' :: (s -> m (a, s)) -> LazyStateT' s m a
lazyStateT' = lazyStateT
{-# INLINE lazyStateT' #-}

runLazyStateT' :: LazyStateT' s m a -> s -> m (a, s)
runLazyStateT' = runLazyStateT
{-# INLINE runLazyStateT' #-}

runLazyState' :: LazyState' s a -> s -> (a, s)
runLazyState' = runLazyState
{-# INLINE runLazyState' #-}

evalLazyStateT' :: Monad m => LazyStateT' s m a -> s -> m a
evalLazyStateT' = evalLazyStateT
{-# INLINE evalLazyStateT' #-}

type LazyState' s = LazyState s s

evalLazyState' :: LazyState' s a -> s -> a
evalLazyState' = evalLazyState
{-# INLINE evalLazyState' #-}

execLazyStateT' :: Monad m => LazyStateT' s m a -> s -> m s
execLazyStateT' = execLazyStateT
{-# INLINE execLazyStateT' #-}

execLazyState' :: LazyState' s a -> s -> s
execLazyState' = execLazyState
{-# INLINE execLazyState' #-}

type MonadState' s = MonadState s s

get' :: forall s m . MonadState' s m => m s
get' = get @s
{-# INLINE get' #-}

gets' :: forall s m a . MonadState' s m => (s -> a) -> m a
gets' = gets @s
{-# INLINE gets' #-}

put' :: forall s m . MonadState' s m => s -> m ()
put' = put @s
{-# INLINE put' #-}

state' :: forall s m a . MonadState' s m => (s -> (a, s)) -> m a
state' = state @s
{-# INLINE state' #-}

modify' :: forall s m . MonadState' s m => (s -> s) -> m ()
modify' = modify @s
{-# INLINE modify' #-}

-- | Encode type-level information for 'zoom'.
data ZOOM t z

type ZoomT t (z :: Type) = TaggedTrans (ZOOM t z) IdentityT

-- | Zoom into a part of a state using a lens.
zoom
  :: forall tag sOuter sInner m a
   . Lens' sOuter sInner
  -> (forall z . Reifies z (ReifiedLens' sOuter sInner) => ZoomT tag z m a)
  -> m a
zoom l m = reify (Lens l) (\(_ :: Proxy z) -> coerce (m @z))
{-# INLINE zoom #-}

instance
    ( MonadState tag sOuter m
    , Reifies z (ReifiedLens' sOuter sInner)
    , trans ~ IdentityT
    ) => MonadState tag sInner (TaggedTrans (ZOOM tag z) trans m)
  where
    state =
      (coerce :: forall eff r a .
        (r ->                       m a) ->
        (r -> TaggedTrans eff trans m a))
      (state @tag . l)
      where
        Lens l = reflect (Proxy :: Proxy z)
    {-# INLINE state #-}