{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module defines the concept of a simplification rule for
-- bindings.  The intent is that you pass some context (such as a symbol
-- table) and a binding, and are given back a sequence of bindings that
-- compute the same result, but are "better" in some sense.
--
-- These rewrite rules are "local", in that they do not maintain any
-- state or look at the program as a whole.  Compare this to the
-- fusion algorithm in @Futhark.Optimise.Fusion.Fusion@, which must be implemented
-- as its own pass.
module Futhark.Optimise.Simplify.Rule
  ( -- * The rule monad
    RuleM,
    cannotSimplify,
    liftMaybe,

    -- * Rule definition
    Rule (..),
    SimplificationRule (..),
    RuleGeneric,
    RuleBasicOp,
    RuleIf,
    RuleDoLoop,

    -- * Top-down rules
    TopDown,
    TopDownRule,
    TopDownRuleGeneric,
    TopDownRuleBasicOp,
    TopDownRuleIf,
    TopDownRuleDoLoop,
    TopDownRuleOp,

    -- * Bottom-up rules
    BottomUp,
    BottomUpRule,
    BottomUpRuleGeneric,
    BottomUpRuleBasicOp,
    BottomUpRuleIf,
    BottomUpRuleDoLoop,
    BottomUpRuleOp,

    -- * Assembling rules
    RuleBook,
    ruleBook,

    -- * Applying rules
    topDownSimplifyStm,
    bottomUpSimplifyStm,
  )
where

import Control.Monad.State
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Binder
import Futhark.IR

-- | The monad in which simplification rules are evaluated.  It wraps a
-- 'BinderT' (so a rule can emit statements and query the scope) over a
-- name-source state, with 'Maybe' at the bottom so that a rule can be
-- abandoned part-way through (see 'cannotSimplify').
newtype RuleM lore a = RuleM (BinderT lore (StateT VNameSource Maybe) a)
  deriving (Functor, Applicative, Monad, MonadFreshNames, HasScope lore, LocalScope lore)

-- Every binder operation is forwarded unchanged to the wrapped 'BinderT'.
instance (ASTLore lore, BinderOps lore) => MonadBinder (RuleM lore) where
  type Lore (RuleM lore) = lore
  mkExpDecM pat e = RuleM (mkExpDecM pat e)
  mkBodyM stms res = RuleM (mkBodyM stms res)
  mkLetNamesM names e = RuleM (mkLetNamesM names e)
  addStms stms = RuleM (addStms stms)
  collectStms (RuleM m) = RuleM (collectStms m)

-- | Execute a 'Rule' in the given scope, starting from the given name
-- source.  Returns 'Nothing' if the rule is 'Skip', or if its 'RuleM'
-- action was abandoned (e.g. through 'cannotSimplify').  If successful,
-- returns the statements the action added together with the updated
-- name source.
simplify ::
  Scope lore ->
  VNameSource ->
  Rule lore ->
  Maybe (Stms lore, VNameSource)
simplify _ _ Skip = Nothing
simplify scope src (Simplify (RuleM m)) =
  runStateT (runBinderT_ m scope) src

-- | Abandon the rule currently being evaluated.  'simplify' then
-- yields 'Nothing' for it.
cannotSimplify :: RuleM lore a
cannotSimplify = RuleM (lift (lift Nothing))

-- | Unwrap a 'Just' value, or abandon the rule (as 'cannotSimplify')
-- on 'Nothing'.
liftMaybe :: Maybe a -> RuleM lore a
liftMaybe = maybe cannotSimplify pure

-- | The result of matching a rule against a statement.  It is a cheap
-- way to say up front whether the rule is worth running at all.
data Rule lore
  = Simplify (RuleM lore ())
  -- ^ Give it a shot: run this action to attempt the rewrite.
  | Skip
  -- ^ Don't bother: the rule cannot apply.

-- | A rule that may be tried on any statement.  It receives the
-- traversal context and the complete statement.
type RuleGeneric lore a =
  a ->
  Stm lore ->
  Rule lore

-- | A rule for statements whose expression is a 'BasicOp'.  It
-- receives the traversal context, the statement's pattern and
-- auxiliary information, and the operation itself.
type RuleBasicOp lore a =
  a ->
  Pattern lore ->
  StmAux (ExpDec lore) ->
  BasicOp ->
  Rule lore

-- | A rule for 'If' statements.  Besides the context, pattern and
-- auxiliary information, it receives the condition, the true branch,
-- the false branch and the branch decoration, in that order.
type RuleIf lore a =
  a ->
  Pattern lore ->
  StmAux (ExpDec lore) ->
  (SubExp, BodyT lore, BodyT lore, IfDec (BranchType lore)) ->
  Rule lore

-- | A rule for 'DoLoop' statements.  Besides the context, pattern and
-- auxiliary information, it receives the loop's two parameter lists,
-- in the order 'DoLoop' stores them (each parameter is paired with a
-- 'SubExp', presumably its initial value — confirm against the IR
-- definition), followed by the loop form and the loop body.
type RuleDoLoop lore a =
  a ->
  Pattern lore ->
  StmAux (ExpDec lore) ->
  ( [(FParam lore, SubExp)],
    [(FParam lore, SubExp)],
    LoopForm lore,
    BodyT lore
  ) ->
  Rule lore

-- | A rule for statements whose expression is an 'Op'.  It receives
-- the traversal context, the statement's pattern and auxiliary
-- information, and the operation.
type RuleOp lore a =
  a -> Pattern lore -> StmAux (ExpDec lore) -> Op lore -> Rule lore

-- | A simplification rule takes some argument (the traversal context)
-- and a statement, and tries to simplify the statement.  The
-- constructor determines which statements the rule is tried on.
data SimplificationRule lore a
  = -- | Tried on every statement.
    RuleGeneric (RuleGeneric lore a)
  | -- | Tried only on 'BasicOp' statements.
    RuleBasicOp (RuleBasicOp lore a)
  | -- | Tried only on 'If' statements.
    RuleIf (RuleIf lore a)
  | -- | Tried only on 'DoLoop' statements.
    RuleDoLoop (RuleDoLoop lore a)
  | -- | Tried only on 'Op' statements.
    RuleOp (RuleOp lore a)

-- | A collection of rules grouped by which forms of statements they
-- may apply to.  As built by 'ruleBook', each specialised group also
-- contains every generic rule.
data Rules lore a = Rules
  { -- | All rules; used for statements without a dedicated group.
    rulesAny :: [SimplificationRule lore a],
    -- | Rules to try on 'BasicOp' statements.
    rulesBasicOp :: [SimplificationRule lore a],
    -- | Rules to try on 'If' statements.
    rulesIf :: [SimplificationRule lore a],
    -- | Rules to try on 'DoLoop' statements.
    rulesDoLoop :: [SimplificationRule lore a],
    -- | Rules to try on 'Op' statements.
    rulesOp :: [SimplificationRule lore a]
  }

-- Combine two rule collections group by group.  Within each group the
-- left operand's rules come first, so they are tried first.
instance Semigroup (Rules lore a) where
  Rules anys1 basics1 ifs1 loops1 ops1 <> Rules anys2 basics2 ifs2 loops2 ops2 =
    Rules
      { rulesAny = anys1 <> anys2,
        rulesBasicOp = basics1 <> basics2,
        rulesIf = ifs1 <> ifs2,
        rulesDoLoop = loops1 <> loops2,
        rulesOp = ops1 <> ops2
      }

-- The empty collection has no rules in any group.
instance Monoid (Rules lore a) where
  mempty = Rules [] [] [] [] []

-- | Context for a rule applied during top-down traversal of the
-- program: the symbol table alone.
type TopDown lore = ST.SymbolTable lore

-- | Any kind of 'SimplificationRule' that receives the top-down context.
type TopDownRule lore = SimplificationRule lore (TopDown lore)

-- | A 'RuleGeneric' that receives the top-down context.
type TopDownRuleGeneric lore = RuleGeneric lore (TopDown lore)

-- | A 'RuleBasicOp' that receives the top-down context.
type TopDownRuleBasicOp lore = RuleBasicOp lore (TopDown lore)

-- | A 'RuleIf' that receives the top-down context.
type TopDownRuleIf lore = RuleIf lore (TopDown lore)

-- | A 'RuleDoLoop' that receives the top-down context.
type TopDownRuleDoLoop lore = RuleDoLoop lore (TopDown lore)

-- | A 'RuleOp' that receives the top-down context.
type TopDownRuleOp lore = RuleOp lore (TopDown lore)

-- | Context for a rule applied during bottom-up traversal of the
-- program: a symbol table paired with a usage table.
type BottomUp lore = (ST.SymbolTable lore, UT.UsageTable)

-- | Any kind of 'SimplificationRule' that receives the bottom-up context.
type BottomUpRule lore = SimplificationRule lore (BottomUp lore)

-- | A 'RuleGeneric' that receives the bottom-up context.
type BottomUpRuleGeneric lore = RuleGeneric lore (BottomUp lore)

-- | A 'RuleBasicOp' that receives the bottom-up context.
type BottomUpRuleBasicOp lore = RuleBasicOp lore (BottomUp lore)

-- | A 'RuleIf' that receives the bottom-up context.
type BottomUpRuleIf lore = RuleIf lore (BottomUp lore)

-- | A 'RuleDoLoop' that receives the bottom-up context.
type BottomUpRuleDoLoop lore = RuleDoLoop lore (BottomUp lore)

-- | A 'RuleOp' that receives the bottom-up context.
type BottomUpRuleOp lore = RuleOp lore (BottomUp lore)

-- | The top-down rules of a 'RuleBook', grouped by statement form.
type TopDownRules lore = Rules lore (TopDown lore)

-- | The bottom-up rules of a 'RuleBook', grouped by statement form.
type BottomUpRules lore = Rules lore (BottomUp lore)

-- | A collection of both top-down and bottom-up rules.
data RuleBook lore = RuleBook
  { -- | Rules used by 'topDownSimplifyStm'.
    bookTopDownRules :: TopDownRules lore,
    -- | Rules used by 'bottomUpSimplifyStm'.
    bookBottomUpRules :: BottomUpRules lore
  }

-- Rule books combine their top-down and bottom-up halves separately.
instance Semigroup (RuleBook lore) where
  RuleBook topDown1 bottomUp1 <> RuleBook topDown2 bottomUp2 =
    RuleBook
      { bookTopDownRules = topDown1 <> topDown2,
        bookBottomUpRules = bottomUp1 <> bottomUp2
      }

-- The empty rule book holds no rules in either direction.
instance Monoid (RuleBook lore) where
  mempty =
    RuleBook
      { bookTopDownRules = mempty,
        bookBottomUpRules = mempty
      }

-- | Construct a rule book from a list of top-down rules and a list of
-- bottom-up rules.  Each list is split into per-statement-form groups,
-- keeping the original order within each group.
ruleBook ::
  [TopDownRule m] ->
  [BottomUpRule m] ->
  RuleBook m
ruleBook topdowns bottomups =
  RuleBook (groupRules topdowns) (groupRules bottomups)
  where
    -- Sort rules into the groups consulted by 'rulesForStm'.
    groupRules :: [SimplificationRule lore a] -> Rules lore a
    groupRules rules =
      Rules
        { rulesAny = rules,
          rulesBasicOp = filter appliesToBasicOp rules,
          rulesIf = filter appliesToIf rules,
          rulesDoLoop = filter appliesToDoLoop rules,
          rulesOp = filter appliesToOp rules
        }

    -- A generic rule can fire on any statement, so every group keeps it.
    appliesToBasicOp, appliesToIf, appliesToDoLoop, appliesToOp ::
      SimplificationRule lore a -> Bool
    appliesToBasicOp rule = case rule of
      RuleBasicOp {} -> True
      RuleGeneric {} -> True
      _ -> False
    appliesToIf rule = case rule of
      RuleIf {} -> True
      RuleGeneric {} -> True
      _ -> False
    appliesToDoLoop rule = case rule of
      RuleDoLoop {} -> True
      RuleGeneric {} -> True
      _ -> False
    appliesToOp rule = case rule of
      RuleOp {} -> True
      RuleGeneric {} -> True
      _ -> False

-- | @topDownSimplifyStm book vtable stm@ tries, in order, the top-down
-- rules of @book@ that may apply to @stm@, passing the symbol table
-- @vtable@ to each rule as context.  The first rule that succeeds
-- supplies the result: a replacement sequence of statements that should
-- bind at least the same names as @stm@ (and possibly more, for
-- intermediate results).  Returns 'Nothing' if no rule succeeds.
topDownSimplifyStm ::
  (MonadFreshNames m, HasScope lore m) =>
  RuleBook lore ->
  ST.SymbolTable lore ->
  Stm lore ->
  m (Maybe (Stms lore))
topDownSimplifyStm = applyRules . bookTopDownRules

-- | @bottomUpSimplifyStm book (vtable, uses) stm@ tries, in order, the
-- bottom-up rules of @book@ that may apply to @stm@, passing the symbol
-- table @vtable@ and the usage table @uses@ to each rule as context
-- (@uses@ presumably records how names are used after @stm@ — confirm
-- against the caller).  The first rule that succeeds supplies the
-- result: a replacement sequence of statements that should bind at
-- least the same names as @stm@ (and possibly more, for intermediate
-- results).  Returns 'Nothing' if no rule succeeds.
bottomUpSimplifyStm ::
  (MonadFreshNames m, HasScope lore m) =>
  RuleBook lore ->
  (ST.SymbolTable lore, UT.UsageTable) ->
  Stm lore ->
  m (Maybe (Stms lore))
bottomUpSimplifyStm = applyRules . bookBottomUpRules

-- | Select the group of rules worth trying on a statement, based on the
-- form of its expression.  Forms without a dedicated group get all rules.
rulesForStm :: Stm lore -> Rules lore a -> [SimplificationRule lore a]
rulesForStm stm = groupFor (stmExp stm)
  where
    groupFor BasicOp {} = rulesBasicOp
    groupFor DoLoop {} = rulesDoLoop
    groupFor Op {} = rulesOp
    groupFor If {} = rulesIf
    groupFor _ = rulesAny

-- | Match a rule against a statement.  A generic rule always receives
-- the whole statement.  Any other rule receives the statement's pieces
-- only when the expression has the form the rule targets; otherwise
-- the result is 'Skip'.
applyRule :: SimplificationRule lore a -> a -> Stm lore -> Rule lore
applyRule rule arg stm =
  case rule of
    RuleGeneric f ->
      f arg stm
    RuleBasicOp f
      | Let pat aux (BasicOp op) <- stm ->
        f arg pat aux op
    RuleDoLoop f
      | Let pat aux (DoLoop ctxparams valparams form body) <- stm ->
        f arg pat aux (ctxparams, valparams, form, body)
    RuleIf f
      | Let pat aux (If cond tbranch fbranch dec) <- stm ->
        f arg pat aux (cond, tbranch, fbranch, dec)
    RuleOp f
      | Let pat aux (Op op) <- stm ->
        f arg pat aux op
    _ ->
      Skip

-- | Try the rules relevant to a statement in order, each starting from
-- the same name source.  The first success yields its statements and
-- advances the monad's name source.  If every rule fails, the result is
-- 'Nothing' and the name source is left as it was.
applyRules ::
  (MonadFreshNames m, HasScope lore m) =>
  Rules lore a ->
  a ->
  Stm lore ->
  m (Maybe (Stms lore))
applyRules all_rules context stm = do
  scope <- askScope

  modifyNameSource $ \src ->
    let -- Run one rule; on failure, fall back to the remaining rules.
        attempt rule orElse =
          case simplify scope src (applyRule rule context stm) of
            Nothing -> orElse
            success -> success
        firstSuccess = foldr attempt Nothing (rulesForStm stm all_rules)
     in maybe (Nothing, src) (\(stms, src') -> (Just stms, src')) firstSuccess