module Language.Lambda.SystemF.State
  ( TypecheckState(..),
    Typecheck(),
    Context(),
    runTypecheck,
    execTypecheck,
    unsafeRunTypecheck,
    unsafeExecTypecheck,
    mkTypecheckState,
    context,
    uniques,
    getContext,
    getUniques,
    modifyContext,
    modifyUniques,
    setContext,
    setUniques
  ) where

import Language.Lambda.Shared.Errors (LambdaException(..))
import Language.Lambda.SystemF.Expression (Ty(..))

import Control.Monad.Except (Except(), runExcept)
import RIO
import RIO.State
import qualified RIO.Map as Map

-- | State threaded through a System F typechecking computation.
data TypecheckState name = TypecheckState
  { tsContext :: Context name
    -- ^ Typing context: the type currently assigned to each name.
  , tsUniques :: [name]
    -- ^ List of names; presumably the supply of fresh names handed out
    -- during typechecking (NOTE(review): confirm against callers).
  }

-- | Typing context: maps each name to its type ('Ty').
type Context name = Map name (Ty name)

-- | A typechecking computation: a 'StateT' over 'TypecheckState' whose base
-- monad is @'Except' 'LambdaException'@, so it can both update the state and
-- abort with a 'LambdaException'.
type Typecheck name = StateT (TypecheckState name) (Except LambdaException)

-- | Run a typechecking computation starting from the given state.
--
-- On success, yields the computation's result together with the final
-- state; on failure, yields the 'LambdaException' it aborted with.
runTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> Either LambdaException (result, TypecheckState name)
runTypecheck computation initialState =
  runExcept (runStateT computation initialState)

-- | Run a typechecking computation and keep only its result.
--
-- Note: despite the @exec@ name, the final state is /discarded/ and the
-- computation's result is returned (this uses 'evalStateT', not
-- 'execStateT').
execTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> Either LambdaException result
execTypecheck computation initialState =
  runExcept (evalStateT computation initialState)

-- | Like 'runTypecheck', but a failure is not returned as 'Left'.
-- Instead, the 'LambdaException' is thrown via 'impureThrow' when the
-- result is forced. Use only where a typecheck failure should surface as
-- an exception.
unsafeRunTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> (result, TypecheckState name)
unsafeRunTypecheck computation initialState =
  case runTypecheck computation initialState of
    Left err      -> impureThrow err
    Right outcome -> outcome

-- | Like 'execTypecheck', but a failure is thrown via 'impureThrow' when
-- the result is forced, instead of being returned as 'Left'.
unsafeExecTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> result
unsafeExecTypecheck computation initialState =
  case execTypecheck computation initialState of
    Left err    -> impureThrow err
    Right value -> value

-- | Initial typechecking state: an empty context, with the given list of
-- names stored as 'tsUniques'.
mkTypecheckState :: [name] -> TypecheckState name
mkTypecheckState names =
  TypecheckState { tsContext = Map.empty, tsUniques = names }

-- | Lens onto the 'tsUniques' field of a 'TypecheckState'.
uniques :: Lens' (TypecheckState name) [name]
uniques focus st = rebuild <$> focus (tsUniques st)
  where
    -- Put the (possibly updated) names back into the original state.
    rebuild updated = st { tsUniques = updated }

-- | Lens onto the 'tsContext' (typing context) field of a 'TypecheckState'.
context :: Lens' (TypecheckState name) (Context name)
context focus st = rebuild <$> focus (tsContext st)
  where
    -- Put the (possibly updated) context back into the original state.
    rebuild updated = st { tsContext = updated }

-- | Read the current 'tsUniques' list from the typechecking state.
getUniques :: Typecheck name [name]
getUniques = gets tsUniques

-- | Read the current typing context from the typechecking state.
getContext :: Typecheck name (Context name)
getContext = gets tsContext

-- | Apply the given function to the typing context held in the state.
modifyContext :: (Context name -> Context name) -> Typecheck name ()
modifyContext update = modify (context %~ update)

-- | Apply the given function to the 'tsUniques' list held in the state.
modifyUniques :: ([name] -> [name]) -> Typecheck name ()
modifyUniques update = modify (uniques %~ update)

-- | Replace the 'tsUniques' list in the state with the given names.
setUniques :: [name] -> Typecheck name ()
setUniques names = modify (uniques .~ names)

-- | Replace the typing context in the state with the given one.
setContext :: Context name -> Typecheck name ()
setContext ctx = modify (context .~ ctx)