module Language.Lambda.Untyped.State
  ( EvalState(..),
    Eval(),
    Globals(),
    runEval,
    execEval,
    unsafeExecEval,
    unsafeRunEval,
    globals,
    uniques,
    mkEvalState,
    getGlobals,
    getUniques,
    setGlobals,
    setUniques
  ) where

import Language.Lambda.Shared.Errors
import Language.Lambda.Untyped.Expression 

import Control.Monad.Except
import RIO
import RIO.State
import qualified RIO.Map as Map

-- | State threaded through an evaluation: the global bindings currently in
-- effect, plus a supply of names that have not been handed out yet.
data EvalState name = EvalState
  { esGlobals :: Globals name
    -- ^ Global variables and the expressions they are bound to
  , esUniques :: [name]
    -- ^ Unused unique names
  }

-- | A stateful evaluation: reads and updates an 'EvalState' and may abort
-- with a 'LambdaException'.
type Eval name = StateT (EvalState name) (Except LambdaException)

-- | Global bindings: each global variable name is mapped to the lambda
-- expression bound to it.
type Globals name
  = Map name (LambdaExpr name)

-- | Run an evaluation starting from the given state.
--
-- Returns @Left@ with the 'LambdaException' if the computation fails;
-- otherwise @Right@ with the result paired with the final 'EvalState'.
runEval
  :: Eval name result
  -> EvalState name
  -> Either LambdaException (result, EvalState name)
runEval computation = runExcept . runStateT computation

-- | Run an evaluation starting from the given state, throwing away the final
-- state.
--
-- Returns @Left@ with the 'LambdaException' if the computation fails;
-- otherwise @Right@ with just the result.
execEval :: Eval name result -> EvalState name -> Either LambdaException result
execEval computation = runExcept . evalStateT computation

-- | Run an evaluation, returning the result together with the final state.
--
-- Partial: if the evaluation fails, this calls 'error' with the 'show'n
-- 'LambdaException'. The exception actually raised is therefore an
-- 'ErrorCall' carrying that text, not the 'LambdaException' itself
-- (unlike 'unsafeExecEval', which rethrows the original exception via
-- 'impureThrow').
--
-- NOTE(review): using 'impureThrow' here would make the two unsafe runners
-- consistent, but it would change the exception type callers catch, so the
-- current behaviour is kept.
unsafeRunEval :: Eval name result -> EvalState name -> (result, EvalState name)
unsafeRunEval computation state' =
  case runEval computation state' of
    Left err -> error $ show err
    Right res -> res
  
-- | Evaluate a computation and keep only its result, discarding the final
-- state.
--
-- Partial: if the evaluation fails, its 'LambdaException' is rethrown from
-- pure code via 'impureThrow'.
unsafeExecEval :: Eval name result -> EvalState name -> result
unsafeExecEval computation initialState =
  either impureThrow id (execEval computation initialState)

-- | Build an initial 'EvalState' that has no global bindings and draws fresh
-- names from the given supply.
mkEvalState :: [name] -> EvalState name
mkEvalState uniqueSupply =
  EvalState
    { esGlobals = Map.empty
    , esUniques = uniqueSupply
    }

-- | Lens focusing on the global bindings ('esGlobals') of an 'EvalState'.
globals :: Lens' (EvalState name) (Globals name)
globals focus st = fmap putBack (focus (esGlobals st))
  where
    -- Rebuild the state with the (possibly updated) globals.
    putBack newGlobals = st { esGlobals = newGlobals }

-- | Lens focusing on the unused unique-name supply ('esUniques') of an
-- 'EvalState'.
uniques :: Lens' (EvalState name) [name]
uniques focus st = fmap putBack (focus (esUniques st))
  where
    -- Rebuild the state with the (possibly updated) name supply.
    putBack newUniques = st { esUniques = newUniques }

-- | Read the current global bindings out of the evaluation state.
getGlobals :: Eval name (Globals name)
getGlobals = gets esGlobals

-- | Read the remaining supply of unused unique names out of the evaluation
-- state.
getUniques :: Eval name [name]
getUniques = gets esUniques

-- | Replace the global bindings in the evaluation state, leaving the unique
-- name supply untouched.
setGlobals :: Globals name -> Eval name ()
setGlobals newGlobals = modify (globals .~ newGlobals)

-- | Replace the supply of unused unique names in the evaluation state,
-- leaving the global bindings untouched.
setUniques :: [name] -> Eval name ()
setUniques newUniques = modify (uniques .~ newUniques)