{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
-- {-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-|
   An e-graph efficiently represents a congruence relation over many expressions.

   Based on \"egg: Fast and Extensible Equality Saturation\" https://arxiv.org/abs/2004.03082.
 -}
module Data.Equality.Graph
    (
      -- * Definition of e-graph
      EGraph

      -- * Functions on e-graphs
    , emptyEGraph

      -- ** Transformations
    , represent, add, merge, rebuild
    -- , repair, repairAnal

      -- ** Querying
    , find, canonicalize

      -- * Re-exports
    , module Data.Equality.Graph.Classes
    , module Data.Equality.Graph.Nodes
    , module Data.Equality.Language
    ) where

-- ROMES:TODO: Is the E-Graph a Monad if the analysis data were the type arg? i.e. Monad (EGraph language)?

-- import GHC.Conc
import Prelude hiding (lookup)

import Data.Function
import Data.Bifunctor
import Data.Containers.ListUtils

import Control.Monad
import Control.Monad.Trans.State
import Control.Exception (assert)

import qualified Data.IntMap.Strict as IM
import qualified Data.Set    as S

import Data.Equality.Utils.SizedList

import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens

import Data.Equality.Utils

-- ROMES:TODO: join things built in parallel?
-- instance Ord1 l => Semigroup (EGraph l) where
--     (<>) eg1 eg2 = undefined -- not so easy
-- instance Ord1 l => Monoid (EGraph l) where
--     mempty = EGraph emptyUF mempty mempty mempty

-- | Represent an expression (in its fixed point form) in an e-graph.
--
-- Returns the id of the e-class in which the expression was represented,
-- paired with the updated e-graph.
--
-- The expression is folded bottom-up with 'cata': every sub-expression becomes
-- a function from an e-graph to the class id it was added to and the e-graph
-- resulting from that addition.
represent :: forall a l. (Analysis a l, Language l) => Fix l -> EGraph a l -> (ClassId, EGraph a l)
represent = cata $ flip $ \e ->
    -- Finally, add the e-node built from the children's class ids
    uncurry add
    -- Wrap the layer of children class ids into an e-node
  . first Node
    -- Run the traversal starting from the e-graph we were given
  . (`runState` e)
    -- Represent every child in turn, threading the e-graph through the state
    -- and keeping only the resulting class id in the layer
  . traverse (gets >=> \(x, e') -> x <$ put e')

-- | Add an e-node to the e-graph
--
-- If the e-node is already represented in this e-graph, the class-id of the
-- class it's already represented in will be returned.
--
-- Otherwise a new singleton e-class is created for the (canonicalized)
-- e-node, the e-node is registered as a parent of each of its children
-- classes, it is pushed onto the worklist and recorded in the memo, and any
-- expressions produced by 'modifyA' are represented and merged into the new
-- class. The id of the new class and the resulting e-graph are returned.
add :: forall a l. (Analysis a l, Language l) => ENode l -> EGraph a l -> (ClassId, EGraph a l)
add uncanon_e egr =
    let !new_en = canonicalize uncanon_e egr
     in case lookupNM new_en (memo egr) of
          Just canon_enode_id -> (find canon_enode_id egr, egr)
          Nothing ->
            let
                -- Make new equivalence class with a new id in the union-find
                (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

                -- New singleton e-class stores the e-node and its analysis data
                -- which is modified according to analysis
                --
                -- The modification also produces a list of expressions to
                -- represent and merge with this class, which we'll do before
                -- returning from this function
                (new_eclass, added_nodes) =
                  modifyA $
                    EClass new_eclass_id
                           (S.singleton new_en)
                           (makeA @a ((\i -> egr ^. _class i . _data @a) <$> unNode new_en))
                           mempty

                -- TODO:Performance: All updates can be done to the map first? Parallelize?
                --
                -- Update e-classes by going through all e-node children and adding
                -- to the e-class parents the new e-node and its e-class id
                --
                -- And add new e-class to existing e-classes
                new_parents      = ((new_eclass_id, new_en) |:)
                new_classes      = IM.insert new_eclass_id new_eclass $
                                        foldr  (IM.adjust (_parents %~ new_parents))
                                               (classes egr)
                                               (unNode new_en)

                -- TODO: From egg: Is this needed?
                -- This is required if we want math pruning to work. Unfortunately, it
                -- also makes the invariants tests x4 slower (because they aren't using
                -- analysis) I think there might be another way to ensure math analysis
                -- pruning to work without having this line here.  Comment it out to
                -- check the result on the unit tests.
                --
                -- Update: I found a fix for that case: the modifyA function must add
                -- the parents of the pruned class to the worklist for them to be
                -- upward merged. I think it's a good compromise for requiring the user
                -- to do this. Adding the added node to the worklist everytime creates
                -- too much unnecessary work.
                --
                -- Actually I've found more bugs regarding this, and can't fix them
                -- there, so indeed this seems to be necessary for sanity with 'modifyA'
                --
                -- This way we also liberate the user from caring about the worklist
                --
                -- The hash cons invariants test suffer from this greatly but the
                -- saturation tests seem mostly fine?
                --
                -- And adding to the analysis worklist doesn't work, so maybe it's
                -- something else?
                --
                -- So in the end, we do need to addToWorklist to get correct results
                new_worklist     = (new_eclass_id, new_en) : worklist egr

                -- Add the e-node's e-class id at the e-node's id
                new_memo         = insertNM new_en new_eclass_id (memo egr)

                -- So we have our almost final e-graph. We just need to represent
                -- and merge in it all expressions which resulted from 'modifyA'
                -- above
                egr1             = egr { unionFind = new_uf
                                       , classes   = new_classes
                                       , worklist  = new_worklist
                                       , memo      = new_memo
                                       }

                egr2             = foldr (representAndMerge new_eclass_id) egr1 added_nodes

             in ( new_eclass_id
                , egr2
                )
{-# INLINABLE add #-}

-- | Merge 2 e-classes by id
--
-- Both ids are first canonicalized with 'find'. If they already denote the
-- same class, that class id and the unchanged e-graph are returned.
-- Otherwise the class with fewer parents is subsumed by the other (the
-- leader): the leader receives the subsumed class' parents and nodes and the
-- joined analysis data, the subsumed class is deleted, both worklists are
-- extended, and any expressions produced by 'modifyA' are represented and
-- merged into the leader. The id returned is the one produced by the
-- union-find for the merged set.
merge :: forall a l. (Analysis a l, Language l) => ClassId -> ClassId -> EGraph a l -> (ClassId, EGraph a l)
merge a b egr0 =

  -- Use canonical ids
  let
      a' = find a egr0
      b' = find b egr0
   in
   if a' == b'
     then (a', egr0)
     else
       let
           -- Get classes being merged
           class_a = egr0 ^. _class a'
           class_b = egr0 ^. _class b'

           -- Leader is the class with more parents
           (leader, leader_class, sub, sub_class) =
               if sizeSL (class_a ^. _parents) < sizeSL (class_b ^. _parents)
                  then (b', class_b, a', class_a) -- b is leader
                  else (a', class_a, b', class_b) -- a is leader

           -- Make leader the leader in the union find
           (new_id, new_uf) = unionSets leader sub (unionFind egr0)
                                & first (\n -> assert (leader == n) n)

           -- Update leader class with all e-nodes and parents from the
           -- subsumed class
           (updatedLeader, added_nodes) = leader_class
                                            & _parents %~ (sub_class ^. _parents <>)
                                            & _nodes   %~ (sub_class ^. _nodes <>)
                                            & _data    .~ new_data
                                            & modifyA

           -- Analysis data of the merged class: the join of both classes' data
           new_data = joinA @a @l (leader_class ^. _data) (sub_class ^. _data)

           -- Update leader in classes so that it has all nodes and parents
           -- from subsumed class, and delete the subsumed class
           --
           -- Additionally modify the e-class according to the analysis
           new_classes = (IM.insert leader updatedLeader . IM.delete sub) (classes egr0)

           -- Add all subsumed parents to worklist We can do this instead of
           -- adding the new e-class itself to the worklist because it would end
           -- up adding its parents anyway
           new_worklist = toListSL (sub_class ^. _parents) <> worklist egr0

           -- If the new_data is different from the classes, the parents of the
           -- class whose data is different from the merged must be put on the
           -- analysisWorklist
           new_analysis_worklist =
             (if new_data /= (sub_class ^. _data)
                then toListSL (sub_class ^. _parents)
                else mempty) <>
             (if new_data /= (leader_class ^. _data)
                then toListSL (leader_class ^. _parents)
                else mempty) <>
             analysisWorklist egr0

           -- ROMES:TODO: The code that makes the -1 * cos test pass when some other things are tweaked
           -- new_memo = foldr (`insertNM` leader) (memo egr0) (sub_class^._nodes)

           -- Build new e-graph
           egr1 = egr0
             { unionFind = new_uf
             , classes   = new_classes
             -- , memo      = new_memo
             , worklist  = new_worklist
             , analysisWorklist = new_analysis_worklist
             }

           -- Represent and merge into the leader everything 'modifyA' asked for
           egr2 = foldr (representAndMerge leader) egr1 added_nodes

        in (new_id, egr2)
{-# INLINEABLE merge #-}
            

-- | The rebuild operation processes the e-graph's current worklist,
-- restoring the invariants of deduplication and congruence. Rebuilding is
-- similar to other approaches in how it restores congruence; but it uniquely
-- allows the client to choose when to restore invariants in the context of a
-- larger algorithm like equality saturation.
--
-- Both worklists are taken out of the e-graph and deduplicated; every entry
-- of the worklist is then processed with 'repair' and every entry of the
-- analysis worklist with 'repairAnal'. If either worklist is non-empty
-- afterwards, 'rebuild' recurses on the result.
rebuild :: (Analysis a l, Language l) => EGraph a l -> EGraph a l
rebuild (EGraph uf cls mm wl awl) =
  -- empty worklists
  -- repair deduplicated e-classes
  let
    -- Same e-graph, but with both worklists emptied
    emptiedEgr = EGraph uf cls mm mempty mempty

    -- Canonicalize the worklist entries (class id and e-node) and drop duplicates
    wl'   = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
    egr'  = foldr repair emptiedEgr wl'

    -- Canonicalize the class ids of the analysis worklist against the repaired
    -- e-graph and keep a single entry per class id
    awl'  = nubIntOn fst $ first (`find` egr') <$> awl
    egr'' = foldr repairAnal egr' awl'
  in
  -- Loop until worklist is completely empty
  if null (worklist egr'') && null (analysisWorklist egr'')
     then egr''
     else rebuild egr'' -- ROMES:TODO: Doesn't seem to be needed at all in the testsuite.
{-# INLINEABLE rebuild #-}

-- ROMES:TODO: find repair_id could be shared between repair and repairAnal?

-- | Repair a single worklist entry.
--
-- A worklist entry pairs an e-class id with an e-node. The node is inserted
-- into the e-graph's memo, associated with that class id:
--
--   * if the memo had no entry for the node, only the memo is updated;
--   * if the memo reported a class for the node, that class is merged with
--     the class being repaired (the id returned by 'merge' is discarded).
--
-- NOTE(review): this assumes 'insertLookupNM' returns the class previously
-- stored for the node (if any) together with the updated memo, as its result
-- type and the two branches below suggest -- confirm against its definition.
repair :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repair (repair_id, node) egr =

   -- TODO We're no longer deleting the uncanonicalized node, how much does it matter that the structure keeps growing?

   case insertLookupNM node repair_id (memo egr) of

      (Nothing, memo') -> egr { memo = memo' } -- new memo with inserted node

      -- The node was already known under another class: both classes
      -- contain the same node, so they are merged.
      (Just existing_class, memo') -> snd (merge existing_class repair_id egr { memo = memo' })
{-# INLINE repair #-}

-- | Repair a single analysis-worklist entry.
--
-- Recomputes the analysis data of the class @repair_id@ by joining ('joinA')
-- its current data with the data made ('makeA') from @node@, where each child
-- class id of the node is replaced by that child class's current data.
--
-- If the joined data equals the data already stored, the e-graph is returned
-- unchanged. Otherwise:
--
--   1. the class, updated with the new data and passed through 'modifyA', is
--      stored back under @repair_id@;
--   2. the parents of the (original) class are put in front of the analysis
--      worklist;
--   3. every expression returned by 'modifyA' is represented in the e-graph
--      and merged with @repair_id@.
repairAnal :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repairAnal (repair_id, node) egr =
    let
        -- Class being repaired, as it currently is in the e-graph
        c1                = egr ^. _class repair_id
        -- Current data joined with the data computed from the node's children
        new_data          = joinA @a @l (c1 ^. _data) (makeA @a ((\i -> egr ^. _class i ^. _data @a) <$> unNode node))
        -- Only demanded in the branch where the data actually changed
        (c2, added_nodes) = modifyA (c1 & _data .~ new_data)
    in
    -- Take action if the new_data is different from the existing data
    if c1 ^. _data /= new_data
        -- Merge result is different from original class data, update class
        -- with new_data
       then
        let
            new_classes = IM.insert repair_id c2 (classes egr)
            -- The parents are taken from c1 (the class before 'modifyA')
            egr1 = egr { analysisWorklist = toListSL (c1 ^. _parents) <> analysisWorklist egr
                       , classes = new_classes
                       }
            egr2 = foldr (representAndMerge repair_id) egr1 added_nodes
         in egr2
       else egr
{-# INLINE repairAnal #-}

-- | Canonicalize an e-node
--
-- Two e-nodes are equal when their canonical form is equal. Canonicalization
-- replaces every e-class id held by the e-node with its canonical id, as given
-- by 'find' on the e-graph. This means two seemingly different e-nodes might
-- turn out to be equal once their e-class ids are replaced by the canonical
-- ids of the classes they belong to.
--
-- canonicalize(𝑓(𝑎,𝑏,𝑐,...)) = 𝑓((find 𝑎), (find 𝑏), (find 𝑐),...)
canonicalize :: Functor l => ENode l -> EGraph a l -> ENode l
canonicalize (Node enode) eg = Node $ fmap (`find` eg) enode
{-# INLINE canonicalize #-}

-- | Find the canonical representation of an e-class id in the e-graph
--
-- Looks the id up ('findRepr') in the e-graph's union-find.
--
-- Invariant: The e-class id always exists.
find :: ClassId -> EGraph a l -> ClassId
find cid = findRepr cid . unionFind
{-# INLINE find #-}

-- | The empty e-graph. Nothing is represented in it yet.
--
-- Built from the empty union-find ('emptyUF') and 'mempty' for the remaining
-- components: the class map, the memo and the two worklists.
emptyEGraph :: Language l => EGraph a l
emptyEGraph = EGraph emptyUF mempty mempty mempty mempty
{-# INLINE emptyEGraph #-}

-- | Represent an expression (in fix-point form) and merge it with the e-class with the given id
--
-- The expression is first added to the e-graph with 'represent'; the class it
-- ends up in is then merged ('merge') with the given class id. Only the
-- resulting e-graph is returned: the id of the merged class is discarded.
representAndMerge :: (Analysis a l, Language l) => ClassId -> Fix l -> EGraph a l -> EGraph a l
representAndMerge o f g = case represent f g of
                        (i, e) -> snd $ merge o i e