{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Equality.Graph
(
EGraph
, emptyEGraph
, add, merge, rebuild
, find, canonicalize
, module Data.Equality.Graph.Classes
, module Data.Equality.Graph.Nodes
, module Data.Equality.Language
) where
import Data.Function
import Data.Bifunctor
import Data.Containers.ListUtils
import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S
import Data.Equality.Utils.SizedList
import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens
-- | Add an e-node to the e-graph.
--
-- The e-node is first canonicalized (each child id is replaced by its
-- union-find representative). If the canonical e-node is already in the
-- memo, the canonical id of the class that holds it is returned together
-- with the unchanged e-graph.
--
-- Otherwise a fresh singleton e-class is created for it:
--
-- * a new id is allocated in the union-find;
-- * the new (class id, e-node) pair is recorded as a parent of every child;
-- * the pair is pushed onto the worklist and the e-node is recorded in the
--   memo under the new id;
-- * finally 'modifyA' is run on the new class.
add :: forall l. Language l => ENode l -> EGraph l -> (ClassId, EGraph l)
add uncanon_e egr =
  let !new_en = canonicalize uncanon_e egr
  in case lookupNM new_en (memo egr) of
       Just canon_enode_id -> (find canon_enode_id egr, egr)
       Nothing ->
         let
           -- Allocate a fresh class id in the union-find.
           (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

           -- Singleton class holding only the new e-node, with no parents
           -- yet. Its analysis data is computed against the e-graph as it
           -- was before this insertion.
           new_eclass :: EClass l
           new_eclass =
             EClass new_eclass_id (S.singleton new_en) (makeA new_en egr) mempty

           -- Register the new e-node as a parent of each of its children,
           -- then insert the new class itself.
           new_parents = ((new_eclass_id, new_en) |:)
           new_classes =
             IM.insert new_eclass_id new_eclass $
               foldr (IM.adjust (_parents %~ new_parents))
                     (classes egr)
                     (unNode new_en)

           new_worklist = (new_eclass_id, new_en) : worklist egr
           new_memo     = insertNM new_en new_eclass_id (memo egr)
         in ( new_eclass_id
            , egr { unionFind = new_uf
                  , classes   = new_classes
                  , worklist  = new_worklist
                  , memo      = new_memo
                  }
                & modifyA new_eclass_id
            )
{-# INLINABLE add #-}
-- | Merge the e-classes identified by the two ids.
--
-- Both ids are canonicalized with 'find'. If they already name the same
-- class, that id is returned with the unchanged e-graph. Otherwise:
--
-- * the class with more parents is kept as the leader, and the other (sub)
--   class is absorbed into it. Its nodes and parents are combined with the
--   leader's, the analysis data of both is combined with 'joinA', and the
--   sub class is deleted from the class map;
-- * the leader and sub ids are unioned in the union-find;
-- * the sub class's parents are pushed onto the worklist so 'rebuild' can
--   restore congruence;
-- * for each of the two classes whose data differs from the joined data,
--   its parents are pushed onto the analysis worklist;
-- * 'modifyA' is run on the id returned by 'unionSets'.
--
-- Returns the id produced by 'unionSets' and the updated e-graph.
merge :: forall l. Language l => ClassId -> ClassId -> EGraph l -> (ClassId, EGraph l)
merge a b egr0
  | a' == b'  = (a', egr0)
  | otherwise = (new_id, new_egr)
  where
    a' = find a egr0
    b' = find b egr0

    class_a = egr0 ^. _class a'
    class_b = egr0 ^. _class b'

    -- Absorb the class with fewer parents, because only the sub class's
    -- parents are re-queued on the worklist below.
    (leader, leader_class, sub, sub_class) =
      if sizeSL (class_a ^. _parents) < sizeSL (class_b ^. _parents)
        then (b', class_b, a', class_a)
        else (a', class_a, b', class_b)

    (new_id, new_uf) = unionSets leader sub (unionFind egr0)

    new_data :: Domain l
    new_data = joinA @l (leader_class ^. _data) (sub_class ^. _data)

    updatedLeader = leader_class
      & _parents %~ (sub_class ^. _parents <>)
      & _nodes   %~ (sub_class ^. _nodes <>)
      & _data    .~ new_data

    new_classes = (IM.insert leader updatedLeader . IM.delete sub) (classes egr0)

    new_worklist = toListSL (sub_class ^. _parents) <> worklist egr0

    -- The parents of a class need their analysis re-run only if the
    -- merged data differs from what that class held before.
    parentsIfChanged cls =
      if new_data /= cls ^. _data
        then toListSL (cls ^. _parents)
        else mempty

    new_analysis_worklist =
      parentsIfChanged sub_class
        <> parentsIfChanged leader_class
        <> analysisWorklist egr0

    new_egr = egr0
      { unionFind        = new_uf
      , classes          = new_classes
      , worklist         = new_worklist
      , analysisWorklist = new_analysis_worklist
      }
      & modifyA new_id
{-# INLINEABLE merge #-}
-- | Process the pending worklists until both are empty.
--
-- Each round proceeds in order:
--
-- 1. Take both worklists out of the e-graph.
-- 2. Canonicalize and deduplicate the congruence worklist, then fold
--    'repair' over it.
-- 3. Canonicalize the analysis worklist against the repaired e-graph, keep
--    one entry per class id, and fold 'repairAnal' over it.
--
-- Repairs may push new entries onto either worklist. The function recurses
-- until a round leaves both worklists empty.
rebuild :: Language l => EGraph l -> EGraph l
rebuild (EGraph uf cls mm wl awl)
  | null (worklist egr'') && null (analysisWorklist egr'') = egr''
  | otherwise                                             = rebuild egr''
  where
    emptiedEgr = EGraph uf cls mm mempty mempty

    wl'  = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
    egr' = foldr repair emptiedEgr wl'

    awl'  = nubIntOn fst $ first (`find` egr') <$> awl
    egr'' = foldr repairAnal egr' awl'
{-# INLINEABLE rebuild #-}
-- | Handle one congruence-worklist entry.
--
-- The e-node is inserted into the memo under @repair_id@; 'rebuild'
-- canonicalizes entries before calling this. If the memo already mapped the
-- same e-node to another class, the two classes are congruent and are merged
-- (on the e-graph carrying the updated memo).
repair :: forall l. Language l => (ClassId, ENode l) -> EGraph l -> EGraph l
repair (repair_id, node) egr =
  case insertLookupNM node repair_id (memo egr) of
    (Nothing, memo') ->
      egr { memo = memo' }
    (Just existing_class, memo') ->
      snd (merge existing_class repair_id (egr { memo = memo' }))
{-# INLINE repair #-}
-- | Handle one analysis-worklist entry.
--
-- The entry's e-node analysis ('makeA') is joined with the current data of
-- its class. If the join changes that data:
--
-- * the class's data is replaced;
-- * its parents are pushed onto the analysis worklist;
-- * 'modifyA' is run on the class.
--
-- Otherwise the e-graph is returned unchanged.
repairAnal :: forall l. Language l => (ClassId, ENode l) -> EGraph l -> EGraph l
repairAnal (repair_id, node) egr
  | c ^. _data /= new_data =
      egr { analysisWorklist = toListSL (c ^. _parents) <> analysisWorklist egr }
        & _classes %~ (IM.adjust (_data .~ new_data) cid)
        & modifyA cid
  | otherwise = egr
  where
    -- 'rebuild' canonicalizes the worklist ids once, before folding this
    -- function over them. An earlier step of the same fold may nevertheless
    -- have merged this class away: 'modifyA' receives the whole e-graph and,
    -- depending on the analysis, may call 'merge', which deletes the
    -- absorbed class from the class map. Re-canonicalize so the partial
    -- lookup below and the write-back use a class that still exists.
    cid = find repair_id egr

    c :: EClass l
    c = (egr ^. _classes) IM.! cid

    new_data :: Domain l
    new_data = joinA @l (c ^. _data) (makeA node egr)
{-# INLINE repairAnal #-}
-- | Replace every child id of an e-node with its canonical representative
-- in the given e-graph (see 'find').
canonicalize :: Functor l => ENode l -> EGraph l -> ENode l
canonicalize (Node enode) eg = Node (fmap (\child -> find child eg) enode)
{-# INLINE canonicalize #-}
-- | Canonical representative of an e-class id, as recorded in the
-- e-graph's union-find.
find :: ClassId -> EGraph l -> ClassId
find cid = findRepr cid . unionFind
{-# INLINE find #-}
-- | An e-graph with an empty union-find and no classes, memo entries or
-- pending worklist items.
emptyEGraph :: Language l => EGraph l
emptyEGraph =
  EGraph
    { unionFind        = emptyUF
    , classes          = mempty
    , memo             = mempty
    , worklist         = mempty
    , analysisWorklist = mempty
    }
{-# INLINE emptyEGraph #-}