{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -ddump-to-file -ddump-simpl #-}
module Data.Equality.Graph
(
EGraph
, represent, add, merge, rebuild
, find, canonicalize
, emptyEGraph, newEClass
, representM, addM, mergeM, rebuildM
, module Data.Equality.Graph.Classes
, module Data.Equality.Graph.Nodes
, module Data.Equality.Language
) where
import Prelude hiding (lookup)
import Data.Function
import Data.Foldable (foldlM)
import Data.Bifunctor
import Data.Containers.ListUtils
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Control.Exception (assert)
import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S
import Data.Equality.Utils.SizedList
import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import qualified Data.Equality.Analysis.Monadic as AM
import Data.Equality.Language
import Data.Equality.Graph.Lens
import Data.Equality.Utils
-- | Represent an expression (in its fixed-point form) in an e-graph.
--
-- The expression is folded bottom-up with 'cata'. Each child is represented
-- first, threading the e-graph through as state. Then the e-node built from the
-- children's e-class ids is inserted with 'add'.
--
-- Returns the id of the e-class the expression was represented in, together
-- with the updated e-graph.
represent :: forall a l. (Analysis a l, Language l) => Fix l -> EGraph a l -> (ClassId, EGraph a l)
represent = cata $ \children egr ->
  -- Every child already has type @EGraph a l -> (ClassId, EGraph a l)@, which is
  -- exactly a 'State' action. Traversing with 'state' runs the children left to
  -- right and collects their e-class ids.
  let (childIds, egr') = runState (traverse state children) egr
   in add (Node childIds) egr'
-- | Add an e-node to the e-graph.
--
-- The e-node is canonicalized first.
--
-- * If an equal canonical e-node is already in the memo, the e-graph is
--   returned unchanged, along with the canonical id of the e-class holding it.
-- * Otherwise a new singleton e-class is created for it. Its analysis data
--   comes from 'makeA' over the children's data. The e-node is recorded as a
--   parent of each child e-class, pushed onto the worklist and inserted into
--   the memo. Finally 'modifyA' is run on the new e-class.
--
-- Returns the id of the e-class containing the e-node and the updated e-graph.
add :: forall a l. (Analysis a l, Language l) => ENode l -> EGraph a l -> (ClassId, EGraph a l)
add uncanon_e egr =
  let !new_en = canonicalize uncanon_e egr
   in case lookupNM new_en (memo egr) of
        Just canon_enode_id -> (find canon_enode_id egr, egr)
        Nothing ->
          let
            -- Fresh equivalence class in the union-find
            (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

            -- Analysis data of the new e-node, computed from its children's data
            new_data = makeA @a ((\i -> egr ^. _class i . _data @a) <$> unNode new_en)

            -- Singleton e-class holding only the new e-node, with no parents yet
            new_eclass = EClass new_eclass_id (S.singleton new_en) new_data mempty

            -- Register the new e-node as a parent of every child e-class, then
            -- insert the new e-class itself
            addParent = ((new_eclass_id, new_en) |:)
            new_classes =
              IM.insert new_eclass_id new_eclass $
                foldr (IM.adjust (_parents %~ addParent)) (classes egr) (unNode new_en)

            new_worklist = (new_eclass_id, new_en) : worklist egr
            new_memo     = insertNM new_en new_eclass_id (memo egr)
           in ( new_eclass_id
              , egr { unionFind = new_uf
                    , classes   = new_classes
                    , worklist  = new_worklist
                    , memo      = new_memo
                    }
                  -- Let the analysis modify the freshly created e-class
                  & modifyA new_eclass_id
              )
{-# INLINABLE add #-}
-- | Merge two e-classes by id.
--
-- Both ids are canonicalized with 'find'. If they already name the same
-- e-class, the e-graph is returned unchanged.
--
-- Otherwise one e-class becomes the leader: the one with more parents, with
-- ties going to @a@'s class. The leader absorbs the other class's e-nodes and
-- parents, and its analysis data becomes the 'joinA' of both. The absorbed
-- e-class is deleted from the class map, and its parents are pushed onto the
-- worklist. For each side whose data differs from the joined data, that side's
-- parents are pushed onto the analysis worklist. Finally 'modifyA' is run on
-- the merged e-class.
--
-- Invariants are only restored by a later 'rebuild'.
--
-- Returns the id of the merged e-class and the updated e-graph.
merge :: forall a l. (Analysis a l, Language l) => ClassId -> ClassId -> EGraph a l -> (ClassId, EGraph a l)
merge a b egr0
  | a' == b'  = (a', egr0)
  | otherwise = (new_id, egr1)
  where
    -- Work with canonical ids
    a' = find a egr0
    b' = find b egr0

    class_a = egr0 ^. _class a'
    class_b = egr0 ^. _class b'

    -- The class with more parents leads, so fewer parents must be re-queued
    (leader, leader_class, sub, sub_class) =
      if sizeSL (class_a ^. _parents) < sizeSL (class_b ^. _parents)
        then (b', class_b, a', class_a)
        else (a', class_a, b', class_b)

    -- The union-find must pick the chosen leader as representative
    (new_id, new_uf) =
      unionSets leader sub (unionFind egr0)
        & first (\n -> assert (leader == n) n)

    new_data = joinA @a @l (leader_class ^. _data) (sub_class ^. _data)

    -- Leader takes over all nodes and parents of the subsumed class
    updatedLeader =
      leader_class
        & _parents %~ (sub_class ^. _parents <>)
        & _nodes   %~ (sub_class ^. _nodes <>)
        & _data    .~ new_data

    new_classes = (IM.insert leader updatedLeader . IM.delete sub) (classes egr0)

    -- Re-queuing the subsumed parents is enough: re-queuing the merged class
    -- would end up adding its parents anyway
    new_worklist = toListSL (sub_class ^. _parents) <> worklist egr0

    -- Parents of a side whose analysis data changed must be re-analysed
    parentsIfChanged cls
      | new_data /= cls ^. _data = toListSL (cls ^. _parents)
      | otherwise                = mempty

    new_analysis_worklist =
      parentsIfChanged sub_class
        <> parentsIfChanged leader_class
        <> analysisWorklist egr0

    egr1 =
      egr0 { unionFind        = new_uf
           , classes          = new_classes
           , worklist         = new_worklist
           , analysisWorklist = new_analysis_worklist
           }
        -- Let the analysis modify the merged e-class
        & modifyA new_id
{-# INLINEABLE merge #-}
-- | Restore the e-graph invariants after a batch of 'add' / 'merge' calls.
--
-- Both worklists are taken out of the e-graph and drained:
--
-- * Congruence worklist: entries are canonicalized, deduplicated and passed to
--   'repair'. 'repair' re-inserts them in the memo and merges any congruent
--   e-classes it finds.
-- * Analysis worklist: entries are canonicalized, deduplicated and passed to
--   'repairAnal', which folds each e-node's analysis data into its e-class.
--
-- Repairs may enqueue new work, so this repeats until both worklists are empty.
rebuild :: (Analysis a l, Language l) => EGraph a l -> EGraph a l
rebuild (EGraph uf cls mm wl awl) =
  let
    -- Same e-graph with empty worklists, so repairs enqueue into fresh ones
    emptiedEgr = EGraph uf cls mm mempty mempty

    wl'  = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
    egr' = foldr repair emptiedEgr wl'

    -- Deduplicate on the whole (class, e-node) pair, not on the class id alone.
    -- 'repairAnal' only folds the one e-node it is given into the class data.
    -- Keeping a single entry per class would drop the contribution of every
    -- other e-node of that class whose children's data changed, leaving the
    -- analysis short of its fixpoint.
    awl'  = nubOrd $ bimap (`find` egr') (`canonicalize` egr') <$> awl
    egr'' = foldr repairAnal egr' awl'
  in
    -- Loop until no repair has produced further work
    if null (worklist egr'') && null (analysisWorklist egr'')
      then egr''
      else rebuild egr''
{-# INLINEABLE rebuild #-}
-- | Process one congruence-worklist entry.
--
-- The e-node is inserted into the memo under @repair_id@ using
-- 'insertLookupNM', and the e-graph keeps the map that call returns. If the
-- memo already held an e-class for an equal e-node, the two e-classes are
-- congruent and are merged with 'merge'.
repair :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repair (repair_id, node) egr =
  let (previous, memo') = insertLookupNM node repair_id (memo egr)
      egrWithMemo       = egr { memo = memo' }
   in maybe egrWithMemo
            (\existing_class -> snd (merge existing_class repair_id egrWithMemo))
            previous
{-# INLINE repair #-}
-- | Process one analysis-worklist entry.
--
-- Recomputes the data 'makeA' gives for @node@ from its children's current
-- data, and joins it with the data already stored in e-class @repair_id@.
-- If the joined value differs from the stored one:
--
-- * it is stored in the class map;
-- * the e-class's parents are pushed onto the analysis worklist;
-- * 'modifyA' is run on the e-class.
--
-- Otherwise the e-graph is returned unchanged.
repairAnal :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repairAnal (repair_id, node) egr
  | old_data /= new_data =
      egr { analysisWorklist = toListSL (c ^. _parents) <> analysisWorklist egr
          , classes          = IM.adjust (_data .~ new_data) repair_id (classes egr)
          }
        & modifyA repair_id
  | otherwise = egr
  where
    c         = egr ^. _class repair_id
    old_data  = c ^. _data
    node_data = makeA @a ((\i -> egr ^. _class i ^. _data @a) <$> unNode node)
    new_data  = joinA @a @l old_data node_data
{-# INLINE repairAnal #-}
-- | Canonicalize an e-node by replacing each child e-class id with the id
-- 'find' returns for it in the given e-graph.
canonicalize :: Functor l => ENode l -> EGraph a l -> ENode l
canonicalize (Node enode) eg = Node ((\cid -> find cid eg) <$> enode)
{-# INLINE canonicalize #-}
-- | Representative id of an e-class id, looked up with 'findRepr' in the
-- e-graph's union-find.
find :: ClassId -> EGraph a l -> ClassId
find cid egr = findRepr cid (unionFind egr)
{-# INLINE find #-}
-- | An e-graph with an empty union-find ('emptyUF') and empty class map, memo
-- and worklists.
emptyEGraph :: Language l => EGraph a l
emptyEGraph =
  EGraph { unionFind        = emptyUF
         , classes          = mempty
         , memo             = mempty
         , worklist         = mempty
         , analysisWorklist = mempty
         }
{-# INLINE emptyEGraph #-}
-- | Create a new e-class with no e-nodes and no parents, carrying the given
-- analysis data.
--
-- Only the union-find and the class map are updated. Unlike 'add', this does
-- not touch the memo or worklists and does not run 'modifyA'.
--
-- Returns the new e-class id and the updated e-graph.
newEClass :: (Language l) => a -> EGraph a l -> (ClassId, EGraph a l)
newEClass adata egr =
  let (new_eclass_id, new_uf) = makeNewSet (unionFind egr)
      new_eclass              = EClass new_eclass_id S.empty adata mempty
   in ( new_eclass_id
      , egr { unionFind = new_uf
            , classes   = IM.insert new_eclass_id new_eclass (classes egr)
            }
      )
{-# INLINE newEClass #-}
-- | Monadic-analysis counterpart of 'represent'.
--
-- Represents an expression (in its fixed-point form) in an e-graph bottom-up,
-- threading the e-graph through the children and inserting each e-node with
-- 'addM'.
--
-- Returns the id of the e-class the expression was represented in and the
-- updated e-graph.
representM :: forall a l m. (AM.AnalysisM m a l, Language l) => Fix l -> EGraph a l -> m (ClassId, EGraph a l)
representM = cata $ \children egr -> do
  -- Every child has type @EGraph a l -> m (ClassId, EGraph a l)@, which is
  -- exactly a 'StateT' action. Traversing runs them in order and collects the
  -- children's e-class ids.
  (childIds, egr') <- runStateT (traverse StateT children) egr
  addM (Node childIds) egr'
-- | Monadic counterpart of 'add': insert an e-node into the e-graph, using the
-- monadic analysis ('AM.makeA' and 'AM.modifyA') when a new e-class is created.
--
-- The node is canonicalized against the current e-graph first.
--
-- * If the canonical node is already in the memo, the canonical id of its
--   e-class is returned and the e-graph is returned unchanged.
-- * Otherwise a fresh e-class id is allocated in the union-find and a class
--   holding only this node is built. Its analysis data comes from 'AM.makeA'
--   over the children's data. The node is then registered as a parent of each
--   of its children, pushed onto the worklist and recorded in the memo.
--   Finally 'AM.modifyA' is run on the new class.
addM :: forall a l m. (AM.AnalysisM m a l, Language l) => ENode l -> EGraph a l -> m (ClassId, EGraph a l)
addM uncanon_e egr
  | Just known_id <- lookupNM canon_node (memo egr)
  = pure (find known_id egr, egr)
  | otherwise
  = do
      node_data <- AM.makeA @m @a (child_data <$> unNode canon_node)
      let
        fresh_class :: EClass a l
        fresh_class = EClass fresh_id (S.singleton canon_node) node_data mempty

        -- every child class gains (fresh_id, canon_node) as a parent
        with_parents :: IntMap (EClass a l)
        with_parents =
          foldr (IM.adjust (_parents %~ ((fresh_id, canon_node) |:)))
                (classes egr)
                (unNode canon_node)

        egr_with_class :: EGraph a l
        egr_with_class = egr
          { unionFind = uf'
          , classes   = IM.insert fresh_id fresh_class with_parents
          , worklist  = (fresh_id, canon_node) : worklist egr
          , memo      = insertNM canon_node fresh_id (memo egr)
          }
      egr_modified <- AM.modifyA fresh_id egr_with_class
      pure (fresh_id, egr_modified)
  where
    -- canonicalize eagerly, before the memo lookup
    !canon_node = canonicalize uncanon_e egr

    -- a fresh singleton set in the union-find for the (possible) new class
    (fresh_id, uf') = makeNewSet (unionFind egr)

    -- analysis data currently stored for a child e-class
    child_data :: ClassId -> a
    child_data i = egr ^. _class i . _data @a
{-# INLINABLE addM #-}
-- | Monadic counterpart of 'merge': union the e-classes of two ids, joining
-- their analysis data with 'AM.joinA'.
--
-- If both ids already have the same representative, the e-graph is returned
-- unchanged together with that representative.
--
-- Otherwise the class with fewer parents becomes the @sub@ class. On a tie the
-- class of the first argument leads. The sub class is then folded into the
-- leader:
--
-- * the sets are unioned in the union-find;
-- * the leader absorbs the sub's parents and nodes and gets the joined data;
-- * the sub's entry is deleted from the class map;
-- * the sub's parents are pushed onto the worklist;
-- * the parents of any class whose data changed are pushed onto the analysis
--   worklist.
--
-- Finally 'AM.modifyA' is run on the resulting class id.
mergeM :: forall a l m. (AM.AnalysisM m a l, Language l) => ClassId -> ClassId -> EGraph a l -> m (ClassId, EGraph a l)
mergeM a b egr0
  | a' == b' = return (a', egr0)
  | otherwise = do
      new_data <- AM.joinA @m @a @l (leader_class ^. _data) (sub_class ^. _data)
      let
        sub_parents = sub_class ^. _parents

        merged_leader :: EClass a l
        merged_leader = leader_class
          & _parents %~ (sub_parents <>)
          & _nodes %~ ((sub_class ^. _nodes) <>)
          & _data .~ new_data

        -- parents must be re-analysed only if this class's data changed
        parents_if_changed cls
          | new_data /= (cls ^. _data) = toListSL (cls ^. _parents)
          | otherwise = mempty

        egr1 :: EGraph a l
        egr1 = egr0
          { unionFind        = new_uf
          , classes          = IM.insert leader merged_leader (IM.delete sub (classes egr0))
          , worklist         = toListSL sub_parents <> worklist egr0
          , analysisWorklist = parents_if_changed sub_class
                            <> parents_if_changed leader_class
                            <> analysisWorklist egr0
          }
      egr2 <- AM.modifyA new_id egr1
      return (new_id, egr2)
  where
    a', b' :: ClassId
    a' = find a egr0
    b' = find b egr0

    class_a, class_b :: EClass a l
    class_a = egr0 ^. _class a'
    class_b = egr0 ^. _class b'

    -- the class with more parents leads (ties favour the first argument)
    (leader, leader_class, sub, sub_class)
      | sizeSL (class_a ^. _parents) < sizeSL (class_b ^. _parents) = (b', class_b, a', class_a)
      | otherwise = (a', class_a, b', class_b)

    -- the union-find is expected to keep the chosen leader as representative
    (new_id, new_uf) =
      first (\n -> assert (leader == n) n) (unionSets leader sub (unionFind egr0))
{-# INLINEABLE mergeM #-}
-- | Monadic counterpart of 'rebuild': restore the e-graph invariants after a
-- batch of 'addM' and 'mergeM' operations.
--
-- Each round works in two phases:
--
-- 1. Congruence. The worklist is cleared. Its entries are canonicalized,
--    deduplicated and fed to 'repairM', which re-inserts each node into the
--    memo and merges classes that turn out to be congruent.
-- 2. Analysis. The analysis worklist is canonicalized, deduplicated and fed
--    to 'repairAnalM', which recomputes the node's analysis data and joins
--    it into its class.
--
-- Repairs can push new entries onto either worklist, so rounds repeat until
-- both worklists are empty.
rebuildM :: forall a l m. (AM.AnalysisM m a l, Language l) => EGraph a l -> m (EGraph a l)
rebuildM (EGraph uf cls mm wl awl) = do
  let
    -- the same graph with both worklists emptied; repairs may refill them
    emptiedEgr :: EGraph a l
    emptiedEgr = EGraph uf cls mm mempty mempty

    wl' :: Worklist l
    wl' = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
  egr' <- foldlM (flip repairM) emptiedEgr wl'
  let
    -- Deduplicate on the whole (class, node) pair, not just the class id.
    -- 'repairAnalM' recomputes 'AM.makeA' for the specific node it is given,
    -- so distinct parent e-nodes in the same e-class each contribute their
    -- own analysis data. Keeping only one entry per class (the former
    -- @nubIntOn fst@) silently dropped those contributions.
    awl' :: Worklist l
    awl' = nubOrd $ bimap (`find` egr') (`canonicalize` egr') <$> awl
  egr'' <- foldlM (flip repairAnalM) egr' awl'
  if null (worklist egr'') && null (analysisWorklist egr'')
    then return egr''
    else rebuildM egr''
{-# INLINEABLE rebuildM #-}
-- | Process one congruence-worklist entry: record @node -> repair_id@ in the
-- memo.
--
-- If the memo already mapped this node to some class, that class is merged
-- with @repair_id@ via 'mergeM' (the merged id is discarded). Otherwise only
-- the updated memo is stored.
repairM :: forall a l m. (AM.AnalysisM m a l, Language l) => (ClassId, ENode l) -> EGraph a l -> m (EGraph a l)
repairM (repair_id, node) egr =
  let (previous, memo') = insertLookupNM node repair_id (memo egr)
      egr' = egr { memo = memo' }
   in maybe (return egr')
            (\existing_class -> snd <$> mergeM existing_class repair_id egr')
            previous
{-# INLINE repairM #-}
-- | Process one analysis-worklist entry.
--
-- The node's analysis data is recomputed with 'AM.makeA' from its children's
-- current data and joined ('AM.joinA') with the data stored on class
-- @repair_id@.
--
-- If the result differs from the stored data:
--
-- * the class data is updated;
-- * the class's parents are pushed onto the analysis worklist;
-- * 'AM.modifyA' is run on the class.
--
-- Otherwise the e-graph is returned unchanged.
repairAnalM :: forall a l m. (AM.AnalysisM m a l, Language l) => (ClassId, ENode l) -> EGraph a l -> m (EGraph a l)
repairAnalM (repair_id, node) egr = do
  made <- AM.makeA @m @a ((\i -> egr ^. _class i ^. _data @a) <$> unNode node)
  new_data <- AM.joinA @m @a @l old_data made
  if old_data /= new_data
    then AM.modifyA repair_id
           (egr { analysisWorklist = toListSL (c ^. _parents) <> analysisWorklist egr
                , classes = IM.adjust (_data .~ new_data) repair_id (classes egr)
                })
    else return egr
  where
    c :: EClass a l
    c = egr ^. _class repair_id

    old_data :: a
    old_data = c ^. _data
{-# INLINE repairAnalM #-}