{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Simplification rules for statements.
--
-- A rule is a function that receives a context value and a statement
-- (or the destructured parts of one) and runs in 'RuleM'.  A rule
-- either signals that it does not apply (see 'cannotSimplify') or
-- succeeds, in which case the statements it added are handed back to
-- the caller by 'topDownSimplifyStm' / 'bottomUpSimplifyStm'.
module Futhark.Optimise.Simplify.Rule
(
-- * The rule monad
RuleM
, cannotSimplify
, liftMaybe
-- * Rule definition
, SimplificationRule(..)
, RuleGeneric
, RuleBasicOp
, RuleIf
, RuleDoLoop
-- NOTE(review): the @RuleOp@ type synonym is defined in this module
-- but, unlike its four siblings above, is not exported (only the
-- 'RuleOp' constructor is, via 'SimplificationRule'(..)).  Confirm
-- this is intentional.
-- * Top-down rules
, TopDown
, TopDownRule
, TopDownRuleGeneric
, TopDownRuleBasicOp
, TopDownRuleIf
, TopDownRuleDoLoop
, TopDownRuleOp
-- * Bottom-up rules
, BottomUp
, BottomUpRule
, BottomUpRuleGeneric
, BottomUpRuleBasicOp
, BottomUpRuleIf
, BottomUpRuleDoLoop
, BottomUpRuleOp
-- * Assembling rules
, RuleBook
, ruleBook
-- * Applying rules
, topDownSimplifyStm
, bottomUpSimplifyStm
) where
import Control.Monad.State
import qualified Control.Monad.Fail as Fail
import Control.Monad.Except
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Representation.AST
import Futhark.Binder
-- | The ways in which a 'RuleM' action can abort.
data RuleError
  = -- | The rule does not apply.  Raised by 'cannotSimplify' and
    -- turned into 'Nothing' by 'simplify'.
    CannotSimplify
  | -- | Any other failure, carrying a message.  Produced by 'fail' in
    -- 'RuleM'; 'simplify' turns it into a call to 'error'.
    OtherError String
-- | The monad in which simplification rules run.  It wraps a
-- 'BinderT' layered over a 'VNameSource' state and a 'RuleError'
-- exception layer; all listed instances are derived from that
-- underlying stack.
newtype RuleM lore a
  = RuleM (BinderT lore (StateT VNameSource (Except RuleError)) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFreshNames
    , HasScope lore
    , LocalScope lore
    , MonadError RuleError
    )
-- | 'fail' (e.g. from a failed pattern match in @do@ notation) is
-- reported as an 'OtherError' carrying the message, not as
-- 'CannotSimplify'.
instance Fail.MonadFail (RuleM lore) where
  fail msg = throwError (OtherError msg)
-- | Every method unwraps the 'RuleM' arguments (where there are any),
-- delegates to the same method of the underlying 'BinderT' stack, and
-- rewraps the result.
instance (Attributes lore, BinderOps lore) => MonadBinder (RuleM lore) where
  type Lore (RuleM lore) = lore

  mkExpAttrM pat e = RuleM (mkExpAttrM pat e)
  mkBodyM stms res = RuleM (mkBodyM stms res)
  mkLetNamesM names e = RuleM (mkLetNamesM names e)

  addStms stms = RuleM (addStms stms)
  collectStms (RuleM m) = RuleM (collectStms m)
  certifying cs (RuleM m) = RuleM (certifying cs m)
-- | Run a 'RuleM' action in the current scope, threading the
-- surrounding monad's name source through it.
--
-- * On success, the result is @Just (value, statements)@ and the name
--   source is replaced by the one the action finished with.
--
-- * On 'CannotSimplify', the result is 'Nothing' and the name source
--   is left exactly as it was before the action ran.
--
-- * On 'OtherError', 'error' is called with the message prefixed by
--   @"simplify: "@.
simplify :: (MonadFreshNames m, HasScope lore m) =>
            RuleM lore a
         -> m (Maybe (a, Stms lore))
simplify (RuleM m) = do
  scope <- askScope
  modifyNameSource $ \src ->
    let outcome = runExcept (runStateT (runBinderT m scope) src)
     in case outcome of
          Right (x, src') -> (Just x, src')
          Left CannotSimplify -> (Nothing, src)
          Left (OtherError err) -> error $ "simplify: " ++ err
-- | Abort the current rule by throwing 'CannotSimplify'.  When the
-- rule is run through 'simplify' this yields 'Nothing', which makes
-- @applyRules'@ move on to the next rule.
cannotSimplify :: RuleM lore a
cannotSimplify = throwError CannotSimplify
-- | Lift a 'Maybe' into 'RuleM': @Just x@ returns @x@, while
-- 'Nothing' aborts the rule via 'cannotSimplify'.
liftMaybe :: Maybe a -> RuleM lore a
liftMaybe = maybe cannotSimplify return
-- | A rule that is handed the entire statement, together with a
-- context value of type @a@.  'applyRule' invokes it for any
-- statement, whatever its expression.
type RuleGeneric lore a = a -> Stm lore -> RuleM lore ()
-- | A rule for statements whose expression is a 'BasicOp'.  It
-- receives the context, the statement's pattern, its auxiliary
-- information, and the 'BasicOp' itself.
type RuleBasicOp lore a =
  a ->
  Pattern lore ->
  StmAux (ExpAttr lore) ->
  BasicOp lore ->
  RuleM lore ()
-- | A rule for statements whose expression is an 'If'.  Besides the
-- context, pattern and auxiliary information, it receives the fields
-- of the 'If' as a tuple, in the order in which 'applyRule' unpacks
-- them: condition, first branch body, second branch body, and the
-- 'IfAttr'.
type RuleIf lore a =
  a ->
  Pattern lore ->
  StmAux (ExpAttr lore) ->
  ( SubExp
  , BodyT lore
  , BodyT lore
  , IfAttr (BranchType lore)
  ) ->
  RuleM lore ()
-- | A rule for statements whose expression is a 'DoLoop'.  Besides
-- the context, pattern and auxiliary information, it receives the
-- fields of the 'DoLoop' as a tuple, in the order in which
-- 'applyRule' unpacks them: the two lists of parameter/initial-value
-- pairs, the loop form, and the loop body.
type RuleDoLoop lore a =
  a ->
  Pattern lore ->
  StmAux (ExpAttr lore) ->
  ( [(FParam lore, SubExp)]
  , [(FParam lore, SubExp)]
  , LoopForm lore
  , BodyT lore
  ) ->
  RuleM lore ()
-- | A rule for statements whose expression is an 'Op'.  It receives
-- the context, the statement's pattern, its auxiliary information,
-- and the lore-specific operation.
type RuleOp lore a =
  a ->
  Pattern lore ->
  StmAux (ExpAttr lore) ->
  Op lore ->
  RuleM lore ()
-- | A simplification rule taking a context of type @a@.  The
-- constructor determines which statements 'applyRule' will hand to
-- the wrapped function; for any other combination 'applyRule' falls
-- back to 'cannotSimplify'.
data SimplificationRule lore a
  = -- | Invoked for every statement.
    RuleGeneric (RuleGeneric lore a)
  | -- | Invoked only for 'BasicOp' statements.
    RuleBasicOp (RuleBasicOp lore a)
  | -- | Invoked only for 'If' statements.
    RuleIf (RuleIf lore a)
  | -- | Invoked only for 'DoLoop' statements.
    RuleDoLoop (RuleDoLoop lore a)
  | -- | Invoked only for 'Op' statements.
    RuleOp (RuleOp lore a)
-- | Rules pre-sorted by the kind of expression they can apply to, so
-- that 'rulesForStm' can pick the candidate list with a single field
-- access.  As built by 'ruleBook', each specialised list also
-- contains the generic rules.
data Rules lore a = Rules
  { rulesAny :: [SimplificationRule lore a]
    -- ^ Every rule; used for expressions without a dedicated list.
  , rulesBasicOp :: [SimplificationRule lore a]
    -- ^ Candidates for 'BasicOp' statements.
  , rulesIf :: [SimplificationRule lore a]
    -- ^ Candidates for 'If' statements.
  , rulesDoLoop :: [SimplificationRule lore a]
    -- ^ Candidates for 'DoLoop' statements.
  , rulesOp :: [SimplificationRule lore a]
    -- ^ Candidates for 'Op' statements.
  }
-- | Field-wise concatenation: within every list, the rules of the
-- left operand come before those of the right operand.
instance Semigroup (Rules lore a) where
  Rules anyL basicL ifL loopL opL <> Rules anyR basicR ifR loopR opR =
    Rules
      (anyL <> anyR)
      (basicL <> basicR)
      (ifL <> ifR)
      (loopL <> loopR)
      (opL <> opR)
-- | The empty rule set: every list is empty.
instance Monoid (Rules lore a) where
  mempty = Rules [] [] [] [] []
-- | Context handed to the rules run by 'topDownSimplifyStm': a symbol
-- table.
type TopDown lore = ST.SymbolTable lore
-- The rule function shapes, specialised to the 'TopDown' context.
type TopDownRuleGeneric lore = RuleGeneric lore (TopDown lore)
type TopDownRuleBasicOp lore = RuleBasicOp lore (TopDown lore)
type TopDownRuleIf lore = RuleIf lore (TopDown lore)
type TopDownRuleDoLoop lore = RuleDoLoop lore (TopDown lore)
type TopDownRuleOp lore = RuleOp lore (TopDown lore)
-- | A 'SimplificationRule' taking the 'TopDown' context.
type TopDownRule lore = SimplificationRule lore (TopDown lore)
-- | Context handed to the rules run by 'bottomUpSimplifyStm': a
-- symbol table paired with a usage table.
type BottomUp lore = (ST.SymbolTable lore, UT.UsageTable)
-- The rule function shapes, specialised to the 'BottomUp' context.
type BottomUpRuleGeneric lore = RuleGeneric lore (BottomUp lore)
type BottomUpRuleBasicOp lore = RuleBasicOp lore (BottomUp lore)
type BottomUpRuleIf lore = RuleIf lore (BottomUp lore)
type BottomUpRuleDoLoop lore = RuleDoLoop lore (BottomUp lore)
type BottomUpRuleOp lore = RuleOp lore (BottomUp lore)
-- | A 'SimplificationRule' taking the 'BottomUp' context.
type BottomUpRule lore = SimplificationRule lore (BottomUp lore)
-- | Grouped rules taking the 'TopDown' context.
type TopDownRules lore = Rules lore (TopDown lore)
-- | Grouped rules taking the 'BottomUp' context.
type BottomUpRules lore = Rules lore (BottomUp lore)
-- | A collection of both top-down and bottom-up rules.  Build one
-- with 'ruleBook' and combine several with '<>'.
data RuleBook lore = RuleBook
  { bookTopDownRules :: TopDownRules lore
    -- ^ The rules consulted by 'topDownSimplifyStm'.
  , bookBottomUpRules :: BottomUpRules lore
    -- ^ The rules consulted by 'bottomUpSimplifyStm'.
  }
-- | Combines the top-down rule sets and the bottom-up rule sets
-- separately, left operand first.
instance Semigroup (RuleBook lore) where
  RuleBook topL bottomL <> RuleBook topR bottomR =
    RuleBook
      { bookTopDownRules = topL <> topR
      , bookBottomUpRules = bottomL <> bottomR
      }
-- | The rule book that contains no rules at all.
instance Monoid (RuleBook lore) where
  mempty =
    RuleBook
      { bookTopDownRules = mempty
      , bookBottomUpRules = mempty
      }
-- | Build a 'RuleBook' from a list of top-down rules and a list of
-- bottom-up rules.
--
-- Each list is grouped by the kind of expression its rules can apply
-- to.  Generic rules are placed in every group, and the relative
-- order of the rules is preserved within each group.
ruleBook :: [TopDownRule m]
         -> [BottomUpRule m]
         -> RuleBook m
ruleBook topdowns bottomups =
  RuleBook
    { bookTopDownRules = groupRules topdowns
    , bookBottomUpRules = groupRules bottomups
    }
  where
    groupRules :: [SimplificationRule m a] -> Rules m a
    groupRules rs =
      Rules
        { rulesAny = rs
        , rulesBasicOp = filter forBasicOp rs
        , rulesIf = filter forIf rs
        , rulesDoLoop = filter forDoLoop rs
        , rulesOp = filter forOp rs
        }

    -- Each predicate accepts the rules specific to one kind of
    -- expression, plus the generic rules.
    forBasicOp rule = case rule of
      RuleBasicOp{} -> True
      RuleGeneric{} -> True
      _ -> False

    forIf rule = case rule of
      RuleIf{} -> True
      RuleGeneric{} -> True
      _ -> False

    forDoLoop rule = case rule of
      RuleDoLoop{} -> True
      RuleGeneric{} -> True
      _ -> False

    forOp rule = case rule of
      RuleOp{} -> True
      RuleGeneric{} -> True
      _ -> False
-- | Try the top-down rules of the rule book on a statement, in order.
-- Returns @Just@ the statements produced by the first rule that
-- succeeds, or 'Nothing' if no rule applies.
topDownSimplifyStm :: (MonadFreshNames m, HasScope lore m, BinderOps lore) =>
                      RuleBook lore
                   -> ST.SymbolTable lore
                   -> Stm lore
                   -> m (Maybe (Stms lore))
topDownSimplifyStm book = applyRules (bookTopDownRules book)
-- | Try the bottom-up rules of the rule book on a statement, in
-- order.  Returns @Just@ the statements produced by the first rule
-- that succeeds, or 'Nothing' if no rule applies.
bottomUpSimplifyStm :: (MonadFreshNames m, HasScope lore m, BinderOps lore) =>
                       RuleBook lore
                    -> (ST.SymbolTable lore, UT.UsageTable)
                    -> Stm lore
                    -> m (Maybe (Stms lore))
bottomUpSimplifyStm book = applyRules (bookBottomUpRules book)
-- | Select the list of candidate rules for a statement, based on the
-- constructor of its expression.  Expressions other than 'BasicOp',
-- 'DoLoop', 'Op' and 'If' get the full 'rulesAny' list.
rulesForStm :: Stm lore -> Rules lore a -> [SimplificationRule lore a]
rulesForStm stm = selectorFor (stmExp stm)
  where
    selectorFor BasicOp{} = rulesBasicOp
    selectorFor DoLoop{} = rulesDoLoop
    selectorFor Op{} = rulesOp
    selectorFor If{} = rulesIf
    selectorFor _ = rulesAny
-- | Run a single rule on a statement.  A generic rule receives the
-- statement as-is; a specialised rule receives the destructured
-- statement, provided the expression has the matching constructor.
-- Any other pairing of rule and statement is 'cannotSimplify'.
applyRule :: SimplificationRule lore a -> a -> Stm lore -> RuleM lore ()
applyRule rule a stm =
  case (rule, stm) of
    (RuleGeneric f, _) ->
      f a stm
    (RuleBasicOp f, Let pat aux (BasicOp e)) ->
      f a pat aux e
    (RuleDoLoop f, Let pat aux (DoLoop ctx val form body)) ->
      f a pat aux (ctx, val, form, body)
    (RuleIf f, Let pat aux (If cond tbody fbody ifsort)) ->
      f a pat aux (cond, tbody, fbody, ifsort)
    (RuleOp f, Let pat aux (Op op)) ->
      f a pat aux op
    _ ->
      cannotSimplify
-- | Pick the candidate rules for the statement with 'rulesForStm' and
-- try them in order via @applyRules'@.
applyRules :: (MonadFreshNames m, HasScope lore m, BinderOps lore) =>
              Rules lore a -> a -> Stm lore
           -> m (Maybe (Stms lore))
applyRules rules context stm = applyRules' candidates context stm
  where
    candidates = rulesForStm stm rules
-- | Try each rule in turn on the statement.  The result is @Just@ the
-- statements produced by the first rule that succeeds; later rules
-- are then not run.  If the list is exhausted (or empty) the result
-- is 'Nothing'.
applyRules' :: (MonadFreshNames m, HasScope lore m, BinderOps lore) =>
               [SimplificationRule lore a] -> a -> Stm lore
            -> m (Maybe (Stms lore))
applyRules' rules context bnd = foldr attempt (return Nothing) rules
  where
    -- @fallback@ is the (not yet run) attempt with the remaining
    -- rules; it is only executed when the current rule gives up.
    attempt rule fallback = do
      outcome <- simplify $ applyRule rule context bnd
      case outcome of
        Just ((), stms) -> return $ Just stms
        Nothing -> fallback