{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.Optimise.Simplify.Engine
(
SimpleM,
runSimpleM,
SimpleOps (..),
SimplifyOp,
bindableSimpleOps,
Env (envHoistBlockers, envRules),
emptyEnv,
HoistBlockers (..),
neverBlocks,
noExtraHoistBlockers,
neverHoist,
BlockPred,
orIf,
hasFree,
isConsumed,
isFalse,
isOp,
isNotSafe,
asksEngineEnv,
askVtable,
localVtable,
SimplifiableLore,
Simplifiable (..),
simplifyStms,
simplifyFun,
simplifyLambda,
simplifyLambdaNoHoisting,
bindLParams,
simplifyBody,
SimplifiedBody,
ST.SymbolTable,
hoistStms,
blockIf,
module Futhark.Optimise.Simplify.Lore,
)
where
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Either
import Data.List (find, foldl', mapAccumL, nub)
import Data.Maybe
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Lore
import Futhark.Optimise.Simplify.Rule
import Futhark.Util (splitFromEnd)
-- | Predicates that decide whether a statement may be hoisted.  The three
-- 'BlockPred' fields are presumably consulted in different hoisting
-- contexts (parallel loops, sequential loops, branches), as their names
-- suggest.
data HoistBlockers lore = HoistBlockers
  { -- | Hoisting blocker for parallel constructs (by name; confirm at use
    -- sites).
    blockHoistPar :: BlockPred (Wise lore),
    -- | Hoisting blocker for sequential loops (by name; confirm at use
    -- sites).
    blockHoistSeq :: BlockPred (Wise lore),
    -- | Hoisting blocker for branches (by name; confirm at use sites).
    blockHoistBranch :: BlockPred (Wise lore),
    -- | Recognises allocation statements.
    isAllocation :: Stm (Wise lore) -> Bool
  }
-- | Hoist blockers that never block anything: every 'BlockPred' is
-- 'neverBlocks' and no statement is considered an allocation.
noExtraHoistBlockers :: HoistBlockers lore
noExtraHoistBlockers =
  HoistBlockers
    { blockHoistPar = neverBlocks,
      blockHoistSeq = neverBlocks,
      blockHoistBranch = neverBlocks,
      isAllocation = const False
    }
-- | Hoist blockers where every 'BlockPred' is 'alwaysBlocks', which
-- presumably forbids all hoisting.  No statement is considered an
-- allocation.
neverHoist :: HoistBlockers lore
neverHoist =
  HoistBlockers
    { blockHoistPar = alwaysBlocks,
      blockHoistSeq = alwaysBlocks,
      blockHoistBranch = alwaysBlocks,
      isAllocation = const False
    }
-- | The read-only environment of the simplifier engine.  It is carried in
-- the 'ReaderT' layer of 'SimpleM'.
data Env lore = Env
  { -- | Simplification rules available to the engine.
    envRules :: RuleBook (Wise lore),
    -- | Predicates controlling which statements may be hoisted.
    envHoistBlockers :: HoistBlockers lore,
    -- | Symbol table of what is currently in scope; extended locally
    -- through 'localVtable'.
    envVtable :: ST.SymbolTable (Wise lore)
  }
-- | An environment with the given rules and hoist blockers and an empty
-- symbol table.
emptyEnv :: RuleBook (Wise lore) -> HoistBlockers lore -> Env lore
emptyEnv rules blockers =
  Env
    { envRules = rules,
      envHoistBlockers = blockers,
      envVtable = mempty
    }
-- | A hook for guarding a hoisted 'Op'.  It receives a 'SubExp' (used as a
-- condition by 'protectIfHoisted'), the pattern being bound and the op
-- itself.  It may return an action in the binder monad that emits the
-- protected code; 'Nothing' means the hook declines.
type Protect binder =
  SubExp ->
  Pattern (Lore binder) ->
  Op (Lore binder) ->
  Maybe (binder ())
-- | Lore-specific operations that the generic simplifier engine needs.
data SimpleOps lore = SimpleOps
  { -- | Construct the expression decoration for binding the given
    -- expression to the given pattern, with the current symbol table
    -- available.
    mkExpDecS ::
      ST.SymbolTable (Wise lore) ->
      Pattern (Wise lore) ->
      Exp (Wise lore) ->
      SimpleM lore (ExpDec (Wise lore)),
    -- | Construct a body from statements and a result, with the current
    -- symbol table available.
    mkBodyS ::
      ST.SymbolTable (Wise lore) ->
      Stms (Wise lore) ->
      Result ->
      SimpleM lore (Body (Wise lore)),
    -- | Guard a hoisted 'Op'.  'protectIfHoisted' passes this hook to
    -- 'protectIf'.
    protectHoistedOpS :: Protect (Binder (Wise lore)),
    -- | Compute the usage table of an 'Op'.
    opUsageS :: Op (Wise lore) -> UT.UsageTable,
    -- | Simplify an 'Op', yielding the simplified op plus extra statements.
    simplifyOpS :: SimplifyOp lore (Op lore)
  }
-- | The shape of an 'Op' simplifier.  It turns an op into its
-- wisdom-annotated simplified form and also returns a collection of
-- statements (presumably ones produced or hoisted while simplifying).
type SimplifyOp lore op =
  op ->
  SimpleM lore (OpWithWisdom op, Stms (Wise lore))
-- | 'SimpleOps' for any 'Bindable' lore, given only an 'Op' simplifier.
--
-- * Decorations and bodies come from 'mkExpDec' and 'mkBody', ignoring the
--   symbol table.
-- * Hoisted ops are never protected: the hook always returns 'Nothing'.
-- * Ops contribute an empty usage table.
bindableSimpleOps ::
  (SimplifiableLore lore, Bindable lore) =>
  SimplifyOp lore (Op lore) ->
  SimpleOps lore
bindableSimpleOps =
  SimpleOps mkExpDecS' mkBodyS' protectHoistedOpS' (const mempty)
  where
    -- The symbol table argument is not needed for Bindable lores.
    mkExpDecS' _ pat e = return $ mkExpDec pat e
    mkBodyS' _ bnds res = return $ mkBody bnds res
    protectHoistedOpS' _ _ _ = Nothing
-- | The simplifier monad.
--
-- The reader part gives access to the lore-specific 'SimpleOps' and the
-- engine 'Env'.  The state part holds:
--
-- * a name source (see the 'MonadFreshNames' instance);
-- * a flag that 'changed' sets to 'True';
-- * certificates accumulated by 'usedCerts' and drained by 'collectCerts'.
newtype SimpleM lore a
  = SimpleM
      ( ReaderT
          (SimpleOps lore, Env lore)
          (State (VNameSource, Bool, Certificates))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadReader (SimpleOps lore, Env lore),
      MonadState (VNameSource, Bool, Certificates)
    )
-- | The name source lives in the first component of the state triple; the
-- other components are left untouched.
instance MonadFreshNames (SimpleM lore) where
  putNameSource src = modify $ \(_, b, c) -> (src, b, c)
  getNameSource = gets $ \(a, _, _) -> a
-- | The scope is derived from the symbol table in the environment.
instance SimplifiableLore lore => HasScope (Wise lore) (SimpleM lore) where
  askScope = ST.toScope <$> askVtable

  -- Missing variables are an internal invariant violation, so this calls
  -- 'error' rather than returning a recoverable failure.
  lookupType name = do
    vtable <- askVtable
    case ST.lookupType name vtable of
      Just t -> return t
      Nothing ->
        error $
          "SimpleM.lookupType: cannot find variable "
            ++ pretty name
            ++ " in symbol table."
-- | Local scopes are appended to the current symbol table (@vtable <> new@).
instance SimplifiableLore lore => LocalScope (Wise lore) (SimpleM lore) where
  localScope types = localVtable (<> ST.fromScope types)
-- | Run a simplifier computation.
--
-- It starts with the given name source, the flag set to 'False' and no
-- certificates.  It returns the result paired with the final flag (set by
-- 'changed'), together with the final name source.  Any certificates still
-- accumulated at the end are discarded.
runSimpleM ::
  SimpleM lore a ->
  SimpleOps lore ->
  Env lore ->
  VNameSource ->
  ((a, Bool), VNameSource)
runSimpleM (SimpleM m) simpl env src =
  let (x, (src', b, _)) =
        runState (runReaderT m (simpl, env)) (src, False, mempty)
   in ((x, b), src')
-- | Retrieve the engine environment (the second reader component).
askEngineEnv :: SimpleM lore (Env lore)
askEngineEnv = asks snd
-- | Project a value out of the engine environment.
asksEngineEnv :: (Env lore -> a) -> SimpleM lore a
asksEngineEnv f = f <$> askEngineEnv
-- | The current symbol table.
askVtable :: SimpleM lore (ST.SymbolTable (Wise lore))
askVtable = asksEngineEnv envVtable
-- | Run a computation with the symbol table transformed by the given
-- function.  The 'SimpleOps' and the other 'Env' fields are unchanged.
localVtable ::
  (ST.SymbolTable (Wise lore) -> ST.SymbolTable (Wise lore)) ->
  SimpleM lore a ->
  SimpleM lore a
localVtable f = local $ \(ops, env) ->
  (ops, env {envVtable = f $ envVtable env})
-- | Run a computation, then return its result along with every certificate
-- in the state accumulator, and reset the accumulator to empty.
--
-- NOTE(review): the accumulator is not cleared before running the
-- computation, so certificates recorded earlier are also returned.  Confirm
-- that callers rely on (or never hit) this.
collectCerts :: SimpleM lore a -> SimpleM lore (a, Certificates)
collectCerts m = do
  x <- m
  (src, flag, cs) <- get
  put (src, flag, mempty)
  return (x, cs)
-- | Set the state flag to 'True'.  'runSimpleM' reports this flag to its
-- caller.
changed :: SimpleM lore ()
changed = modify $ \(src, _, cs) -> (src, True, cs)
-- | Record certificates in the state accumulator, in front of those
-- already present (@new <> old@).
usedCerts :: Certificates -> SimpleM lore ()
usedCerts cs = modify $ \(a, b, c) -> (a, b, cs <> c)
-- | Run a computation with the symbol table passed through 'ST.deepen'
-- (presumably marking a deeper loop nesting level).
enterLoop :: SimpleM lore a -> SimpleM lore a
enterLoop = localVtable ST.deepen
-- | Run a computation with the given function parameters inserted into the
-- symbol table via 'ST.insertFParams'.
bindFParams ::
  SimplifiableLore lore =>
  [FParam (Wise lore)] ->
  SimpleM lore a ->
  SimpleM lore a
bindFParams params = localVtable $ ST.insertFParams params
-- | Run a computation with the given lambda parameters inserted into the
-- symbol table.  Because this is a right fold, the last parameter is
-- inserted first.
bindLParams ::
  SimplifiableLore lore =>
  [LParam (Wise lore)] ->
  SimpleM lore a ->
  SimpleM lore a
bindLParams params =
  localVtable $ \vtable -> foldr ST.insertLParam vtable params
-- | Like 'bindLParams', but parameters are inserted in list order using a
-- strict left fold.
bindArrayLParams ::
  SimplifiableLore lore =>
  [LParam (Wise lore)] ->
  SimpleM lore a ->
  SimpleM lore a
bindArrayLParams params =
  localVtable $ \vtable -> foldl' (flip ST.insertLParam) vtable params
-- | Run a computation with loop merge parameters inserted into the symbol
-- table via 'ST.insertLoopMerge'.  The two 'SubExp's accompany each
-- parameter; presumably the initial value and the value from the loop
-- body (confirm in "Futhark.Analysis.SymbolTable").
bindMerge ::
  SimplifiableLore lore =>
  [(FParam (Wise lore), SubExp, SubExp)] ->
  SimpleM lore a ->
  SimpleM lore a
bindMerge = localVtable . ST.insertLoopMerge
-- | Run a computation with a loop variable of the given integer type and
-- bound inserted into the symbol table via 'ST.insertLoopVar'.
bindLoopVar ::
  SimplifiableLore lore =>
  VName ->
  IntType ->
  SubExp ->
  SimpleM lore a ->
  SimpleM lore a
bindLoopVar var it bound = localVtable $ ST.insertLoopVar var it bound
-- | Post-process statements that a computation wants to hoist out of a
-- branch.
--
-- If every hoisted statement is 'safeExp', the statements are re-emitted
-- unchanged.  Otherwise each statement is passed to 'protectIf' together
-- with:
--
-- * the lore's 'protectHoistedOpS' hook;
-- * the predicate @unsafeOrCostly@ (not safe, or not cheap);
-- * a guard condition.  The guard is @cond@ when @side@ is 'True', and a
--   fresh @cond_neg@ binding of @not cond@ otherwise.
--
-- The exact guarding semantics live in 'protectIf' (defined elsewhere).
protectIfHoisted ::
  SimplifiableLore lore =>
  -- | Branch condition.
  SubExp ->
  -- | Which side of the branch the statements came from ('True' keeps
  -- @cond@ as-is; 'False' negates it).
  Bool ->
  SimpleM lore (a, Stms (Wise lore)) ->
  SimpleM lore (a, Stms (Wise lore))
protectIfHoisted cond side m = do
  (x, stms) <- m
  ops <- asks $ protectHoistedOpS . fst
  runBinder $ do
    if all (safeExp . stmExp) stms
      then addStms stms
      else do
        cond' <-
          if side
            then return cond
            else letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
        mapM_ (protectIf ops unsafeOrCostly cond') stms
    return x
  where
    -- A statement needs protection if it is unsafe or expensive.
    unsafeOrCostly e = not (safeExp e) || not (cheapExp e)
-- | Run the action @m@ and re-emit the statements it returns in a fresh
-- binder, guarding them so they are safe to execute even if the loop
-- described by @form@ would not run its body at all.
--
-- * If every returned statement is already safe ('safeExp'), they are
--   added back unchanged.
-- * Otherwise a boolean "the loop runs at least once" is computed from
--   @form@, and every statement is passed through 'protectIf' with the
--   protection function taken from the 'SimpleOps' in the environment.
--
-- The @ctx@ and @val@ arguments are merge parameters paired with their
-- initial values. They are only consulted for a @while@ loop, to look
-- up the initial value of the loop condition.
--
-- Returns the value produced by @m@ together with the (possibly
-- protected) statements.
protectLoopHoisted ::
  SimplifiableLore lore =>
  [(FParam (Wise lore), SubExp)] ->
  [(FParam (Wise lore), SubExp)] ->
  LoopForm (Wise lore) ->
  SimpleM lore (a, Stms (Wise lore)) ->
  SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted ctx val form m = do
  (result, hoisted) <- m
  protectOp <- asks (protectHoistedOpS . fst)
  runBinder $ do
    if all (safeExp . stmExp) hoisted
      then addStms hoisted
      else do
        nonEmpty <- loopRunsAtLeastOnce
        mapM_ (protectIf protectOp (not . safeExp) nonEmpty) hoisted
    return result
  where
    -- A boolean 'SubExp' that is true when the loop body executes at
    -- least once.
    loopRunsAtLeastOnce =
      case form of
        -- A for-loop runs iff 0 < bound (signed comparison in the
        -- loop's own integer type).
        ForLoop _ it bound _ ->
          letSubExp "loop_nonempty" . BasicOp $
            CmpOp (CmpSlt it) (intConst it 0) bound
        -- A while-loop runs iff its condition holds initially; the
        -- condition is looked up among the merge parameters.
        WhileLoop cond ->
          case find ((== cond) . paramName . fst) (ctx ++ val) of
            Just (_, cond_init) -> return cond_init
            -- Condition is not a merge parameter: treat the loop as
            -- non-empty.
            Nothing -> return $ constant True
-- | Emit the statement @stm@ so that it is harmless when the boolean
-- @taken@ is false.
--
-- * An @if@ of sort 'IfFallback' has its condition conjoined with
--   @taken@.
-- * An 'Assert' is weakened to @(not taken) || cond@.
-- * An 'Op' is offered to the supplied 'Protect' function; if that
--   yields an action, the action is run under the statement's aux.
-- * Any other expression for which the predicate holds is replaced by
--   its 'makeSafe' variant when one exists. Otherwise it is wrapped in
--   an @if taken@ whose false branch yields placeholder values built by
--   'emptyOfType'.
-- * Every remaining statement is emitted unchanged.
protectIf ::
  MonadBinder m =>
  Protect m ->
  (Exp (Lore m) -> Bool) ->
  SubExp ->
  Stm (Lore m) ->
  m ()
protectIf _ _ taken (Let pat aux (If cond tbranch fbranch (IfDec ts IfFallback))) = do
  conj <- letSubExp "protect_cond_conj" . BasicOp $ BinOp LogAnd taken cond
  auxing aux . letBind pat $ If conj tbranch fbranch (IfDec ts IfFallback)
protectIf _ _ taken (Let pat aux (BasicOp (Assert cond msg loc))) = do
  notTaken <- letSubExp "loop_not_taken" . BasicOp $ UnOp Not taken
  disj <- letSubExp "protect_assert_disj" . BasicOp $ BinOp LogOr notTaken cond
  auxing aux . letBind pat . BasicOp $ Assert disj msg loc
protectIf protect _ taken (Let pat aux (Op op))
  | Just protected <- protect taken pat op =
    auxing aux protected
protectIf _ needsProtection taken (Let pat aux e)
  | needsProtection e = do
    guarded <- case makeSafe e of
      -- A safe variant exists: no branching needed.
      Just safeE -> return safeE
      -- Otherwise only evaluate @e@ when @taken@ holds, producing
      -- placeholder values of the pattern's types in the other branch.
      Nothing -> do
        takenBody <- eBody [pure e]
        untakenBody <-
          eBody $
            map (emptyOfType (patternContextNames pat)) (patternValueTypes pat)
        ts <- expTypesFromPattern pat
        return $ If taken takenBody untakenBody (IfDec ts IfFallback)
    auxing aux $ letBind pat guarded
protectIf _ _ _ stm =
  addStm stm
-- | Return an equivalent expression that cannot fail, if one is known.
--
-- Only integer division/modulo/remainder binary operators are handled.
-- Their 'Safety' flag is replaced by 'Safe' and the operands are kept.
-- Every other expression yields 'Nothing'.
makeSafe :: Exp lore -> Maybe (Exp lore)
makeSafe (BasicOp (BinOp op x y)) =
  fmap (\op' -> BasicOp (BinOp op' x y)) (safeVariant op)
  where
    -- Lookup table from a potentially failing operator to its safe form.
    safeVariant (SDiv t _) = Just (SDiv t Safe)
    safeVariant (SDivUp t _) = Just (SDivUp t Safe)
    safeVariant (SQuot t _) = Just (SQuot t Safe)
    safeVariant (UDiv t _) = Just (UDiv t Safe)
    safeVariant (UDivUp t _) = Just (UDivUp t Safe)
    safeVariant (SMod t _) = Just (SMod t Safe)
    safeVariant (SRem t _) = Just (SRem t Safe)
    safeVariant (UMod t _) = Just (UMod t Safe)
    safeVariant _ = Nothing
makeSafe _ =
  Nothing
-- | Build an expression yielding a placeholder value of the given type.
-- 'protectIf' uses it for the untaken branch of a guarding @if@.
--
-- * Primitive types give the constant 'blankPrimValue'.
-- * Array types give a 'Scratch' array of the same element type. Any
--   dimension that is a variable listed in @ctx_names@ is replaced by
--   the constant zero; the other dimensions are kept.
-- * Memory types are rejected with 'error'.
emptyOfType :: MonadBinder m => [VName] -> Type -> m (Exp (Lore m))
emptyOfType ctx_names ty =
  case ty of
    Mem {} ->
      error "emptyOfType: Cannot hoist non-existential memory."
    Prim pt ->
      return . BasicOp . SubExp . Constant $ blankPrimValue pt
    Array pt shape _ ->
      return . BasicOp . Scratch pt $ map sizeOrZero (shapeDims shape)
  where
    -- NOTE(review): the zero is an Int32 constant; confirm this matches
    -- the integer type used for array sizes in this IR version.
    sizeOrZero (Var v)
      | v `elem` ctx_names = intConst Int32 0
    sizeOrZero se = se
-- | A 'BlockPred' that blocks a statement when its expression is not
-- 'safeExp' and its pattern binds at least one array (a type of rank
-- greater than zero). The symbol table and usage table arguments are
-- ignored.
--
-- NOTE(review): presumably such statements are excluded because hoisting
-- them would require guarding via 'protectLoopHoisted'. Confirm intent.
notWorthHoisting :: ASTLore lore => BlockPred lore
notWorthHoisting _ _ (Let pat _ e)
  | safeExp e = False
  | otherwise = any ((> 0) . arrayRank) (patternTypes pat)
hoistStms ::
SimplifiableLore lore =>
RuleBook (Wise lore) ->
BlockPred (Wise lore) ->
ST.SymbolTable (Wise lore) ->
UT.UsageTable ->
Stms (Wise lore) ->
SimpleM
lore
( Stms (Wise lore),
Stms (Wise lore)
)
hoistStms :: RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
hoistStms RuleBook (Wise lore)
rules BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable UsageTable
uses Stms (Wise lore)
orig_stms = do
([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted) <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
simplifyStmsBottomUp SymbolTable (Wise lore)
vtable UsageTable
uses Stms (Wise lore)
orig_stms
Bool -> SimpleM lore () -> SimpleM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Stm (Wise lore)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Stm (Wise lore)]
hoisted) SimpleM lore ()
forall lore. SimpleM lore ()
changed
(Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise lore)] -> Stms (Wise lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm (Wise lore)]
blocked, [Stm (Wise lore)] -> Stms (Wise lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm (Wise lore)]
hoisted)
where
simplifyStmsBottomUp :: SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
simplifyStmsBottomUp SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms = do
(UsageTable
_, [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms') <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms
let ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted) = [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)]))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)])
forall a b. (a -> b) -> a -> b
$ [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall lore.
ASTLore lore =>
[Either (Stm lore) (Stm lore)] -> [Either (Stm lore) (Stm lore)]
blockUnhoistedDeps [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms'
([Stm (Wise lore)], [Stm (Wise lore)])
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted)
simplifyStmsBottomUp' :: SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms = do
OpWithWisdom (Op lore) -> UsageTable
opUsage <- ((SimpleOps lore, Env lore)
-> OpWithWisdom (Op lore) -> UsageTable)
-> SimpleM lore (OpWithWisdom (Op lore) -> UsageTable)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (((SimpleOps lore, Env lore)
-> OpWithWisdom (Op lore) -> UsageTable)
-> SimpleM lore (OpWithWisdom (Op lore) -> UsageTable))
-> ((SimpleOps lore, Env lore)
-> OpWithWisdom (Op lore) -> UsageTable)
-> SimpleM lore (OpWithWisdom (Op lore) -> UsageTable)
forall a b. (a -> b) -> a -> b
$ SimpleOps lore -> OpWithWisdom (Op lore) -> UsageTable
forall lore. SimpleOps lore -> Op (Wise lore) -> UsageTable
opUsageS (SimpleOps lore -> OpWithWisdom (Op lore) -> UsageTable)
-> ((SimpleOps lore, Env lore) -> SimpleOps lore)
-> (SimpleOps lore, Env lore)
-> OpWithWisdom (Op lore)
-> UsageTable
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SimpleOps lore, Env lore) -> SimpleOps lore
forall a b. (a, b) -> a
fst
let usageInStm :: Stm (Wise lore) -> UsageTable
usageInStm Stm (Wise lore)
stm =
Stm (Wise lore) -> UsageTable
forall lore. (ASTLore lore, Aliased lore) => Stm lore -> UsageTable
UT.usageInStm Stm (Wise lore)
stm
UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<> case Stm (Wise lore) -> Exp (Wise lore)
forall lore. Stm lore -> Exp lore
stmExp Stm (Wise lore)
stm of
Op Op (Wise lore)
op -> OpWithWisdom (Op lore) -> UsageTable
opUsage Op (Wise lore)
OpWithWisdom (Op lore)
op
Exp (Wise lore)
_ -> UsageTable
forall a. Monoid a => a
mempty
((UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))]))
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ((Stm (Wise lore) -> UsageTable)
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
hoistable Stm (Wise lore) -> UsageTable
usageInStm) (UsageTable
uses', []) ([(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))]))
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall a b. (a -> b) -> a -> b
$ [(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a. [a] -> [a]
reverse ([(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))])
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a b. (a -> b) -> a -> b
$ [Stm (Wise lore)]
-> [SymbolTable (Wise lore)]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a b. [a] -> [b] -> [(a, b)]
zip (Stms (Wise lore) -> [Stm (Wise lore)]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms (Wise lore)
stms) [SymbolTable (Wise lore)]
vtables
where
vtables :: [SymbolTable (Wise lore)]
vtables = (SymbolTable (Wise lore)
-> Stm (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore)
-> [Stm (Wise lore)]
-> [SymbolTable (Wise lore)]
forall b a. (b -> a -> b) -> b -> [a] -> [b]
scanl ((Stm (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore)
-> Stm (Wise lore)
-> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip Stm (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore.
(ASTLore lore, IndexOp (Op lore), Aliased lore) =>
Stm lore -> SymbolTable lore -> SymbolTable lore
ST.insertStm) SymbolTable (Wise lore)
vtable' ([Stm (Wise lore)] -> [SymbolTable (Wise lore)])
-> [Stm (Wise lore)] -> [SymbolTable (Wise lore)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise lore) -> [Stm (Wise lore)]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms (Wise lore)
stms
hoistable :: (Stm (Wise lore) -> UsageTable)
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
hoistable Stm (Wise lore) -> UsageTable
usageInStm (UsageTable
uses', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms) (Stm (Wise lore)
stm, SymbolTable (Wise lore)
vtable')
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> UsageTable -> Bool
`UT.isUsedDirectly` UsageTable
uses') ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Stm (Wise lore) -> [VName]
forall lore. Stm lore -> [VName]
provides Stm (Wise lore)
stm
=
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
| Bool
otherwise = do
Maybe (Stms (Wise lore))
res <-
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable (SymbolTable (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. a -> b -> a
const SymbolTable (Wise lore)
vtable') (SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore))))
-> SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall a b. (a -> b) -> a -> b
$
RuleBook (Wise lore)
-> (SymbolTable (Wise lore), UsageTable)
-> Stm (Wise lore)
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall (m :: * -> *) lore.
(MonadFreshNames m, HasScope lore m) =>
RuleBook lore
-> (SymbolTable lore, UsageTable)
-> Stm lore
-> m (Maybe (Stms lore))
bottomUpSimplifyStm RuleBook (Wise lore)
rules (SymbolTable (Wise lore)
vtable', UsageTable
uses') Stm (Wise lore)
stm
case Maybe (Stms (Wise lore))
res of
Maybe (Stms (Wise lore))
Nothing
| BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm ->
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return
( (Stm (Wise lore) -> UsageTable)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stm (Wise lore)
-> UsageTable
forall lore.
(ASTLore lore, Aliased lore) =>
(Stm lore -> UsageTable)
-> SymbolTable lore -> UsageTable -> Stm lore -> UsageTable
expandUsage Stm (Wise lore) -> UsageTable
usageInStm SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm
UsageTable -> [VName] -> UsageTable
`UT.without` Stm (Wise lore) -> [VName]
forall lore. Stm lore -> [VName]
provides Stm (Wise lore)
stm,
Stm (Wise lore) -> Either (Stm (Wise lore)) (Stm (Wise lore))
forall a b. a -> Either a b
Left Stm (Wise lore)
stm Either (Stm (Wise lore)) (Stm (Wise lore))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms
)
| Bool
otherwise ->
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return
( (Stm (Wise lore) -> UsageTable)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stm (Wise lore)
-> UsageTable
forall lore.
(ASTLore lore, Aliased lore) =>
(Stm lore -> UsageTable)
-> SymbolTable lore -> UsageTable -> Stm lore -> UsageTable
expandUsage Stm (Wise lore) -> UsageTable
usageInStm SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm,
Stm (Wise lore) -> Either (Stm (Wise lore)) (Stm (Wise lore))
forall a b. b -> Either a b
Right Stm (Wise lore)
stm Either (Stm (Wise lore)) (Stm (Wise lore))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms
)
Just Stms (Wise lore)
optimstms -> do
SimpleM lore ()
forall lore. SimpleM lore ()
changed
(UsageTable
uses'', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms') <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
optimstms
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses'', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms' [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. [a] -> [a] -> [a]
++ [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
-- | Propagate blocking through data dependencies.
--
-- The list is walked from left to right.  Every 'Left' statement counts as
-- blocked.  A 'Right' statement (a hoisting candidate) is demoted to 'Left'
-- when it refers to a name bound by a statement blocked earlier in the
-- list.  Statements that stay 'Right' do not add to the blocked name set.
blockUnhoistedDeps ::
  ASTLore lore =>
  [Either (Stm lore) (Stm lore)] ->
  [Either (Stm lore) (Stm lore)]
blockUnhoistedDeps = snd . mapAccumL step mempty
  where
    step blocked entry =
      case entry of
        Left stm -> markBlocked stm
        Right stm
          | blocked `namesIntersect` freeIn stm -> markBlocked stm
          | otherwise -> (blocked, Right stm)
      where
        -- Blocking a statement also blocks every name it binds.
        markBlocked stm = (blocked <> namesFromList (provides stm), Left stm)
-- | The names bound by the pattern of a statement.
provides :: Stm lore -> [VName]
provides stm = patternNames (stmPattern stm)
-- | Extend a usage table with the usages that arise from one statement.
--
-- The result combines three parts:
--
-- * the statement's own usages (from the given function) together with the
--   usages of each bound name transferred onto the names it aliases; this
--   combined table is expanded through 'ST.lookupAliases';
--
-- * if any name bound by the statement is recorded as a size in the
--   incoming table, every free variable of the expression is marked as a
--   size usage;
--
-- * the incoming usage table itself.
expandUsage ::
  (ASTLore lore, Aliased lore) =>
  (Stm lore -> UT.UsageTable) ->
  ST.SymbolTable lore ->
  UT.UsageTable ->
  Stm lore ->
  UT.UsageTable
expandUsage usageInStm vtable utable stm@(Let pat _ e) =
  mconcat [expanded, sizeUsage, utable]
  where
    bound = patternNames pat

    expanded =
      UT.expand (`ST.lookupAliases` vtable) (usageInStm stm <> aliasUsage)

    sizeUsage
      | any (`UT.isSize` utable) bound = UT.sizeUsages (freeIn e)
      | otherwise = mempty

    -- Whatever usage a bound name already has is also attributed to
    -- every name that it aliases.
    aliasUsage =
      mconcat $ catMaybes $ zipWith usageOfAliases bound (patternAliases pat)

    usageOfAliases name aliases =
      (\uses -> foldMap (`UT.usage` uses) (namesToList aliases))
        <$> UT.lookup name utable
-- | A predicate that decides whether a statement is blocked from being
-- hoisted.  It receives the symbol table and usage table in effect at that
-- point.  'True' means the statement is blocked and stays where it is.
type BlockPred lore = ST.SymbolTable lore -> UT.UsageTable -> Stm lore -> Bool
-- | A blocking predicate that never blocks any statement.
neverBlocks :: BlockPred lore
neverBlocks _ _ = const False
-- | A blocking predicate that blocks every statement.
alwaysBlocks :: BlockPred lore
alwaysBlocks _ _ = const True
-- | @isFalse b@ blocks every statement when @b@ is 'False' and blocks
-- nothing when @b@ is 'True'.  The tables and the statement are ignored.
isFalse :: Bool -> BlockPred lore
isFalse b _ _ = const (not b)
-- | Block a statement if either predicate blocks it.  The second predicate
-- is consulted only when the first returns 'False'.
orIf :: BlockPred lore -> BlockPred lore -> BlockPred lore
orIf p q symtab usages stm = p symtab usages stm || q symtab usages stm
-- | Block a statement only if both predicates block it.  The second
-- predicate is consulted only when the first returns 'True'.
andAlso :: BlockPred lore -> BlockPred lore -> BlockPred lore
andAlso p q symtab usages stm = p symtab usages stm && q symtab usages stm
-- | Block statements that bind a name recorded as consumed in the usage
-- table.
isConsumed :: BlockPred lore
isConsumed _ utable stm =
  any (\v -> v `UT.isConsumed` utable) (patternNames (stmPattern stm))
-- | Block statements whose expression is an 'Op'.
isOp :: BlockPred lore
isOp _ _ stm =
  case stmExp stm of
    Op {} -> True
    _ -> False
-- | Build a body from already-simplified statements and a result.
--
-- The statements are added inside 'insertStmsM', presumably so that they
-- become part of the returned body.  The statement sequence that
-- 'runBinder' returns alongside the body is discarded.
constructBody ::
  SimplifiableLore lore =>
  Stms (Wise lore) ->
  Result ->
  SimpleM lore (Body (Wise lore))
constructBody stms res =
  fst <$> runBinder (insertStmsM (addStms stms >> resultBodyM res))
-- | The outcome of simplifying a body: a value paired with a usage table,
-- together with a sequence of statements.  'blockIf' and 'hoistCommon'
-- split these statements with 'hoistStms'.
type SimplifiedBody lore a = ((a, UT.UsageTable), Stms (Wise lore))
-- | Run a body-simplifying action, then split the statements it produced
-- with 'hoistStms' under the given blocking predicate.
--
-- The blocked statements are returned paired with the action's value.  The
-- hoisted statements are returned separately.
blockIf ::
  SimplifiableLore lore =>
  BlockPred (Wise lore) ->
  SimpleM lore (SimplifiedBody lore a) ->
  SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
blockIf block m = do
  ((x, usages), stms) <- m
  vtable <- askVtable
  rules <- asksEngineEnv envRules
  let attach (blocked, hoisted) = ((blocked, x), hoisted)
  attach <$> hoistStms rules block vtable usages stms
-- | Block statements that refer to any of the given names.
hasFree :: ASTLore lore => Names -> BlockPred lore
hasFree ks _ _ stm = namesIntersect ks (freeIn stm)
-- | Block statements whose expression is not considered safe by 'safeExp'.
isNotSafe :: ASTLore lore => BlockPred lore
isNotSafe _ _ stm = not $ safeExp $ stmExp stm
-- | Block statements whose expression is an in-place 'Update'.
isInPlaceBound :: BlockPred m
isInPlaceBound _ _ stm =
  case stmExp stm of
    BasicOp Update {} -> True
    _ -> False
-- | Block statements that 'cheapStm' does not consider cheap.
isNotCheap :: ASTLore lore => BlockPred lore
isNotCheap _ _ stm = not (cheapStm stm)
-- | A statement is cheap when its expression is cheap (see 'cheapExp').
cheapStm :: ASTLore lore => Stm lore -> Bool
cheapStm stm = cheapExp (stmExp stm)
-- | A heuristic for whether an expression is cheap to compute.
--
-- 'Copy', 'Manifest' and loops are treated as expensive.  An @if@ is cheap
-- when every statement in both branches is cheap, and an 'Op' defers to
-- 'cheapOp'.  Everything else, including all other basic operations, is
-- treated as cheap.
cheapExp :: ASTLore lore => Exp lore -> Bool
cheapExp e =
  case e of
    BasicOp op -> cheapBasicOp op
    DoLoop {} -> False
    If _ tbranch fbranch _ -> all (all cheapStm . bodyStms) [tbranch, fbranch]
    Op op -> cheapOp op
    _ -> True
  where
    cheapBasicOp Copy {} = False
    cheapBasicOp Manifest {} = False
    cheapBasicOp _ = True
-- | Lift a plain statement predicate into a 'BlockPred' that ignores the
-- symbol table and the usage table.
stmIs :: (Stm lore -> Bool) -> BlockPred lore
stmIs f = \_ _ -> f
-- | Whether every free variable of a statement is among the names that
-- 'ST.availableAtClosestLoop' reports for the symbol table.
loopInvariantStm :: ASTLore lore => ST.SymbolTable lore -> Stm lore -> Bool
loopInvariantStm vtable stm =
  all (`nameIn` ST.availableAtClosestLoop vtable) (namesToList (freeIn stm))
-- | Hoist statements out of both branches of an @if@.
--
-- Each branch's statements are split by 'hoistStms' under one shared
-- blocking predicate, inside 'protectIfHoisted' ('True' for the first
-- branch, 'False' for the second).  Returns the rebuilt branch bodies and
-- the statements hoisted out of both branches.
hoistCommon ::
  SimplifiableLore lore =>
  SubExp ->
  IfSort ->
  SimplifiedBody lore Result ->
  SimplifiedBody lore Result ->
  SimpleM
    lore
    ( Body (Wise lore),
      Body (Wise lore),
      Stms (Wise lore)
    )
hoistCommon cond ifsort ((res1, usages1), stms1) ((res2, usages2), stms2) = do
  blockers <- asksEngineEnv envHoistBlockers
  vtable <- askVtable
  rules <- asksEngineEnv envRules
  let is_alloc_fun = isAllocation blockers
      branch_blocker = blockHoistBranch blockers

      -- Whether the branch condition only mentions names that are
      -- available at the closest enclosing loop.
      cond_loop_invariant =
        all (`nameIn` ST.availableAtClosestLoop vtable) $ namesToList $ freeIn cond

      -- Allocations are always desirable to hoist.  Any other statement is
      -- desirable only when we are inside a loop, both the condition and
      -- the statement are invariant to the closest loop, and this is not
      -- a fallback branch.
      desirableToHoist stm
        | is_alloc_fun stm = True
        | otherwise =
          ST.loopDepth vtable > 0
            && cond_loop_invariant
            && ifsort /= IfFallback
            && loopInvariantStm vtable stm

      -- Never blocks array literals, plain subexpressions, statements
      -- binding a name used as a size, or allocations.  Any other
      -- statement is blocked unless this is an 'IfEquiv' branch.
      isNotHoistableStm _ usages stm =
        case stmExp stm of
          BasicOp ArrayLit {} -> False
          BasicOp SubExp {} -> False
          _
            | any (`UT.isSize` usages) (patternNames (stmPattern stm)) -> False
            | is_alloc_fun stm -> False
            | otherwise -> ifsort /= IfEquiv

      block =
        branch_blocker
          `orIf` ((isNotSafe `orIf` isNotCheap) `andAlso` stmIs (not . desirableToHoist))
          `orIf` isInPlaceBound
          `orIf` isNotHoistableStm

      hoistBranch which usages stms =
        protectIfHoisted cond which $ hoistStms rules block vtable usages stms

  (body1_stms, safe1) <- hoistBranch True usages1 stms1
  (body2_stms, safe2) <- hoistBranch False usages2 stms2
  body1' <- constructBody body1_stms res1
  body2' <- constructBody body2_stms res2
  return (body1', body2', safe1 <> safe2)
-- | Simplify a body.
--
-- The result is simplified by 'simplifyResult'; the diets apply to the
-- trailing value results.  That action runs under 'simplifyStms' for the
-- body's statements.  The inner action contributes no statements of its
-- own, so any statements in the returned sequence come from
-- 'simplifyStms'.
simplifyBody ::
  SimplifiableLore lore =>
  [Diet] ->
  Body lore ->
  SimpleM lore (SimplifiedBody lore Result)
simplifyBody ds (Body _ stms res) =
  simplifyStms stms $ withoutStms <$> simplifyResult ds res
  where
    withoutStms simplified = (simplified, mempty)
-- | Simplify a body result.
--
-- The last @length ds@ elements of the result are the value results,
-- paired one-to-one with the diets @ds@; everything before them is
-- treated as context.  Returns the simplified result together with a
-- usage table made of the free variables of the new result plus
-- whatever 'consumeResult' derives from the diets and value results.
simplifyResult ::
  SimplifiableLore lore =>
  [Diet] ->
  Result ->
  SimpleM lore (Result, UT.UsageTable)
simplifyResult ds res = do
  let (ctx, vals) = splitFromEnd (length ds) res
  -- Certificates produced while simplifying the context part are
  -- collected and intentionally dropped here.
  (ctx', _) <- collectCerts $ mapM simplify ctx
  vals' <- mapM propagate vals
  let res' = ctx' <> vals'
      uses = UT.usages (freeIn res') <> consumeResult (zip ds vals')
  return (res', uses)
  where
    -- Copy-propagate a value result through the symbol table, but only
    -- when the binding carries no certificates: a plain 'SubExp' result
    -- has no slot to keep them in.
    propagate se@(Constant _) = return se
    propagate se@(Var name) = do
      vtable <- askVtable
      return $ case ST.lookupSubExp name vtable of
        Just (se', cs) | cs == mempty -> se'
        _ -> se
-- | Usage table in which every variable returned directly by a loop
-- body result is recorded with 'UT.inResultUsage'.  Constant results
-- contribute nothing.
isDoLoopResult :: Result -> UT.UsageTable
isDoLoopResult res = mconcat [UT.inResultUsage v | Var v <- res]
-- | Simplify a sequence of statements, then run the continuation @m@.
--
-- For each statement, its certificates, expression and pattern are
-- simplified.  Certificates gathered while simplifying the expression
-- and the pattern are appended to the statement's own certificates.
-- Any statements produced by simplifying the expression are inspected
-- first, then the rebuilt statement, then the remaining statements.
-- With no statements left, only the continuation is inspected.
simplifyStms ::
  SimplifiableLore lore =>
  Stms lore ->
  SimpleM lore (a, Stms (Wise lore)) ->
  SimpleM lore (a, Stms (Wise lore))
simplifyStms stms m
  | Just (Let pat (StmAux stm_cs attrs dec) e, rest) <- stmsHead stms = do
      stm_cs' <- simplify stm_cs
      ((e', e_stms), e_cs) <- collectCerts $ simplifyExp e
      (pat', pat_cs) <- collectCerts $ simplifyPattern pat
      let aux' = StmAux (stm_cs' <> e_cs <> pat_cs) attrs dec
          stm' = mkWiseLetStm pat' aux' e'
      inspectStms e_stms . inspectStm stm' $ simplifyStms rest m
  | otherwise = inspectStms mempty m
-- | Inspect a single statement by wrapping it in a one-element
-- statement sequence and passing it to 'inspectStms'.
inspectStm ::
  SimplifiableLore lore =>
  Stm (Wise lore) ->
  SimpleM lore (a, Stms (Wise lore)) ->
  SimpleM lore (a, Stms (Wise lore))
inspectStm stm = inspectStms (oneStm stm)
-- | Apply the top-down rule book to each statement in turn.
--
-- If a rule rewrites a statement, 'changed' is signalled and the
-- replacement statements are inspected in its place, followed by the
-- rest.  Otherwise the statement is kept: it is inserted into the
-- symbol table while the remaining statements (and ultimately the
-- continuation @m@) are processed, and it is prepended to the
-- statements they return.  With no statements, @m@ is run as-is.
inspectStms ::
  SimplifiableLore lore =>
  Stms (Wise lore) ->
  SimpleM lore (a, Stms (Wise lore)) ->
  SimpleM lore (a, Stms (Wise lore))
inspectStms stms m = maybe m step (stmsHead stms)
  where
    step (stm, rest) = do
      rules <- asksEngineEnv envRules
      vtable <- askVtable
      rewritten <- topDownSimplifyStm rules vtable stm
      case rewritten of
        Just new_stms -> do
          changed
          inspectStms (new_stms <> rest) m
        Nothing -> do
          (x, rest') <- localVtable (ST.insertStm stm) $ inspectStms rest m
          return (x, oneStm stm <> rest')
-- | Simplify an operation by delegating to the 'simplifyOpS' callback
-- of the 'SimpleOps' held in the first component of the reader
-- environment.
simplifyOp :: Op lore -> SimpleM lore (Op (Wise lore), Stms (Wise lore))
simplifyOp op = do
  ops <- asks fst
  simplifyOpS ops op
simplifyExp ::
SimplifiableLore lore =>
Exp lore ->
SimpleM lore (Exp (Wise lore), Stms (Wise lore))
simplifyExp :: Exp lore -> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
simplifyExp (If SubExp
cond BodyT lore
tbranch BodyT lore
fbranch (IfDec [BranchType lore]
ts IfSort
ifsort)) = do
SubExp
cond' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
cond
[BranchType lore]
ts' <- (BranchType lore -> SimpleM lore (BranchType lore))
-> [BranchType lore] -> SimpleM lore [BranchType lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM BranchType lore -> SimpleM lore (BranchType lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify [BranchType lore]
ts
let ds :: [Diet]
ds = (BranchType lore -> Diet) -> [BranchType lore] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (Diet -> BranchType lore -> Diet
forall a b. a -> b -> a
const Diet
Consume) [BranchType lore]
ts
SimplifiedBody lore Result
tbranch' <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
ds BodyT lore
tbranch
SimplifiedBody lore Result
fbranch' <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
ds BodyT lore
fbranch
(Body (Wise lore)
tbranch'', Body (Wise lore)
fbranch'', Stms (Wise lore)
hoisted) <- SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
hoistCommon SubExp
cond' IfSort
ifsort SimplifiedBody lore Result
tbranch' SimplifiedBody lore Result
fbranch'
(Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> Body (Wise lore)
-> Body (Wise lore)
-> IfDec (BranchType (Wise lore))
-> Exp (Wise lore)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond' Body (Wise lore)
tbranch'' Body (Wise lore)
fbranch'' (IfDec (BranchType (Wise lore)) -> Exp (Wise lore))
-> IfDec (BranchType (Wise lore)) -> Exp (Wise lore)
forall a b. (a -> b) -> a -> b
$ [BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ts' IfSort
ifsort, Stms (Wise lore)
hoisted)
simplifyExp (DoLoop [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form BodyT lore
loopbody) = do
let ([FParam lore]
ctxparams, Result
ctxinit) = [(FParam lore, SubExp)] -> ([FParam lore], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam lore, SubExp)]
ctx
([FParam lore]
valparams, Result
valinit) = [(FParam lore, SubExp)] -> ([FParam lore], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam lore, SubExp)]
val
[FParam lore]
ctxparams' <- (FParam lore -> SimpleM lore (FParam lore))
-> [FParam lore] -> SimpleM lore [FParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamInfo lore -> SimpleM lore (FParamInfo lore))
-> FParam lore -> SimpleM lore (FParam lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse FParamInfo lore -> SimpleM lore (FParamInfo lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [FParam lore]
ctxparams
Result
ctxinit' <- (SubExp -> SimpleM lore SubExp) -> Result -> SimpleM lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify Result
ctxinit
[FParam lore]
valparams' <- (FParam lore -> SimpleM lore (FParam lore))
-> [FParam lore] -> SimpleM lore [FParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamInfo lore -> SimpleM lore (FParamInfo lore))
-> FParam lore -> SimpleM lore (FParam lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse FParamInfo lore -> SimpleM lore (FParamInfo lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [FParam lore]
valparams
Result
valinit' <- (SubExp -> SimpleM lore SubExp) -> Result -> SimpleM lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify Result
valinit
let ctx' :: [(FParam lore, SubExp)]
ctx' = [FParam lore] -> Result -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam lore]
ctxparams' Result
ctxinit'
val' :: [(FParam lore, SubExp)]
val' = [FParam lore] -> Result -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam lore]
valparams' Result
valinit'
diets :: [Diet]
diets = (FParam lore -> Diet) -> [FParam lore] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase Shape Uniqueness -> Diet
forall shape. TypeBase shape Uniqueness -> Diet
diet (TypeBase Shape Uniqueness -> Diet)
-> (FParam lore -> TypeBase Shape Uniqueness)
-> FParam lore
-> Diet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType) [FParam lore]
valparams'
(LoopForm (Wise lore)
form', Names
boundnames, SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
wrapbody) <- case LoopForm lore
form of
ForLoop VName
loopvar IntType
it SubExp
boundexp [(LParam lore, VName)]
loopvars -> do
SubExp
boundexp' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
boundexp
let ([LParam lore]
loop_params, [VName]
loop_arrs) = [(LParam lore, VName)] -> ([LParam lore], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(LParam lore, VName)]
loopvars
[LParam lore]
loop_params' <- (LParam lore -> SimpleM lore (LParam lore))
-> [LParam lore] -> SimpleM lore [LParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((LParamInfo lore -> SimpleM lore (LParamInfo lore))
-> LParam lore -> SimpleM lore (LParam lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse LParamInfo lore -> SimpleM lore (LParamInfo lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [LParam lore]
loop_params
[VName]
loop_arrs' <- (VName -> SimpleM lore VName) -> [VName] -> SimpleM lore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> SimpleM lore VName
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify [VName]
loop_arrs
let form' :: LoopForm (Wise lore)
form' = VName
-> IntType
-> SubExp
-> [(LParam (Wise lore), VName)]
-> LoopForm (Wise lore)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
loopvar IntType
it SubExp
boundexp' ([LParam lore] -> [VName] -> [(LParam lore, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam lore]
loop_params' [VName]
loop_arrs')
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM
lore
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
forall (m :: * -> *) a. Monad m => a -> m a
return
( LoopForm (Wise lore)
form',
[VName] -> Names
namesFromList (VName
loopvar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: (LParam lore -> VName) -> [LParam lore] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam lore -> VName
forall dec. Param dec -> VName
paramName [LParam lore]
loop_params') Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
fparamnames,
VName
-> IntType
-> SubExp
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a
bindLoopVar VName
loopvar IntType
it SubExp
boundexp'
(SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' LoopForm (Wise lore)
form'
(SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [LParam (Wise lore)]
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[LParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindArrayLParams [LParam lore]
[LParam (Wise lore)]
loop_params'
)
WhileLoop VName
cond -> do
VName
cond' <- VName -> SimpleM lore VName
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify VName
cond
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM
lore
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
forall (m :: * -> *) a. Monad m => a -> m a
return
( VName -> LoopForm (Wise lore)
forall lore. VName -> LoopForm lore
WhileLoop VName
cond',
Names
fparamnames,
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' (VName -> LoopForm (Wise lore)
forall lore. VName -> LoopForm lore
WhileLoop VName
cond')
)
BlockPred (Wise lore)
seq_blocker <- (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall lore a. (Env lore -> a) -> SimpleM lore a
asksEngineEnv ((Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore)))
-> (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall a b. (a -> b) -> a -> b
$ HoistBlockers lore -> BlockPred (Wise lore)
forall lore. HoistBlockers lore -> BlockPred (Wise lore)
blockHoistSeq (HoistBlockers lore -> BlockPred (Wise lore))
-> (Env lore -> HoistBlockers lore)
-> Env lore
-> BlockPred (Wise lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore -> HoistBlockers lore
forall lore. Env lore -> HoistBlockers lore
envHoistBlockers
((Stms (Wise lore)
loopstms, Result
loopres), Stms (Wise lore)
hoisted) <-
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a. SimpleM lore a -> SimpleM lore a
enterLoop (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
consumeMerge (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
[(FParam (Wise lore), SubExp, SubExp)]
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp, SubExp)]
-> SimpleM lore a -> SimpleM lore a
bindMerge (((FParam lore, SubExp) -> SubExp -> (FParam lore, SubExp, SubExp))
-> [(FParam lore, SubExp)]
-> Result
-> [(FParam lore, SubExp, SubExp)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (FParam lore, SubExp) -> SubExp -> (FParam lore, SubExp, SubExp)
forall a b c. (a, b) -> c -> (a, b, c)
withRes ([(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val') (BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
loopbody)) (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
wrapbody (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore a)
-> SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
blockIf
( Names -> BlockPred (Wise lore)
forall lore. ASTLore lore => Names -> BlockPred lore
hasFree Names
boundnames BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. BlockPred lore
isConsumed
BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
seq_blocker
BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. ASTLore lore => BlockPred lore
notWorthHoisting
)
(SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ do
((Result
res, UsageTable
uses), Stms (Wise lore)
stms) <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
diets BodyT lore
loopbody
SimplifiedBody lore Result
-> SimpleM lore (SimplifiedBody lore Result)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Result
res, UsageTable
uses UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<> Result -> UsageTable
isDoLoopResult Result
res), Stms (Wise lore)
stms)
Body (Wise lore)
loopbody' <- Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
forall lore.
SimplifiableLore lore =>
Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
constructBody Stms (Wise lore)
loopstms Result
loopres
(Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return ([(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> Body (Wise lore)
-> Exp (Wise lore)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' LoopForm (Wise lore)
form' Body (Wise lore)
loopbody', Stms (Wise lore)
hoisted)
where
fparamnames :: Names
fparamnames =
[VName] -> Names
namesFromList (((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val)
consumeMerge :: SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
consumeMerge =
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable ((SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ (SymbolTable (Wise lore) -> [VName] -> SymbolTable (Wise lore))
-> [VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((SymbolTable (Wise lore) -> VName -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore) -> [VName] -> SymbolTable (Wise lore)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' ((VName -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore) -> VName -> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore. VName -> SymbolTable lore -> SymbolTable lore
ST.consume)) ([VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> [VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList Names
consumed_by_merge
consumed_by_merge :: Names
consumed_by_merge =
Result -> Names
forall a. FreeIn a => a -> Names
freeIn (Result -> Names) -> Result -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> SubExp)
-> [(FParam lore, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ([(FParam lore, SubExp)] -> Result)
-> [(FParam lore, SubExp)] -> Result
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> Bool)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (TypeBase Shape Uniqueness -> Bool)
-> ((FParam lore, SubExp) -> TypeBase Shape Uniqueness)
-> (FParam lore, SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType (FParam lore -> TypeBase Shape Uniqueness)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> TypeBase Shape Uniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val
withRes :: (a, b) -> c -> (a, b, c)
withRes (a
p, b
x) c
y = (a
p, b
x, c
y)
-- Ops are simplified by the lore-specific 'simplifyOp', which may also
-- return statements that have to be bound before the new expression.
simplifyExp (Op op) =
  (\(op', op_stms) -> (Op op', op_stms)) <$> simplifyOp op
-- Operands of a commutative binary operator are simplified (left operand
-- first) and then placed in a canonical order, smaller operand first, so
-- that expressions differing only in operand order become syntactically
-- identical.
simplifyExp (BasicOp (BinOp op x y))
  | commutativeBinOp op = do
    (x', y') <- simplify (x, y)
    let canonical = BinOp op (min x' y') (max x' y')
    pure (BasicOp canonical, mempty)
-- Every other expression is simplified structurally by 'simplifyExpBase'
-- and produces no extra statements.
simplifyExp e =
  fmap (\e' -> (e', mempty)) (simplifyExpBase e)
-- | Structurally simplify an expression by applying 'simplify' to its
-- subexpressions, variable names, return types and branch types.
--
-- Bodies, function parameters, lambda parameters and 'Op's are not
-- handled here: the mapper calls 'error' if it meets any of them. Such
-- constructs (e.g. 'Op', or loops) are handled by dedicated cases of
-- 'simplifyExp' before this function is reached.
simplifyExpBase ::
  SimplifiableLore lore =>
  Exp lore ->
  SimpleM lore (Exp (Wise lore))
simplifyExpBase = mapExpM simplifier
  where
    simplifier =
      Mapper
        { -- Leaf-level components are simplified directly.
          mapOnSubExp = simplify,
          mapOnVName = simplify,
          mapOnRetType = simplify,
          mapOnBranchType = simplify,
          -- Bodies need a result diet, so the caller must handle them.
          mapOnBody =
            error "Unhandled body in simplification engine.",
          -- Parameters need to be bound, so the caller must handle them.
          mapOnFParam =
            error "Unhandled FParam in simplification engine.",
          mapOnLParam =
            error "Unhandled LParam in simplification engine.",
          mapOnOp =
            error "Unhandled Op in simplification engine."
        }
-- | Everything the simplification engine requires of a lore.
--
-- Every annotation type the engine rewrites must itself be
-- 'Simplifiable'. The lore's operations must support wisdom
-- ('CanBeWise'), the 'IsOp' interface and an 'ST.IndexOp' instance for
-- their wisdom-annotated form. The wise lore must be buildable
-- ('BinderOps').
type SimplifiableLore lore =
  ( ASTLore lore,
    IsOp (Op lore),
    CanBeWise (Op lore),
    ST.IndexOp (OpWithWisdom (Op lore)),
    BinderOps (Wise lore),
    Simplifiable (LetDec lore),
    Simplifiable (FParamInfo lore),
    Simplifiable (LParamInfo lore),
    Simplifiable (RetType lore),
    Simplifiable (BranchType lore)
  )
-- | Values that can be simplified against the current symbol table,
-- for example by substituting variables with known equivalents.
class Simplifiable a where
  -- | Simplify a value. Instances may record, in the 'SimpleM' state,
  -- that something changed and which certificates were used.
  simplify :: SimplifiableLore lore => a -> SimpleM lore a
-- | Pairs are simplified componentwise, left component first.
instance (Simplifiable a, Simplifiable b) => Simplifiable (a, b) where
  simplify (l, r) = do
    l' <- simplify l
    r' <- simplify r
    pure (l', r')
-- | Triples are simplified componentwise, from left to right.
instance (Simplifiable a, Simplifiable b, Simplifiable c) => Simplifiable (a, b, c) where
  simplify (p, q, r) = do
    p' <- simplify p
    q' <- simplify q
    r' <- simplify r
    pure (p', q', r')
-- | Integers are returned unchanged.
instance Simplifiable Int where
  simplify = return
-- | An optional value is simplified if present; 'Nothing' stays 'Nothing'.
instance Simplifiable a => Simplifiable (Maybe a) where
  simplify = traverse simplify
-- | Lists are simplified element by element, in order.
instance Simplifiable a => Simplifiable [a] where
  simplify = traverse simplify
-- | Constants are left untouched. A variable that the symbol table maps
-- to an equivalent subexpression (a constant or another variable) is
-- replaced by it. The change and the certificates the replacement
-- depends on are recorded in the process.
instance Simplifiable SubExp where
  simplify se@Constant {} = pure se
  simplify se@(Var name) = do
    vtable <- askVtable
    case ST.lookupSubExp name vtable of
      Just (replacement, cs) -> do
        changed
        usedCerts cs
        pure replacement
      Nothing -> pure se
-- | Simplify the annotation of every pattern element, context elements
-- first and then value elements. Element names are kept as they are.
simplifyPattern ::
  (SimplifiableLore lore, Simplifiable dec) =>
  PatternT dec ->
  SimpleM lore (PatternT dec)
simplifyPattern pat = do
  ctx <- mapM simplifyElem $ patternContextElements pat
  vals <- mapM simplifyElem $ patternValueElements pat
  pure $ Pattern ctx vals
  where
    simplifyElem (PatElem name dec) = do
      dec' <- simplify dec
      pure $ PatElem name dec'
-- | The unit annotation is returned unchanged.
instance Simplifiable () where
  simplify = return
-- | A name is replaced only when the symbol table maps it to another
-- variable; the change and the attached certificates are then recorded.
-- A mapping to a constant is ignored, since the result must be a 'VName'.
instance Simplifiable VName where
  simplify name = do
    vtable <- askVtable
    case ST.lookupSubExp name vtable of
      Just (Var name', cs) -> do
        changed
        usedCerts cs
        pure name'
      _ -> pure name
-- | A shape is simplified by simplifying each of its dimensions.
instance Simplifiable d => Simplifiable (ShapeBase d) where
  simplify (Shape dims) = Shape <$> simplify dims
-- | Existential sizes ('Ext') are kept as they are; known sizes ('Free')
-- have their subexpression simplified.
instance Simplifiable ExtSize where
  simplify size = case size of
    Ext _ -> pure size
    Free se -> Free <$> simplify se
-- | The dimensions of a scalar space are simplified; its primitive type
-- and every other kind of space are returned unchanged.
instance Simplifiable Space where
  simplify (ScalarSpace dims pt) = do
    dims' <- simplify dims
    pure $ ScalarSpace dims' pt
  simplify space = pure space
-- | Array shapes and memory spaces are simplified. Element types,
-- uniqueness and primitive types are returned unchanged.
instance Simplifiable shape => Simplifiable (TypeBase shape u) where
  simplify t = case t of
    Prim _ -> pure t
    Mem space -> Mem <$> simplify space
    Array et shape u -> (\shape' -> Array et shape' u) <$> simplify shape
-- | Every component of a fixed index or a slice is simplified, in
-- left-to-right order.
instance Simplifiable d => Simplifiable (DimIndex d) where
  simplify (DimFix ix) = DimFix <$> simplify ix
  simplify (DimSlice start num stride) = do
    start' <- simplify start
    num' <- simplify num
    stride' <- simplify stride
    pure $ DimSlice start' num' stride'
-- | Simplify a lambda using the engine's 'blockHoistPar' predicate to
-- decide which statements may not be hoisted out of its body.
-- Returns the simplified lambda together with the hoisted statements.
simplifyLambda ::
  SimplifiableLore lore =>
  Lambda lore ->
  SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyLambda lam = do
  blockers <- asksEngineEnv envHoistBlockers
  simplifyLambdaMaybeHoist (blockHoistPar blockers) lam
-- | Simplify a lambda using the blocking predicate @isFalse False@. The
-- statements reported as hoisted are discarded.
-- NOTE(review): this is only sound if @isFalse False@ blocks every
-- statement (as the name suggests), so nothing is actually hoisted;
-- confirm against 'isFalse'.
simplifyLambdaNoHoisting ::
  SimplifiableLore lore =>
  Lambda lore ->
  SimpleM lore (Lambda (Wise lore))
simplifyLambdaNoHoisting lam = do
  (lam', _) <- simplifyLambdaMaybeHoist (isFalse False) lam
  pure lam'
-- | Simplify a lambda, hoisting statements out of its body unless they
-- are blocked.
--
-- The given predicate is extended with 'hasFree' over the names bound
-- by the lambda and with 'isConsumed'. The body is simplified inside
-- 'enterLoop' with the simplified parameters bound, and every result is
-- given the 'Observe' diet. The return type is simplified last.
-- Returns the new lambda and the statements hoisted out of it.
simplifyLambdaMaybeHoist ::
  SimplifiableLore lore =>
  BlockPred (Wise lore) ->
  Lambda lore ->
  SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyLambdaMaybeHoist blocked lam@(Lambda params body rettype) = do
  params' <- traverse (traverse simplify) params
  let bound_here = namesFromList $ boundByLambda lam
      keep_inside = blocked `orIf` hasFree bound_here `orIf` isConsumed
      diets = Observe <$ rettype
  ((lamstms, lamres), hoisted) <-
    enterLoop . bindLParams params' . blockIf keep_inside $
      simplifyBody diets body
  body' <- constructBody lamstms lamres
  rettype' <- simplify rettype
  pure (Lambda params' body' rettype', hoisted)
-- | Build a usage table in which every name aliased by a result with the
-- 'Consume' diet is marked as consumed. Results with any other diet
-- contribute nothing.
consumeResult :: [(Diet, SubExp)] -> UT.UsageTable
consumeResult results =
  mconcat
    [ UT.consumedUsage v
      | (Consume, se) <- results,
        v <- namesToList $ subExpAliases se
    ]
-- | Each certificate is resolved against the symbol table:
--
-- * one bound to the constant 'Checked' is replaced by the certificates
--   attached to that binding;
-- * one equivalent to another variable is renamed to that variable (the
--   binding's own certificates are not carried over);
-- * any other certificate is kept.
--
-- Duplicates are then removed with 'nub', keeping first occurrences.
instance Simplifiable Certificates where
  simplify (Certificates ocs) = do
    vtable <- askVtable
    let resolve idd = case ST.lookupSubExp idd vtable of
          Just (Constant Checked, Certificates cs) -> cs
          Just (Var idd', _) -> [idd']
          _ -> [idd]
    pure $ Certificates $ nub $ concatMap resolve ocs
-- | Run a body simplification under the blocking predicate
-- @isFalse False@, discard the statements reported as hoisted, and
-- assemble the remaining statements and result into a 'Body'.
-- NOTE(review): relies on @isFalse False@ blocking every statement, so
-- that the discarded hoisted statements are empty; confirm.
insertAllStms ::
  SimplifiableLore lore =>
  SimpleM lore (SimplifiedBody lore Result) ->
  SimpleM lore (Body (Wise lore))
insertAllStms simplified = do
  ((stms, res), _) <- blockIf (isFalse False) simplified
  constructBody stms res
-- | Simplify a whole function definition.
--
-- The return types and the annotation of every parameter are simplified
-- first.  The body is then simplified by 'simplifyBody', using the diets of
-- the /simplified/ return types, inside 'bindFParams' (so the parameters are
-- visible while simplifying) and wrapped by 'insertAllStms' (presumably so
-- that no statements are hoisted out of the function body — confirm).
-- The entry point, attributes and function name are carried over unchanged.
--
-- NOTE(review): the body is simplified with the /original/ parameters bound,
-- not the simplified ones; this mirrors the existing behaviour — confirm it
-- is intended.
simplifyFun ::
  SimplifiableLore lore =>
  FunDef lore ->
  SimpleM lore (FunDef (Wise lore))
simplifyFun (FunDef entry attrs fname rettype params body) = do
  simpleRet <- simplify rettype
  -- Each parameter is a container of its annotation; simplify the
  -- annotation in place.
  simpleParams <- traverse (traverse simplify) params
  let retDiets = diet . declExtTypeOf <$> simpleRet
  simpleBody <-
    bindFParams params . insertAllStms $
      simplifyBody retDiets body
  pure $ FunDef entry attrs fname simpleRet simpleParams simpleBody