{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
-- {-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -ddump-to-file -ddump-simpl #-}
{-|
   An e-graph efficiently represents a congruence relation over many expressions.

   Based on \"egg: Fast and Extensible Equality Saturation\" https://arxiv.org/abs/2004.03082.
 -}
module Data.Equality.Graph
    (
      -- ** Definition of e-graph
      EGraph

      -- ** E-graph transformations
    , represent, add, merge, rebuild
    -- , repair, repairAnal

      -- ** Querying
    , find, canonicalize

      -- ** Functions on e-graphs
    , emptyEGraph, newEClass

      -- * E-graph transformations for monadic analysis
      -- | These are the same operations over e-graphs as above but over a monad in which the analysis is defined.
      -- It is common to only have a valid 'Analysis' under a monadic context.
      -- In that case, these are the functions to use -- they are just like the
      -- non-monadic ones, but require an 'Analysis' defined in a
      -- monadic context ('AnalysisM').
    , representM, addM, mergeM, rebuildM

      -- * Re-exports
    , module Data.Equality.Graph.Classes
    , module Data.Equality.Graph.Nodes
    , module Data.Equality.Language
    ) where

-- ROMES:TODO: Is the E-Graph a Monad if the analysis data were the type arg? i.e. instance Monad (EGraph language)?

-- import GHC.Conc
import Prelude hiding (lookup)

import Data.Function
import Data.Foldable (foldlM)
import Data.Bifunctor
import Data.Containers.ListUtils

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Control.Exception (assert)

import qualified Data.IntMap.Strict as IM
import qualified Data.Set    as S

import Data.Equality.Utils.SizedList

import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import qualified Data.Equality.Analysis.Monadic as AM
import Data.Equality.Language
import Data.Equality.Graph.Lens

import Data.Equality.Utils

-- ROMES:TODO: join things built in parallel?
-- instance Ord1 l => Semigroup (EGraph l) where
--     (<>) eg1 eg2 = undefined -- not so easy
-- instance Ord1 l => Monoid (EGraph l) where
--     mempty = EGraph emptyUF mempty mempty mempty

-- | Represent an expression (in its fixed-point form) in an e-graph.
--
-- The expression is added bottom-up: 'cata' turns every sub-expression into
-- a function that represents it in a given e-graph. The children are run in
-- 'traverse' order, threading the e-graph through them. The resulting node
-- of child e-class ids is then inserted with 'add'.
--
-- Returns the id of the e-class in which the expression was represented,
-- together with the updated e-graph.
represent :: forall a l. (Analysis a l, Language l) => Fix l -> EGraph a l -> (ClassId, EGraph a l)
represent = cata $ \children egr ->
  let
      -- Each child is an @EGraph -> (ClassId, EGraph)@ step, which is exactly
      -- a 'State' action; running them all yields the children's class ids
      -- and the e-graph containing every child.
      (child_ids, egr') = runState (traverse state children) egr
   in add (Node child_ids) egr'

-- | Add an e-node to the e-graph.
--
-- The e-node is first canonicalized against the e-graph.
--
-- * If the canonical e-node is already in the hash-cons ('memo'), the e-graph
--   is returned unchanged. The id returned is the canonical id of the
--   e-class that already represents it.
--
-- * Otherwise the node goes into a fresh singleton e-class:
--
--     * a new set is created in the union-find;
--     * the class's analysis data is computed with 'makeA' from the data of
--       the node's children;
--     * the e-node is recorded as a parent of each of its children;
--     * the e-node is registered in the hash-cons and pushed onto the
--       worklist;
--     * finally the analysis 'modifyA' hook is run on the new e-class.
add :: forall a l. (Analysis a l, Language l) => ENode l -> EGraph a l -> (ClassId, EGraph a l)
add uncanon_e egr =
  let !new_en = canonicalize uncanon_e egr
   in case lookupNM new_en (memo egr) of
        Just canon_enode_id -> (find canon_enode_id egr, egr)
        Nothing ->
          let
              -- Make new equivalence class with a new id in the union-find
              (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

              -- Analysis data of every child, in the node's shape
              children_data = (\i -> egr ^. _class i . _data @a) <$> unNode new_en

              -- New singleton e-class stores the e-node and its analysis data
              new_eclass = EClass new_eclass_id
                                  (S.singleton new_en)
                                  (makeA @a children_data)
                                  mempty

              -- TODO:Performance: All updates can be done to the map first? Parallelize?
              --
              -- Record the new e-node (with its e-class id) as a parent of each
              -- child e-class, then insert the new e-class itself. A child that
              -- occurs several times in the node receives one parent entry per
              -- occurrence.
              new_parents = ((new_eclass_id, new_en) |:)
              new_classes = IM.insert new_eclass_id new_eclass $
                              foldr (IM.adjust (_parents %~ new_parents))
                                    (classes egr)
                                    (unNode new_en)

              -- TODO: From egg: Is this needed?
              --
              -- Putting every newly added e-node on the worklist costs time: the
              -- hash-cons invariant tests become roughly 4x slower. It is
              -- nevertheless needed for correct results with 'modifyA' (e.g.
              -- analysis-based pruning). Two alternatives were tried and were
              -- insufficient:
              --   * having 'modifyA' enqueue the parents of the pruned class;
              --   * using the analysis worklist instead.
              -- Doing it here also frees the user from caring about the worklist.
              new_worklist = (new_eclass_id, new_en) : worklist egr

              -- Map the canonical e-node to its new e-class id in the hash-cons
              new_memo = insertNM new_en new_eclass_id (memo egr)

           in ( new_eclass_id
              , egr { unionFind = new_uf
                    , classes   = new_classes
                    , worklist  = new_worklist
                    , memo      = new_memo
                    }
                    -- Modify created node according to analysis
                    & modifyA new_eclass_id
              )
{-# INLINABLE add #-}

-- | Merge two e-classes by id.
--
-- Both ids are canonicalized first. If they already denote the same
-- e-class, the e-graph is returned unchanged.
--
-- Otherwise the class with more parents becomes the leader; the class of
-- the first argument wins ties. The other class is unioned into it:
--
-- * its e-nodes and parents are moved into the leader;
-- * the two analysis data are combined with 'joinA';
-- * the subsumed class's parents are pushed onto the worklist;
-- * for each of the two classes whose data differs from the joined data,
--   that class's parents are pushed onto the analysis worklist;
-- * finally 'modifyA' is run on the merged class.
--
-- Returns the id of the merged e-class and the updated e-graph. The pending
-- worklist entries are only processed by 'rebuild'.
merge :: forall a l. (Analysis a l, Language l) => ClassId -> ClassId -> EGraph a l -> (ClassId, EGraph a l)
merge a b egr0 =

  -- Use canonical ids
  let
      a' = find a egr0
      b' = find b egr0
   in
   if a' == b'
     then (a', egr0)
     else
       let
           -- Get classes being merged
           class_a = egr0 ^. _class a'
           class_b = egr0 ^. _class b'

           -- Leader is the class with more parents (a wins ties)
           (leader, leader_class, sub, sub_class) =
               if sizeSL (class_a ^. _parents) < sizeSL (class_b ^. _parents)
                  then (b', class_b, a', class_a) -- b is leader
                  else (a', class_a, b', class_b) -- a is leader

           -- Make leader the leader in the union find
           (new_id, new_uf) = unionSets leader sub (unionFind egr0)
                                & first (\n -> assert (leader == n) n)

           -- Analysis data of the merged class
           new_data = joinA @a @l (leader_class ^. _data) (sub_class ^. _data)

           -- Update leader class with all e-nodes and parents from the
           -- subsumed class
           updatedLeader = leader_class
                             & _parents %~ (sub_class ^. _parents <>)
                             & _nodes   %~ (sub_class ^. _nodes <>)
                             & _data    .~ new_data

           -- Replace the leader and delete the subsumed class
           new_classes = (IM.insert leader updatedLeader . IM.delete sub) (classes egr0)

           -- Add all subsumed parents to the worklist. This replaces adding
           -- the merged e-class itself, which would end up adding its
           -- parents anyway.
           new_worklist = toListSL (sub_class ^. _parents) <> worklist egr0

           -- Parents of a class whose data changed because of the join must
           -- have their analysis recomputed
           parentsIfChanged :: EClass a l -> [(ClassId, ENode l)]
           parentsIfChanged cls
             | new_data /= cls ^. _data = toListSL (cls ^. _parents)
             | otherwise                = mempty

           new_analysis_worklist =
             parentsIfChanged sub_class
               <> parentsIfChanged leader_class
               <> analysisWorklist egr0

           -- ROMES:TODO: The code that makes the -1 * cos test pass when some other things are tweaked
           -- new_memo = foldr (`insertNM` leader) (memo egr0) (sub_class^._nodes)

           -- Build new e-graph
           egr1 = egr0
             { unionFind        = new_uf
             , classes          = new_classes
             -- , memo      = new_memo
             , worklist         = new_worklist
             , analysisWorklist = new_analysis_worklist
             }
             -- Modify according to analysis
             & modifyA new_id

        in (new_id, egr1)
{-# INLINEABLE merge #-}
            

-- | The rebuild operation processes the e-graph's current worklist,
-- restoring the invariants of deduplication and congruence. Rebuilding is
-- similar to other approaches in how it restores congruence; but it uniquely
-- allows the client to choose when to restore invariants in the context of a
-- larger algorithm like equality saturation.
--
-- Each round clears both worklists and then:
--
--   1. canonicalizes and deduplicates the worklist entries and 'repair's
--      each one (re-inserting it in the hash-cons, merging on collision);
--   2. canonicalizes the class ids of the analysis-worklist entries, keeps
--      one entry per class id, and 'repairAnal's each one.
--
-- Repairs may enqueue new work, so rounds repeat until both worklists are
-- empty.
rebuild :: (Analysis a l, Language l) => EGraph a l -> EGraph a l
rebuild (EGraph uf cls mm wl awl) =
  let
    -- Same e-graph with both worklists emptied
    emptiedEgr = EGraph uf cls mm mempty mempty

    -- Repair deduplicated, canonicalized worklist entries
    wl'   = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
    egr'  = foldr repair emptiedEgr wl'

    -- Repair analysis-worklist entries, at most one per canonical class id
    awl'  = nubIntOn fst $ first (`find` egr') <$> awl
    egr'' = foldr repairAnal egr' awl'
  in
  -- Loop until both worklists are completely empty
  if null (worklist egr'') && null (analysisWorklist egr'')
     then egr''
     else rebuild egr'' -- ROMES:TODO: Doesn't seem to be needed at all in the testsuite.
{-# INLINEABLE rebuild #-}

-- ROMES:TODO: find repair_id could be shared between repair and repairAnal?

-- | Repair a single worklist entry.
--
-- The entry's e-node is (re)inserted into the hashcons ('memo'), mapped to
-- @repair_id@. If 'insertLookupNM' reports that an equal e-node was already
-- mapped to some e-class, that e-class is merged with @repair_id@. Otherwise
-- only the memo is updated.
repair :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repair (repair_id, node) egr =
  -- TODO We're no longer deleting the uncanonicalized node, how much does it matter that the structure keeps growing?
  let (previous, memo') = insertLookupNM node repair_id (memo egr)
      -- E-graph with the memo that now contains the node
      egr_memo          = egr { memo = memo' }
      -- Congruent node found: union the two e-classes, discarding the merged id
      mergeWith existing_class = snd (merge existing_class repair_id egr_memo)
   in maybe egr_memo mergeWith previous
{-# INLINE repair #-}

-- | Repair a single analysis-worklist entry.
--
-- Recomputes the analysis data of e-class @repair_id@. The result is the
-- 'joinA' of the class's current data and the data 'makeA' derives from the
-- data of @node@'s children. When the recomputed data differs from the
-- current data, three things happen:
--
-- * the class's data is replaced;
-- * its parents are pushed onto the 'analysisWorklist';
-- * 'modifyA' is run on the class.
--
-- Otherwise the e-graph is returned unchanged.
repairAnal :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repairAnal (repair_id, node) egr
  | old_data /= new_data = modifyA repair_id updated
  | otherwise            = egr
  where
    c = egr^._class repair_id

    old_data = c^._data

    -- Analysis data currently stored for a child e-class
    childData i = (egr^._class i)^._data @a

    new_data = joinA @a @l old_data (makeA @a (childData <$> unNode node))

    -- The data changed: store it and propagate the change to the parents
    updated = egr { analysisWorklist = toListSL (c^._parents) <> analysisWorklist egr
                  , classes          = IM.adjust (_data .~ new_data) repair_id (classes egr)
                  }
{-# INLINE repairAnal #-}

-- | Canonicalize an e-node
--
-- Two e-nodes are equal when their canonical forms are equal. Canonicalization
-- replaces every e-class id the e-node holds with its canonical id, so two
-- seemingly different e-nodes may turn out to be equal once their child ids
-- are found to share the same canonical representatives.
--
-- @canonicalize(f(a, b, c, ...)) = f(find a, find b, find c, ...)@
canonicalize :: Functor l => ENode l -> EGraph a l -> ENode l
canonicalize (Node enode) eg = Node (canon <$> enode)
  where
    -- Canonical representative of a child id in this e-graph
    canon cid = find cid eg
{-# INLINE canonicalize #-}

-- | Find the canonical representation of an e-class id in the e-graph
--
-- Looks the id up in the e-graph's union-find ('findRepr').
--
-- Invariant: The e-class id always exists.
find :: ClassId -> EGraph a l -> ClassId
find cid egr = findRepr cid (unionFind egr)
{-# INLINE find #-}

-- | The empty e-graph. Nothing is represented in it yet.
--
-- It has an empty union-find, no e-classes, an empty hashcons, and empty
-- worklists.
emptyEGraph :: Language l => EGraph a l
emptyEGraph =
  EGraph
    { unionFind        = emptyUF
    , classes          = mempty
    , memo             = mempty
    , worklist         = mempty
    , analysisWorklist = mempty
    }
{-# INLINE emptyEGraph #-}

-- | Creates an empty e-class in an e-graph, with the explicitly given domain analysis data.
-- (That is, an e-class with no e-nodes)
--
-- Only the union-find and the class map are updated. The memo and the
-- worklists are left untouched. Returns the new e-class id together with the
-- updated e-graph.
newEClass :: (Language l) => a -> EGraph a l -> (ClassId, EGraph a l)
newEClass adata egr = (new_eclass_id, egr')
  where
    -- A fresh set in the union-find provides the new e-class id
    (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

    -- No e-nodes and no parents: the class only carries the analysis data
    new_eclass = EClass new_eclass_id S.empty adata mempty

    egr' = egr { unionFind = new_uf
               , classes   = IM.insert new_eclass_id new_eclass (classes egr)
               }
{-# INLINE newEClass #-}

----- Monadic operations on e-graphs
-- Unfortunately, these cannot be defined in terms of the primary ones.
-- This could almost be done by defining the domain of the standard Analysis to
-- be (m a), for some Monad m, but this would require an instance Eq (m a),
-- which often won't exist.
--
-- This separate API also gives monadic analyses a better story: the arguments
-- of the type-class functions need not live in the same monad as the result
-- (unlike a normal analysis whose domain is itself monadic).
--
-- Be sure to update these functions if the above "canonical" versions are changed.

-- TODO: Move these to new module?

-- * E-graph operations for analysis defined monadically ('AM.AnalysisM')

-- | Like 'represent', but for a monadic analysis
--
-- Represents an expression (a fixed point of the language functor) in the
-- e-graph. Each sub-expression is represented first, with the e-graph threaded
-- through as state. The node built from the resulting child e-class ids is
-- then added with 'addM'. Returns the e-class id of the whole expression
-- together with the updated e-graph.
representM :: forall a l m. (AM.AnalysisM m a l, Language l) => Fix l -> EGraph a l -> m (ClassId, EGraph a l)
representM = cata $ \children e -> do
  -- Canonical implementation is represent, this is just the monadic variant of it.
  --
  -- Each child is already a state-passing action of type
  -- @EGraph a l -> m (ClassId, EGraph a l)@, which is exactly what the
  -- 'StateT' constructor wraps. Traversing with it threads the e-graph
  -- through the children in 'traverse' order. This is equivalent to the
  -- former @get >>= lift . f >>= StateT . const . pure@ chain.
  (ids, e') <- runStateT (traverse StateT children) e
  addM (Node ids) e'

-- | Like 'add', but for a monadic analysis
--
-- The e-node is canonicalized first. If an equal e-node is already in the
-- hashcons ('memo'), the canonical id of its e-class is returned and the
-- e-graph is unchanged. Otherwise a new singleton e-class is created for it:
--
-- * its analysis data is computed with 'AM.makeA' from the children's data;
-- * it is recorded as a parent of every child e-class;
-- * it is pushed onto the 'worklist' and inserted into the 'memo';
-- * 'AM.modifyA' is run on the new class.
addM :: forall a l m. (AM.AnalysisM m a l, Language l) => ENode l -> EGraph a l -> m (ClassId, EGraph a l)
addM uncanon_e egr =
  -- Canonical implementation is add, this is just the monadic variant of it
  case lookupNM new_en (memo egr) of
    Just canon_enode_id ->
      pure (find canon_enode_id egr, egr)
    Nothing -> do
      new_data <- AM.makeA @m @a (childData <$> unNode new_en)
      let new_eclass = EClass new_eclass_id (S.singleton new_en) new_data mempty
          egr' = egr { unionFind = new_uf
                     , classes   = IM.insert new_eclass_id new_eclass
                                     (foldr addParent (classes egr) (unNode new_en))
                     , worklist  = (new_eclass_id, new_en) : worklist egr
                     , memo      = insertNM new_en new_eclass_id (memo egr)
                     }
      (new_eclass_id,) <$> AM.modifyA new_eclass_id egr'
  where
    !new_en = canonicalize uncanon_e egr

    -- Fresh id for the e-class that will hold the new e-node
    (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

    -- Analysis data currently stored for a child e-class
    childData i = egr^._class i._data @a

    -- Register the new e-node as a parent of the child e-class with this id
    addParent = IM.adjust (_parents %~ ((new_eclass_id, new_en) |:))
{-# INLINABLE addM #-}

-- | Like 'merge', but for a monadic analysis
--
-- Unions the e-classes of the two ids and returns the id of the merged
-- e-class with the updated e-graph. If both ids already share a
-- representative, the e-graph is returned unchanged. Otherwise:
--
-- * The class with more parents becomes the leader (ties favour @a@'s class).
-- * The leader absorbs the other class's parents and nodes.
-- * The leader's data becomes the 'AM.joinA' of both classes' data.
-- * The absorbed class is removed from the class map.
-- * The absorbed class's parents are pushed onto the 'worklist'.
-- * The parents of each class whose data differs from the joined data are
--   pushed onto the 'analysisWorklist'.
-- * 'AM.modifyA' is run on the merged class.
mergeM :: forall a l m. (AM.AnalysisM m a l, Language l) => ClassId -> ClassId -> EGraph a l -> m (ClassId, EGraph a l)
mergeM a b egr0
  -- Canonical implementation is merge, this is just the monadic variant of it
  | a' == b'  = return (a', egr0)
  | otherwise = do
      new_data <- AM.joinA @m @a @l (leader_class^._data) (sub_class^._data)

      let -- Parents of a class whose data differs from the joined data must
          -- have their analysis data recomputed.
          parentsIfChanged cls
            | new_data /= cls^._data = toListSL (cls^._parents)
            | otherwise              = mempty

          merged_leader = leader_class
                            & _parents %~ (sub_class^._parents <>)
                            & _nodes   %~ (sub_class^._nodes <>)
                            & _data    .~ new_data

          -- ROMES:TODO: The code that makes the -1 * cos test pass when some other things are tweaked
          -- new_memo = foldr (`insertNM` leader) (memo egr0) (sub_class^._nodes)
          egr1 = egr0 { unionFind        = new_uf
                      , classes          = IM.insert leader merged_leader (IM.delete sub (classes egr0))
                      -- , memo          = new_memo
                        -- Queuing the absorbed parents suffices: re-queuing the
                        -- merged class itself would add its parents anyway.
                      , worklist         = toListSL (sub_class^._parents) <> worklist egr0
                      , analysisWorklist = parentsIfChanged sub_class
                                        <> parentsIfChanged leader_class
                                        <> analysisWorklist egr0
                      }

      (new_id,) <$> AM.modifyA new_id egr1
  where
    a' = find a egr0
    b' = find b egr0

    class_a = egr0^._class a'
    class_b = egr0^._class b'

    -- The class with fewer parents is absorbed into the other one
    (leader, leader_class, sub, sub_class)
      | sizeSL (class_a^._parents) < sizeSL (class_b^._parents) = (b', class_b, a', class_a) -- b is leader
      | otherwise                                               = (a', class_a, b', class_b) -- a is leader

    (new_id, new_uf) =
      first (\n -> assert (leader == n) n) (unionSets leader sub (unionFind egr0))
{-# INLINEABLE mergeM #-}

-- | Like 'rebuild', but for a monadic analysis.
--
-- Clears both worklists, then:
--
--   1. canonicalizes and deduplicates the pending congruence worklist and
--      repairs every entry with 'repairM';
--   2. canonicalizes and deduplicates the pending analysis worklist and
--      re-runs the analysis on every entry with 'repairAnalM'.
--
-- Both repair steps may push new work onto the worklists ('mergeM' refills
-- the congruence worklist, 'repairAnalM' enqueues parents on the analysis
-- worklist), so the whole process repeats until both worklists are empty.
rebuildM :: forall a l m. (AM.AnalysisM m a l, Language l) => EGraph a l -> m (EGraph a l)
rebuildM (EGraph uf cls mm wl awl) = do
  -- Canonical implementation is rebuild, this is just the monadic variant of it
  let
    -- Same e-graph, but with both worklists emptied; the repairs below refill them.
    emptiedEgr = EGraph uf cls mm mempty mempty

    -- Canonical (class, node) pairs, so that duplicates collapse before repairing.
    wl' = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl

  egr'  <- foldlM (flip repairM) emptiedEgr wl'

  let
    -- Deduplicate on the whole canonical (class, node) pair, not on the class
    -- id alone. Each entry asks 'repairAnalM' to re-run @makeA@ on one specific
    -- node and join the result into its class. Keying only on the class id
    -- would silently drop other pending nodes of the same class, and their
    -- contribution to the class data would be lost.
    awl' = nubOrd $ bimap (`find` egr') (`canonicalize` egr') <$> awl

  egr'' <- foldlM (flip repairAnalM) egr' awl'

  -- Loop until both worklists are completely empty
  if null (worklist egr'') && null (analysisWorklist egr'')
     then return egr''
     else rebuildM egr''
{-# INLINEABLE rebuildM #-}

-- | Like 'repair', but for a monadic analysis.
--
-- Inserts @node@ into the e-graph's 'memo' under @repair_id@. If the memo
-- held no class for that node, only the memo is updated. If it already mapped
-- the node to some class, that class and @repair_id@ are merged with 'mergeM';
-- the id returned by the merge is discarded. The node is expected to be
-- canonical already: 'rebuildM' canonicalizes its worklist before calling this.
repairM :: forall a l m. (AM.AnalysisM m a l, Language l) => (ClassId, ENode l) -> EGraph a l -> m (EGraph a l)
repairM (repair_id, node) egr =
  -- Canonical implementation is repair, this is just the monadic variant of it
  maybe (pure withNewMemo) mergeInto previousClass
  where
    (previousClass, updatedMemo) = insertLookupNM node repair_id (memo egr)

    withNewMemo = egr { memo = updatedMemo }

    -- The node was already known under another class: those classes are congruent.
    mergeInto existing_class = snd <$> mergeM existing_class repair_id withNewMemo
{-# INLINE repairM #-}

-- | Like 'repairAnal', but for a monadic analysis.
--
-- Recomputes the analysis data of @node@ with 'AM.makeA', using the current
-- data of its child classes. The result is joined with the existing data of
-- class @repair_id@ using 'AM.joinA'. If the joined data equals the old data,
-- the e-graph is returned unchanged. Otherwise the new data is stored in the
-- class, the class's parents are pushed onto the analysis worklist, and
-- 'AM.modifyA' is run on the class.
repairAnalM :: forall a l m. (AM.AnalysisM m a l, Language l) => (ClassId, ENode l) -> EGraph a l -> m (EGraph a l)
repairAnalM (repair_id, node) egr = do
  -- Canonical implementation is repairAnal, this is just the monadic variant of it
  let eclass      = egr^._class repair_id
      old_data    = eclass^._data
      childData i = egr^._class i^._data @a

  made     <- AM.makeA @m @a (childData <$> unNode node)
  new_data <- AM.joinA @m @a @l old_data made

  if old_data == new_data
    then pure egr
    else
      AM.modifyA repair_id
        (egr { analysisWorklist = toListSL (eclass^._parents) <> analysisWorklist egr
             , classes          = IM.adjust (_data .~ new_data) repair_id (classes egr)
             })
{-# INLINE repairAnalM #-}