module Language.Lambda.Untyped.Eval
  ( EvalState(..),
    evalExpr,
    subGlobals,
    betaReduce,
    alphaConvert,
    etaConvert,
    freeVarsOf
  ) where

import Control.Monad.Except
import Prettyprinter
import RIO
import RIO.List (find)
import qualified RIO.Map as Map

import Language.Lambda.Shared.Errors
import Language.Lambda.Untyped.Expression
import Language.Lambda.Untyped.State

-- | Evaluate a top-level expression in the context of the current globals.
--
-- Free references to global names are substituted first, then the result is
-- evaluated with 'evalExpr''. A @let@ expression additionally stores its
-- evaluated value in the globals under the bound name, and is returned as a
-- @Let@ carrying that evaluated value.
evalExpr :: (Pretty name, Ord name) => LambdaExpr name -> Eval name (LambdaExpr name)
evalExpr expr = do
  globals' <- getGlobals
  let resolve e = evalExpr' (subGlobals globals' e)
  case expr of
    Let name body -> do
      value <- resolve body
      setGlobals (Map.insert name value globals')
      pure (Let name value)
    _ -> resolve expr

-- | Evaluate an expression; @let@ is not supported here.
--
-- Variables are returned unchanged. Abstraction bodies are evaluated in
-- place. For an application, the function side is evaluated first, then the
-- argument, and the two results are handed to 'betaReduce'. A @let@ raises
-- 'InvalidLet' carrying the pretty-printed expression.
evalExpr' :: (Eq name, Pretty name) => LambdaExpr name -> Eval name (LambdaExpr name)
evalExpr' expr = case expr of
  Var _ -> pure expr
  Abs name body -> Abs name <$> evalExpr' body
  App fn arg -> do
    fn' <- evalExpr' fn
    arg' <- evalExpr' arg
    betaReduce fn' arg'
  Let _ _ -> throwError (InvalidLet (prettyPrint expr))

-- | Look up free vars that have global bindings and substitute them.
--
-- A variable bound in the map is replaced by its definition; unbound
-- variables are left as-is. A lambda binder shadows a global of the same
-- name. That one name is therefore removed from the map while descending
-- into the abstraction's body, but every other global is still substituted
-- there. Any other form (e.g. a @let@) is returned unchanged.
subGlobals
  :: Ord name
  => Map name (LambdaExpr name)
  -> LambdaExpr name
  -> LambdaExpr name
subGlobals globals' expr@(Var x) = Map.findWithDefault expr x globals'
subGlobals globals' (App e1 e2) = App (subGlobals globals' e1) (subGlobals globals' e2)
-- Dropping only the shadowed name (rather than skipping the whole body)
-- keeps references to unrelated globals inside the abstraction resolvable.
subGlobals globals' (Abs name expr) = Abs name (subGlobals (Map.delete name globals') expr)
subGlobals _ expr = expr

-- | Function application: apply the evaluated expression @fn@ to @arg@.
--
-- * A variable in function position cannot be reduced, so the application
--   is rebuilt as-is.
-- * A nested application has its own head/argument reduced first, and the
--   result is re-applied to @arg@.
-- * An abstraction has its body alpha-converted against the free variables
--   of @arg@. Then @arg@ is substituted for the parameter and the result is
--   evaluated again.
-- * Any other form raises 'ImpossibleError'.
betaReduce
  :: (Eq name, Pretty name)
  => LambdaExpr name
  -> LambdaExpr name
  -> Eval name (LambdaExpr name)
betaReduce fn arg = case fn of
  Var _ -> pure (App fn arg)
  App inner innerArg -> flip App arg <$> betaReduce inner innerArg
  Abs param body -> do
    body' <- alphaConvert (freeVarsOf arg) body
    evalExpr' (substitute body' param arg)
  _ -> throwError ImpossibleError

