{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NamedFieldPuns #-}
module Data.Equality.Graph
(
EGraph
, emptyEGraph
, represent, add, merge, rebuild
, find, canonicalize
, module Data.Equality.Graph.Classes
, module Data.Equality.Graph.Nodes
, module Data.Equality.Language
) where
import Prelude hiding (lookup)
import Data.Function
import Data.Bifunctor
import Data.Containers.ListUtils
import Control.Monad
import Control.Monad.Trans.State
import Control.Exception (assert)
import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S
import Data.Equality.Utils.SizedList
import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens
import Data.Equality.Utils
-- | Represent an expression (in its fixed-point form) in an e-graph.
--
-- The expression is folded bottom-up with 'cata': the children of each
-- layer are represented first, threading the e-graph through them in
-- 'traverse' order, and the resulting e-node of child class ids is then
-- inserted with 'add'.
--
-- Returns the id of the e-class the expression is represented in, together
-- with the updated e-graph.
represent :: forall a l. (Analysis a l, Language l) => Fix l -> EGraph a l -> (ClassId, EGraph a l)
represent = cata $ \children egr ->
  let
    -- Each already-folded child has type @EGraph a l -> (ClassId, EGraph a l)@,
    -- which is exactly a 'State' action, so 'state' lets 'traverse' chain
    -- them left to right.
    (child_ids, egr') = runState (traverse state children) egr
   in add (Node child_ids) egr'
-- | Add an e-node to the e-graph.
--
-- The e-node is canonicalized first. If an equal canonical e-node is already
-- in the memo, the e-graph is returned unchanged along with the canonical id
-- of the e-class recorded for it.
--
-- Otherwise a fresh singleton e-class is created for it:
--
-- * its analysis data is computed with 'makeA' from the children's data;
-- * 'modifyA' is applied to the new class;
-- * the e-node is registered as a parent of each child class;
-- * the e-node is recorded in the memo and pushed on the worklist;
-- * any expressions returned by 'modifyA' are represented and merged into
--   the new class.
--
-- Returns the e-class id of the e-node and the updated e-graph.
add :: forall a l. (Analysis a l, Language l) => ENode l -> EGraph a l -> (ClassId, EGraph a l)
add uncanon_e egr =
  let !new_en = canonicalize uncanon_e egr
   in case lookupNM new_en (memo egr) of
        Just known_id -> (find known_id egr, egr)
        Nothing       -> addFresh new_en
  where
    -- Insert the canonical e-node @en@, known to be absent from the memo,
    -- as the only member of a brand new e-class.
    addFresh en = (fresh_id, foldr (representAndMerge fresh_id) egr1 extra_terms)
      where
        (fresh_id, uf') = makeNewSet (unionFind egr)

        children = unNode en

        -- Analysis data of the new class, built from its children's data.
        initial_data = makeA @a ((\i -> egr ^. _class i . _data @a) <$> children)

        -- 'modifyA' may rewrite the class and return extra expressions that
        -- must be represented in it.
        (fresh_class, extra_terms) =
          modifyA (EClass fresh_id (S.singleton en) initial_data mempty)

        -- Every child class gains the new e-node as a parent.
        addAsParent = IM.adjust (_parents %~ ((fresh_id, en) |:))
        classes' = IM.insert fresh_id fresh_class
                 $ foldr addAsParent (classes egr) children

        egr1 = egr { unionFind = uf'
                   , classes   = classes'
                   , worklist  = (fresh_id, en) : worklist egr
                   , memo      = insertNM en fresh_id (memo egr)
                   }
{-# INLINABLE add #-}
-- | Merge the e-classes identified by the two given ids.
--
-- Both ids are canonicalized first. If they already denote the same e-class,
-- the e-graph is returned unchanged together with that canonical id.
--
-- Otherwise the class with more parents becomes the leader. The leader
-- absorbs the other class's parents and nodes, takes the joined analysis
-- data, and has 'modifyA' applied. The subsumed class is deleted from the
-- class map, and its parents are pushed on the worklist. Parents of any class
-- whose data differs from the joined data are pushed on the analysis
-- worklist. Expressions returned by 'modifyA' are represented and merged
-- into the leader.
--
-- Returns the id of the merged e-class and the updated e-graph.
merge :: forall a l. (Analysis a l, Language l) => ClassId -> ClassId -> EGraph a l -> (ClassId, EGraph a l)
merge a b egr0
  | a' == b'  = (a', egr0)
  | otherwise = (new_id, egr2)
  where
    -- Work with canonical ids only.
    a' = find a egr0
    b' = find b egr0

    class_a = egr0 ^. _class a'
    class_b = egr0 ^. _class b'

    -- The class with strictly more parents leads; on a tie @a'@ leads.
    (leader, leader_class, sub, sub_class)
      | sizeSL (class_b ^. _parents) > sizeSL (class_a ^. _parents) = (b', class_b, a', class_a)
      | otherwise                                                   = (a', class_a, b', class_b)

    -- The union-find is expected to keep the leader as the representative.
    (new_id, new_uf) = first (\n -> assert (leader == n) n) (unionSets leader sub (unionFind egr0))

    new_data = joinA @a @l (leader_class ^. _data) (sub_class ^. _data)

    -- The leader absorbs the subsumed class's parents, nodes and data; then
    -- 'modifyA' may rewrite it and request extra expressions.
    (updatedLeader, added_nodes) =
      modifyA $ leader_class & _parents %~ (sub_class ^. _parents <>)
                             & _nodes   %~ (sub_class ^. _nodes <>)
                             & _data    .~ new_data

    new_classes = IM.insert leader updatedLeader (IM.delete sub (classes egr0))

    -- Parents of a class whose data differs from the joined data need their
    -- analysis recomputed.
    parentsIfChanged cls
      | new_data /= cls ^. _data = toListSL (cls ^. _parents)
      | otherwise                = []

    egr1 = egr0 { unionFind        = new_uf
                , classes          = new_classes
                -- Re-examining the subsumed class's parents is enough for
                -- congruence: the leader's own parents were already canonical.
                , worklist         = toListSL (sub_class ^. _parents) <> worklist egr0
                , analysisWorklist = parentsIfChanged sub_class
                                  <> parentsIfChanged leader_class
                                  <> analysisWorklist egr0
                }

    egr2 = foldr (representAndMerge leader) egr1 added_nodes
{-# INLINEABLE merge #-}
-- | Restore the e-graph invariants by processing the worklists until both
-- are empty.
--
-- Each round starts from the e-graph with both worklists emptied and then:
--
-- 1. canonicalizes and deduplicates the worklist, then 'repair's each entry.
--    This re-inserts the e-nodes into the memo and merges classes whose
--    canonical e-nodes collide.
-- 2. canonicalizes and deduplicates the analysis worklist, then
--    'repairAnal's each entry. This joins the data computed from that e-node
--    into its class.
--
-- Repairs may enqueue new work, so the process recurses until no work is
-- left.
rebuild :: (Analysis a l, Language l) => EGraph a l -> EGraph a l
rebuild (EGraph uf cls mm wl awl) =
  let
    emptiedEgr = EGraph uf cls mm mempty mempty

    wl'   = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
    egr'  = foldr repair emptiedEgr wl'

    -- Deduplicate on the whole (class, e-node) pair, not on the class id
    -- alone. 'repairAnal' only joins in the data computed from the single
    -- e-node it is given, so discarding the other queued e-nodes of the same
    -- class would silently lose their analysis contributions.
    awl'  = nubOrd $ bimap (`find` egr') (`canonicalize` egr') <$> awl
    egr'' = foldr repairAnal egr' awl'
  in
  if null (worklist egr'') && null (analysisWorklist egr'')
     then egr''
     else rebuild egr''
{-# INLINEABLE rebuild #-}
-- | Re-insert one worklist entry into the memo.
--
-- The e-node is recorded in the memo under @repair_id@. If the memo already
-- mapped that e-node to some class, that class and @repair_id@ are merged
-- (the merged id is discarded).
repair :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repair (repair_id, node) egr
  | Just existing_class <- previous = snd (merge existing_class repair_id egr')
  | otherwise                       = egr'
  where
    (previous, memo') = insertLookupNM node repair_id (memo egr)
    egr' = egr { memo = memo' }
{-# INLINE repair #-}
-- | Process one analysis worklist entry.
--
-- The entry's class id is canonicalized, then the class's data is joined
-- with the data 'makeA' computes for the given e-node.
--
-- If the join changes the data, the class is updated (after 'modifyA'), its
-- parents are pushed on the analysis worklist, and any expressions returned
-- by 'modifyA' are represented and merged into the class. Otherwise the
-- e-graph is returned unchanged.
repairAnal :: forall a l. (Analysis a l, Language l) => (ClassId, ENode l) -> EGraph a l -> EGraph a l
repairAnal (entry_id, node) egr
  | c1 ^. _data /= new_data = foldr (representAndMerge repair_id) egr1 added_nodes
  | otherwise               = egr
  where
    -- @entry_id@ was canonical when the worklist was built, but an earlier
    -- 'repairAnal' in the same fold may have merged its class away (via the
    -- 'representAndMerge' of 'modifyA' results). Writing the updated class
    -- under that stale id would leave a dead entry in the class map.
    repair_id = find entry_id egr

    c1 = egr ^. _class repair_id

    new_data = joinA @a @l (c1 ^. _data)
                           (makeA @a ((\i -> egr ^. _class i . _data @a) <$> unNode node))

    (c2, added_nodes) = modifyA (c1 & _data .~ new_data)

    egr1 = egr { analysisWorklist = toListSL (c1 ^. _parents) <> analysisWorklist egr
               , classes          = IM.insert repair_id c2 (classes egr)
               }
{-# INLINE repairAnal #-}
-- | Canonicalize an e-node by replacing each of its child class ids with the
-- canonical id given by 'find'.
canonicalize :: Functor l => ENode l -> EGraph a l -> ENode l
canonicalize (Node children) eg = Node ((\cid -> find cid eg) <$> children)
{-# INLINE canonicalize #-}
-- | Canonical id of the e-class that the given id belongs to, as recorded in
-- the e-graph's union-find ('findRepr').
find :: ClassId -> EGraph a l -> ClassId
find cid egr = findRepr cid (unionFind egr)
{-# INLINE find #-}
-- | An e-graph with nothing in it: an empty union-find ('emptyUF') and empty
-- class map, memo, worklist and analysis worklist.
emptyEGraph :: Language l => EGraph a l
emptyEGraph = EGraph
  { unionFind        = emptyUF
  , classes          = mempty
  , memo             = mempty
  , worklist         = mempty
  , analysisWorklist = mempty
  }
{-# INLINE emptyEGraph #-}
-- | Represent an expression in the e-graph ('represent') and merge its
-- e-class with the e-class @o@, discarding the merged id.
representAndMerge :: (Analysis a l, Language l) => ClassId -> Fix l -> EGraph a l -> EGraph a l
representAndMerge o f g =
  case represent f g of
    (i, e) -> let (_, merged) = merge o i e in merged