{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{-|
Module       : ATP.FirstOrder.Alpha
Description  : Monads and monad transformers for computations with free and
               bound variables.
Copyright    : (c) Evgenii Kotelnikov, 2019-2021
License      : GPL-3
Maintainer   : evgeny.kotelnikov@gmail.com
Stability    : experimental
-}

module ATP.FirstOrder.Alpha (
  AlphaT,
  evalAlphaT,
  Alpha,
  evalAlpha,
  lookup,
  scope,
  enter,
  share,
  MonadAlpha(..)
) where

import Prelude hiding (lookup)
import Control.Applicative ((<|>))
import Control.Monad.Trans (MonadTrans(..))
import Control.Monad.Reader (MonadReader(..), ReaderT(..), asks)
import Control.Monad.State (MonadState(..), StateT(..), modify, gets)
import Data.Functor.Identity (Identity(..))
import qualified Data.List as L (lookup)
import qualified Data.Map as M (empty, lookup, insert, elems)
import Data.Map (Map)

import ATP.FirstOrder.Core


-- | The stack of renamings for the bound variables in the expression.
-- Each entry pairs a variable with its renaming. 'enter' pushes new entries
-- onto the head of the list and 'lookup' returns the first match, so the
-- innermost binding of a variable shadows the outer ones.
type Stack = [(Var, Var)]

-- | The rename mapping for the free variables in the expression.
-- Keys are the variables being renamed and values are their renamings;
-- entries are added (or replaced) by 'share'.
type Global = Map Var Var

-- | A monad transformer that equips the underlying monad @m@ with the
-- bookkeeping needed to rename variables in a first-order expression:
--
-- * a read-only 'Stack' of bound-variable renamings, scoped with 'enter';
--
-- * a mutable 'Global' map of free-variable renamings, extended with 'share'.
--
-- Every instance is inherited from the representation
-- @'ReaderT' 'Stack' ('StateT' 'Global' m)@ by newtype deriving.
newtype AlphaT m a = AlphaT (ReaderT Stack (StateT Global m) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadReader Stack
    , MonadState Global
    )

-- | Lift an action of the underlying monad through both the state layer and
-- the reader layer. The renamings are left untouched.
instance MonadTrans AlphaT where
  lift action = AlphaT (lift (lift action))

-- | Run an alpha computation starting from an empty stack of bound-variable
-- renamings and an empty global scope. Returns the result together with the
-- final global renaming of free variables.
runAlphaT :: AlphaT m a -> m (a, Global)
runAlphaT (AlphaT reader) = runStateT stateful M.empty
  where
    stateful = runReaderT reader []

-- | Evaluate an alpha computation and return its final value,
-- throwing away the global renaming accumulated along the way.
evalAlphaT :: Monad m => AlphaT m a -> m a
evalAlphaT computation = fst <$> runAlphaT computation


-- | The alpha monad: 'AlphaT' over 'Identity', i.e. with no underlying
-- effects. @Alpha a@ is a pure alpha computation returning a value of type @a@.
--
-- The synonym is eta-reduced so that 'Alpha' can also be used unapplied,
-- e.g. as the base monad of another transformer. Saturated uses such as
-- @Alpha Int@ keep working as before.
type Alpha = AlphaT Identity

-- | Evaluate a pure 'Alpha' computation and return its final value,
-- throwing away the final variable renaming.
evalAlpha :: Alpha a -> a
evalAlpha computation = runIdentity (evalAlphaT computation)


-- | Look up the renaming of a variable. Both the stack of bound-variable
-- renamings and the global scope are read. A renaming found on the stack
-- (the innermost one, if there are several) takes precedence over one in the
-- global scope. Returns 'Nothing' if the variable is renamed in neither.
lookup :: Monad m => Var -> AlphaT m (Maybe Var)
lookup v = (<|>) <$> asks (L.lookup v) <*> gets (M.lookup v)

-- | Read the variables that currently appear as renaming targets.
-- The result lists the targets of the bound-variable renamings on the stack
-- (innermost first), followed by the targets of the global renaming
-- (ordered by the variables they rename). The result is a plain list and
-- may contain duplicates.
scope :: Monad m => AlphaT m [Var]
scope = (++) <$> asks (map snd) <*> gets M.elems

-- | Run a computation with an extra bound-variable renaming of @v@ to @w@
-- pushed onto the stack. The renaming is visible only inside the given
-- computation, and it shadows any earlier renaming of @v@ on the stack.
enter :: Monad m => Var -> Var -> AlphaT m a -> AlphaT m a
enter v w = local (\stack -> (v, w) : stack)

-- | Record in the global scope that the variable @v@ is renamed to @w@.
-- Any previous global renaming of @v@ is replaced.
share :: Monad m => Var -> Var -> AlphaT m ()
share v w = modify $ \global -> M.insert v w global


-- | A class of monads @m@ that supply the per-variable actions run inside
-- @'AlphaT' m@ on free and bound occurrences of variables.
class Monad m => MonadAlpha m where
  -- | A monadic action to perform on a variable under a quantifier.
  -- NOTE(review): presumably the result is the variable to use for the
  -- binding after renaming; confirm against the instances.
  binding :: Var -> AlphaT m Var

  -- | A monadic action to perform on a variable occurrence.
  -- NOTE(review): presumably the result is the variable that replaces this
  -- occurrence; confirm against the instances.
  occurrence :: Var -> AlphaT m Var