{-# LANGUAGE TupleSections #-}
{-|
   Monadic interface to e-graph stateful computations
 -}
module Data.Equality.Graph.Monad
  ( egraph
  , represent
  , add
  , merge
  , rebuild
  , EG.canonicalize
  , EG.find
  , EG.emptyEGraph

  -- * E-graph stateful computations
  , EGraphM
  , runEGraphM

  -- * E-graph definition re-export
  , EG.EGraph

  -- * 'State' monad re-exports
  , modify, get, gets
  ) where

import Control.Monad ((>=>))
import Control.Monad.Trans.State.Strict

import Data.Equality.Utils (Fix, cata)

import Data.Equality.Graph (EGraph, ClassId, Language, ENode(..))
import qualified Data.Equality.Graph as EG

-- | A stateful computation over an e-graph whose e-nodes are built from the
-- language functor @l@: a strict 'State' monad carrying an @'EGraph' l@.
type EGraphM l = State (EGraph l)

-- | Run an 'EGraphM' computation starting from an empty e-graph, returning
-- the computation's result together with the final e-graph.
--
-- Equivalent to @runEGraphM EG.emptyEGraph@.
--
-- === Example
-- Given expressions @t1, t2 :: Fix l@:
--
-- @
-- egraph $ do
--   id1 <- represent t1
--   id2 <- represent t2
--   merge id1 id2
-- @
--
-- 'merge' may leave the e-graph invariants broken; 'rebuild' should be
-- called before relying on them.
egraph :: Language l => EGraphM l a -> (a, EGraph l)
egraph = runEGraphM EG.emptyEGraph
{-# INLINE egraph #-}

-- | Represent an expression (@Fix l@) in the e-graph by recursively
-- representing its subexpressions, returning the 'ClassId' of the root.
--
-- The expression is folded bottom-up with 'cata'. At each layer, the
-- children are represented first (sequenced in the order given by @l@'s
-- 'Traversable' instance). The node is then rebuilt over the children's
-- class ids, wrapped in 'Node' and 'add'ed.
represent :: Language l => Fix l -> EGraphM l ClassId
represent = cata (sequence >=> add . Node)
{-# INLINE represent #-}

-- | Add an e-node to the e-graph.
--
-- Lifts the pure @EG.add@ into 'EGraphM': the e-graph it returns replaces
-- the current state, and its 'ClassId' is the result.
add :: Language l => ENode l -> EGraphM l ClassId
add = state . EG.add
{-# INLINE add #-}

-- | Merge two e-classes by id.
--
-- Lifts the pure @EG.merge@ into 'EGraphM': the e-graph it returns replaces
-- the current state, and its 'ClassId' is the result.
--
-- E-graph invariants may be broken by merging, and 'rebuild' should be used
-- /eventually/ to restore them.
merge :: Language l => ClassId -> ClassId -> EGraphM l ClassId
merge a b = state (EG.merge a b)
{-# INLINE merge #-}

-- | Rebuild: restore the e-graph invariants.
--
-- E-graph invariants are traditionally maintained after every merge, but we
-- allow operations to temporarily break the invariants (specifically, until
-- 'rebuild' is called).
--
-- Replaces the current e-graph with the result of the pure @EG.rebuild@.
--
-- The paper describing rebuilding in detail is
-- <https://arxiv.org/abs/2004.03082>.
rebuild :: Language l => EGraphM l ()
rebuild = modify EG.rebuild
{-# INLINE rebuild #-}

-- | Run an 'EGraphM' computation on a given e-graph, returning the
-- computation's result together with the final e-graph.
--
-- This is 'runState' with its arguments flipped, so the starting e-graph
-- can be partially applied (as 'egraph' does).
runEGraphM :: EGraph l -> EGraphM l a -> (a, EGraph l)
runEGraphM g m = runState m g
{-# INLINE runEGraphM #-}