{-# language DeriveTraversable #-}
{-# language StandaloneDeriving #-}
{-# language LambdaCase #-}
{-# language TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}

module Data.SRTree.EqSat ( simplifyEqSat ) where

import Control.Applicative (liftA2)
import Control.Monad (unless)
import Data.AEq ( AEq((~==)) )
import Data.Eq.Deriving ( deriveEq1 )
import Data.Equality.Analysis ( Analysis(..) )
import Data.Equality.Graph ( ClassId, Language, ENode(unNode) )
import Data.Equality.Graph.Lens hiding ((^.))
import Data.Equality.Graph.Lens qualified as L
import Data.Equality.Matching
import Data.Equality.Matching.Database ( Subst )
import Data.Equality.Saturation
import Data.Equality.Saturation.Scheduler ( BackoffScheduler(BackoffScheduler) )
import Data.Foldable qualified as F
import Data.IntMap.Strict qualified as IM
import Data.Maybe (isJust, isNothing)
import Data.Ord.Deriving ( deriveOrd1 )
import Data.SRTree hiding (Fix(..))
import Data.SRTree.Recursion qualified as R
import Data.Set qualified as S
import Text.Show.Deriving ( deriveShow1 )

-- Standalone-derived 'Foldable' and 'Traversable' instances for the 'SRTree'
-- functor (the type itself is imported, not defined in this module).
deriving instance Foldable SRTree
deriving instance Traversable SRTree

-- Template Haskell splices generating the lifted 'Eq1', 'Ord1' and 'Show1'
-- instances for 'SRTree'.
-- NOTE(review): presumably these are needed so 'SRTree' can be used as an
-- e-graph language below -- confirm against the e-graph library's constraints.
deriveEq1 ''SRTree
deriveOrd1 ''SRTree
deriveShow1 ''SRTree

-- | Arithmetic syntax for rewrite patterns.
--
-- Every operator builds the matching 'SRTree' node and wraps it in
-- 'NonVariablePattern', so rewrite rules can be written with ordinary
-- arithmetic notation (e.g. @"x" + "y"@).
instance Num (Pattern SRTree) where
  l + r = NonVariablePattern $ Bin Add l r
  l - r = NonVariablePattern $ Bin Sub l r
  l * r = NonVariablePattern $ Bin Mul l r
  abs   = NonVariablePattern . Uni Abs

  -- Negation is encoded as multiplication by the constant -1.
  negate t    = fromInteger (-1) * t
  -- This instance has no pattern encoding for 'signum'; fail loudly with a
  -- message that identifies the culprit instead of a bare 'undefined'.
  signum _    = error "Data.SRTree.EqSat: signum is not supported for Pattern SRTree"
  -- Integer literals become 'Const' nodes.
  fromInteger = NonVariablePattern . Const . fromInteger

-- | Division and fractional literals for rewrite patterns.
--
-- @a / b@ builds a 'Bin' 'Div' node; rational literals become 'Const' nodes.
instance Fractional (Pattern SRTree) where
    a / b        = NonVariablePattern $ Bin Div a b
    fromRational = NonVariablePattern . Const . fromRational

-- | Transcendental functions for rewrite patterns.
--
-- Each unary function maps to the 'Uni' node with the matching 'Function'
-- tag, @pi@ becomes a 'Const' node and @(**)@ becomes a 'Bin' 'Power' node.
instance Floating (Pattern SRTree) where
  pi    = NonVariablePattern $ Const pi
  exp   = NonVariablePattern . Uni Exp
  log   = NonVariablePattern . Uni Log
  sqrt  = NonVariablePattern . Uni Sqrt
  sin   = NonVariablePattern . Uni Sin
  cos   = NonVariablePattern . Uni Cos
  tan   = NonVariablePattern . Uni Tan
  asin  = NonVariablePattern . Uni ASin
  acos  = NonVariablePattern . Uni ACos
  atan  = NonVariablePattern . Uni ATan
  sinh  = NonVariablePattern . Uni Sinh
  cosh  = NonVariablePattern . Uni Cosh
  tanh  = NonVariablePattern . Uni Tanh
  asinh = NonVariablePattern . Uni ASinh
  acosh = NonVariablePattern . Uni ACosh
  atanh = NonVariablePattern . Uni ATanh

  l ** r = NonVariablePattern (Bin Power l r)

  -- Previously 'undefined'. There is no dedicated node for an arbitrary-base
  -- logarithm here, so it is expressed through natural logarithms, which is
  -- the same definition the 'Floating' class uses by default.
  logBase b x = log x / log b

-- | Constant-folding analysis: each e-class carries @Just c@ when it is known
-- to evaluate to the constant @c@, and 'Nothing' otherwise.
instance Analysis (Maybe Double) SRTree where
    -- type Domain SRTreeF = Maybe Double

    -- The value of a node is obtained by evaluating it over the values of
    -- its children.
    makeA = evalConstant -- ((\c -> egr L.^._class c._data) <$> e)

    -- Joins the values of two merged classes. The result is 'Nothing' as soon
    -- as either side is 'Nothing'. When both sides are constants they must
    -- agree, either within an absolute tolerance of 1e-6 or according to
    -- '~=='; otherwise this raises an error. The tolerance test already
    -- accepts the 0 / -0 pairs, so no separate signed-zero clause is needed.
    joinA ma mb = do
        a <- ma
        b <- mb
        -- The bang pattern forces the check before the result is returned.
        !_ <- unless (abs (a - b) <= 1e-6 || a ~== b)
                     (error $ "Merged non-equal constants! "
                           <> show a <> " " <> show b <> " " <> show (a == b))
        pure a

    -- When the class has a known constant value, keep only its child-less
    -- nodes and hand back the constant as a term.
    -- NOTE(review): presumably the library adds the returned terms to the
    -- class -- confirm against the 'Analysis' documentation.
    modifyA cl = case cl L.^. _data of
        Nothing -> (cl, [])
        Just d  -> ((_nodes %~ S.filter (F.null . unNode)) cl, [Fix (Const d)])

-- | Evaluates a single 'SRTree' node whose children have already been reduced
-- to optional constants.
--
-- Returns @Just@ the numeric result when every child is a known constant and
-- 'Nothing' when any child is unknown. 'Var' and 'Param' leaves are never
-- constants; 'Const' leaves always are.
evalConstant :: SRTree (Maybe Double) -> Maybe Double
evalConstant = \case
    -- Exception: Negative exponent: BinOp Pow e1 e2 -> liftA2 (^) e1 (round <$> e2 :: Maybe Integer)
    Bin Div e1 e2   -> liftA2 (/) e1 e2
    Bin Sub e1 e2   -> liftA2 (-) e1 e2
    Bin Mul e1 e2   -> liftA2 (*) e1 e2
    Bin Add e1 e2   -> liftA2 (+) e1 e2
    Bin Power e1 e2 -> liftA2 (**) e1 e2
    Uni f e1        -> evalFun f <$> e1
    Var _           -> Nothing
    Const x         -> Just x -- TODO: investigate why it cannot handle NaN
    Param _         -> Nothing

-- Declares 'SRTree' as an e-graph 'Language'; the instance body is empty, so
-- no methods are defined here.
instance Language SRTree

-- | Cost of a node given the costs of its children.
--
-- Variables cost 1, constants and parameters cost 5, and every operator or
-- function application adds 1 to the summed cost of its arguments.
cost :: CostFunction SRTree Int
cost = \case
  Const _     -> 5
  Var _       -> 1
  Bin _ c1 c2 -> c1 + c2 + 1
  Uni _ c     -> c + 1
  Param _     -> 5

-- | Looks up the e-class bound to a pattern variable in a substitution.
--
-- Partial by design: calls 'error' when given a 'NonVariablePattern', or when
-- the variable has no binding in the substitution.
unsafeGetSubst :: Pattern SRTree -> Subst -> ClassId
unsafeGetSubst (NonVariablePattern _) _ =
    error "unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
unsafeGetSubst (VariablePattern v) subst =
    case IM.lookup v subst of
      Nothing       -> error "Searching for non existent bound var in conditional"
      Just class_id -> class_id

-- | Rewrite condition: holds unless the class bound to the pattern variable
-- is known to be the constant 0.
--
-- Classes with no known constant value ('Nothing') satisfy the condition.
is_not_zero :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_not_zero v subst egr = classValue /= Just 0
  where
    -- Analysis value of the class the variable is bound to.
    classValue = egr L.^. _class (unsafeGetSubst v subst) . _data

-- | Rewrite condition: holds when at least one of the two pattern variables
-- is bound to a class with a known constant value that is @>= 0@.
--
-- Note the disjunction: it is enough for one of the two to qualify. The
-- second variable is only looked up when the first does not qualify.
is_not_neg_consts :: Pattern SRTree -> Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_not_neg_consts v1 v2 subst egr = nonNegConst v1 || nonNegConst v2
  where
    -- True only for a known constant that is non-negative.
    nonNegConst v =
        maybe False (>= 0) (egr L.^. _class (unsafeGetSubst v subst) . _data)

-- | Rewrite condition: holds when the class bound to the pattern variable has
-- a known constant value that is strictly negative.
is_negative :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_negative v subst egr =
    maybe False (< 0) (egr L.^. _class (unsafeGetSubst v subst) . _data)

-- | Rewrite condition: holds when the class bound to the pattern variable has
-- a known constant value.
is_const :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_const v subst egr =
    isJust (egr L.^. _class (unsafeGetSubst v subst) . _data)

-- | Rewrite condition: holds when the class bound to the pattern variable has
-- no known constant value.
is_not_const :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_not_const v subst egr =
    isNothing (egr L.^. _class (unsafeGetSubst v subst) . _data)

-- | Basic algebraic rewrite rules: commutativity, power merging,
-- associativity, distribution of subtraction/negation and reciprocal
-- products. String literals are pattern variables.
rewritesBasic :: [Rewrite (Maybe Double) SRTree]
rewritesBasic =
    [   -- commutativity
        "x" + "y" := "y" + "x"
      , "x" * "y" := "y" * "x"
      -- merging repeated factors into powers
      , "x" * "x" := "x" ** 2
      , ("x" ** "a") * "x" := "x" ** ("a" + 1)
      , ("x" ** "a") * ("x" ** "b") := "x" ** ("a" + "b")
      -- associativity
      , ("x" + "y") + "z" := "x" + ("y" + "z")
      , ("x" * "y") * "z" := "x" * ("y" * "z")
      -- , "x" * ("y" / "z") := ("x" * "y") / "z"
      , ("x" * "y") / "z" := "x" * ("y" / "z")
      -- distributive and factorization
      , "x" - ("y" + "z") := ("x" - "y") - "z"
      , "x" - ("y" - "z") := ("x" - "y") + "z"
      , negate ("x" + "y") := negate "x" - "y"
      , ("x" - "a") := "x" + negate "a" :| is_const "a" :| is_not_const "x"
      , ("x" - ("a" * "y")) := "x" + (negate "a" * "y") :| is_const "a" :| is_not_const "y"
      , (1 / "x") * (1 / "y") := 1 / ("x" * "y")
      -- AQ
      -- NOTE(review): the right-hand side moves "b" out of the square root,
      -- but sqrt (1 + (b*y)**2) is not b * sqrt (1 + y**2) in general --
      -- confirm this is intended (e.g. only meaningful when the constants are
      -- refitted afterwards).
      , ("a" * "x") / sqrt (1 + ("b" * "y") ** 2) := ("a" / "b" * "x") / sqrt (1 + "y" ** 2) :| is_const "a" :| is_const "b"
    ]

-- | Rewrite rules for nonlinear functions (@log@, @exp@, @sqrt@, @abs@ and
-- powers). String literals are pattern variables; rules followed by @:|@
-- only fire when the listed conditions hold.
rewritesFun :: [Rewrite (Maybe Double) SRTree]
rewritesFun =
    [   -- logarithm of products, quotients and powers
        log ("x" * "y") := log "x" + log "y" :| is_not_neg_consts "x" "x" :| is_not_zero "x"
      , "x" ** "a" * "x" ** "b" := "x" ** ("a" + "b")
      , log ("x" / "y") := log "x" - log "y" :| is_not_neg_consts "x" "x" :| is_not_zero "x"
      -- NOTE(review): the side conditions here constrain the exponent "y",
      -- not the base "x" whose logarithm is taken -- confirm this is intended.
      , log ("x" ** "y") := "y" * log "x" :| is_not_neg_consts "y" "y" :| is_not_zero "y"
      , log (sqrt "x") := 0.5 * log "x" :| is_not_const "x"
      -- inverse pairs
      , log (exp "x") := "x" :| is_not_const "x"
      , exp (log "x") := "x" :| is_not_const "x"
      -- square roots
      , "x" ** (1/2) := sqrt "x"
      , sqrt ("a" * "x") := sqrt "a" * sqrt "x" :| is_not_neg_consts "a" "x"
      , sqrt ("a" * ("x" - "y")) := sqrt (negate "a") * sqrt ("y" - "x") :| is_negative "a"
      -- NOTE(review): negating ("b" + "y") gives (negate "b" - "y"), while the
      -- right-hand side uses ("b" - "y") -- confirm this is intended.
      , sqrt ("a" * ("b" + "y")) := sqrt (negate "a") * sqrt ("b" - "y") :| is_negative "a" :| is_negative "b"
      , sqrt ("a" / "x") := sqrt "a" / sqrt "x" :| is_not_neg_consts "a" "x"
      -- absolute value of a product
      , abs ("x" * "y") := abs "x" * abs "y" -- :| is_const "x"
    ]

-- | Rules that reduce redundant parameters.
--
-- Each rule has the shape @lhs := rhs@, optionally followed by one or more
-- @:| condition@ guards. The rules cover algebraic identities, cancellations,
-- multiplication by an inverse, factoring of a common multiplicand,
-- negation, and fusion of constant sub-terms.
--
-- The guards (@is_const@, @is_not_const@, @is_not_zero@) are defined
-- elsewhere in this module; presumably they inspect the @Maybe Double@
-- analysis data of the e-class bound to the named pattern variable.
--
-- NOTE(review): @0 / "x" := 0@ and @0 ** "x" := 0@ carry no guard on @"x"@,
-- so they also fire when @"x"@ is zero -- confirm this is intended.
constReduction :: [Rewrite (Maybe Double) SRTree]
constReduction = [
      -- identities
        0 + "x" := "x"
      , "x" + 0 := "x"
      , "x" - 0 := "x"
      , 1 * "x" := "x"
      , "x" * 1 := "x"
      , 0 * "x" := 0
      , "x" * 0 := 0
      , 0 / "x" := 0
      -- cancellations
      , "x" - "x" := 0
      , "x" / "x" := 1 :| is_not_zero "x"
      , "x" ** 1 := "x"
      , 0 ** "x" := 0
      , 1 ** "x" := 1
      -- multiplication of inverse
      , "x" * (1 / "x") := 1 :| is_not_zero "x"
      -- factor out a common multiplicand
      , ("x" * "y") + ("x" * "z") := "x" * ("y" + "z")
      -- negate
      , "x" - ((-1) * "y") := "x" + "y" :| is_not_const "y"
      , "x" + negate "y" := "x" - "y" :| is_not_const "y"
      , 0 - "x" := negate "x" :| is_not_const "x"
      -- constant fusion
      , ("a" * "x") * ("b" * "y") := ("a" * "b") * ("x" * "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      , "a" / ("b" * "x") := ("a" / "b") / "x"
          :| is_const "a" :| is_const "b" :| is_not_const "x"
    ]

-- | Rules that move parameters to the outside and to the left.
--
-- In every rule @"a"@ and @"b"@ are guarded by @is_const@ and @"x"@ / @"y"@
-- by @is_not_const@ (guards defined elsewhere in this module), so a rule only
-- fires when the constant factors can be pulled out in front of the
-- non-constant sub-terms.
--
-- Every right-hand side is algebraically equal to its left-hand side; the
-- identity used is spelled out next to the less obvious rules.
constFusion :: [Rewrite (Maybe Double) SRTree]
constFusion = [
      -- a*x + b == a * (x + b/a)
        "a" * "x" + "b" := "a" * ("x" + ("b" / "a"))
          :| is_const "a" :| is_const "b" :| is_not_const "x"
      -- a*x + b/y == a * (x + (b/a)/y)
      , "a" * "x" + "b" / "y" := "a" * ("x" + ("b" / "a") / "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      -- a*x - b/y == a * (x - (b/a)/y)
      , "a" * "x" - "b" / "y" := "a" * ("x" - ("b" / "a") / "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      -- x / (b*y) == (1/b) * x / y
      , "x" / ("b" * "y") := (1 / "b") * "x" / "y"
          :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      -- x/a + b == (1/a) * (x + b*a)
      , "x" / "a" + "b" := (1 / "a") * ("x" + ("b" * "a"))
          :| is_const "a" :| is_const "b" :| is_not_const "x"
      -- x/a - b == (1/a) * (x - b*a)
      , "x" / "a" - "b" := (1 / "a") * ("x" - ("b" * "a"))
          :| is_const "a" :| is_const "b" :| is_not_const "x"
      -- b - x/a == (1/a) * (b*a - x)
      , "b" - "x" / "a" := (1 / "a") * (("b" * "a") - "x")
          :| is_const "a" :| is_const "b" :| is_not_const "x"
      -- x/a + b*y == (1/a) * (x + (b*a)*y)
      , "x" / "a" + "b" * "y" := (1 / "a") * ("x" + ("b" * "a") * "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      -- x/a + y/b == (1/a) * (x + (a/b)*y)
      -- (the factor multiplying y must be a/b, not 1/(b*a), for the two
      -- sides to stay equal)
      , "x" / "a" + "y" / "b" := (1 / "a") * ("x" + ("a" / "b") * "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      -- x/a - b*y == (1/a) * (x - (b*a)*y)
      , "x" / "a" - "b" * "y" := (1 / "a") * ("x" - ("b" * "a") * "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
      -- x/a - b/y == (1/a) * (x - (b*a)/y)
      -- (y stays in the denominator, as on the left-hand side)
      , "x" / "a" - "b" / "y" := (1 / "a") * ("x" - ("b" * "a") / "y")
          :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
    ]

-- | Run equality saturation on a tree and return the extracted tree.
--
-- * @rules@   - rewrite rules to apply.
-- * @n@, @coolOff@ - forwarded, in this order, as the two 'Int' arguments of
--   'BackoffScheduler' (presumably the match limit and the ban length --
--   confirm against the scheduler's documentation).
-- * @c@       - cost function used to extract the result.
-- * @t@       - the expression to rewrite.
--
-- Only the first component of the result of @equalitySaturation'@ (the
-- extracted expression) is kept; the e-graph is discarded.
rewriteTree
  :: (Analysis a l, Language l, Ord cost)
  => [Rewrite a l]
  -> Int
  -> Int
  -> CostFunction l cost
  -> Fix l
  -> Fix l
rewriteTree rules n coolOff c t =
  fst (equalitySaturation' (BackoffScheduler n coolOff) t rules c)

-- | Rewrite a tree with the full rule set (@rewritesBasic@,
-- 'constReduction', 'constFusion' and @rewritesFun@, concatenated in that
-- order), using scheduler arguments 2500 and 30 and the module's @cost@
-- function.
rewriteAll :: Fix SRTree -> Fix SRTree
rewriteAll = rewriteTree allRules 2500 30 cost
  where
    allRules = rewritesBasic <> constReduction <> constFusion <> rewritesFun

-- | Rewrite a tree with only the 'constReduction' rules, using scheduler
-- arguments 100 and 10 and the module's @cost@ function.
rewriteConst :: Fix SRTree -> Fix SRTree
rewriteConst = rewriteTree constReduction 100 10 cost

-- | Apply a rotating list of rewriters until the tree stops changing or the
-- iteration budget runs out.
--
-- At each step the first rewriter in the list is applied to the current
-- tree. If the result equals the input, it is returned; otherwise that
-- rewriter is moved to the back of the list, the budget is decremented, and
-- the process continues with the rewritten tree.
--
-- * A budget of @0@ returns the tree unchanged.
-- * An empty rewriter list returns the tree unchanged (there is nothing to
--   apply), instead of failing on the head of an empty list.
rewriteUntilNoChange :: [Fix SRTree -> Fix SRTree] -> Int -> Fix SRTree -> Fix SRTree
rewriteUntilNoChange _ 0 t = t
rewriteUntilNoChange [] _ t = t
rewriteUntilNoChange (rewriter : rest) n t
  | t == t'   = t'
  | otherwise = rewriteUntilNoChange (rest <> [rewriter]) (n - 1) t'
  where
    -- result of applying the current head rewriter once
    t' = rewriter t

-- | Simplify an expression tree using equality saturation.
--
-- The pipeline, read from right to left:
--
-- 1. 'toEqFix' converts the tree to the fixed-point type used by the
--    equality-saturation library.
-- 2. 'rewriteConst' runs a pass with the constant-reduction rules only.
-- 3. 'rewriteUntilNoChange' applies 'rewriteAll' until the tree stops
--    changing, with an iteration budget of 2.
-- 4. 'fromEqFix' converts back to the @R.Fix@ representation.
-- 5. @relabelParams@ (defined outside this module section) is applied to
--    the result.
simplifyEqSat :: R.Fix SRTree -> R.Fix SRTree
simplifyEqSat =
    relabelParams
  . fromEqFix
  . rewriteUntilNoChange [rewriteAll] 2
  . rewriteConst
  . toEqFix

-- | Convert a tree from the equality-saturation library's 'Fix' to the
-- @R.Fix@ representation.
--
-- Every node keeps its constructor and fields; only the fixed-point wrapper
-- changes, so the fold algebra is simply the @R.Fix@ constructor. This also
-- stays total for every constructor of 'SRTree'.
fromEqFix :: Fix SRTree -> R.Fix SRTree
fromEqFix = cata R.Fix

-- | Convert a tree from the @R.Fix@ representation to the
-- equality-saturation library's 'Fix'.
--
-- Every node keeps its constructor and fields; only the fixed-point wrapper
-- changes, so the fold algebra is simply the 'Fix' constructor. This also
-- stays total for every constructor of 'SRTree'.
toEqFix :: R.Fix SRTree -> Fix SRTree
toEqFix = R.cata Fix