-- | Rename abstraction parameters to avoid name captures.
--
-- Every binder in the expression whose name appears in @freeVars@ is renamed
-- to a name drawn from the unique-name supply ('getUniques'). This includes
-- binders nested under other binders and inside applications. The new name
-- is chosen so that it is neither in @freeVars@ nor mentioned anywhere
-- (free or bound) in that abstraction's body. As a result, the renaming can
-- neither capture an existing variable nor be shadowed by an inner binder.
-- If the supply offers no such name, the parameter keeps its original name.
alphaConvert :: Eq name => [name] -> LambdaExpr name -> Eval name (LambdaExpr name)
alphaConvert freeVars (Abs name body)
  | name `elem` freeVars = do
      uniques' <- getUniques
      let taken = freeVars ++ namesIn body
          nextVar = fromMaybe name $ find (`notElem` taken) uniques'
      -- Keep converting after a rename: binders further in may clash too.
      Abs nextVar <$> alphaConvert freeVars (substitute body name (Var nextVar))
  | otherwise = Abs name <$> alphaConvert freeVars body
  where
    -- Every variable name mentioned in an expression, free or bound.
    namesIn (Var x) = [x]
    namesIn (Abs x e) = x : namesIn e
    namesIn (App e1 e2) = namesIn e1 ++ namesIn e2
    namesIn _ = []
-- Abstractions can sit in argument position of a (stuck) application, so
-- both sides must be converted as well.
alphaConvert freeVars (App e1 e2) =
  App <$> alphaConvert freeVars e1 <*> alphaConvert freeVars e2
alphaConvert _ expr = pure expr

-- | Eliminate superfluous abstractions (eta reduction).
--
-- An abstraction of the form @\\x -> e x@ is rewritten to @e@ only when @x@
-- does not occur free in @e@. Otherwise dropping the binder would change the
-- meaning: @\\x -> x x@ is not equivalent to @x@. Nested abstractions are
-- reduced from the inside out, and the conversion recurses into abstraction
-- bodies and both sides of applications. Any other form is returned as-is.
etaConvert :: Eq n => LambdaExpr n -> LambdaExpr n
etaConvert (Abs n (App e1 (Var n')))
  | n == n' && n `notElem` freeVarsOf e1 = etaConvert e1
  | otherwise = Abs n (App (etaConvert e1) (Var n'))
etaConvert (Abs n e@(Abs _ _))
  -- If `etaConvert e == e` then etaConverting it again would loop forever
  | e == e'   = Abs n e'
  | otherwise = etaConvert (Abs n e')
  where e' = etaConvert e
etaConvert (Abs n expr) = Abs n (etaConvert expr)
etaConvert (App e1 e2)  = App (etaConvert e1) (etaConvert e2)
etaConvert expr = expr

-- | Substitute an expression for a variable name in another expression.
--
-- @substitute subExpr subName inExpr@ replaces occurrences of @subName@ in
-- @subExpr@ with @inExpr@. Substitution stops at an abstraction that rebinds
-- @subName@. Binders are not renamed here, so the result is only
-- capture-free if no binder in @subExpr@ clashes with a free variable of
-- @inExpr@ (callers such as 'betaReduce' alpha-convert first). Any other
-- form (e.g. a @let@) is returned unchanged.
substitute :: Eq name => LambdaExpr name -> name -> LambdaExpr name -> LambdaExpr name
substitute subExpr@(Var name) subName inExpr
  | name == subName = inExpr
  | otherwise = subExpr

substitute subExpr@(Abs name expr) subName inExpr
  | name == subName = subExpr
  | otherwise = Abs name (substitute expr subName inExpr)

substitute (App e1 e2) subName inExpr
  = App (sub e1) (sub e2)
  where sub expr = substitute expr subName inExpr

-- Previously this returned the replacement expression, discarding the
-- expression being substituted into; leave unhandled forms untouched instead.
substitute subExpr _ _ = subExpr

-- | Find the free variables in an expression.
--
-- Variables are reported left to right, once per free occurrence, so the
-- list may contain duplicates. Occurrences bound by an enclosing abstraction
-- are excluded. Any form other than a variable, abstraction or application
-- contributes nothing.
freeVarsOf :: Eq n => LambdaExpr n -> [n]
freeVarsOf expr = case expr of
  Var n -> [n]
  App e1 e2 -> freeVarsOf e1 ++ freeVarsOf e2
  Abs n body -> [v | v <- freeVarsOf body, v /= n]
  _ -> []