{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.IR.SOACS.Simplify
( simplifySOACS,
simplifyLambda,
simplifyFun,
simplifyStms,
simplifyConsts,
simpleSOACS,
simplifySOAC,
soacRules,
HasSOAC (..),
simplifyKnownIterationSOAC,
removeReplicateMapping,
liftIdentityMapping,
SOACS,
)
where
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip6, zip6)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.Analysis.DataDependencies
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import qualified Futhark.IR as AST
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify as Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util
-- | Simplifier operations for the SOACS representation, built from the
-- generic bindable operations and 'simplifySOAC' for the 'Op' case.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC
-- | Simplify a whole SOACS program using 'simpleSOACS', the SOAC rule
-- book 'soacRules', and no extra hoisting blockers.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers
-- | Simplify a single SOACS function definition in the context of the
-- given symbol table.
simplifyFun ::
  MonadFreshNames m =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers
-- | Simplify a SOACS lambda, using the scope available from the ambient
-- 'HasScope' monad.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Lambda ->
  m Lambda
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers
-- | Simplify a sequence of SOACS statements.
--
-- The current scope is taken from the ambient monad and passed to the
-- simplifier. The result is the resulting symbol table together with the
-- simplified statements.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Stms SOACS ->
  m (ST.SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms
-- | Simplify top-level constant statements.
--
-- Unlike 'simplifyStms', this starts from an empty ('mempty') scope rather
-- than one obtained from the ambient monad.
simplifyConsts ::
  MonadFreshNames m =>
  Stms SOACS ->
  m (ST.SymbolTable (Wise SOACS), Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms simpleSOACS soacRules Engine.noExtraHoistBlockers mempty
-- | Simplify every component of a SOAC: its size operands, input arrays,
-- neutral elements and lambdas.
--
-- Returns the rebuilt SOAC together with all statements hoisted out of
-- its lambdas, concatenated in the order the lambdas were simplified.
simplifySOAC ::
  Simplify.SimplifiableLore lore =>
  Simplify.SimplifyOp lore (SOAC lore)
simplifySOAC (Stream outerdim arr form nes lam) = do
  outerdim' <- Engine.simplify outerdim
  (form', form_hoisted) <- simplifyStreamForm form
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.simplifyLambda lam
  return
    ( Stream outerdim' arr' form' nes' lam',
      form_hoisted <> lam_hoisted
    )
  where
    -- A 'Parallel' form carries a lambda that must be simplified too;
    -- 'Sequential' has nothing to simplify and hoists nothing.
    simplifyStreamForm (Parallel o comm lam0) = do
      (lam0', hoisted) <- Engine.simplifyLambda lam0
      return (Parallel o comm lam0', hoisted)
    simplifyStreamForm Sequential =
      return (Sequential, mempty)
simplifySOAC (Scatter len lam ivs as) = do
  len' <- Engine.simplify len
  (lam', hoisted) <- Engine.simplifyLambda lam
  ivs' <- mapM Engine.simplify ivs
  as' <- mapM Engine.simplify as
  return (Scatter len' lam' ivs' as', hoisted)
simplifySOAC (Hist w ops bfun imgs) = do
  w' <- Engine.simplify w
  (ops', ops_hoisted) <-
    fmap unzip $
      forM ops $ \(HistOp dests_w rf dests nes op) -> do
        dests_w' <- Engine.simplify dests_w
        rf' <- Engine.simplify rf
        dests' <- Engine.simplify dests
        nes' <- mapM Engine.simplify nes
        (op', hoisted) <- Engine.simplifyLambda op
        return (HistOp dests_w' rf' dests' nes' op', hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.simplifyLambda bfun
  return (Hist w' ops' bfun' imgs', mconcat ops_hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  (scans', scans_hoisted) <-
    fmap unzip $
      forM scans $ \(Scan lam nes) -> do
        (lam', hoisted) <- Engine.simplifyLambda lam
        nes' <- Engine.simplify nes
        return (Scan lam' nes', hoisted)

  (reds', reds_hoisted) <-
    fmap unzip $
      forM reds $ \(Reduce comm lam nes) -> do
        (lam', hoisted) <- Engine.simplifyLambda lam
        nes' <- Engine.simplify nes
        return (Reduce comm lam' nes', hoisted)

  (map_lam', map_lam_hoisted) <- Engine.simplifyLambda map_lam

  -- The width and input arrays are simplified after the lambdas, in
  -- the same order as before.
  w' <- Engine.simplify w
  arrs' <- Engine.simplify arrs
  return
    ( Screma w' arrs' (ScremaForm scans' reds' map_lam'),
      mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted
    )
-- | No methods are overridden here; every 'BinderOps' method uses the
-- class's default implementation.
instance BinderOps (Wise SOACS)
-- | Specialise a lambda by fixing some of its parameters to known values.
--
-- The @i@th element of @fixes@ corresponds to the @i@th lambda parameter.
--
-- * @Just x@: binds that parameter's name to @x@ inside the body, built
--   with 'runBodyBinder'. The bindings are presumably placed before the
--   original body statements. The parameter is then removed from the
--   parameter list.
-- * @Nothing@: leaves the parameter untouched.
--
-- Parameters beyond the end of @fixes@ are treated as @Nothing@.
fixLambdaParams ::
  (MonadBinder m, Bindable (Lore m), BinderOps (Lore m)) =>
  AST.Lambda (Lore m) ->
  [Maybe SubExp] ->
  m (AST.Lambda (Lore m))
fixLambdaParams lam fixes = do
  body <- runBodyBinder $
    localScope (scopeOfLParams $ lambdaParams lam) $ do
      zipWithM_ maybeFix (lambdaParams lam) fixes'
      return $ lambdaBody lam
  return
    lam
      { lambdaBody = body,
        lambdaParams =
          map fst $
            filter (isNothing . snd) $
              zip (lambdaParams lam) fixes'
      }
  where
    -- Pad with 'Nothing' so every parameter has a corresponding entry.
    fixes' = fixes ++ repeat Nothing

    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = return ()
-- | Drop the lambda results whose corresponding flag in @keep@ is 'False'.
--
-- Both the body result list and the declared return types are filtered.
-- Results beyond the end of @keep@ are retained.
removeLambdaResults :: [Bool] -> AST.Lambda lore -> AST.Lambda lore
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body',
      lambdaReturnType = ret
    }
  where
    -- Pad with 'True' so that unflagged trailing results are kept.
    keep' :: [a] -> [a]
    keep' = map snd . filter fst . zip (keep ++ repeat True)
    lam_body = lambdaBody lam
    lam_body' = lam_body {bodyResult = keep' $ bodyResult lam_body}
    ret = keep' $ lambdaReturnType lam
-- | The rule book used when simplifying SOACS: the 'standardRules'
-- extended with the SOAC-specific 'topDownRules' and 'bottomUpRules'.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules
-- | Lores whose 'Op' may be a 'SOAC'.  Rules such as
-- 'liftIdentityMapping' and 'removeReplicateMapping' are written
-- against this class so that they can inspect and rebuild SOACs
-- without being tied to the 'SOACS' lore itself.
class HasSOAC lore where
  -- | Extract the 'SOAC' from an 'Op', if the 'Op' is one.
  asSOAC :: Op lore -> Maybe (SOAC lore)

  -- | Embed a 'SOAC' as an 'Op'.
  soacOp :: SOAC lore -> Op lore
-- | In 'SOACS' the 'Op' type is the 'SOAC' type itself, so extraction
-- always succeeds and embedding is the identity.
instance HasSOAC (Wise SOACS) where
  asSOAC = Just
  soacOp = id
-- | SOAC-specific rules applied during the top-down simplification
-- pass.  Every entry is an 'Op' rule.  The order is kept exactly as
-- before in case the rule book depends on it.
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  [ RuleOp hoistCertificates,
    RuleOp removeReplicateMapping,
    RuleOp removeReplicateWrite,
    RuleOp removeUnusedSOACInput,
    RuleOp simplifyClosedFormReduce,
    RuleOp simplifyKnownIterationSOAC,
    RuleOp liftIdentityMapping,
    RuleOp removeDuplicateMapOutput,
    RuleOp fuseConcatScatter,
    RuleOp simplifyMapIota,
    RuleOp moveTransformToInput
  ]
-- | SOAC-specific rules applied during the bottom-up simplification
-- pass.  'removeUnnecessaryCopy' is the only 'BasicOp' rule here; the
-- others are 'Op' rules.  The order is kept exactly as before.
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  [ RuleOp removeDeadMapping,
    RuleOp removeDeadReduction,
    RuleOp removeDeadWrite,
    RuleBasicOp removeUnnecessaryCopy,
    RuleOp liftIdentityStreaming,
    RuleOp mapOpToOp
  ]
-- | Hoist certificates out of trivial statements inside a SOAC's lambdas.
--
-- For every statement of the form @let p = se@ (a plain 'SubExp') in
-- the body of each lambda of the SOAC, the certificates on that
-- statement are split in two.  Those whose names are already in the
-- enclosing symbol table are removed from the statement.  The rest stay
-- where they are.  The removed certificates are then attached to the
-- SOAC statement itself.  The rule fires only when at least one
-- certificate was hoisted.
hoistCertificates :: TopDownRuleOp (Wise SOACS)
hoistCertificates vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
    Simplify $ auxing aux $ certifying hoisted $ letBind pat $ Op soac'
  where
    -- Only the lambdas of the SOAC are visited; everything else is
    -- passed through unchanged.
    mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    -- Rewrite the body statements and rebuild the body with 'mkBody'
    -- so that its annotations are recomputed.
    onLambda lam = do
      stms' <- mapM onStm $ bodyStms $ lambdaBody lam
      return
        lam
          { lambdaBody = mkBody stms' $ bodyResult $ lambdaBody lam
          }

    -- Certificates bound outside the SOAC ("invariant") are collected
    -- in the state.  Certificates bound inside it stay on the statement.
    onStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      let (invariant, variant) =
            partition (`ST.elem` vtable) $
              unCertificates $ stmAuxCerts se_aux
          se_aux' = se_aux {stmAuxCerts = Certificates variant}
      modify (Certificates invariant <>)
      return $ Let se_pat se_aux' $ BasicOp $ SubExp se
    onStm stm = return stm
hoistCertificates _ _ _ _ = Skip
-- | Move results that do not depend on the iteration out of a @map@.
--
-- For a 'Screma' that is a map, each result of the lambda is examined:
--
-- * if it is one of the lambda parameters, the corresponding output
--   becomes a 'Copy' of the matching input array;
--
-- * otherwise, if it is a constant or a variable not bound inside the
--   lambda body, the output becomes a 'Replicate' of it with outer
--   size @w@;
--
-- * everything else stays a result of the map.
--
-- The rule is skipped when no result could be lifted.  The statement's
-- 'StmAux' is applied only to the remaining map, not to the lifted
-- bindings.
liftIdentityMapping ::
  forall lore.
  (Bindable lore, Simplify.SimplifiableLore lore, HasSOAC (Wise lore)) =>
  TopDownRuleOp (Wise lore)
liftIdentityMapping _ pat aux op
  | Just (Screma w arrs form :: SOAC (Wise lore)) <- asSOAC op,
    Just fun <- isMapSOAC form = do
    let -- Lambda parameter name -> the input array it is drawn from.
        inputMap = M.fromList $ zip (map paramName $ lambdaParams fun) arrs
        free = freeIn $ lambdaBody fun
        rettype = lambdaReturnType fun
        ses = bodyResult $ lambdaBody fun

        -- A variable that is free in the body is not bound by any of
        -- the body's statements, so it has the same value in every
        -- iteration (lambda parameters are handled earlier).
        freeOrConst (Var v) = v `nameIn` free
        freeOrConst Constant {} = True

        -- Sort each (pattern element, result, type) triple into
        -- lifted bindings or remaining map results.
        checkInvariance (outId, Var v, _) (invariant, mapresult, rettype')
          | Just inp <- M.lookup v inputMap =
            ( (Pattern [] [outId], BasicOp (Copy inp)) : invariant,
              mapresult,
              rettype'
            )
        checkInvariance (outId, e, t) (invariant, mapresult, rettype')
          | freeOrConst e =
            ( (Pattern [] [outId], BasicOp $ Replicate (Shape [w]) e) : invariant,
              mapresult,
              rettype'
            )
          | otherwise =
            ( invariant,
              (outId, e) : mapresult,
              t : rettype'
            )

    case foldr checkInvariance ([], [], []) $
      zip3 (patternElements pat) ses rettype of
      ([], _, _) -> Skip
      (invariant, mapresult, rettype') -> Simplify $ do
        let (pat', ses') = unzip mapresult
            fun' =
              fun
                { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                  lambdaReturnType = rettype'
                }
        mapM_ (uncurry letBind) invariant
        auxing aux $
          letBindNames (map patElemName pat') $
            Op $ soacOp $ Screma w arrs (mapSOAC fun')
liftIdentityMapping _ _ _ _ = Skip
-- | Move identity outputs out of a 'Stream'.
--
-- The first @length nes@ results of the stream lambda are the fold
-- (accumulator) results and are always kept.  Of the remaining map
-- results, any that is directly one of the lambda's array parameters
-- is removed from the stream.  The corresponding pattern element is
-- instead bound to a 'Copy' of the matching input array.  The rule
-- fires only when at least one such result exists.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pattern [] pes) aux (Stream w arrs form nes lam)
  | (variant_map, invariant_map) <-
      partitionEithers $ map isInvariantRes $ zip3 map_ts map_pes map_res,
    not $ null invariant_map = Simplify $ do
    forM_ invariant_map $ \(pe, arr) ->
      letBind (Pattern [] [pe]) $ BasicOp $ Copy arr

    let (variant_map_ts, variant_map_pes, variant_map_res) = unzip3 variant_map
        lam' =
          lam
            { lambdaBody =
                (lambdaBody lam) {bodyResult = fold_res ++ variant_map_res},
              lambdaReturnType = fold_ts ++ variant_map_ts
            }

    auxing aux $
      letBind (Pattern [] $ fold_pes ++ variant_map_pes) $
        Op $ Stream w arrs form nes lam'
  where
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    lam_res = bodyResult $ lambdaBody lam
    (fold_res, map_res) = splitAt num_folds lam_res

    -- Pair each array parameter of the lambda with its input array.
    -- The leading parameter (presumably the chunk size -- confirm
    -- against the Stream definition) and the accumulator parameters
    -- are skipped.
    params_to_arrs = zip (map paramName $ drop (1 + num_folds) $ lambdaParams lam) arrs

    -- 'Right' marks a result that can be replaced by a copy of an input
    -- array.  'Left' marks a result that must stay in the stream.
    isInvariantRes (_, pe, Var v)
      | Just arr <- lookup v params_to_arrs =
        Right (pe, arr)
    isInvariantRes x =
      Left x
liftIdentityStreaming _ _ _ _ = Skip
-- | Simplification rule for map-only 'Screma's.
--
-- If some input arrays of the map are bound (according to the symbol
-- table) to a @Replicate@, and the corresponding lambda parameter is
-- not consumed, those parameters are turned into ordinary bindings
-- placed before the map. The arrays are then dropped from the map's
-- inputs. The detection is done by 'removeReplicateInput'; this rule
-- only emits the resulting statements. Any other operation yields
-- 'Skip'.
removeReplicateMapping ::
  (Bindable lore, Simplify.SimplifiableLore lore, HasSOAC (Wise lore)) =>
  TopDownRuleOp (Wise lore)
removeReplicateMapping vtable pat aux op
  | Just (Screma w arrs form) <- asSOAC op,
    Just fun <- isMapSOAC form,
    Just (bnds, fun', arrs') <- removeReplicateInput vtable fun arrs = Simplify $ do
    -- Bind each former parameter under the certificates of the
    -- replicate it came from.
    forM_ bnds $ \(vs, cs, e) -> certifying cs $ letBindNames vs e
    -- Re-emit the map with the remaining inputs, keeping the original
    -- statement's aux (certificates/attributes).
    auxing aux $ letBind pat $ Op $ soacOp $ Screma w arrs' $ mapSOAC fun'
removeReplicateMapping _ _ _ _ = Skip
-- | The same transformation as 'removeReplicateMapping', but applied to
-- the index/value lambda of a 'Scatter'.
--
-- Inputs of the scatter lambda that are bound to a @Replicate@ (and
-- whose parameter is not consumed) become bindings placed before the
-- 'Scatter'. The destination arrays @as@ are left untouched.
removeReplicateWrite :: TopDownRuleOp (Wise SOACS)
removeReplicateWrite vtable pat aux (Scatter len lam ivs as)
  | Just (bnds, lam', ivs') <- removeReplicateInput vtable lam ivs = Simplify $ do
    forM_ bnds $ \(vs, cs, e) -> certifying cs $ letBindNames vs e
    auxing aux $ letBind pat $ Op $ Scatter len lam' ivs' as
removeReplicateWrite _ _ _ _ = Skip
-- | Find lambda inputs that come from replicated arrays.
--
-- The last @length arrs@ parameters of the lambda are paired with
-- @arrs@. Any leading parameters (here called @acc_params@) are kept
-- unchanged. A pair qualifies when the array is bound in the symbol
-- table to @Replicate (Shape (_ : ds)) e@ and the parameter is not
-- consumed by the lambda. Each qualifying parameter is replaced by a
-- binding of its name to:
--
-- * @e@ itself, when @ds@ is empty; otherwise
-- * @Replicate (Shape ds) e@, i.e. the replicate with its outer
--   dimension removed.
--
-- Each binding carries the certificates found with the replicate.
--
-- Returns 'Nothing' when no input qualifies. Otherwise it returns a
-- triple of:
--
-- * the new bindings;
-- * the lambda without the qualifying parameters;
-- * the remaining input arrays.
removeReplicateInput ::
  Aliased lore =>
  ST.SymbolTable lore ->
  AST.Lambda lore ->
  [VName] ->
  Maybe
    ( [([VName], Certificates, AST.Exp lore)],
      AST.Lambda lore,
      [VName]
    )
removeReplicateInput vtable fun arrs
  | not $ null parameterBnds = do
    let (arr_params', arrs') = unzip params_and_arrs
        fun' = fun {lambdaParams = acc_params <> arr_params'}
    return (parameterBnds, fun', arrs')
  | otherwise = Nothing
  where
    params = lambdaParams fun
    (acc_params, arr_params) =
      splitAt (length params - length arrs) params
    -- Left: keep as a lambda input; Right: turn into a binding.
    (params_and_arrs, parameterBnds) =
      partitionEithers $ zipWith isReplicateAndNotConsumed arr_params arrs

    isReplicateAndNotConsumed p v
      | Just (BasicOp (Replicate (Shape (_ : ds)) e), v_cs) <-
          ST.lookupExp v vtable,
        -- Only parameters the lambda does not consume are eligible.
        not $ paramName p `nameIn` consumedByLambda fun =
        Right
          ( [paramName p],
            v_cs,
            case ds of
              [] -> BasicOp $ SubExp e
              _ -> BasicOp $ Replicate (Shape ds) e
          )
      | otherwise =
        Left (p, v)
-- | Remove 'Screma' inputs whose lambda parameter is never referenced.
--
-- The map lambda's parameters are paired one-to-one with the input
-- arrays. A pair is kept only if the parameter name is free in the
-- lambda body. If at least one pair is unused, the 'Screma' is rebuilt
-- with just the used parameters and arrays. Scans and reductions are
-- carried over unchanged. Otherwise 'Skip'.
removeUnusedSOACInput :: TopDownRuleOp (Wise SOACS)
removeUnusedSOACInput _ pat aux (Screma w arrs (ScremaForm scan reduce map_lam))
  | (used, unused) <- partition usedInput params_and_arrs,
    not (null unused) = Simplify $ do
    let (used_params, used_arrs) = unzip used
        map_lam' = map_lam {lambdaParams = used_params}
    auxing aux $
      letBind pat $
        Op $ Screma w used_arrs (ScremaForm scan reduce map_lam')
  where
    params_and_arrs = zip (lambdaParams map_lam) arrs
    used_in_body = freeIn $ lambdaBody map_lam
    usedInput (param, _) = paramName param `nameIn` used_in_body
removeUnusedSOACInput _ _ _ _ = Skip
-- | Remove map results that are never used.
--
-- This is a bottom-up rule over a map-only 'Screma'. Pattern elements,
-- lambda results and lambda return types are zipped together. Triples
-- whose pattern element is not marked as used in the usage table are
-- dropped. If this changes the pattern, the map is re-emitted with the
-- trimmed pattern, body result and return type. Otherwise 'Skip'.
removeDeadMapping :: BottomUpRuleOp (Wise SOACS)
removeDeadMapping (_, used) pat aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
    let ses = bodyResult $ lambdaBody fun
        isUsed (bindee, _, _) = patElemName bindee `UT.used` used
        (pat', ses', ts') =
          unzip3 $
            filter isUsed $
              zip3 (patternElements pat) ses $ lambdaReturnType fun
        fun' =
          fun
            { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
              lambdaReturnType = ts'
            }
     in if pat /= Pattern [] pat'
          then
            Simplify $
              auxing aux $
                letBind (Pattern [] pat') $ Op $ Screma w arrs $ mapSOAC fun'
          else Skip
removeDeadMapping _ _ _ _ = Skip
-- | Remove duplicate results from a map.
--
-- If the lambda of a map-only 'Screma' returns the same 'SubExp' in
-- several result positions, only the first occurrence is kept. Each
-- later pattern element is instead bound as a 'Copy' of the pattern
-- element produced by that first occurrence. If no duplicates exist,
-- the result is 'Skip'.
removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS)
removeDuplicateMapOutput _ pat aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
    let ses = bodyResult $ lambdaBody fun
        ts = lambdaReturnType fun
        pes = patternValueElements pat
        ses_ts_pes = zip3 ses ts pes
        -- Strict left fold: the accumulator is rebuilt at every step.
        (ses_ts_pes', copies) =
          foldl' checkForDuplicates (mempty, mempty) ses_ts_pes
     in if null copies
          then Skip
          else Simplify $ do
            let (ses', ts', pes') = unzip3 ses_ts_pes'
                pat' = Pattern [] pes'
                fun' =
                  fun
                    { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                      lambdaReturnType = ts'
                    }
            auxing aux $ letBind pat' $ Op $ Screma w arrs $ mapSOAC fun'
            forM_ copies $ \(from, to) ->
              letBind (Pattern [] [to]) $ BasicOp $ Copy $ patElemName from
  where
    checkForDuplicates (ses_ts_pes', copies) (se, t, pe)
      -- This subexpression was already returned, producing pe'; bind
      -- pe as a copy of pe' rather than returning se again.
      | Just (_, _, pe') <- find (\(x, _, _) -> x == se) ses_ts_pes' =
        (ses_ts_pes', (pe', pe) : copies)
      | otherwise = (ses_ts_pes' ++ [(se, t, pe)], copies)
removeDuplicateMapOutput _ _ _ _ = Skip
-- | Bottom-up rule: a @map@ whose body is a single reshape, concat,
-- rearrange or rotate of its parameter(s) is replaced by that same
-- operation applied directly to the mapped-over array(s), with the map
-- width handled as a new outermost dimension.  The reshape, rearrange
-- and rotate cases only fire when 'UT.isConsumed' does not report the
-- map result as consumed; the concat case has no such check.
mapOpToOp :: BottomUpRuleOp (Wise SOACS)
mapOpToOp (_, used) pat aux1 e
  | Just (map_pe, cs, w, BasicOp (Reshape newshape reshape_arr), [p], [arr]) <- inner,
    paramName p == reshape_arr,
    notConsumed map_pe =
    Simplify $ do
      -- The new outer dimension is a coercion exactly when the inner
      -- reshape is one.
      let outer
            | isJust $ shapeCoercion newshape = DimCoercion w
            | otherwise = DimNew w
      bindLifted cs $ Reshape (outer : newshape) arr
  | Just (_, cs, _, BasicOp (Concat d arr arrs dw), ps, outer_arr : outer_arrs) <- inner,
    map paramName ps == arr : arrs =
    -- Concatenating along dimension d inside the map becomes
    -- concatenating along dimension d+1 outside it.
    Simplify $ bindLifted cs $ Concat (d + 1) outer_arr outer_arrs dw
  | Just (map_pe, cs, _, BasicOp (Rearrange perm rearrange_arr), [p], [arr]) <- inner,
    paramName p == rearrange_arr,
    notConsumed map_pe =
    -- Keep the outer dimension first and shift the inner permutation.
    Simplify $ bindLifted cs $ Rearrange (0 : map (1 +) perm) arr
  | Just (map_pe, cs, _, BasicOp (Rotate rots rotate_arr), [p], [arr]) <- inner,
    paramName p == rotate_arr,
    notConsumed map_pe =
    -- The outer dimension is rotated by zero.
    Simplify $ bindLifted cs $ Rotate (intConst Int64 0 : rots) arr
  where
    -- The single-operation view of the map, if it has one.
    inner = isMapWithOp pat e

    -- True when the given pattern element is not consumed afterwards.
    notConsumed pe = not $ UT.isConsumed (patElemName pe) used

    -- Bind the original pattern to the lifted operation, attaching the
    -- certificates of both the outer statement and the inner one.
    bindLifted :: Certificates -> BasicOp -> RuleM (Wise SOACS) ()
    bindLifted cs op =
      certifying (stmAuxCerts aux1 <> cs) $ letBind pat $ BasicOp op
mapOpToOp _ _ _ _ = Skip
-- | Recognise a @map@ (with a single-element pattern) whose lambda body
-- consists of exactly one statement binding one name, where that name
-- is also the lambda's sole result.
--
-- On a match, returns: the map's pattern element, the certificates of
-- the inner statement, the map width, the inner expression, the lambda
-- parameters, and the mapped-over arrays.  Otherwise 'Nothing'.
isMapWithOp ::
  PatternT dec ->
  SOAC (Wise SOACS) ->
  Maybe
    ( PatElemT dec,
      Certificates,
      SubExp,
      AST.Exp (Wise SOACS),
      [Param Type],
      [VName]
    )
isMapWithOp (Pattern [] [map_pe]) (Screma w arrs form) = do
  map_lam <- isMapSOAC form
  let body = lambdaBody map_lam
  case (stmsToList (bodyStms body), bodyResult body) of
    ([Let (Pattern [] [pe]) aux2 e'], [Var r])
      | r == patElemName pe ->
        Just (map_pe, stmAuxCerts aux2, w, e', lambdaParams map_lam, arrs)
    _ -> Nothing
isMapWithOp _ _ = Nothing
-- | Some of the results of a reduction (really: a redomap with a single
-- 'Reduce') may be dead, i.e. not used after the statement.  Remove
-- them.  A result whose pattern element is unused may still feed into a
-- live result through the reduction operator, so the data dependencies
-- of the reduction lambda are consulted (via 'findNecessaryForReturned')
-- to decide which results can actually be dropped.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
  | Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form,
    -- Quick/cheap check: if every pattern element is used, nothing is dead.
    not $ all (`UT.used` used) $ patternNames pat,
    let (red_pes, map_pes) = splitAt (length nes) $ patternElements pat,
    let redlam_deps = dataDependencies $ lambdaBody redlam,
    let redlam_res = bodyResult $ lambdaBody redlam,
    let redlam_params = lambdaParams redlam,
    -- Reduction parameters whose corresponding pattern element is used.
    let used_after =
          map snd $
            filter ((`UT.used` used) . patElemName . fst) $
              zip red_pes redlam_params,
    -- The lambda results are paired with the parameter list twice, once
    -- per half of the parameters (assumed: accumulator and element
    -- halves of a reduction operator).
    let necessary =
          findNecessaryForReturned
            (`elem` used_after)
            (zip redlam_params $ redlam_res <> redlam_res)
            redlam_deps,
    let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
    not (and alive_mask) =
    Simplify $ do
      -- A dead parameter position is fixed to its neutral element.
      let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
          dead_fix = zipWith fixDeadToNeutral alive_mask nes
          (used_red_pes, _, used_nes) =
            unzip3 $
              filter (\(_, x, _) -> paramName x `nameIn` necessary) $
                zip3 red_pes redlam_params nes
          -- Which of the reduction results survive.
          alive_res = take (length nes) alive_mask
          maplam' = removeLambdaResults alive_res maplam
      redlam' <-
        removeLambdaResults alive_res
          <$> fixLambdaParams redlam (dead_fix ++ dead_fix)
      auxing aux $
        letBind (Pattern [] $ used_red_pes ++ map_pes) $
          Op $
            Screma w arrs $
              redomapSOAC [Reduce comm redlam' used_nes] maplam'
removeDeadReduction _ _ _ _ = Skip
-- | If a @scatter@ writes to an array whose result is never used
-- afterwards, drop that destination together with the index and value
-- results of the lambda that feed it.  Does nothing when every result
-- is used.
removeDeadWrite :: BottomUpRuleOp (Wise SOACS)
removeDeadWrite (_, used) pat aux (Scatter w fun arrs dests)
  | pat /= Pattern [] live_pes =
    Simplify $
      auxing aux $
        letBind (Pattern [] live_pes) $
          Op $ Scatter w fun_live arrs live_dests
  | otherwise = Skip
  where
    -- Split a flat list of lambda results (or types) into per-destination
    -- (indices, value) groups.  Used at both 'SubExp' and 'Type', hence
    -- the explicit polymorphic signature.
    groupByDest :: [a] -> [([a], a)]
    groupByDest = groupScatterResults' dests

    (idx_res, val_res) = unzip $ groupByDest $ bodyResult $ lambdaBody fun
    (idx_ts, val_ts) = unzip $ groupByDest $ lambdaReturnType fun

    -- A destination is live if its pattern element is used later.
    isLive (pe, _, _, _, _, _) = patElemName pe `UT.used` used

    (live_pes, live_idx_res, live_val_res, live_idx_ts, live_val_ts, live_dests) =
      unzip6 . filter isLive $
        zip6 (patternElements pat) idx_res val_res idx_ts val_ts dests

    -- The lambda restricted to the live destinations: all index results
    -- first, then all value results.
    fun_live =
      fun
        { lambdaBody =
            (lambdaBody fun) {bodyResult = concat live_idx_res ++ live_val_res},
          lambdaReturnType = concat live_idx_ts ++ live_val_ts
        }
removeDeadWrite _ _ _ _ = Skip
fuseConcatScatter :: TopDownRuleOp (Wise SOACS)
fuseConcatScatter :: RuleOp (Wise SOACS) (SymbolTable (Wise SOACS))
fuseConcatScatter SymbolTable (Wise SOACS)
vtable Pattern (Wise SOACS)
pat StmAux (ExpDec (Wise SOACS))
_ (Scatter SubExp
_ Lambda (Wise SOACS)
fun [VName]
arrs [(Shape, Int, VName)]
dests)
| Just (ws :: [SubExp]
ws@(SubExp
w' : [SubExp]
_), [[VName]]
xss, [Certificates]
css) <- [(SubExp, [VName], Certificates)]
-> ([SubExp], [[VName]], [Certificates])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(SubExp, [VName], Certificates)]
-> ([SubExp], [[VName]], [Certificates]))
-> Maybe [(SubExp, [VName], Certificates)]
-> Maybe ([SubExp], [[VName]], [Certificates])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Maybe (SubExp, [VName], Certificates))
-> [VName] -> Maybe [(SubExp, [VName], Certificates)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> Maybe (SubExp, [VName], Certificates)
isConcat [VName]
arrs,
[[VName]]
xivs <- [[VName]] -> [[VName]]
forall a. [[a]] -> [[a]]
transpose [[VName]]
xss,
(SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp
w' SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
==) [SubExp]
ws = RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
let r :: Int
r = [[VName]] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [[VName]]
xivs
[Lambda (Wise SOACS)]
fun2s <- (Int -> RuleM (Wise SOACS) (Lambda (Wise SOACS)))
-> [Int] -> RuleM (Wise SOACS) [Lambda (Wise SOACS)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\Int
_ -> Lambda (Wise SOACS) -> RuleM (Wise SOACS) (Lambda (Wise SOACS))
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda (Wise SOACS)
fun) [Int
1 .. Int
r Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
let ([[SubExp]]
fun_is, [[SubExp]]
fun_vs) =
[([SubExp], [SubExp])] -> ([[SubExp]], [[SubExp]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([([SubExp], [SubExp])] -> ([[SubExp]], [[SubExp]]))
-> [([SubExp], [SubExp])] -> ([[SubExp]], [[SubExp]])
forall a b. (a -> b) -> a -> b
$
(Lambda (Wise SOACS) -> ([SubExp], [SubExp]))
-> [Lambda (Wise SOACS)] -> [([SubExp], [SubExp])]
forall a b. (a -> b) -> [a] -> [b]
map
( [(Shape, Int, VName)] -> [SubExp] -> ([SubExp], [SubExp])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, VName)]
dests
([SubExp] -> ([SubExp], [SubExp]))
-> (Lambda (Wise SOACS) -> [SubExp])
-> Lambda (Wise SOACS)
-> ([SubExp], [SubExp])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BodyT (Wise SOACS) -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult
(BodyT (Wise SOACS) -> [SubExp])
-> (Lambda (Wise SOACS) -> BodyT (Wise SOACS))
-> Lambda (Wise SOACS)
-> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Wise SOACS) -> BodyT (Wise SOACS)
forall lore. LambdaT lore -> BodyT lore
lambdaBody
)
(Lambda (Wise SOACS)
fun Lambda (Wise SOACS)
-> [Lambda (Wise SOACS)] -> [Lambda (Wise SOACS)]
forall a. a -> [a] -> [a]
: [Lambda (Wise SOACS)]
fun2s)
([[Type]]
its, [[Type]]
vts) =
[([Type], [Type])] -> ([[Type]], [[Type]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([([Type], [Type])] -> ([[Type]], [[Type]]))
-> [([Type], [Type])] -> ([[Type]], [[Type]])
forall a b. (a -> b) -> a -> b
$
Int -> ([Type], [Type]) -> [([Type], [Type])]
forall a. Int -> a -> [a]
replicate Int
r (([Type], [Type]) -> [([Type], [Type])])
-> ([Type], [Type]) -> [([Type], [Type])]
forall a b. (a -> b) -> a -> b
$
[(Shape, Int, VName)] -> [Type] -> ([Type], [Type])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, VName)]
dests ([Type] -> ([Type], [Type])) -> [Type] -> ([Type], [Type])
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Wise SOACS)
fun
new_stmts :: Stms (Wise SOACS)
new_stmts = [Stms (Wise SOACS)] -> Stms (Wise SOACS)
forall a. Monoid a => [a] -> a
mconcat ([Stms (Wise SOACS)] -> Stms (Wise SOACS))
-> [Stms (Wise SOACS)] -> Stms (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ (Lambda (Wise SOACS) -> Stms (Wise SOACS))
-> [Lambda (Wise SOACS)] -> [Stms (Wise SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (BodyT (Wise SOACS) -> Stms (Wise SOACS)
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT (Wise SOACS) -> Stms (Wise SOACS))
-> (Lambda (Wise SOACS) -> BodyT (Wise SOACS))
-> Lambda (Wise SOACS)
-> Stms (Wise SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Wise SOACS) -> BodyT (Wise SOACS)
forall lore. LambdaT lore -> BodyT lore
lambdaBody) (Lambda (Wise SOACS)
fun Lambda (Wise SOACS)
-> [Lambda (Wise SOACS)] -> [Lambda (Wise SOACS)]
forall a. a -> [a] -> [a]
: [Lambda (Wise SOACS)]
fun2s)
let fun' :: Lambda (Wise SOACS)
fun' =
Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
{ lambdaParams :: [LParam (Wise SOACS)]
lambdaParams = [[Param Type]] -> [Param Type]
forall a. Monoid a => [a] -> a
mconcat ([[Param Type]] -> [Param Type]) -> [[Param Type]] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ (Lambda (Wise SOACS) -> [Param Type])
-> [Lambda (Wise SOACS)] -> [[Param Type]]
forall a b. (a -> b) -> [a] -> [b]
map Lambda (Wise SOACS) -> [Param Type]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (Lambda (Wise SOACS)
fun Lambda (Wise SOACS)
-> [Lambda (Wise SOACS)] -> [Lambda (Wise SOACS)]
forall a. a -> [a] -> [a]
: [Lambda (Wise SOACS)]
fun2s),
lambdaBody :: BodyT (Wise SOACS)
lambdaBody =
Stms (Wise SOACS) -> [SubExp] -> BodyT (Wise SOACS)
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody Stms (Wise SOACS)
new_stmts ([SubExp] -> BodyT (Wise SOACS)) -> [SubExp] -> BodyT (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
[[SubExp]] -> [SubExp]
forall {a}. [[a]] -> [a]
mix [[SubExp]]
fun_is [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> [[SubExp]] -> [SubExp]
forall {a}. [[a]] -> [a]
mix [[SubExp]]
fun_vs,
lambdaReturnType :: [Type]
lambdaReturnType = [[Type]] -> [Type]
forall {a}. [[a]] -> [a]
mix [[Type]]
its [Type] -> [Type] -> [Type]
forall a. Semigroup a => a -> a -> a
<> [[Type]] -> [Type]
forall {a}. [[a]] -> [a]
mix [[Type]]
vts
}
Certificates -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying ([Certificates] -> Certificates
forall a. Monoid a => [a] -> a
mconcat [Certificates]
css) (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM (Wise SOACS)))
-> Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore (RuleM (Wise SOACS)))
Pattern (Wise SOACS)
pat (Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ Op (Wise SOACS) -> Exp (Wise SOACS)
forall lore. Op lore -> ExpT lore
Op (Op (Wise SOACS) -> Exp (Wise SOACS))
-> Op (Wise SOACS) -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp
-> Lambda (Wise SOACS)
-> [VName]
-> [(Shape, Int, VName)]
-> SOAC (Wise SOACS)
forall lore.
SubExp
-> Lambda lore -> [VName] -> [(Shape, Int, VName)] -> SOAC lore
Scatter SubExp
w' Lambda (Wise SOACS)
fun' ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
xivs) ([(Shape, Int, VName)] -> SOAC (Wise SOACS))
-> [(Shape, Int, VName)] -> SOAC (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ ((Shape, Int, VName) -> (Shape, Int, VName))
-> [(Shape, Int, VName)] -> [(Shape, Int, VName)]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> (Shape, Int, VName) -> (Shape, Int, VName)
forall {b} {a} {c}. Num b => b -> (a, b, c) -> (a, b, c)
incWrites Int
r) [(Shape, Int, VName)]
dests
where
sizeOf :: VName -> Maybe SubExp
sizeOf :: VName -> Maybe SubExp
sizeOf VName
x = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 (Type -> SubExp)
-> (Entry (Wise SOACS) -> Type) -> Entry (Wise SOACS) -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Entry (Wise SOACS) -> Type
forall t. Typed t => t -> Type
typeOf (Entry (Wise SOACS) -> SubExp)
-> Maybe (Entry (Wise SOACS)) -> Maybe SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> SymbolTable (Wise SOACS) -> Maybe (Entry (Wise SOACS))
forall lore. VName -> SymbolTable lore -> Maybe (Entry lore)
ST.lookup VName
x SymbolTable (Wise SOACS)
vtable
mix :: [[a]] -> [a]
mix = [[a]] -> [a]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[a]] -> [a]) -> ([[a]] -> [[a]]) -> [[a]] -> [a]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [[a]] -> [[a]]
forall a. [[a]] -> [[a]]
transpose
incWrites :: b -> (a, b, c) -> (a, b, c)
incWrites b
r (a
w, b
n, c
a) = (a
w, b
n b -> b -> b
forall a. Num a => a -> a -> a
* b
r, c
a)
isConcat :: VName -> Maybe (SubExp, [VName], Certificates)
isConcat VName
v = case VName
-> SymbolTable (Wise SOACS)
-> Maybe (Exp (Wise SOACS), Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v SymbolTable (Wise SOACS)
vtable of
Just (BasicOp (Concat Int
0 VName
x [VName]
ys SubExp
_), Certificates
cs) -> do
SubExp
x_w <- VName -> Maybe SubExp
sizeOf VName
x
[SubExp]
y_ws <- (VName -> Maybe SubExp) -> [VName] -> Maybe [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> Maybe SubExp
sizeOf [VName]
ys
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp
x_w SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
==) [SubExp]
y_ws
(SubExp, [VName], Certificates)
-> Maybe (SubExp, [VName], Certificates)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
x_w, VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
ys, Certificates
cs)
Just (BasicOp (Reshape ShapeChange SubExp
reshape VName
arr), Certificates
cs) -> do
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Maybe [SubExp] -> Bool
forall a. Maybe a -> Bool
isJust (Maybe [SubExp] -> Bool) -> Maybe [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
reshape
(SubExp
a, [VName]
b, Certificates
cs') <- VName -> Maybe (SubExp, [VName], Certificates)
isConcat VName
arr
(SubExp, [VName], Certificates)
-> Maybe (SubExp, [VName], Certificates)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
a, [VName]
b, Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> Certificates
cs')
Maybe (Exp (Wise SOACS), Certificates)
_ -> Maybe (SubExp, [VName], Certificates)
forall a. Maybe a
Nothing
fuseConcatScatter SymbolTable (Wise SOACS)
_ Pattern (Wise SOACS)
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ = Rule (Wise SOACS)
forall lore. Rule lore
Skip
-- | Closed-form simplification of reductions.
--
-- * A redomap whose width is the constant zero is replaced by its neutral
--   elements, one per reduction result.  The rule only fires when the
--   pattern consists solely of those reduction results: a redomap pattern
--   is the reduction results followed by the map outputs, and zipping the
--   pattern against the neutral elements alone would leave the map outputs
--   unbound.
-- * A plain single reduction is handed to 'foldClosedForm', together with a
--   symbol-table lookup for the expressions that bind its inputs.
simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS)
simplifyClosedFormReduce _ pat _ (Screma (Constant w) _ form)
  | Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form,
    zeroIsh w,
    -- Any pattern elements beyond the neutral elements are map outputs,
    -- which this rewrite has no binding for.
    length (patternNames pat) == length nes =
    Simplify $
      forM_ (zip (patternNames pat) nes) $ \(v, ne) ->
        letBindNames [v] $ BasicOp $ SubExp ne
simplifyClosedFormReduce vtable pat _ (Screma _ arrs form)
  | Just [Reduce _ red_fun nes] <- isReduceSOAC form =
    Simplify $ foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs
simplifyClosedFormReduce _ _ _ _ = Skip
-- | Inline SOACs whose width is the constant one.
--
-- * 'Screma': each map parameter is bound to row 0 of its input array.
--   The map body is then inlined, and the (combined) scan and reduce
--   operators are each applied once to their neutral elements and the
--   corresponding map results.  Scan and map results become one-element
--   array literals; reduce results are bound directly.
-- * 'Stream': the chunk-size parameter is bound to 1, the accumulator
--   parameters to the neutral elements and the slice parameters to the
--   input arrays themselves.  The fold body is then inlined and its
--   results are bound to the pattern.
simplifyKnownIterationSOAC ::
  (Bindable lore, Simplify.SimplifiableLore lore, HasSOAC (Wise lore)) =>
  TopDownRuleOp (Wise lore)
simplifyKnownIterationSOAC _ pat _ op
  | Just (Screma (Constant k) arrs (ScremaForm scans reds map_lam)) <- asSOAC op,
    oneIsh k =
    Simplify $ do
      let Reduce _ red_lam red_nes = singleReduce reds
          Scan scan_lam scan_nes = singleScan scans
          num_scan = length scan_nes
          num_red = length red_nes
          (scan_pes, red_pes, map_pes) =
            splitAt3 num_scan num_red $ patternElements pat
          -- Wrap a single value as a one-row array matching the pattern type.
          singletonArray pe se =
            letBindNames [patElemName pe] $
              BasicOp $ ArrayLit [se] $ rowType $ patElemType pe

      -- The single iteration reads row 0 of every input array.
      forM_ (zip (lambdaParams map_lam) arrs) $ \(p, arr) -> do
        arr_t <- lookupType arr
        letBindNames [paramName p] $
          BasicOp $ Index arr $ fullSlice arr_t [DimFix $ constant (0 :: Int64)]

      map_body_res <- bodyBind $ lambdaBody map_lam
      let (to_scan, to_red, map_res) = splitAt3 num_scan num_red map_body_res

      scan_res <- eLambda scan_lam $ map eSubExp $ scan_nes ++ to_scan
      red_res <- eLambda red_lam $ map eSubExp $ red_nes ++ to_red

      forM_ (zip scan_pes scan_res) $ uncurry singletonArray
      forM_ (zip red_pes red_res) $ \(pe, se) ->
        letBindNames [patElemName pe] $ BasicOp $ SubExp se
      forM_ (zip map_pes map_res) $ uncurry singletonArray
simplifyKnownIterationSOAC _ pat _ op
  | Just (Stream (Constant k) arrs _ nes fold_lam) <- asSOAC op,
    oneIsh k =
    Simplify $ do
      let (chunk_param, acc_params, slice_params) =
            partitionChunkedFoldParameters (length nes) $ lambdaParams fold_lam
          bindTo p se = letBindNames [paramName p] $ BasicOp $ SubExp se

      bindTo chunk_param $ intConst Int64 1
      zipWithM_ bindTo acc_params nes
      zipWithM_ bindTo slice_params $ map Var arrs

      fold_res <- bodyBind $ lambdaBody fold_lam
      zipWithM_
        (\v se -> letBindNames [v] $ BasicOp $ SubExp se)
        (patternNames pat)
        fold_res
simplifyKnownIterationSOAC _ _ _ _ = Skip
-- | A simple array operation, together with the certificates that guard it.
data ArrayOp
  = -- | Indexing (or slicing) an array.
    ArrayIndexing Certificates VName (Slice SubExp)
  | -- | Permuting the dimensions of an array.
    ArrayRearrange Certificates VName [Int]
  | -- | Rotating an array by the given per-dimension amounts.
    ArrayRotate Certificates VName [SubExp]
  | -- | A plain reference to an array.  Not produced by 'isArrayOp'.
    ArrayVar Certificates VName
  deriving (Eq, Ord, Show)
-- | The array that an 'ArrayOp' operates on.
arrayOpArr :: ArrayOp -> VName
arrayOpArr op =
  case op of
    ArrayIndexing _ arr _ -> arr
    ArrayRearrange _ arr _ -> arr
    ArrayRotate _ arr _ -> arr
    ArrayVar _ arr -> arr
-- | The certificates attached to an 'ArrayOp'.
arrayOpCerts :: ArrayOp -> Certificates
arrayOpCerts op =
  case op of
    ArrayIndexing cs _ _ -> cs
    ArrayRearrange cs _ _ -> cs
    ArrayRotate cs _ _ -> cs
    ArrayVar cs _ -> cs
-- | Recognise an expression as an 'ArrayOp', attaching the given
-- certificates.  Only 'Index', 'Rearrange' and 'Rotate' are recognised;
-- every other expression (including a plain variable) yields 'Nothing'.
isArrayOp :: Certificates -> AST.Exp (Wise SOACS) -> Maybe ArrayOp
isArrayOp cs e =
  case e of
    BasicOp (Index arr slice) -> Just $ ArrayIndexing cs arr slice
    BasicOp (Rearrange perm arr) -> Just $ ArrayRearrange cs arr perm
    BasicOp (Rotate rots arr) -> Just $ ArrayRotate cs arr rots
    _ -> Nothing
-- | Turn an 'ArrayOp' back into an expression, returned together with
-- the certificates the operation carries.  'ArrayVar' becomes a plain
-- variable reference ('SubExp' of a 'Var').
fromArrayOp :: ArrayOp -> (Certificates, AST.Exp (Wise SOACS))
fromArrayOp op = (certs, BasicOp basic)
  where
    (certs, basic) =
      case op of
        ArrayIndexing cs v slice -> (cs, Index v slice)
        ArrayRearrange cs v perm -> (cs, Rearrange perm v)
        ArrayRotate cs v rots -> (cs, Rotate rots v)
        ArrayVar cs v -> (cs, SubExp (Var v))
-- | Collect every array operation (as recognised by 'isArrayOp') bound
-- anywhere in a body, each paired with the pattern it is bound to.
--
-- A statement that is itself an array operation contributes exactly
-- that operation.  Any other statement is walked: nested bodies are
-- searched recursively, and so are the lambda bodies of nested SOACs.
arrayOps :: AST.Body (Wise SOACS) -> S.Set (AST.Pattern (Wise SOACS), ArrayOp)
arrayOps = foldMap fromStm . stmsToList . bodyStms
  where
    fromStm (Let pat aux e)
      | Just op <- isArrayOp (stmAuxCerts aux) e = S.singleton (pat, op)
      | otherwise = execState (walkExpM nestedWalker e) mempty

    -- Only nested bodies and SOAC operators can contain further
    -- statements, so those are the only walker hooks overridden.
    nestedWalker =
      identityWalker
        { walkOnBody = \_ body -> modify (arrayOps body <>),
          walkOnOp = \soac -> modify (fromSOAC soac <>)
        }

    -- Gather array operations from every lambda inside a SOAC.
    fromSOAC =
      execWriter . mapSOACM identitySOACMapper {mapOnSOACLambda = fromLambda}

    fromLambda lam = do
      tell (arrayOps (lambdaBody lam))
      pure lam
-- | Rewrite a body, replacing every array operation that is a key of
-- @substs@ with the corresponding value.
--
-- The rewrite descends into nested bodies and into the lambda bodies of
-- nested SOACs.  Each rewritten statement is re-certified: with the
-- replacement's certificates when a substitution happened, otherwise
-- with the statement's own certificates.
replaceArrayOps ::
  M.Map ArrayOp ArrayOp ->
  AST.Body (Wise SOACS) ->
  AST.Body (Wise SOACS)
replaceArrayOps substs (Body _ stms res) =
  mkBody (fmap rewriteStm stms) res
  where
    rewriteStm (Let pat aux e) =
      certify new_cs $
        mkLet' (patternContextIdents pat) (patternValueIdents pat) aux new_e
      where
        (new_cs, new_e) = rewriteExp (stmAuxCerts aux) e

    -- Substitute a directly matching array operation; otherwise
    -- recurse structurally into the expression.
    rewriteExp cs e =
      case isArrayOp cs e >>= (`M.lookup` substs) of
        Just replacement -> fromArrayOp replacement
        Nothing -> (cs, mapExp recurse e)

    recurse =
      identityMapper
        { mapOnBody = \_ body -> pure (replaceArrayOps substs body),
          mapOnOp = pure . rewriteSOAC
        }

    rewriteSOAC =
      runIdentity
        . mapSOACM identitySOACMapper {mapOnSOACLambda = pure . rewriteLambda}

    rewriteLambda lam =
      lam {lambdaBody = replaceArrayOps substs (lambdaBody lam)}
-- | Turn
--
-- >    map (\i -> ... xs[i] ...) (iota n)
--
-- into
--
-- >    map (\i x -> ... x ...) (iota n) xs
--
-- The rule fires when some lambda parameter @p@ is bound to an 'Iota'
-- whose offset and stride are the constants zero and one.  It looks for
-- indexings of the form @arr[p, ...]@ inside the lambda body, where
-- @arr@ and all of the indexing's certificates are bound in the
-- enclosing symbol table (i.e. outside the lambda).  For each such
-- distinct indexing it does three things:
--
-- * it adds @arr@ as a new input to the 'Screma', sliced to its first
--   @w@ rows if its outer size is not syntactically @w@;
-- * it adds a new lambda parameter for the row of that input;
-- * it rewrites the indexing to index that parameter with the
--   remaining slice dimensions.
simplifyMapIota :: TopDownRuleOp (Wise SOACS)
simplifyMapIota vtable pat aux (Screma w arrs (ScremaForm scan reduce map_lam))
  | Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs),
    -- 'S.map' deduplicates: the same indexing may be bound by several
    -- statements, but it needs only one new input and parameter.
    indexings <-
      filter (indexesWith (paramName p)) $
        S.toList $
          S.map snd $
            arrayOps $ lambdaBody map_lam,
    not $ null indexings = Simplify $ do
    (more_arrs, more_params, subst_pairs) <-
      unzip3 . catMaybes <$> mapM mapOverArr indexings
    -- Each substitution is paired with its source indexing by
    -- 'mapOverArr' itself, so no positional zipping can misalign.
    let substs = M.fromList subst_pairs
        map_lam' =
          map_lam
            { lambdaParams = lambdaParams map_lam <> more_params,
              lambdaBody = replaceArrayOps substs $ lambdaBody map_lam
            }
    auxing aux $
      letBind pat $
        Op $
          Screma w (arrs <> more_arrs) (ScremaForm scan reduce map_lam')
  where
    -- Is this input an @iota@ with constant offset 0 and stride 1?
    isIota (_, arr) = case ST.lookupBasicOp arr vtable of
      Just (Iota _ (Constant o) (Constant s) _, _) ->
        zeroIsh o && oneIsh s
      _ -> False

    -- Does this operation index an outside-bound array with @v@ as
    -- its first index?
    indexesWith v (ArrayIndexing cs arr (DimFix (Var i) : _))
      | arr `ST.elem` vtable,
        all (`ST.elem` vtable) $ unCertificates cs =
        i == v
    indexesWith _ _ = False

    -- Produce the new Screma input, the new lambda parameter, and the
    -- (old, new) substitution for one indexing.
    mapOverArr op@(ArrayIndexing cs arr slice) = do
      arr_elem <- newVName $ baseString arr ++ "_elem"
      arr_t <- lookupType arr
      arr' <-
        if arraySize 0 arr_t == w
          then return arr
          else
            certifying cs $
              letExp (baseString arr ++ "_prefix") $
                BasicOp $
                  Index arr $
                    fullSlice arr_t [DimSlice (intConst Int64 0) w (intConst Int64 1)]
      return $
        Just
          ( arr',
            Param arr_elem (rowType arr_t),
            (op, ArrayIndexing cs arr_elem (drop 1 slice))
          )
    mapOverArr _ = return Nothing
simplifyMapIota _ _ _ _ = Skip
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput :: RuleOp (Wise SOACS) (SymbolTable (Wise SOACS))
moveTransformToInput SymbolTable (Wise SOACS)
vtable Pattern (Wise SOACS)
pat StmAux (ExpDec (Wise SOACS))
aux (Screma SubExp
w [VName]
arrs (ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam))
| [ArrayOp]
ops <- ((PatternT (VarWisdom, Type), ArrayOp) -> ArrayOp)
-> [(PatternT (VarWisdom, Type), ArrayOp)] -> [ArrayOp]
forall a b. (a -> b) -> [a] -> [b]
map (PatternT (VarWisdom, Type), ArrayOp) -> ArrayOp
forall a b. (a, b) -> b
snd ([(PatternT (VarWisdom, Type), ArrayOp)] -> [ArrayOp])
-> [(PatternT (VarWisdom, Type), ArrayOp)] -> [ArrayOp]
forall a b. (a -> b) -> a -> b
$ ((PatternT (VarWisdom, Type), ArrayOp) -> Bool)
-> [(PatternT (VarWisdom, Type), ArrayOp)]
-> [(PatternT (VarWisdom, Type), ArrayOp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (PatternT (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam ([(PatternT (VarWisdom, Type), ArrayOp)]
-> [(PatternT (VarWisdom, Type), ArrayOp)])
-> [(PatternT (VarWisdom, Type), ArrayOp)]
-> [(PatternT (VarWisdom, Type), ArrayOp)]
forall a b. (a -> b) -> a -> b
$ Set (PatternT (VarWisdom, Type), ArrayOp)
-> [(PatternT (VarWisdom, Type), ArrayOp)]
forall a. Set a -> [a]
S.toList (Set (PatternT (VarWisdom, Type), ArrayOp)
-> [(PatternT (VarWisdom, Type), ArrayOp)])
-> Set (PatternT (VarWisdom, Type), ArrayOp)
-> [(PatternT (VarWisdom, Type), ArrayOp)]
forall a b. (a -> b) -> a -> b
$ BodyT (Wise SOACS) -> Set (Pattern (Wise SOACS), ArrayOp)
arrayOps (BodyT (Wise SOACS) -> Set (Pattern (Wise SOACS), ArrayOp))
-> BodyT (Wise SOACS) -> Set (Pattern (Wise SOACS), ArrayOp)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> BodyT (Wise SOACS)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Wise SOACS)
map_lam,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [ArrayOp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [ArrayOp]
ops = RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
([VName]
more_arrs, [Param Type]
more_params, [ArrayOp]
replacements) <-
[(VName, Param Type, ArrayOp)]
-> ([VName], [Param Type], [ArrayOp])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(VName, Param Type, ArrayOp)]
-> ([VName], [Param Type], [ArrayOp]))
-> ([Maybe (VName, Param Type, ArrayOp)]
-> [(VName, Param Type, ArrayOp)])
-> [Maybe (VName, Param Type, ArrayOp)]
-> ([VName], [Param Type], [ArrayOp])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (VName, Param Type, ArrayOp)]
-> [(VName, Param Type, ArrayOp)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (VName, Param Type, ArrayOp)]
-> ([VName], [Param Type], [ArrayOp]))
-> RuleM (Wise SOACS) [Maybe (VName, Param Type, ArrayOp)]
-> RuleM (Wise SOACS) ([VName], [Param Type], [ArrayOp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (ArrayOp
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp)))
-> [ArrayOp]
-> RuleM (Wise SOACS) [Maybe (VName, Param Type, ArrayOp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ArrayOp -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
mapOverArr [ArrayOp]
ops
Bool -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([VName] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
more_arrs) RuleM (Wise SOACS) ()
forall lore a. RuleM lore a
cannotSimplify
let substs :: Map ArrayOp ArrayOp
substs = [(ArrayOp, ArrayOp)] -> Map ArrayOp ArrayOp
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(ArrayOp, ArrayOp)] -> Map ArrayOp ArrayOp)
-> [(ArrayOp, ArrayOp)] -> Map ArrayOp ArrayOp
forall a b. (a -> b) -> a -> b
$ [ArrayOp] -> [ArrayOp] -> [(ArrayOp, ArrayOp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [ArrayOp]
ops [ArrayOp]
replacements
map_lam' :: Lambda (Wise SOACS)
map_lam' =
Lambda (Wise SOACS)
map_lam
{ lambdaParams :: [LParam (Wise SOACS)]
lambdaParams = Lambda (Wise SOACS) -> [LParam (Wise SOACS)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Wise SOACS)
map_lam [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
more_params,
lambdaBody :: BodyT (Wise SOACS)
lambdaBody =
Map ArrayOp ArrayOp -> BodyT (Wise SOACS) -> BodyT (Wise SOACS)
replaceArrayOps Map ArrayOp ArrayOp
substs (BodyT (Wise SOACS) -> BodyT (Wise SOACS))
-> BodyT (Wise SOACS) -> BodyT (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
Lambda (Wise SOACS) -> BodyT (Wise SOACS)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Wise SOACS)
map_lam
}
StmAux (ExpWisdom, ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM (Wise SOACS)))
-> Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore (RuleM (Wise SOACS)))
Pattern (Wise SOACS)
pat (Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ Op (Wise SOACS) -> Exp (Wise SOACS)
forall lore. Op lore -> ExpT lore
Op (Op (Wise SOACS) -> Exp (Wise SOACS))
-> Op (Wise SOACS) -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm (Wise SOACS) -> SOAC (Wise SOACS)
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Screma SubExp
w ([VName]
arrs [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
more_arrs) ([Scan (Wise SOACS)]
-> [Reduce (Wise SOACS)]
-> Lambda (Wise SOACS)
-> ScremaForm (Wise SOACS)
forall lore.
[Scan lore] -> [Reduce lore] -> Lambda lore -> ScremaForm lore
ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam')
where
map_param_names :: [VName]
map_param_names = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda (Wise SOACS) -> [LParam (Wise SOACS)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Wise SOACS)
map_lam)
topLevelPattern :: PatternT (VarWisdom, Type) -> Bool
topLevelPattern = (PatternT (VarWisdom, Type)
-> Seq (PatternT (VarWisdom, Type)) -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` (Stm (Wise SOACS) -> PatternT (VarWisdom, Type))
-> Stms (Wise SOACS) -> Seq (PatternT (VarWisdom, Type))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Stm (Wise SOACS) -> PatternT (VarWisdom, Type)
forall lore. Stm lore -> Pattern lore
stmPattern (BodyT (Wise SOACS) -> Stms (Wise SOACS)
forall lore. BodyT lore -> Stms lore
bodyStms (Lambda (Wise SOACS) -> BodyT (Wise SOACS)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Wise SOACS)
map_lam)))
onlyUsedOnce :: VName -> Bool
onlyUsedOnce VName
arr =
case (Stm (Wise SOACS) -> Bool)
-> [Stm (Wise SOACS)] -> [Stm (Wise SOACS)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName
arr VName -> Names -> Bool
`nameIn`) (Names -> Bool)
-> (Stm (Wise SOACS) -> Names) -> Stm (Wise SOACS) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm (Wise SOACS) -> Names
forall a. FreeIn a => a -> Names
freeIn) ([Stm (Wise SOACS)] -> [Stm (Wise SOACS)])
-> [Stm (Wise SOACS)] -> [Stm (Wise SOACS)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise SOACS) -> [Stm (Wise SOACS)]
forall lore. Stms lore -> [Stm lore]
stmsToList (Stms (Wise SOACS) -> [Stm (Wise SOACS)])
-> Stms (Wise SOACS) -> [Stm (Wise SOACS)]
forall a b. (a -> b) -> a -> b
$ BodyT (Wise SOACS) -> Stms (Wise SOACS)
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT (Wise SOACS) -> Stms (Wise SOACS))
-> BodyT (Wise SOACS) -> Stms (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> BodyT (Wise SOACS)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Wise SOACS)
map_lam of
Stm (Wise SOACS)
_ : Stm (Wise SOACS)
_ : [Stm (Wise SOACS)]
_ -> Bool
False
[Stm (Wise SOACS)]
_ -> Bool
True
arrayIsMapParam :: (PatternT (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam (PatternT (VarWisdom, Type)
pat', ArrayIndexing Certificates
cs VName
arr Slice SubExp
slice) =
VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certificates -> Names
forall a. FreeIn a => a -> Names
freeIn Certificates
cs Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
slice)
Bool -> Bool -> Bool
&& Bool -> Bool
not (Slice SubExp -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Slice SubExp
slice)
Bool -> Bool -> Bool
&& (Bool -> Bool
not ([SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice) Bool -> Bool -> Bool
|| (PatternT (VarWisdom, Type) -> Bool
topLevelPattern PatternT (VarWisdom, Type)
pat' Bool -> Bool -> Bool
&& VName -> Bool
onlyUsedOnce VName
arr))
arrayIsMapParam (PatternT (VarWisdom, Type)
_, ArrayRearrange Certificates
cs VName
arr [Int]
perm) =
VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certificates -> Names
forall a. FreeIn a => a -> Names
freeIn Certificates
cs)
Bool -> Bool -> Bool
&& Bool -> Bool
not ([Int] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
perm)
arrayIsMapParam (PatternT (VarWisdom, Type)
_, ArrayRotate Certificates
cs VName
arr [SubExp]
rots) =
VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certificates -> Names
forall a. FreeIn a => a -> Names
freeIn Certificates
cs Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
rots)
arrayIsMapParam (PatternT (VarWisdom, Type)
_, ArrayVar {}) =
Bool
False
mapOverArr :: ArrayOp -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
mapOverArr ArrayOp
op
| Just (VName
_, VName
arr) <- ((VName, VName) -> Bool)
-> [(VName, VName)] -> Maybe (VName, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayOp -> VName
arrayOpArr ArrayOp
op) (VName -> Bool)
-> ((VName, VName) -> VName) -> (VName, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, VName) -> VName
forall a b. (a, b) -> a
fst) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
map_param_names [VName]
arrs) = do
Type
arr_t <- VName -> RuleM (Wise SOACS) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
let whole_dim :: DimIndex SubExp
whole_dim = SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
arr_t) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)
VName
arr_transformed <- Certificates
-> RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (ArrayOp -> Certificates
arrayOpCerts ArrayOp
op) (RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName)
-> RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$
String
-> Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_transformed") (Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName)
-> Exp (Lore (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$
case ArrayOp
op of
ArrayIndexing Certificates
_ VName
_ Slice SubExp
slice ->
BasicOp -> Exp (Wise SOACS)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ DimIndex SubExp
whole_dim DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
slice
ArrayRearrange Certificates
_ VName
_ [Int]
perm ->
BasicOp -> Exp (Wise SOACS)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange (Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [Int]
perm) VName
arr
ArrayRotate Certificates
_ VName
_ [SubExp]
rots ->
BasicOp -> Exp (Wise SOACS)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
rots) VName
arr
ArrayVar {} ->
BasicOp -> Exp (Wise SOACS)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
Type
arr_transformed_t <- VName -> RuleM (Wise SOACS) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr_transformed
VName
arr_transformed_row <- String -> RuleM (Wise SOACS) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM (Wise SOACS) VName)
-> String -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_transformed_row"
Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp)))
-> Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
forall a b. (a -> b) -> a -> b
$
(VName, Param Type, ArrayOp) -> Maybe (VName, Param Type, ArrayOp)
forall a. a -> Maybe a
Just
( VName
arr_transformed,
VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param VName
arr_transformed_row (Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
arr_transformed_t),
Certificates -> VName -> ArrayOp
ArrayVar Certificates
forall a. Monoid a => a
mempty VName
arr_transformed_row
)
mapOverArr ArrayOp
_ = Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (VName, Param Type, ArrayOp)
forall a. Maybe a
Nothing
moveTransformToInput SymbolTable (Wise SOACS)
_ Pattern (Wise SOACS)
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ =
Rule (Wise SOACS)
forall lore. Rule lore
Skip