{-# LANGUAGE AllowAmbiguousTypes #-} -- joinA
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-|

E-class analysis, which allows the concise expression of a program analysis over
the e-graph.

An e-class analysis resembles abstract interpretation lifted to the e-graph
level, attaching analysis data from a semilattice to each e-class.

The e-graph maintains and propagates this data as e-classes get merged and new
e-nodes are added.

Analysis data can be used directly to modify the e-graph, to inform how or if
rewrites apply their right-hand sides, or to determine the cost of terms during
the extraction process.

References: "egg: Fast and Extensible Equality Saturation" (Willsey et al.),
https://arxiv.org/pdf/2004.03082.pdf

-}
module Data.Equality.Analysis where

import Data.Kind (Type)
import Control.Arrow ((***))

import Data.Function ((&))
import Data.Equality.Graph.Lens
import Data.Equality.Language
import Data.Equality.Graph.Internal (EGraph)
import Data.Equality.Graph.Classes

-- | An e-class analysis with domain @domain@ defined for a language @l@.
--
-- The @domain@ is the type of the domain of the e-class analysis, that is, the
-- type of the data stored in an e-class according to this e-class analysis.
class Eq domain => Analysis domain (l :: Type -> Type) where

    -- | When a new e-node is added into a new, singleton e-class, construct a
    -- new value of the domain to be associated with the new e-class, by
    -- accessing the associated data of the node's children.
    --
    -- The argument is the e-node term populated with its children's data.
    --
    -- === Example
    --
    -- @
    -- -- domain = Maybe Double
    -- makeA :: Expr (Maybe Double) -> Maybe Double
    -- makeA = \case
    --     BinOp Div e1 e2 -> liftA2 (/) e1 e2
    --     BinOp Sub e1 e2 -> liftA2 (-) e1 e2
    --     BinOp Mul e1 e2 -> liftA2 (*) e1 e2
    --     BinOp Add e1 e2 -> liftA2 (+) e1 e2
    --     Const x -> Just x
    --     Sym _ -> Nothing
    -- @
    makeA :: l domain -> domain

    -- | When e-classes @c1@ and @c2@ are being merged into @c@, join @d_c1@
    -- and @d_c2@ into a new value @d_c@ to be associated with the new
    -- e-class @c@.
    --
    -- Analysis data is meant to come from a semilattice (see the module
    -- description), so this is expected to behave as its join operation.
    joinA :: domain -> domain -> domain

    -- | Optionally modify the e-class @c@ (based on @d_c@), typically by adding
    -- an e-node to @c@. Modify should be idempotent if no other changes occur
    -- to the e-class, i.e., @modify(modify(c)) = modify(c)@.
    --
    -- The default implementation leaves the e-graph unchanged.
    --
    -- === Example
    --
    -- Pruning all nodes of an e-class with a constant value except for the
    -- leaf values, and adding a constant value node:
    --
    -- @
    -- modifyA cl eg0
    --   = case eg0^._class cl._data of
    --       Nothing -> eg0
    --       Just d  ->
    --         -- Add constant as e-node
    --         let (new_c, eg1) = represent (Fix (Const d)) eg0
    --             (rep, eg2)   = merge cl new_c eg1
    --          -- Prune all except leaf e-nodes
    --          in eg2 & _class rep._nodes %~ S.filter (F.null . unNode)
    -- @
    modifyA :: ClassId
            -- ^ Id of class @c@ whose new data @d_c@ triggered the modify call
            -> EGraph domain l
            -- ^ E-graph where class @c@ being modified exists
            -> EGraph domain l
            -- ^ E-graph resulting from the modification
    modifyA _ = id
    {-# INLINE modifyA #-}


-- | The simplest analysis: the domain is @()@ and it does nothing otherwise.
--
-- 'makeA' always produces @()@, 'joinA' combines values with @()@'s
-- 'Semigroup' instance, and 'modifyA' is the class default.
instance forall l. Analysis () l where
  makeA :: l () -> ()
  makeA _ = ()

  joinA :: () -> () -> ()
  joinA = (<>)


-- | This instance is not necessarily well behaved for any two analyses, so care
-- must be taken when using it.
--
-- A possible criterion is:
--
-- For any two analyses, where 'modifyA' is called @m1@ and @m2@ respectively,
-- this instance is well behaved if @m1@ and @m2@ commute, and the analyses
-- only change the e-class being modified.
--
-- That is, if @m1@ and @m2@ satisfy the following law:
--
-- @
-- m1 . m2 = m2 . m1
-- @
--
-- A simple criterion that should suffice for commutativity. If:
--
--  * The modify function only depends on the analysis value, and
--  * The modify function doesn't change the analysis value
--
-- then any two such functions commute.
--
-- Note: there are weaker (or at least different) criteria for this instance to
-- be well behaved.
instance (Language l, Analysis a l, Analysis b l) => Analysis (a, b) l where

  -- | Build each component independently from the matching projection of
  -- the children's data.
  makeA :: l (a, b) -> (a, b)
  makeA g = (makeA @a (fst <$> g), makeA @b (snd <$> g))

  -- | Join componentwise. @joinA (x, y)@ is the function that joins @x@ into
  -- the first component and @y@ into the second component of its argument.
  joinA :: (a, b) -> (a, b) -> (a, b)
  joinA (x, y) = joinA @a @l x *** joinA @b @l y

  -- | Run each component's 'modifyA' on a copy of the e-graph whose class data
  -- has been projected to that component. Then rebuild class @c@ from the two
  -- results.
  --
  -- Only class @c@ is read back from each result. Any other change a
  -- component's 'modifyA' makes (new classes, changes to other classes) is
  -- discarded. Hence the "only change the e-class being modified" criterion
  -- above.
  modifyA :: ClassId -> EGraph (a, b) l -> EGraph (a, b) l
  modifyA c egr =
    let egra :: EGraph a l
        egra = modifyA @a c (egr & _classes . _data %~ fst)
        egrb :: EGraph b l
        egrb = modifyA @b c (egr & _classes . _data %~ snd)
        ca :: EClass a l
        ca = egra ^. _class c
        cb :: EClass b l
        cb = egrb ^. _class c
     in egr & _class c .~ EClass c
                             -- NOTE(review): if this Set's '<>' is a union,
                             -- nodes pruned by only one component reappear.
                             (eClassNodes ca <> eClassNodes cb)
                             (eClassData ca, eClassData cb)
                             -- NOTE(review): when neither component touches the
                             -- parents, '<>' may concatenate the same parent
                             -- list twice; confirm SList's Semigroup semantics.
                             (eClassParents ca <> eClassParents cb)