{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
-- {-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-|
   An e-graph efficiently represents a congruence relation over many expressions.

   Based on \"egg: Fast and Extensible Equality Saturation\" https://arxiv.org/abs/2004.03082.
 -}
module Data.Equality.Graph
    (
      -- * Definition of e-graph
      EGraph

      -- * Functions on e-graphs
    , emptyEGraph

      -- ** Transformations
    , add, merge, rebuild
    -- , repair, repairAnal

      -- ** Querying
    , find, canonicalize

      -- * Re-exports
    , module Data.Equality.Graph.Classes
    , module Data.Equality.Graph.Nodes
    , module Data.Equality.Language
    ) where

-- import GHC.Conc

import Data.Function
import Data.Bifunctor
import Data.Containers.ListUtils

import qualified Data.IntMap.Strict as IM
import qualified Data.Set    as S

import Data.Equality.Utils.SizedList

import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens

-- ROMES:TODO: join e-graphs built in parallel?
-- instance Ord1 l => Semigroup (EGraph l) where
--     (<>) eg1 eg2 = undefined -- not so easy
-- instance Ord1 l => Monoid (EGraph l) where
--     mempty = EGraph emptyUF mempty mempty mempty


-- | Add an e-node to the e-graph
--
-- The e-node is canonicalized first. If its canonical form is already
-- represented in this e-graph, the canonical id of the e-class it is
-- represented in is returned together with the unchanged e-graph.
--
-- Otherwise a fresh singleton e-class is created for it. The new e-node is
-- registered as a parent of each of its children, recorded in the memo and
-- pushed onto the worklist. Then 'modifyA' runs on the new e-class.
add :: forall l. Language l => ENode l -> EGraph l -> (ClassId, EGraph l)
add uncanon_e egr =
    let !new_en = canonicalize uncanon_e egr

     in case lookupNM new_en (memo egr) of
      -- Already represented: report the canonical id of its e-class
      Just canon_enode_id -> (find canon_enode_id egr, egr)

      Nothing ->
        let
            -- Make new equivalence class with a new id in the union-find
            (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

            -- New singleton e-class stores the e-node and its analysis data
            new_eclass = EClass new_eclass_id (S.singleton new_en) (makeA new_en egr) mempty

            -- Add (new e-class id, new e-node) to the parents of every child
            -- e-class, then insert the new e-class itself.
            -- TODO:Performance: All updates can be done to the map first? Parallelize?
            new_parents = ((new_eclass_id, new_en) |:)
            new_classes = IM.insert new_eclass_id new_eclass $
                            foldr (IM.adjust (_parents %~ new_parents))
                                  (classes egr)
                                  (unNode new_en)

            -- Push the new e-class onto the worklist. egg questions whether
            -- this is needed, but without it analyses that rely on 'modifyA'
            -- (e.g. math pruning) give wrong results. Doing it here also frees
            -- users from managing the worklist themselves. The cost is
            -- noticeably slower hash-cons invariant tests.
            new_worklist = (new_eclass_id, new_en) : worklist egr

            -- Map the canonical e-node to its new e-class id
            new_memo = insertNM new_en new_eclass_id (memo egr)

         in ( new_eclass_id
            , egr { unionFind = new_uf
                  , classes   = new_classes
                  , worklist  = new_worklist
                  , memo      = new_memo
                  }
                  -- Modify created node according to analysis
                  & modifyA new_eclass_id
            )
{-# INLINABLE add #-}

-- | Merge 2 e-classes by id
--
-- Returns the canonical id of the merged e-class and the updated e-graph. If
-- both ids already have the same representative, the e-graph is returned
-- unchanged.
--
-- Otherwise the e-class with fewer parents is subsumed into the other one
-- (the leader). The leader receives the subsumed class's e-nodes and
-- parents, plus the 'joinA' of both classes' analysis data. The subsumed
-- class's parents are pushed onto the worklist. The parents of any side
-- whose analysis data changed are pushed onto the analysis worklist. Both
-- worklists are processed by 'rebuild'.
merge :: forall l. Language l => ClassId -> ClassId -> EGraph l -> (ClassId, EGraph l)
merge a b egr0 =
  let
      -- Use canonical ids
      a' = find a egr0
      b' = find b egr0
   in
   if a' == b'
     then (a', egr0)
     else
       let
           -- Get classes being merged
           class_a = egr0^._class a'
           class_b = egr0^._class b'

           -- Leader is the class with more parents
           (leader, leader_class, sub, sub_class) =
               if sizeSL (class_a^._parents) < sizeSL (class_b^._parents)
                  then (b', class_b, a', class_a) -- b is leader
                  else (a', class_a, b', class_b) -- a is leader

           -- Make leader the leader in the union find
           (new_id, new_uf) = unionSets leader sub (unionFind egr0)

           -- Joined analysis data of both classes
           new_data = joinA @l (leader_class^._data) (sub_class^._data)

           -- Update leader class with all e-nodes and parents from the
           -- subsumed class
           updatedLeader = leader_class & _parents %~ (sub_class^._parents <>)
                                        & _nodes   %~ (sub_class^._nodes <>)
                                        & _data    .~ new_data

           -- Store the updated leader and delete the subsumed class
           new_classes = IM.insert leader updatedLeader (IM.delete sub (classes egr0))

           -- Add all subsumed parents to the worklist. This suffices instead of
           -- adding the merged e-class itself, which would end up adding its
           -- parents anyway.
           new_worklist = toListSL (sub_class^._parents) <> worklist egr0

           -- A class whose data differs from the merged data must have its
           -- parents re-analysed
           parentsIfChanged cls =
             if new_data /= (cls^._data)
                then toListSL (cls^._parents)
                else mempty
           new_analysis_worklist =
             parentsIfChanged sub_class
               <> parentsIfChanged leader_class
               <> analysisWorklist egr0

           -- ROMES:TODO: The code that makes the -1 * cos test pass when some other things are tweaked
           -- new_memo = foldr (`insertNM` leader) (memo egr0) (sub_class^._nodes)

           -- Build new e-graph
           new_egr = egr0
             { unionFind        = new_uf
             , classes          = new_classes
             -- , memo          = new_memo
             , worklist         = new_worklist
             , analysisWorklist = new_analysis_worklist
             }
             -- Modify according to analysis
             & modifyA new_id

        in (new_id, new_egr)
{-# INLINEABLE merge #-}
            

-- | The rebuild operation processes the e-graph's current worklist,
-- restoring the invariants of deduplication and congruence. Rebuilding is
-- similar to other approaches in how it restores congruence; but it uniquely
-- allows the client to choose when to restore invariants in the context of a
-- larger algorithm like equality saturation.
--
-- Each round empties both worklists, repairs the deduplicated entries, and
-- recurses while repairing has queued new work.
rebuild :: Language l => EGraph l -> EGraph l
rebuild (EGraph uf cls mm wl awl) =
  let
    -- Start from an e-graph with both worklists emptied; entries pushed
    -- while repairing are handled by the next round.
    emptiedEgr = EGraph uf cls mm mempty mempty

    -- Canonicalize and deduplicate the pending (e-class, e-node) entries,
    -- then restore the memo and congruence invariants.
    wl'  = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl
    egr' = foldr repair emptiedEgr wl'

    -- Deduplicate analysis entries on the whole (e-class id, e-node) pair,
    -- not on the id alone. One e-class can be queued once per parent e-node
    -- whose children changed, and each of those e-nodes must be re-analysed.
    -- Otherwise their contribution to the class data is lost.
    awl'  = nubOrd $ first (`find` egr') <$> awl
    egr'' = foldr repairAnal egr' awl'
  in
  -- Loop until worklist is completely empty
  if null (worklist egr'') && null (analysisWorklist egr'')
     then egr''
     else rebuild egr'' -- ROMES:TODO: Doesn't seem to be needed at all in the testsuite.
{-# INLINEABLE rebuild #-}

-- ROMES:TODO: find repair_id could be shared between repair and repairAnal?

-- | Repair a single worklist entry.
--
-- Inserts the e-node into the memo under @repair_id@. If the memo already
-- mapped that e-node to an e-class, the two e-classes are congruent and are
-- merged. The merged id is discarded; only the e-graph is returned.
repair :: forall l. Language l => (ClassId, ENode l) -> EGraph l -> EGraph l
repair (repair_id, node) egr =
   -- TODO We're no longer deleting the uncanonicalized node, how much does it matter that the structure keeps growing?
   case insertLookupNM node repair_id (memo egr) of
      -- First time this e-node is seen: just keep the updated memo
      (Nothing, memo') -> egr { memo = memo' }

      -- The e-node is already represented by existing_class: merge the
      -- congruent e-classes
      (Just existing_class, memo') ->
        snd (merge existing_class repair_id egr { memo = memo' })
{-# INLINE repair #-}

-- | Repair a single analysis-worklist entry.
--
-- Joins the analysis data of the e-node into the data of its e-class. If
-- the class data changes, the class is updated, its parents are pushed onto
-- the analysis worklist, and 'modifyA' runs on it.
repairAnal :: forall l. Language l => (ClassId, ENode l) -> EGraph l -> EGraph l
repairAnal (repair_id, node) egr =
    let
        -- Entries are canonicalized before 'rebuild' starts folding over
        -- them, but an earlier 'modifyA' in the same fold may have merged
        -- this e-class away. Look up its current representative so the map
        -- lookup below cannot hit a deleted key.
        canon_id = find repair_id egr
        c        = (egr^._classes) IM.! canon_id
        new_data = joinA @l (c^._data) (makeA node egr)
    in
    -- Take action only if the joined data differs from the existing data
    if c^._data /= new_data
       then egr { analysisWorklist = toListSL (c^._parents) <> analysisWorklist egr }
              & _classes %~ IM.adjust (_data .~ new_data) canon_id
              & modifyA canon_id
       else egr
{-# INLINE repairAnal #-}

-- | Canonicalize an e-node
--
-- Two e-nodes are equal when their canonical forms are equal. Canonicalizing
-- replaces every e-class id held by the e-node with its canonical id. So two
-- seemingly different e-nodes may turn out equal once their children are
-- known to belong to the same e-classes.
--
-- canonicalize(f(a,b,c,...)) = f((find a), (find b), (find c),...)
canonicalize :: Functor l => ENode l -> EGraph l -> ENode l
canonicalize (Node enode) eg = Node (fmap (`find` eg) enode)
{-# INLINE canonicalize #-}

-- | Find the canonical representation of an e-class id in the e-graph,
-- by looking it up in the e-graph's union-find.
--
-- Invariant: The e-class id always exists.
find :: ClassId -> EGraph l -> ClassId
find cid = findRepr cid . unionFind
{-# INLINE find #-}

-- | The empty e-graph. Nothing is represented in it yet.
--
-- It has an empty union-find, and empty classes, memo, worklist and
-- analysis worklist.
emptyEGraph :: Language l => EGraph l
emptyEGraph = EGraph emptyUF mempty mempty mempty mempty
{-# INLINE emptyEGraph #-}