{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TemplateHaskell #-}

-- | An example using arithmetic expressions.
module Overeasy.Example
  ( Arith (..)
  , ArithF (..)
  , exampleGraph
  , examplePat
  , exampleMain
  ) where

import Control.DeepSeq (NFData)
import Control.Monad.State.Strict (execState)
import Data.Functor.Foldable.TH (makeBaseFunctor)
import Data.Hashable (Hashable)
import GHC.Generics (Generic)
import Overeasy.EGraph (EClassId (..), EGraph, egAddTerm, egMerge, egNew, noAnalysis)
import Overeasy.Matching (Pat, match)
import Unfree (pattern FreeEmbed, pattern FreePure)

-- | Arithmetic expressions.
--
-- The 'Int' fields are strict; the recursive 'Arith' fields are lazy.
-- 'ArithF' is the base functor for this type (generated by the
-- 'makeBaseFunctor' splice below).
data Arith
  = ArithPlus Arith Arith
  | ArithTimes Arith Arith
  | ArithShiftL Arith !Int
  | ArithShiftR Arith !Int
  | ArithConst !Int
  deriving stock (Eq, Show, Generic)
  -- Hashable and NFData use their default methods, which rely on the
  -- Generic instance derived above.
  deriving anyclass (Hashable, NFData)

-- Template Haskell splice: generates the base functor 'ArithF' (e.g.
-- 'ArithPlusF', whose recursive positions are a type parameter) along with
-- the other recursion-schemes boilerplate for 'Arith'. It must precede the
-- instances below, which mention 'ArithF'.
makeBaseFunctor ''Arith

-- Instances the compiler derives structurally.
deriving stock instance Generic (ArithF a)
deriving stock instance Show a => Show (ArithF a)
deriving stock instance Eq a => Eq (ArithF a)

-- Instances built from the classes' default methods (DeriveAnyClass),
-- which use the Generic instance above.
deriving anyclass instance NFData a => NFData (ArithF a)
deriving anyclass instance Hashable a => Hashable (ArithF a)

-- | Creates a simple e-graph with the equality @2 + 2 = 4@.
--
-- Starting from 'egNew', this adds the terms @4@, @2@ and @2 + 2@, then
-- merges the e-class of @4@ with the e-class of @2 + 2@.
exampleGraph :: EGraph () ArithF
exampleGraph = flip execState egNew $ do
  -- We don't need to analyze here, so the analysis data is '()'.
  let ana = noAnalysis
  -- Some simple terms:
  let termFour = ArithConst 4
      termTwo = ArithConst 2
      termPlus = ArithPlus termTwo termTwo
  -- Add the term `4`
  (_, cidFour) <- egAddTerm ana termFour
  -- Add the term `2` (its class id is not needed afterwards)
  (_, _cidTwo) <- egAddTerm ana termTwo
  -- Add the term `2 + 2`
  (_, cidPlus) <- egAddTerm ana termPlus
  -- Merge `4` and `2 + 2`; the merge result itself is not needed.
  _ <- egMerge cidFour cidPlus
  pure ()

-- | Creates a simple pattern to match nodes like @x + x@.
--
-- The pattern is an 'ArithPlusF' node whose two operands are both the
-- pattern variable @"x"@.
examplePat :: Pat ArithF String
examplePat = FreeEmbed (ArithPlusF x x)
  where
    -- The shared pattern variable used for both operands.
    x = FreePure "x"

-- | Build an e-graph, e-match on it, and print the result.
--
-- Prints 'exampleGraph', then 'examplePat', then the result of 'match'ing
-- the pattern against the graph. Each value is preceded by a label line.
exampleMain :: IO ()
exampleMain = do
  let eg = exampleGraph
      pat = examplePat
  putStrLn "e-graph:"
  print eg
  putStrLn "pattern:"
  print pat
  putStrLn "e-matches:"
  print (match pat eg)