{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.TileLoops (tileLoops) where
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import qualified Data.Sequence as Seq
import Futhark.IR.Kernels
import Futhark.MonadFreshNames
import Futhark.Optimise.BlkRegTiling
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)
-- | The loop-tiling pass.  Hands 'onStms' to
-- 'intraproceduralTransformation', which supplies a scope and a sequence
-- of statements; those statements are rewritten by 'optimiseStms'.
tileLoops :: Pass Kernels Kernels
tileLoops =
  Pass "tile loops" "Tile stream loops inside kernels" $
    intraproceduralTransformation onStms
  where
    -- Run the 'optimiseStms' computation with the given scope as the
    -- reader environment, threading the monad's name source as the state.
    onStms :: MonadFreshNames m => Scope Kernels -> Stms Kernels -> m (Stms Kernels)
    onStms scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms stms) scope
-- | Optimise the statements of a body with 'optimiseStms'.  The body
-- result is kept unchanged.
optimiseBody :: Body Kernels -> TileM (Body Kernels)
optimiseBody (Body () stms res) =
  Body () <$> optimiseStms stms <*> pure res
-- | Optimise every statement with 'optimiseStm' and concatenate the
-- resulting statement sequences (one input statement may become several).
-- The scope is extended with the bindings of the whole sequence while doing so.
optimiseStms :: Stms Kernels -> TileM (Stms Kernels)
optimiseStms stms =
  localScope (scopeOf stms) $
    mconcat <$> mapM optimiseStm (stmsToList stms)
-- | Optimise a single statement.
--
-- For a thread-level 'SegMap', three tilings are attempted in order, and
-- the first that applies wins:
--
--   1. 'doRegTiling3D';
--   2. 'mmBlkRegTiling';
--   3. generic tiling via 'tileInKernelBody'.
--
-- A successful tiling may produce extra statements, which are placed
-- before the rewritten statement.  Any other statement is traversed
-- recursively, optimising the bodies nested inside its expression.
optimiseStm :: Stm Kernels -> TileM (Stms Kernels)
optimiseStm stm@(Let pat aux (Op (SegOp (SegMap lvl@SegThread {} space ts kbody)))) = do
  res3dtiling <- doRegTiling3D stm
  case res3dtiling of
    Just tiled -> return $ withExtra tiled
    Nothing -> do
      -- Only tried when 3D register tiling did not apply.
      blkRegTiling_res <- mmBlkRegTiling stm
      case blkRegTiling_res of
        Just tiled -> return $ withExtra tiled
        Nothing -> do
          (host_stms, (lvl', space', kbody')) <-
            tileInKernelBody mempty initial_variance lvl space ts kbody
          return $
            host_stms
              <> oneStm (Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody')
  where
    -- Prepend the extra statements produced by a tiling to its new statement.
    withExtra (extra_bnds, stmt') = extra_bnds <> oneStm stmt'

    -- Every name bound by the segment space starts with an empty entry.
    initial_variance :: VarianceTable
    initial_variance = M.map (const mempty) $ scopeOfSegSpace space
optimiseStm (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM optimise e)
  where
    -- Recurse into nested bodies, extending the scope with what the
    -- mapper provides for each body.
    optimise = identityMapper {mapOnBody = \scope -> localScope scope . optimiseBody}
-- | Attempt generic tiling of a kernel body.
--
-- Tiling is only attempted when every kernel result is a plain 'Returns'.
-- In that case the body's statements and returned values are handed to
-- 'tileInBody'.  On success this returns:
--
--   * the host statements produced by tiling;
--   * the tiling's level and space;
--   * a new kernel body built from the tiled body, whose results come
--     from 'tilingTileReturns'.
--
-- In every other case it returns no host statements and the original
-- level, space and body unchanged.
tileInKernelBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  KernelBody Kernels ->
  TileM (Stms Kernels, (SegLevel, SegSpace, KernelBody Kernels))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody
  | Just kbody_res <- mapM isSimpleResult $ kernelBodyResult kbody = do
    maybe_tiled <-
      tileInBody branch_variant initial_variance lvl initial_kspace ts $
        Body () (kernelBodyStms kbody) kbody_res
    case maybe_tiled of
      Just (host_stms, tiling, tiledBody) -> do
        (res', stms') <-
          runBinder $ mapM (tilingTileReturns tiling) =<< tiledBody mempty mempty
        return
          ( host_stms,
            ( tilingLevel tiling,
              tilingSpace tiling,
              KernelBody () stms' res'
            )
          )
      Nothing ->
        return (mempty, (lvl, initial_kspace, kbody))
  | otherwise =
    return (mempty, (lvl, initial_kspace, kbody))
  where
    -- Only 'Returns' results can be turned into an ordinary body result.
    isSimpleResult :: KernelResult -> Maybe SubExp
    isSimpleResult (Returns _ se) = Just se
    isSimpleResult _ = Nothing
-- | Walk the statements of a body, looking for the first statement that
-- can be tiled.
--
-- The candidates, tried for each statement in this order, are:
--
--   * a 'tileable' statement whose inputs are invariant to one of the two
--     innermost dimensions of the space (2D tiling via 'tiling2d');
--   * a 'tileable' statement handled by 1D tiling over the innermost
--     dimension, provided that dimension's thread index is not in
--     @branch_variant@;
--   * a sequential @for@ loop (no loop-variant parameters) whose body
--     itself contains a tileable statement, found recursively.
--
-- The statements preceding the tiled one are split by 'preludeToPostlude';
-- the ones it keeps in front are injected with 'injectPrelude'.  Returns
-- 'Nothing' if no statement can be tiled.
tileInBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  Body Kernels ->
  TileM (Maybe (Stms Kernels, Tiling, TiledBody))
tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) =
  descend mempty $ stmsToList initial_kstms
  where
    variance :: VarianceTable
    variance = varianceInStms initial_variance initial_kstms

    -- 'descend prestms stms': @prestms@ are the statements already passed
    -- over; @stms@ are the ones still to consider.
    descend :: Stms Kernels -> [Stm Kernels] -> TileM (Maybe (Stms Kernels, Tiling, TiledBody))
    descend _ [] =
      return Nothing
    descend prestms (stm_to_tile : poststms)
      -- 2D tiling over the two innermost dimensions.
      | (gtids, kdims) <- unzip $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        Just inputs <-
          mapM (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs,
        not $ null $ tiledInputs inputs,
        gtid_y : gtid_x : top_gtids_rev <- reverse gtids,
        kdim_y : kdim_x : top_kdims_rev <- reverse kdims,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res =
        Just . injectPrelude initial_space variance prestms' used
          <$> tileGeneric
            (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev)
            initial_lvl
            res_ts
            (stmPattern stm_to_tile)
            (gtid_x, gtid_y)
            (kdim_x, kdim_y)
            w
            form
            inputs
            poststms'
            stms_res
      -- 1D tiling over the innermost dimension.
      | (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        inputs <- map (is1DTileable gtid variance) arrs,
        not $ null $ tiledInputs inputs,
        not $ gtid `nameIn` branch_variant,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res =
        Just . injectPrelude initial_space variance prestms' used
          <$> tileGeneric
            (tiling1d $ reverse top_space_rev)
            initial_lvl
            res_ts
            (stmPattern stm_to_tile)
            gtid
            kdim
            w
            form
            inputs
            poststms'
            stms_res
      -- Tiling of a statement inside a sequential for-loop.
      | DoLoop [] merge (ForLoop i it bound []) loopbody <- stmExp stm_to_tile,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) = do
        let -- Names the loop bound varies with also count as branch-variant
            -- when tiling inside the loop body.
            branch_variant' =
              branch_variant
                <> mconcat
                  ( map
                      (flip (M.findWithDefault mempty) variance)
                      (namesToList (freeIn bound))
                  )
            merge_params = map fst merge

        maybe_tiled <-
          localScope (M.insert i (IndexName it) $ scopeOfFParams merge_params) $
            tileInBody
              branch_variant'
              variance
              initial_lvl
              initial_space
              (map paramType merge_params)
              $ mkBody (bodyStms loopbody) (bodyResult loopbody)

        case maybe_tiled of
          Nothing -> next
          Just tiled ->
            Just
              <$> tileDoLoop
                initial_space
                variance
                prestms'
                (freeIn loopbody <> freeIn merge)
                tiled
                res_ts
                (stmPattern stm_to_tile)
                (stmAux stm_to_tile)
                merge
                i
                it
                bound
                poststms'
                stms_res
      | otherwise = next
      where
        -- Move on: the current statement joins the prelude and its
        -- bindings enter scope.
        next =
          localScope (scopeOf stm_to_tile) $
            descend (prestms <> oneStm stm_to_tile) poststms
-- | Split the statements preceding the statement to tile.
--
-- A prelude statement is kept in front when any name it binds is:
--
--   * free in @stm_to_tile@, or
--   * recorded in @variance@ for one of those free names.
--
-- Returns a pair:
--
--   1. the prelude statements kept in front;
--   2. the remaining prelude statements followed by @postlude@.
--
-- 'Seq.partition' preserves the relative order of the statements within
-- each group.
preludeToPostlude ::
  VarianceTable ->
  Stms Kernels ->
  Stm Kernels ->
  Stms Kernels ->
  (Stms Kernels, Stms Kernels)
preludeToPostlude variance prelude stm_to_tile postlude =
  (prelude_used, prelude_not_used <> postlude)
  where
    used_in_tiled :: Names
    used_in_tiled = freeIn stm_to_tile

    -- The free names of the tiled statement, plus everything 'variance'
    -- records for each of them.
    used_in_stm_variant :: Names
    used_in_stm_variant =
      (used_in_tiled <>) $
        mconcat $
          map (flip (M.findWithDefault mempty) variance) $
            namesToList used_in_tiled

    used :: Stm Kernels -> Bool
    used stm =
      any (`nameIn` used_in_stm_variant) $
        patternNames $ stmPattern stm

    (prelude_used, prelude_not_used) =
      Seq.partition used prelude
-- | Split prelude statements into three groups, returned as
-- @(invariant, precomputed_variant, recomputed_variant)@.
--
-- * /invariant/: statements whose variance, as recorded in the
--   'VarianceTable' for their first bound name, mentions none of the
--   @private@ names. A statement binding no names counts as invariant.
--
-- * /recomputed_variant/: non-invariant statements that either bind a
--   name that must be inlined, or whose variance mentions such a name.
--
-- * /precomputed_variant/: all remaining non-invariant statements.
--
-- A name must be inlined when it is bound by a non-invariant statement
-- whose expression is an 'Index' with a non-empty 'sliceDims', a
-- 'Rotate', a 'Rearrange' or a 'Reshape', and the name occurs in
-- @used_after@.
partitionPrelude ::
  VarianceTable ->
  Stms Kernels ->
  -- | Names that make a statement variant.
  Names ->
  -- | Names used after the prelude.
  Names ->
  (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude variance prestms private used_after =
  (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms)
  where
    stmDefs :: Stm Kernels -> [VName]
    stmDefs = patternNames . stmPattern

    -- Only the first bound name's variance is consulted.
    invariantTo names stm
      | v : _ <- stmDefs stm =
          not . any (`nameIn` names) . namesToList $
            M.findWithDefault mempty v variance
      | otherwise = True

    (invariant_prestms, variant_prestms) =
      Seq.partition (invariantTo private) prestms

    inlinedExp e = case e of
      BasicOp (Index _ slice) -> not (null (sliceDims slice))
      BasicOp Rotate {} -> True
      BasicOp Rearrange {} -> True
      BasicOp Reshape {} -> True
      _ -> False

    mustBeInlined stm =
      inlinedExp (stmExp stm) && any (`nameIn` used_after) (stmDefs stm)

    must_be_inlined =
      namesFromList . foldMap stmDefs $
        Seq.filter mustBeInlined variant_prestms

    -- Recompute anything that binds, or depends on, a must-inline name.
    recompute stm =
      any (`nameIn` must_be_inlined) (stmDefs stm)
        || not (invariantTo must_be_inlined stm)

    (recomputed_variant_prestms, precomputed_variant_prestms) =
      Seq.partition recompute variant_prestms
-- | Wrap the tiled body of a tiling so that the prelude statements
-- @prestms@ are handled before the original body runs.
--
-- The new body treats two groups of names as private: the names it is
-- given, and the kernel dimensions of @initial_space@ that are not in
-- the tiling space. It splits @prestms@ with 'partitionPrelude', using
-- @used@ as the names needed afterwards:
--
-- * invariant statements are added directly;
--
-- * precomputed statements are run through 'doPrelude' for their live
--   names, meaning those free in @used@ or in the recomputed statements;
--
-- * recomputed statements are passed to the original body as extra
--   'PrivStms', together with 'mkReadPreludeValues' for the prelude
--   results.
--
-- Host statements and the tiling are returned unchanged.
injectPrelude ::
  SegSpace ->
  VarianceTable ->
  Stms Kernels ->
  Names ->
  (Stms Kernels, Tiling, TiledBody) ->
  (Stms Kernels, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, bodyWithPrelude)
  where
    -- Kernel dimensions that the tiling does not cover.
    untiled_dims =
      namesFromList
        [ gtid
          | dim@(gtid, _) <- unSegSpace initial_space,
            dim `notElem` unSegSpace (tilingSpace tiling)
        ]

    bodyWithPrelude :: TiledBody
    bodyWithPrelude private privstms = do
      let private' = private <> untiled_dims
          (invariant_stms, precomputed_stms, recomputed_stms) =
            partitionPrelude variance prestms private' used
          live_set =
            namesToList . liveSet precomputed_stms $
              used <> freeIn recomputed_stms
      addStms invariant_stms
      prelude_arrs <-
        inScopeOf precomputed_stms $
          doPrelude tiling privstms precomputed_stms live_set
      let prelude_privstms =
            PrivStms recomputed_stms $ mkReadPreludeValues prelude_arrs live_set
      tiledBody private' $ prelude_privstms <> privstms
-- | Tile a sequential @for@-loop (index @i@ of integer type @it@, up to
-- @bound@) that is preceded by @prestms@ and followed by @poststms@.
--
-- The returned body does the following:
--
-- * It splits @prestms@ with 'partitionPrelude'. The variant names are
--   the kernel dimensions of @initial_space@ outside the tiling space.
--   Invariant statements are added directly and precomputed ones are
--   run through 'doPrelude'.
--
-- * It gives each merge parameter a fresh @_group@ counterpart whose
--   type is extended with the tiling's tile shape. The counterparts are
--   initialised through 'tilingSegMap' (named @tiled_loopinit@) under
--   the certificates of @aux@.
--
-- * It runs the original tiled body inside the loop. The original merge
--   parameters are bound per thread by indexing their tiled
--   counterparts.
--
-- * It finishes with 'postludeGeneric' on the loop results, @poststms@
--   and @poststms_res@.
--
-- Host statements and the tiling are returned unchanged.
tileDoLoop ::
  SegSpace ->
  VarianceTable ->
  Stms Kernels ->
  Names ->
  (Stms Kernels, Tiling, TiledBody) ->
  [Type] ->
  Pattern Kernels ->
  StmAux (ExpDec Kernels) ->
  [(FParam Kernels, SubExp)] ->
  VName ->
  IntType ->
  SubExp ->
  Stms Kernels ->
  Result ->
  TileM (Stms Kernels, Tiling, TiledBody)
tileDoLoop
  initial_space
  variance
  prestms
  used_in_body
  (host_stms, tiling, tiledBody)
  res_ts
  pat
  aux
  merge
  i
  it
  bound
  poststms
  poststms_res =
    return (host_stms, tiling, tiledLoop)
    where
      -- Kernel dimensions that the tiling does not cover.
      untiled_kdims =
        namesFromList
          [ gtid
            | dim@(gtid, _) <- unSegSpace initial_space,
              dim `notElem` unSegSpace (tilingSpace tiling)
          ]

      -- Names from the prelude that are needed after it.
      prestms_used = used_in_body <> freeIn poststms <> freeIn poststms_res

      (invariant_stms, precomputed_stms, recomputed_stms) =
        partitionPrelude variance prestms untiled_kdims prestms_used

      (mergeparams, mergeinits) = unzip merge

      -- Extend a merge parameter type with the tile shape, keeping its
      -- uniqueness.
      tileDim t = arrayOf t (tilingTileShape tiling) (uniqueness t)

      merge_scope = M.insert i (IndexName it) (scopeOfFParams mergeparams)

      tiledLoop :: TiledBody
      tiledLoop private privstms =
        localScope (scopeOf host_stms <> merge_scope) $ do
          addStms invariant_stms

          let live_set =
                namesToList . liveSet precomputed_stms $
                  freeIn recomputed_stms <> prestms_used
          prelude_arrs <-
            inScopeOf precomputed_stms $
              doPrelude tiling privstms precomputed_stms live_set

          mergeparams' <- forM mergeparams $ \(Param pname pt) -> do
            pname' <- newVName $ baseString pname ++ "_group"
            pure $ Param pname' (tileDim pt)

          let merge_ts = map paramType mergeparams
              inloop_privstms =
                PrivStms recomputed_stms $
                  mkReadPreludeValues prelude_arrs live_set
              -- Per-thread initial values, guarded by the bounds check.
              initThread in_bounds slice =
                fmap (map Var) . protectOutOfBounds "loopinit" in_bounds merge_ts $ do
                  addPrivStms slice inloop_privstms
                  addPrivStms slice privstms
                  pure mergeinits

          mergeinit' <-
            map Var
              <$> certifying
                (stmAuxCerts aux)
                (tilingSegMap tiling "tiled_loopinit" (scalarLevel tiling) ResultPrivate initThread)

          let merge' = zip mergeparams' mergeinit'
              -- Bind each original merge parameter to this thread's
              -- slice of its tiled counterpart.
              indexMergeParams slice =
                localScope (scopeOfFParams mergeparams') $
                  forM_ (zip mergeparams mergeparams') $ \(to, from) ->
                    letBindNames [paramName to] $
                      BasicOp $
                        Index (paramName from) (fullSlice (paramType from) slice)
              private' =
                private <> namesFromList (map paramName (mergeparams ++ mergeparams'))
              privstms' =
                PrivStms mempty indexMergeParams <> privstms <> inloop_privstms

          loopbody' <-
            runBodyBinder $
              resultBody . map Var <$> tiledBody private' privstms'
          accs' <-
            letTupExp "tiled_inside_loop" $
              DoLoop [] merge' (ForLoop i it bound []) loopbody'
          postludeGeneric
            tiling
            (privstms <> inloop_privstms)
            pat
            accs'
            poststms
            poststms_res
            res_ts
-- | Run @prestms@ once per thread of the tiling's 'tilingSegMap'.
-- The map is named @prelude@, runs at 'scalarLevel', and uses
-- 'ResultPrivate'. The function returns the names bound by that call.
--
-- A thread for which @in_bounds@ holds first adds @privstms@ for its
-- slice, then @prestms@, and yields the values of @prestms_live@. Any
-- other thread yields blank values of the same types.
doPrelude :: Tiling -> PrivStms -> Stms Kernels -> [VName] -> Binder Kernels [VName]
doPrelude tiling privstms prestms prestms_live =
  tilingSegMap tiling "prelude" (scalarLevel tiling) ResultPrivate perThread
  where
    perThread in_bounds slice = do
      ts <- mapM lookupType prestms_live
      let inBoundsBranch = do
            addPrivStms slice privstms
            addStms prestms
            resultBodyM $ map Var prestms_live
          outOfBoundsBranch = eBody $ map eBlank ts
      pre <- letTupExp "pre" =<< eIf (toExp in_bounds) inBoundsBranch outOfBoundsBranch
      pure $ map Var pre
-- | The subset of names bound by the patterns of @stms@ that occur free
-- in @after@.
liveSet :: FreeIn a => Stms Kernels -> a -> Names
liveSet stms after = bound_in_stms `namesIntersection` freeIn after
  where
    bound_in_stms = namesFromList $ foldMap (patternNames . stmPattern) stms
-- | Recognise a statement that is a tileable redomap-style 'Screma'.
--
-- On success, returns the width, the input arrays, and the tuple
-- (commutativity, reduction lambda, neutral elements, map lambda).
-- The statement is accepted only if:
--
-- * its form is a redomap with a single (combined) reduction;
-- * every value the map lambda returns feeds the reduction (the map and
--   reduction lambdas have identical return types);
-- * it has at least one input array;
-- * the map lambda's results and parameters are all of primitive type.
tileable ::
  Stm Kernels ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
    )
tileable stm
  | Op (OtherOp (Screma w arrs form)) <- stmExp stm,
    Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce red_comm red_lam red_nes <- singleReduce reds,
    lambdaReturnType map_lam == lambdaReturnType red_lam,
    not $ null arrs,
    all primType $ lambdaReturnType map_lam,
    all (primType . paramType) $ lambdaParams map_lam =
    Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
  | otherwise =
    Nothing
-- | An input array of a tiled loop.  'InputTile' marks an array that
-- is tiled and carries an @[Int]@ permutation; 'InputDontTile' marks
-- one that is read directly.
-- NOTE(review): the permutation is presumably over the kernel
-- dimensions — confirm against the tiling strategies.
data InputArray
  = InputTile [Int] VName
  | InputDontTile VName

-- | The tiled inputs paired with their permutations, in input order.
-- Untiled inputs are dropped.
tiledInputs :: [InputArray] -> [(VName, [Int])]
tiledInputs = mapMaybe f
  where
    f (InputTile perm arr) = Just (arr, perm)
    f InputDontTile {} = Nothing
-- | What a tile-processing step reads from: a tile array (keeping the
-- input's permutation) or the original, untiled array.
data InputTile
  = InputTiled [Int] VName
  | InputUntiled VName

-- | Pair every input with what it should be read from.
--
-- Each 'InputTile' consumes the next name from the tile list.  Each
-- 'InputDontTile' keeps its own array and consumes nothing.  The result
-- stops at the end of the inputs, or at the first tiled input for which
-- no tile name is left.
inputsToTiles :: [InputArray] -> [VName] -> [InputTile]
inputsToTiles (InputTile perm _ : inputs) (tile : tiles) =
  InputTiled perm tile : inputsToTiles inputs tiles
inputsToTiles (InputDontTile arr : inputs) tiles =
  InputUntiled arr : inputsToTiles inputs tiles
inputsToTiles _ _ = []
-- | Slice out the part of an untiled input that belongs to tile number
-- @tile_id@.
--
-- The slice is @this_tile_size@ elements with stride 1, starting at
-- offset @tile_id * full_tile_size@ in the outermost dimension.  The
-- remaining dimensions are taken in full via 'fullSlice'.
-- @this_tile_size@ is separate from @full_tile_size@, presumably so
-- that a shorter final tile can be sliced.
sliceUntiled ::
  MonadBinder m =>
  VName ->
  SubExp ->
  SubExp ->
  SubExp ->
  m VName
sliceUntiled arr tile_id full_tile_size this_tile_size = do
  arr_t <- lookupType arr
  slice_offset <-
    letSubExp "slice_offset" =<< toExp (pe64 tile_id * pe64 full_tile_size)
  let slice = DimSlice slice_offset this_tile_size (intConst Int64 1)
  letExp "untiled_slice" $
    BasicOp $ Index arr $ fullSlice arr_t [slice]
-- | Statements to be emitted inside thread-private code, together with
-- a 'ReadPrelude' action that is run before them ('addPrivStms' runs
-- the prelude first, then adds the statements).
data PrivStms = PrivStms (Stms Kernels) ReadPrelude

-- | Wrap plain statements as 'PrivStms' with a prelude that does
-- nothing.
privStms :: Stms Kernels -> PrivStms
privStms stms = PrivStms stms $ const $ return ()
-- | Emit 'PrivStms' at the current position.  First run the prelude
-- with the given slice, then add the statements.
addPrivStms :: Slice SubExp -> PrivStms -> Binder Kernels ()
addPrivStms local_slice (PrivStms stms readPrelude) = do
  readPrelude local_slice
  addStms stms
-- | Concatenates the statements.  The combined prelude runs the left
-- prelude and then the right one, both on the same slice.
instance Semigroup PrivStms where
  PrivStms stms_x readPrelude_x <> PrivStms stms_y readPrelude_y =
    PrivStms stms_z readPrelude_z
    where
      stms_z = stms_x <> stms_y
      readPrelude_z slice = readPrelude_x slice >> readPrelude_y slice

-- | No statements and a prelude that does nothing.
instance Monoid PrivStms where
  mempty = privStms mempty
-- | An action, parameterised by a slice, that 'addPrivStms' runs before
-- adding the private statements.
type ReadPrelude = Slice SubExp -> Binder Kernels ()

-- | Arguments for processing one whole tile.
data ProcessTileArgs = ProcessTileArgs
  { -- | Statements to replay in thread-private code.
    processPrivStms :: PrivStms,
    -- | Commutativity of the reduction.
    processComm :: Commutativity,
    -- | The reduction lambda.
    processRedLam :: Lambda Kernels,
    -- | The map lambda applied to each element.
    processMapLam :: Lambda Kernels,
    -- | The tile (or untiled array) for each input.
    processTiles :: [InputTile],
    -- | The current accumulators.
    processAcc :: [VName],
    -- | Index of the tile being processed.
    processTileId :: SubExp
  }

-- | Arguments for processing the elements left after all whole tiles.
data ResidualTileArgs = ResidualTileArgs
  { -- | Statements to replay in thread-private code.
    residualPrivStms :: PrivStms,
    -- | Commutativity of the reduction.
    residualComm :: Commutativity,
    -- | The reduction lambda.
    residualRedLam :: Lambda Kernels,
    -- | The map lambda applied to each element.
    residualMapLam :: Lambda Kernels,
    -- | The loop's inputs.
    residualInput :: [InputArray],
    -- | The accumulators after the whole-tile loop.
    residualAcc :: [VName],
    -- | Total number of elements traversed (the loop width).
    residualInputSize :: SubExp,
    -- | Number of whole tiles already processed.
    residualNumWholeTiles :: SubExp
  }
-- | The operations and parameters that a particular tiling strategy
-- supplies to the generic tiling driver.
data Tiling = Tiling
  { -- | Build a SegMap from a description, a level, a result manifest
    -- and a per-thread body.  The body receives a boolean 'PrimExp'
    -- (used in this module as the in-bounds condition passed to
    -- 'protectOutOfBounds') and a 'Slice' (used here to index
    -- per-thread values).
    tilingSegMap ::
      String ->
      SegLevel ->
      ResultManifest ->
      (PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp]) ->
      Binder Kernels [VName],
    -- | Read a tile of the given inputs for the given tile index.
    tilingReadTile ::
      TileKind ->
      PrivStms ->
      SubExp ->
      [InputArray] ->
      Binder Kernels [InputTile],
    -- | Fold one whole tile into the accumulators and return the
    -- updated accumulators.
    tilingProcessTile ::
      ProcessTileArgs ->
      Binder Kernels [VName],
    -- | Fold the elements left after the whole tiles into the
    -- accumulators and return the updated accumulators.
    tilingProcessResidualTile ::
      ResidualTileArgs ->
      Binder Kernels [VName],
    -- | Produce the 'KernelResult' for an array.
    tilingTileReturns :: VName -> Binder Kernels KernelResult,
    -- | The 'SegSpace' associated with this tiling.
    tilingSpace :: SegSpace,
    -- | Shape given to each per-thread accumulator in the tile loop.
    tilingTileShape :: Shape,
    -- | Level from which 'scalarLevel' takes group count and size.
    tilingLevel :: SegLevel,
    -- | Computes the number of whole tiles (the tile loop's trip count).
    tilingNumWholeTiles :: Binder Kernels SubExp
  }

-- | A tiling strategy.  Given the initial level, the strategy-specific
-- index (@gtids@) and dimension (@kdims@) information, and the loop
-- width, it constructs a 'Tiling'.
type DoTiling gtids kdims =
  SegLevel -> gtids -> kdims -> SubExp -> Binder Kernels Tiling
-- | A 'SegThread' level with the same number of groups and group size
-- as the tiling's level, but with no virtualisation ('SegNoVirt').
scalarLevel :: Tiling -> SegLevel
scalarLevel tiling =
  SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt
  where
    lvl = tilingLevel tiling
-- | Bind the results of @m@ under an @if in_bounds@ guard, naming the
-- bindings with @desc@.
--
-- The true branch runs @m@.  The false branch produces blank values
-- ('eBlank') of the types @ts@, which must match what @m@ returns.
protectOutOfBounds ::
  String ->
  PrimExp VName ->
  [Type] ->
  Binder Kernels [SubExp] ->
  Binder Kernels [VName]
protectOutOfBounds desc in_bounds ts m =
  letTupExp desc
    =<< eIf (toExp in_bounds) (resultBody <$> m) (eBody $ map eBlank ts)
-- | Build the per-thread postlude SegMap (described as @"thread_res"@).
--
-- Each thread does the following:
--
-- 1. Binds every name in @pat@ to its element (selected by the thread's
--    slice) of the corresponding accumulator array in @accs'@.
-- 2. Emits the private statements and then @poststms@, and returns
--    @poststms_res@.
--
-- When @poststms@ is empty, the private statements are emitted
-- unguarded and @poststms_res@ is returned directly.  Otherwise the
-- work is wrapped in 'protectOutOfBounds' with result types @res_ts@.
postludeGeneric ::
  Tiling ->
  PrivStms ->
  Pattern Kernels ->
  [VName] ->
  Stms Kernels ->
  Result ->
  [Type] ->
  Binder Kernels [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" (scalarLevel tiling) ResultPrivate $ \in_bounds slice -> do
    forM_ (zip (patternNames pat) accs') $ \(us, everyone) -> do
      everyone_t <- lookupType everyone
      letBindNames [us] $ BasicOp $ Index everyone $ fullSlice everyone_t slice

    -- Sequence emptiness via Foldable 'null' rather than an 'Eq'
    -- comparison against 'mempty'.
    if null poststms
      then do
        -- Still emit the private statements; the result presumably may
        -- refer to names they bind.
        addPrivStms slice privstms
        return poststms_res
      else fmap (map Var) $
        protectOutOfBounds "postlude" in_bounds res_ts $ do
          addPrivStms slice privstms
          addStms poststms
          return poststms_res
-- | Given a 'Names' argument and private statements, produce the tiled
-- computation and return the names of its results.
type TiledBody = Names -> PrivStms -> Binder Kernels [VName]

-- | Generic driver for tiling a redomap-shaped loop.
--
-- Runs @doTiling@ in its own binder to obtain a 'Tiling' and the
-- statements that set it up.  Returns those statements, the 'Tiling',
-- and a 'TiledBody' that does the following:
--
-- 1. Computes per-thread accumulator initial values from the
--    reduction's neutral elements.
-- 2. Loops over the whole tiles.  Each iteration reads a tile with
--    'tilingReadTile' and folds it in with 'tilingProcessTile'.
-- 3. Folds in the remaining elements with 'tilingProcessResidualTile'.
-- 4. Runs 'postludeGeneric' to bind @pat@ and evaluate @poststms@.
tileGeneric ::
  DoTiling gtids kdims ->
  SegLevel ->
  [Type] ->
  Pattern Kernels ->
  gtids ->
  kdims ->
  SubExp ->
  (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels) ->
  [InputArray] ->
  Stms Kernels ->
  Result ->
  TileM (Stms Kernels, Tiling, TiledBody)
tileGeneric doTiling initial_lvl res_ts pat gtids kdims w form inputs poststms poststms_res = do
  (tiling, tiling_stms) <- runBinder $ doTiling initial_lvl gtids kdims w
  return (tiling_stms, tiling, tiledBody tiling)
  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> Names -> PrivStms -> Binder Kernels [VName]
    tiledBody tiling _private privstms = do
      let tile_shape = tilingTileShape tiling

      num_whole_tiles <- tilingNumWholeTiles tiling

      -- Per-thread accumulator initial values.  Neutral elements with
      -- no free variables are returned as-is.  Otherwise they are
      -- computed behind an in-bounds guard, after the private statements.
      mergeinits <-
        tilingSegMap tiling "mergeinit" (scalarLevel tiling) ResultPrivate $
          \in_bounds slice ->
            if freeIn red_nes == mempty
              then return red_nes
              else fmap (map Var) $
                protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
                  addPrivStms slice privstms
                  return red_nes

      -- Loop parameters: one unique array of shape tile_shape per
      -- reduction parameter, initialised from mergeinits.
      merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) ->
        (,)
          <$> newParam
            (baseString (paramName p) ++ "_merge")
            (paramType p `arrayOfShape` tile_shape `toDecl` Unique)
          <*> pure (Var mergeinit)

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int64 num_whole_tiles []
      loopbody <- renameBody <=< runBodyBinder $
        inScopeOf loopform $
          localScope (scopeOfFParams $ map fst merge) $ do
            -- Read tile number tile_id, then fold it into the accumulators.
            tile <- tilingReadTile tiling TilePartial privstms (Var tile_id) inputs
            let accs = map (paramName . fst) merge
                tile_args =
                  ProcessTileArgs privstms red_comm red_lam map_lam tile accs (Var tile_id)
            resultBody . map Var <$> tilingProcessTile tiling tile_args

      accs <- letTupExp "accs" $ DoLoop [] merge loopform loopbody

      -- Fresh copies of the lambdas for the residual tile; presumably
      -- this avoids re-binding names already used in the loop body.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      let residual_args =
            ResidualTileArgs privstms red_comm red_lam' map_lam' inputs accs w num_whole_tiles
      accs' <- tilingProcessResidualTile tiling residual_args

      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts
-- | Passed to 'tilingReadTile'.  The generic driver reads every
-- whole-tile iteration with 'TilePartial'.
-- NOTE(review): presumably 'TilePartial' means the tile may lie partly
-- out of bounds, so reads must be checked — confirm in the strategies.
data TileKind = TilePartial | TileFull

-- | A 'ReadPrelude' that, for each pair @(arr, v)@, binds @v@ to @arr@
-- indexed by the given slice (completed to a full slice with
-- 'fullSlice').  The pairs are processed in order.
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  forM_ (zip prestms_live_arrs prestms_live) $ \(arr, v) -> do
    arr_t <- lookupType arr
    letBindNames [v] $ BasicOp $ Index arr $ fullSlice arr_t slice
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Binder Kernels KernelResult
tileReturns :: [(VName, SubExp)]
-> [(SubExp, SubExp)]
-> VName
-> BinderT Kernels (State VNameSource) KernelResult
tileReturns [(VName, SubExp)]
dims_on_top [(SubExp, SubExp)]
dims VName
arr = do
let unit_dims :: Result
unit_dims = Int -> SubExp -> Result
forall a. Int -> a -> [a]
replicate ([(VName, SubExp)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(VName, SubExp)]
dims_on_top) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)
VName
arr' <-
if [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
dims_on_top
then VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
arr
else do
Type
arr_t <- VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
let new_shape :: Result
new_shape = Result
unit_dims Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
arr_t
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (VName -> String
baseString VName
arr) (Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew Result
new_shape) VName
arr
let tile_dims :: [(SubExp, SubExp)]
tile_dims = Result -> Result -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
dims_on_top) Result
unit_dims [(SubExp, SubExp)] -> [(SubExp, SubExp)] -> [(SubExp, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(SubExp, SubExp)]
dims
KernelResult -> BinderT Kernels (State VNameSource) KernelResult
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> BinderT Kernels (State VNameSource) KernelResult)
-> KernelResult -> BinderT Kernels (State VNameSource) KernelResult
forall a b. (a -> b) -> a -> b
$ [(SubExp, SubExp)] -> VName -> KernelResult
TileReturns [(SubExp, SubExp)]
tile_dims VName
arr'
is1DTileable :: VName -> M.Map VName Names -> VName -> InputArray
is1DTileable :: VName -> VarianceTable -> VName -> InputArray
is1DTileable VName
gtid VarianceTable
variance VName
arr
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Names -> Bool
nameIn VName
gtid (Names -> Bool) -> Names -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> VName -> VarianceTable -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault Names
forall a. Monoid a => a
mempty VName
arr VarianceTable
variance =
[Int] -> VName -> InputArray
InputTile [Int
0] VName
arr
| Bool
otherwise =
VName -> InputArray
InputDontTile VName
arr
segMap1D ::
String ->
SegLevel ->
ResultManifest ->
(VName -> Binder Kernels [SubExp]) ->
Binder Kernels [VName]
segMap1D :: String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
desc SegLevel
lvl ResultManifest
manifest VName -> BinderT Kernels (State VNameSource) Result
f = do
VName
ltid <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ltid"
VName
ltid_flat <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ltid_flat"
let space :: SegSpace
space = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
ltid_flat [(VName
ltid, Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount (Count GroupSize SubExp -> SubExp)
-> Count GroupSize SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl)]
(([Type]
ts, Result
res), Stms Kernels
stms) <- Binder Kernels ([Type], Result)
-> BinderT
Kernels (State VNameSource) (([Type], Result), Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels ([Type], Result)
-> BinderT
Kernels (State VNameSource) (([Type], Result), Stms Kernels))
-> Binder Kernels ([Type], Result)
-> BinderT
Kernels (State VNameSource) (([Type], Result), Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
Result
res <- VName -> BinderT Kernels (State VNameSource) Result
f VName
ltid
[Type]
ts <- (SubExp -> BinderT Kernels (State VNameSource) Type)
-> Result -> BinderT Kernels (State VNameSource) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> BinderT Kernels (State VNameSource) Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType Result
res
([Type], Result) -> Binder Kernels ([Type], Result)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Type]
ts, Result
res)
Body BodyDec Kernels
_ Stms Kernels
stms' Result
res' <- BodyT Kernels -> Binder Kernels (BodyT Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (BodyT Kernels -> Binder Kernels (BodyT Kernels))
-> BodyT Kernels -> Binder Kernels (BodyT Kernels)
forall a b. (a -> b) -> a -> b
$ Stms Kernels -> Result -> BodyT Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms Kernels
stms Result
res
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
desc (Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName])
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$
SegLevel
-> SegSpace
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
lvl SegSpace
space [Type]
ts (KernelBody Kernels -> SegOp SegLevel Kernels)
-> KernelBody Kernels -> SegOp SegLevel Kernels
forall a b. (a -> b) -> a -> b
$ BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
stms' ([KernelResult] -> KernelBody Kernels)
-> [KernelResult] -> KernelBody Kernels
forall a b. (a -> b) -> a -> b
$ (SubExp -> KernelResult) -> Result -> [KernelResult]
forall a b. (a -> b) -> [a] -> [b]
map (ResultManifest -> SubExp -> KernelResult
Returns ResultManifest
manifest) Result
res'
reconstructGtids1D ::
Count GroupSize SubExp ->
VName ->
VName ->
VName ->
Binder Kernels ()
reconstructGtids1D :: Count GroupSize SubExp
-> VName
-> VName
-> VName
-> BinderT Kernels (State VNameSource) ()
reconstructGtids1D Count GroupSize SubExp
group_size VName
gtid VName
gid VName
ltid =
[VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid]
(ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid)
readTile1D ::
SubExp ->
VName ->
VName ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
TileKind ->
PrivStms ->
SubExp ->
[InputArray] ->
Binder Kernels [InputTile]
readTile1D :: SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile]
readTile1D SubExp
tile_size VName
gid VName
gtid Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TileKind
kind PrivStms
privstms SubExp
tile_id [InputArray]
inputs =
([VName] -> [InputTile])
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels [InputTile]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([InputArray] -> [VName] -> [InputTile]
inputsToTiles [InputArray]
inputs)
(BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels [InputTile])
-> ((VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
"full_tile" SegLevel
lvl ResultManifest
ResultNoSimplify
((VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
SubExp
j <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"j"
(ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid)
Count GroupSize SubExp
-> VName
-> VName
-> VName
-> BinderT Kernels (State VNameSource) ()
reconstructGtids1D Count GroupSize SubExp
group_size VName
gtid VName
gid VName
ltid
Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid] PrivStms
privstms
let arrs :: [VName]
arrs = ((VName, [Int]) -> VName) -> [(VName, [Int])] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, [Int]) -> VName
forall a b. (a, b) -> a
fst ([(VName, [Int])] -> [VName]) -> [(VName, [Int])] -> [VName]
forall a b. (a -> b) -> a -> b
$ [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs
[Type]
arr_ts <- (VName -> BinderT Kernels (State VNameSource) Type)
-> [VName] -> BinderT Kernels (State VNameSource) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
arrs
let tile_ts :: [Type]
tile_ts = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType [Type]
arr_ts
w :: SubExp
w = Int -> [Type] -> SubExp
forall u. Int -> [TypeBase Shape u] -> SubExp
arraysSize Int
0 [Type]
arr_ts
let readTileElem :: VName -> BinderT Kernels (State VNameSource) VName
readTileElem VName
arr =
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"tile_elem" (BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
j])
([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
case TileKind
kind of
TileKind
TilePartial ->
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"pre"
(ExpT Kernels -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
(TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
j TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
w)
(Result -> BodyT Kernels
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> BodyT Kernels)
-> BinderT Kernels (State VNameSource) Result
-> Binder Kernels (BodyT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Binder Kernels SubExp)
-> [VName] -> BinderT Kernels (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName -> SubExp)
-> BinderT Kernels (State VNameSource) VName
-> Binder Kernels SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var (BinderT Kernels (State VNameSource) VName
-> Binder Kernels SubExp)
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> VName
-> Binder Kernels SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> BinderT Kernels (State VNameSource) VName
readTileElem) [VName]
arrs)
([BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody ([BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource)))))
-> [BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (Type -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> [Type] -> [BinderT Kernels (State VNameSource) (ExpT Kernels)]
forall a b. (a -> b) -> [a] -> [b]
map Type -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank [Type]
tile_ts)
TileKind
TileFull ->
(VName -> BinderT Kernels (State VNameSource) VName)
-> [VName] -> BinderT Kernels (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BinderT Kernels (State VNameSource) VName
readTileElem [VName]
arrs
where
lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegNoVirt
processTile1D ::
VName ->
VName ->
SubExp ->
SubExp ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
ProcessTileArgs ->
Binder Kernels [VName]
processTile1D :: VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ProcessTileArgs
-> BinderT Kernels (State VNameSource) [VName]
processTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size ProcessTileArgs
tile_args = do
let red_comm :: Commutativity
red_comm = ProcessTileArgs -> Commutativity
processComm ProcessTileArgs
tile_args
privstms :: PrivStms
privstms = ProcessTileArgs -> PrivStms
processPrivStms ProcessTileArgs
tile_args
map_lam :: Lambda Kernels
map_lam = ProcessTileArgs -> Lambda Kernels
processMapLam ProcessTileArgs
tile_args
red_lam :: Lambda Kernels
red_lam = ProcessTileArgs -> Lambda Kernels
processRedLam ProcessTileArgs
tile_args
tiles :: [InputTile]
tiles = ProcessTileArgs -> [InputTile]
processTiles ProcessTileArgs
tile_args
tile_id :: SubExp
tile_id = ProcessTileArgs -> SubExp
processTileId ProcessTileArgs
tile_args
accs :: [VName]
accs = ProcessTileArgs -> [VName]
processAcc ProcessTileArgs
tile_args
String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
"acc" SegLevel
lvl ResultManifest
ResultPrivate ((VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
Count GroupSize SubExp
-> VName
-> VName
-> VName
-> BinderT Kernels (State VNameSource) ()
reconstructGtids1D Count GroupSize SubExp
group_size VName
gtid VName
gid VName
ltid
Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid] PrivStms
privstms
Result
thread_accs <- [VName]
-> (VName -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
accs ((VName -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) Result)
-> (VName -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
acc ->
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"acc" (Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
acc [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid]
let sliceTile :: InputTile -> BinderT Kernels (State VNameSource) VName
sliceTile (InputTiled [Int]
_ VName
arr) =
VName -> BinderT Kernels (State VNameSource) VName
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
arr
sliceTile (InputUntiled VName
arr) =
VName
-> SubExp
-> SubExp
-> SubExp
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
VName -> SubExp -> SubExp -> SubExp -> m VName
sliceUntiled VName
arr SubExp
tile_id SubExp
tile_size SubExp
tile_size
[VName]
tiles' <- (InputTile -> BinderT Kernels (State VNameSource) VName)
-> [InputTile] -> BinderT Kernels (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM InputTile -> BinderT Kernels (State VNameSource) VName
sliceTile [InputTile]
tiles
let form' :: ScremaForm Kernels
form' = [Reduce Kernels] -> Lambda Kernels -> ScremaForm Kernels
forall lore. [Reduce lore] -> Lambda lore -> ScremaForm lore
redomapSOAC [Commutativity -> Lambda Kernels -> Result -> Reduce Kernels
forall lore. Commutativity -> Lambda lore -> Result -> Reduce lore
Reduce Commutativity
red_comm Lambda Kernels
red_lam Result
thread_accs] Lambda Kernels
map_lam
([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"acc"
(ExpT Kernels -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
(TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim)
([BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [ExpT Kernels -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SOAC Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. op -> HostOp lore op
OtherOp (SOAC Kernels -> HostOp Kernels (SOAC Kernels))
-> SOAC Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm Kernels -> SOAC Kernels
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Screma SubExp
tile_size [VName]
tiles' ScremaForm Kernels
form'])
(Result
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM Result
thread_accs)
where
lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegNoVirt
processResidualTile1D ::
VName ->
VName ->
SubExp ->
SubExp ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
ResidualTileArgs ->
Binder Kernels [VName]
processResidualTile1D :: VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ResidualTileArgs
-> BinderT Kernels (State VNameSource) [VName]
processResidualTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size ResidualTileArgs
args = do
SubExp
residual_input <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"residual_input" (Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SRem IntType
Int64 Safety
Unsafe) SubExp
w SubExp
tile_size
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"acc_after_residual"
(ExpT Kernels -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
(TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
residual_input TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
(Result
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource)))))
-> Result
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
accs)
(SubExp -> Binder Kernels (BodyT Kernels)
nonemptyTile SubExp
residual_input)
where
red_comm :: Commutativity
red_comm = ResidualTileArgs -> Commutativity
residualComm ResidualTileArgs
args
map_lam :: Lambda Kernels
map_lam = ResidualTileArgs -> Lambda Kernels
residualMapLam ResidualTileArgs
args
red_lam :: Lambda Kernels
red_lam = ResidualTileArgs -> Lambda Kernels
residualRedLam ResidualTileArgs
args
privstms :: PrivStms
privstms = ResidualTileArgs -> PrivStms
residualPrivStms ResidualTileArgs
args
inputs :: [InputArray]
inputs = ResidualTileArgs -> [InputArray]
residualInput ResidualTileArgs
args
accs :: [VName]
accs = ResidualTileArgs -> [VName]
residualAcc ResidualTileArgs
args
num_whole_tiles :: SubExp
num_whole_tiles = ResidualTileArgs -> SubExp
residualNumWholeTiles ResidualTileArgs
args
w :: SubExp
w = ResidualTileArgs -> SubExp
residualInputSize ResidualTileArgs
args
nonemptyTile :: SubExp -> Binder Kernels (BodyT Kernels)
nonemptyTile SubExp
residual_input = Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels))
-> Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels)
forall a b. (a -> b) -> a -> b
$ do
[InputTile]
full_tiles <-
SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile]
readTile1D
SubExp
tile_size
VName
gid
VName
gtid
Count NumGroups SubExp
num_groups
Count GroupSize SubExp
group_size
TileKind
TilePartial
PrivStms
privstms
SubExp
num_whole_tiles
[InputArray]
inputs
let sliceTile :: InputTile -> BinderT Kernels (State VNameSource) InputTile
sliceTile (InputUntiled VName
arr) =
InputTile -> BinderT Kernels (State VNameSource) InputTile
forall (f :: * -> *) a. Applicative f => a -> f a
pure (InputTile -> BinderT Kernels (State VNameSource) InputTile)
-> InputTile -> BinderT Kernels (State VNameSource) InputTile
forall a b. (a -> b) -> a -> b
$ VName -> InputTile
InputUntiled VName
arr
sliceTile (InputTiled [Int]
perm VName
tile) = do
let slice :: DimIndex SubExp
slice =
SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) SubExp
residual_input (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)
[Int] -> VName -> InputTile
InputTiled [Int]
perm
(VName -> InputTile)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) InputTile
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"partial_tile" (BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
tile [DimIndex SubExp
slice])
[InputTile]
tiles <- (InputTile -> BinderT Kernels (State VNameSource) InputTile)
-> [InputTile] -> Binder Kernels [InputTile]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM InputTile -> BinderT Kernels (State VNameSource) InputTile
sliceTile [InputTile]
full_tiles
let tile_args :: ProcessTileArgs
tile_args =
PrivStms
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> [InputTile]
-> [VName]
-> SubExp
-> ProcessTileArgs
ProcessTileArgs PrivStms
privstms Commutativity
red_comm Lambda Kernels
red_lam Lambda Kernels
map_lam [InputTile]
tiles [VName]
accs SubExp
num_whole_tiles
Result -> BodyT Kernels
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> BodyT Kernels)
-> ([VName] -> Result) -> [VName] -> BodyT Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var
([VName] -> BodyT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels (BodyT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ProcessTileArgs
-> BinderT Kernels (State VNameSource) [VName]
processTile1D VName
gid VName
gtid SubExp
kdim SubExp
residual_input Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size ProcessTileArgs
tile_args
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d [(VName, SubExp)]
dims_on_top SegLevel
initial_lvl VName
gtid SubExp
kdim SubExp
w = do
VName
gid <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid"
VName
gid_flat <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_flat"
(SegLevel
lvl, SegSpace
space) <-
if [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
dims_on_top
then
(SegLevel, SegSpace)
-> BinderT Kernels (State VNameSource) (SegLevel, SegSpace)
forall (m :: * -> *) a. Monad m => a -> m a
return
( Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
initial_lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
initial_lvl) (SegVirt -> SegLevel) -> SegVirt -> SegLevel
forall a b. (a -> b) -> a -> b
$ SegLevel -> SegVirt
segVirt SegLevel
initial_lvl,
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat [(VName
gid, Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount (Count NumGroups SubExp -> SubExp)
-> Count NumGroups SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
initial_lvl)]
)
else do
SubExp
group_size <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"computed_group_size" (Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMin IntType
Int64) (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
initial_lvl)) SubExp
kdim
SubExp
ldim <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"ldim" (Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SDivUp IntType
Int64 Safety
Unsafe) SubExp
kdim SubExp
group_size
SubExp
num_groups <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"computed_num_groups"
(ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> Result
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> SubExp -> Result -> m (Exp (Lore m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) SubExp
ldim (((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
dims_on_top)
(SegLevel, SegSpace)
-> BinderT Kernels (State VNameSource) (SegLevel, SegSpace)
forall (m :: * -> *) a. Monad m => a -> m a
return
( Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
num_groups) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirt,
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)]
dims_on_top [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid, SubExp
ldim)]
)
let tile_size :: SubExp
tile_size = Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount (Count GroupSize SubExp -> SubExp)
-> Count GroupSize SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
Tiling -> Binder Kernels Tiling
forall (m :: * -> *) a. Monad m => a -> m a
return
Tiling :: (String
-> SegLevel
-> ResultManifest
-> (PrimExp VName
-> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName])
-> (TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile])
-> (ProcessTileArgs -> BinderT Kernels (State VNameSource) [VName])
-> (ResidualTileArgs
-> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) KernelResult)
-> SegSpace
-> Shape
-> SegLevel
-> Binder Kernels SubExp
-> Tiling
Tiling
{ tilingSegMap :: String
-> SegLevel
-> ResultManifest
-> (PrimExp VName
-> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
tilingSegMap = \String
desc SegLevel
lvl' ResultManifest
manifest PrimExp VName
-> Slice SubExp -> BinderT Kernels (State VNameSource) Result
f -> String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
desc SegLevel
lvl' ResultManifest
manifest ((VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
[VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid]
(ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid)
PrimExp VName
-> Slice SubExp -> BinderT Kernels (State VNameSource) Result
f (TPrimExp Bool VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Bool VName -> PrimExp VName)
-> TPrimExp Bool VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim) [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid],
tilingReadTile :: TileKind
-> PrivStms -> SubExp -> [InputArray] -> Binder Kernels [InputTile]
tilingReadTile =
SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile]
readTile1D SubExp
tile_size VName
gid VName
gtid (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
tilingProcessTile :: ProcessTileArgs -> BinderT Kernels (State VNameSource) [VName]
tilingProcessTile =
VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ProcessTileArgs
-> BinderT Kernels (State VNameSource) [VName]
processTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
tilingProcessResidualTile :: ResidualTileArgs -> BinderT Kernels (State VNameSource) [VName]
tilingProcessResidualTile =
VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ResidualTileArgs
-> BinderT Kernels (State VNameSource) [VName]
processResidualTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
tilingTileReturns :: VName -> BinderT Kernels (State VNameSource) KernelResult
tilingTileReturns = [(VName, SubExp)]
-> [(SubExp, SubExp)]
-> VName
-> BinderT Kernels (State VNameSource) KernelResult
tileReturns [(VName, SubExp)]
dims_on_top [(SubExp
kdim, SubExp
tile_size)],
tilingTileShape :: Shape
tilingTileShape = Result -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
tile_size],
tilingNumWholeTiles :: Binder Kernels SubExp
tilingNumWholeTiles =
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"num_whole_tiles" (Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
w SubExp
tile_size,
tilingLevel :: SegLevel
tilingLevel = SegLevel
lvl,
tilingSpace :: SegSpace
tilingSpace = SegSpace
space
}
invariantToOneOfTwoInnerDims ::
Names ->
M.Map VName Names ->
[VName] ->
VName ->
Maybe InputArray
invariantToOneOfTwoInnerDims :: Names -> VarianceTable -> [VName] -> VName -> Maybe InputArray
invariantToOneOfTwoInnerDims Names
branch_variant VarianceTable
variance [VName]
dims VName
arr = do
VName
j : VName
i : [VName]
_ <- [VName] -> Maybe [VName]
forall a. a -> Maybe a
Just ([VName] -> Maybe [VName]) -> [VName] -> Maybe [VName]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. [a] -> [a]
reverse [VName]
dims
let variant_to :: Names
variant_to = Names -> VName -> VarianceTable -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault Names
forall a. Monoid a => a
mempty VName
arr VarianceTable
variance
branch_invariant :: Bool
branch_invariant = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Names -> Bool
nameIn VName
j Names
branch_variant Bool -> Bool -> Bool
|| VName -> Names -> Bool
nameIn VName
i Names
branch_variant
if Bool
branch_invariant Bool -> Bool -> Bool
&& VName
i VName -> Names -> Bool
`nameIn` Names
variant_to Bool -> Bool -> Bool
&& Bool -> Bool
not (VName
j VName -> Names -> Bool
`nameIn` Names
variant_to)
then InputArray -> Maybe InputArray
forall a. a -> Maybe a
Just (InputArray -> Maybe InputArray) -> InputArray -> Maybe InputArray
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> InputArray
InputTile [Int
0, Int
1] VName
arr
else
if Bool
branch_invariant Bool -> Bool -> Bool
&& VName
j VName -> Names -> Bool
`nameIn` Names
variant_to Bool -> Bool -> Bool
&& Bool -> Bool
not (VName
i VName -> Names -> Bool
`nameIn` Names
variant_to)
then InputArray -> Maybe InputArray
forall a. a -> Maybe a
Just (InputArray -> Maybe InputArray) -> InputArray -> Maybe InputArray
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> InputArray
InputTile [Int
1, Int
0] VName
arr
else InputArray -> Maybe InputArray
forall a. a -> Maybe a
Just (InputArray -> Maybe InputArray) -> InputArray -> Maybe InputArray
forall a b. (a -> b) -> a -> b
$ VName -> InputArray
InputDontTile VName
arr
reconstructGtids2D ::
SubExp ->
(VName, VName) ->
(VName, VName) ->
(VName, VName) ->
Binder Kernels ()
reconstructGtids2D :: SubExp
-> (VName, VName)
-> (VName, VName)
-> (VName, VName)
-> BinderT Kernels (State VNameSource) ()
reconstructGtids2D SubExp
tile_size (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) (VName
ltid_x, VName
ltid_y) = do
[VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_x]
(ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
[VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_y]
(ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
readTile2D ::
(SubExp, SubExp) ->
(VName, VName) ->
(VName, VName) ->
SubExp ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
TileKind ->
PrivStms ->
SubExp ->
[InputArray] ->
Binder Kernels [InputTile]
readTile2D :: (SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile]
readTile2D (SubExp
kdim_x, SubExp
kdim_y) (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) SubExp
tile_size Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TileKind
kind PrivStms
privstms SubExp
tile_id [InputArray]
inputs =
([VName] -> [InputTile])
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels [InputTile]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([InputArray] -> [VName] -> [InputTile]
inputsToTiles [InputArray]
inputs)
(BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels [InputTile])
-> (((VName, VName) -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName])
-> ((VName, VName) -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap2D
String
"full_tile"
(Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegNoVirtFull)
ResultManifest
ResultNoSimplify
(SubExp
tile_size, SubExp
tile_size)
(((VName, VName) -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile])
-> ((VName, VName) -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall a b. (a -> b) -> a -> b
$ \(VName
ltid_x, VName
ltid_y) -> do
SubExp
i <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i"
(ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
SubExp
j <-
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"j"
(ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
SubExp
-> (VName, VName)
-> (VName, VName)
-> (VName, VName)
-> BinderT Kernels (State VNameSource) ()
reconstructGtids2D SubExp
tile_size (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) (VName
ltid_x, VName
ltid_y)
Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_x, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_y] PrivStms
privstms
let arrs_and_perms :: [(VName, [Int])]
arrs_and_perms = [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs
readTileElem :: (VName, [Int]) -> BinderT Kernels (State VNameSource) VName
readTileElem (VName
arr, [Int]
perm) =
String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp
String
"tile_elem"
( BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index
VName
arr
[SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ Result -> SubExp
forall a. [a] -> a
last (Result -> SubExp) -> Result -> SubExp
forall a b. (a -> b) -> a -> b
$ [Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm [SubExp
i, SubExp
j]]
)
readTileElemIfInBounds :: (VName, [Int])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
readTileElemIfInBounds (VName
arr, [Int]
perm) = do
Type
arr_t <- VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
let tile_t :: Type
tile_t = Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
arr_t
w :: SubExp
w = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
arr_t
idx :: SubExp
idx = Result -> SubExp
forall a. [a] -> a
last (Result -> SubExp) -> Result -> SubExp
forall a b. (a -> b) -> a -> b
$ [Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm [SubExp
i, SubExp
j]
othercheck :: TPrimExp Bool VName
othercheck =
[TPrimExp Bool VName] -> TPrimExp Bool VName
forall a. [a] -> a
last ([TPrimExp Bool VName] -> TPrimExp Bool VName)
-> [TPrimExp Bool VName] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$
[Int] -> [TPrimExp Bool VName] -> [TPrimExp Bool VName]
forall a. [Int] -> [a] -> [a]
rearrangeShape
[Int]
perm
[ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_y,
VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_x
]
BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
(TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
idx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Bool VName
othercheck)
([BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [ExpT Kernels -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]])
([BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
Kernels
(State VNameSource)
(Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [Type
-> BinderT
Kernels
(State VNameSource)
(Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank Type
tile_t])
([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
case TileKind
kind of
TileKind
TilePartial ->
((VName, [Int]) -> BinderT Kernels (State VNameSource) VName)
-> [(VName, [Int])] -> BinderT Kernels (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"pre" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> ((VName, [Int])
-> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> (VName, [Int])
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< (VName, [Int])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
readTileElemIfInBounds) [(VName, [Int])]
arrs_and_perms
TileKind
TileFull ->
((VName, [Int]) -> BinderT Kernels (State VNameSource) VName)
-> [(VName, [Int])] -> BinderT Kernels (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (VName, [Int]) -> BinderT Kernels (State VNameSource) VName
readTileElem [(VName, [Int])]
arrs_and_perms
findTileSize :: HasScope lore m => [InputTile] -> m SubExp
findTileSize :: forall lore (m :: * -> *).
HasScope lore m =>
[InputTile] -> m SubExp
findTileSize [InputTile]
tiles =
case (InputTile -> Maybe VName) -> [InputTile] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe InputTile -> Maybe VName
isTiled [InputTile]
tiles of
VName
v : [VName]
_ -> Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 (Type -> SubExp) -> m Type -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
v
[] -> SubExp -> m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
where
isTiled :: InputTile -> Maybe VName
isTiled InputUntiled {} = Maybe VName
forall a. Maybe a
Nothing
isTiled (InputTiled [Int]
_ VName
tile) = VName -> Maybe VName
forall a. a -> Maybe a
Just VName
tile
-- | Let every thread of a @tile_size × tile_size@ thread-level segmap
-- fold one tile into its private accumulators.
--
-- Each thread (local ids @ltid_x@, @ltid_y@):
--
--   * reconstructs its global thread ids from the group and local ids;
--   * binds the private statements at its own @[ltid_x, ltid_y]@ position;
--   * reads its own element of every accumulator array.
--
-- Threads whose global ids are below @(kdim_x, kdim_y)@ then run a
-- redomap over their part of each input. Other threads return their
-- accumulators unchanged. The result names the updated per-thread
-- accumulators.
processTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Binder Kernels [VName]
processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size num_groups group_size tile_args = do
  -- Outer size of the first tiled input (0 if none is tiled). This can be
  -- smaller than tile_size when the caller passes truncated tiles.
  elems_in_tile <- findTileSize input_tiles

  segMap2D "acc" thread_level ResultPrivate (tile_size, tile_size) $
    \local_ids@(ltid_x, ltid_y) -> do
      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) local_ids

      let my_elem = [DimFix (Var ltid_x), DimFix (Var ltid_y)]
      addPrivStms my_elem (processPrivStms tile_args)

      -- This thread's current value of every accumulator.
      my_accs <- forM (processAcc tile_args) $ \acc_arr ->
        letSubExp "acc" $ BasicOp $ Index acc_arr my_elem

      let thread_form =
            redomapSOAC
              [Reduce (processComm tile_args) (processRedLam tile_args) my_accs]
              (processMapLam tile_args)

          -- The part of one input that this thread traverses.
          threadSlice :: InputTile -> Binder Kernels VName
          threadSlice inp = case inp of
            InputUntiled arr ->
              sliceUntiled arr (processTileId tile_args) tile_size elems_in_tile
            InputTiled perm tile_arr -> do
              tile_t <- lookupType tile_arr
              -- Fix tile dimension (head perm) to the local id that the
              -- permutation places first.
              let which_ltid = head (rearrangeShape perm [ltid_x, ltid_y])
              letExp "tile" . BasicOp . Index tile_arr $
                sliceAt tile_t (head perm) [DimFix (Var which_ltid)]

          in_kernel_bounds =
            le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<. pe64 kdim_y

      my_slices <- mapM threadSlice input_tiles

      map Var
        <$> ( letTupExp "acc"
                =<< eIf
                  (toExp in_kernel_bounds)
                  (eBody [pure $ Op $ OtherOp $ Screma elems_in_tile my_slices thread_form])
                  (resultBodyM my_accs)
            )
  where
    input_tiles = processTiles tile_args
    thread_level = SegThread num_groups group_size SegNoVirtFull
-- | Fold in the input elements left over after all whole tiles.
--
-- Computes @residual_input = w `rem` tile_size@ (signed, unsafe), where
-- @w@ is 'residualInputSize'. If it is zero, the accumulators are
-- returned as they are.
--
-- Otherwise a partial tile is read with 'readTile2D' in 'TilePartial'
-- mode. Every tiled input is then truncated to its first
-- @residual_input@ elements in both dimensions, and the result is
-- processed with 'processTile2D'. 'residualNumWholeTiles' is passed both
-- to 'readTile2D' and as the tile id for 'processTile2D'.
processResidualTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Binder Kernels [VName]
processResidualTile2D gids gtids kdims tile_size num_groups group_size args = do
  leftover <-
    letSubExp "residual_input" . BasicOp $
      BinOp (SRem Int64 Unsafe) (residualInputSize args) tile_size

  let no_leftover = pe64 leftover .==. 0

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp no_leftover)
      (resultBodyM $ map Var $ residualAcc args)
      (partialTileBody leftover)
  where
    -- Body for the non-empty case. It is built with runBodyBinder and then
    -- renamed (presumably so its bound names stay unique).
    partialTileBody :: SubExp -> Binder Kernels (Body Kernels)
    partialTileBody leftover = renameBody <=< runBodyBinder $ do
      full_tile <-
        readTile2D
          kdims
          gtids
          gids
          tile_size
          num_groups
          group_size
          TilePartial
          (residualPrivStms args)
          (residualNumWholeTiles args)
          (residualInput args)

      let first_leftover =
            DimSlice (intConst Int64 0) leftover (intConst Int64 1)

          -- NOTE(review): both tile dimensions are truncated to
          -- `leftover`, although processTile2D indexes one tile
          -- dimension by a local thread id ranging up to tile_size.
          -- Confirm this is intended.
          truncateTile :: InputTile -> Binder Kernels InputTile
          truncateTile (InputUntiled arr) = pure (InputUntiled arr)
          truncateTile (InputTiled perm tile_arr) =
            InputTiled perm
              <$> letExp
                "partial_tile"
                (BasicOp $ Index tile_arr [first_leftover, first_leftover])

      partial_tiles <- mapM truncateTile full_tile

      let tile_args =
            ProcessTileArgs
              (residualPrivStms args)
              (residualComm args)
              (residualRedLam args)
              (residualMapLam args)
              partial_tiles
              (residualAcc args)
              (residualNumWholeTiles args)

      resultBody . map Var
        <$> processTile2D gids gtids kdims tile_size num_groups group_size tile_args
-- | The two-dimensional tiling strategy.
--
-- The kernel becomes a group-level segment over @dims_on_top@ followed
-- by a @num_groups_x × num_groups_y@ grid of groups. Each group has
-- @tile_size * tile_size@ threads, and each @num_groups_*@ is the kernel
-- size in that dimension divided by @tile_size@, rounded up.
--
-- @tile_size@ is obtained through 'GetSize' with 'SizeTile', keyed by a
-- freshly generated name. The incoming segment level is ignored. The
-- number of whole tiles is @w `quot` tile_size@.
tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d dims_on_top _initial_lvl (gtid_x, gtid_y) (kdim_x, kdim_y) w = do
  gid_x <- newVName "gid_x"
  gid_y <- newVName "gid_y"

  tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
  tile_size <- letSubExp "tile_size" $ Op $ SizeOp $ GetSize tile_size_key SizeTile
  group_size <-
    letSubExp "group_size" $
      BasicOp $ BinOp (Mul Int64 OverflowUndef) tile_size tile_size

  -- Groups needed to cover one kernel dimension (ceiling division).
  let groupsAlong :: String -> SubExp -> Binder Kernels SubExp
      groupsAlong desc kdim =
        letSubExp desc $ BasicOp $ BinOp (SDivUp Int64 Unsafe) kdim tile_size
  num_groups_x <- groupsAlong "num_groups_x" kdim_x
  num_groups_y <- groupsAlong "num_groups_y" kdim_y

  num_groups <-
    letSubExp "num_groups_top"
      =<< foldBinOp
        (Mul Int64 OverflowUndef)
        num_groups_x
        (num_groups_y : map snd dims_on_top)

  gid_flat <- newVName "gid_flat"

  let lvl = SegGroup (Count num_groups) (Count group_size) SegNoVirtFull
      grp_count = segNumGroups lvl
      grp_size = segGroupSize lvl
      gids = (gid_x, gid_y)
      gtids = (gtid_x, gtid_y)
      kdims = (kdim_x, kdim_y)
      -- True for threads whose global ids lie inside the kernel.
      in_bounds =
        untyped $ le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<. pe64 kdim_y

  pure
    Tiling
      { tilingSegMap = \desc lvl' manifest f ->
          segMap2D desc lvl' manifest (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do
            reconstructGtids2D tile_size gtids gids (ltid_x, ltid_y)
            f in_bounds [DimFix (Var ltid_x), DimFix (Var ltid_y)],
        tilingReadTile =
          readTile2D kdims gtids gids tile_size grp_count grp_size,
        tilingProcessTile =
          processTile2D gids gtids kdims tile_size grp_count grp_size,
        tilingProcessResidualTile =
          processResidualTile2D gids gtids kdims tile_size grp_count grp_size,
        tilingTileReturns =
          tileReturns dims_on_top [(kdim_x, tile_size), (kdim_y, tile_size)],
        tilingTileShape = Shape [tile_size, tile_size],
        tilingNumWholeTiles =
          letSubExp "num_whole_tiles" $
            BasicOp $ BinOp (SQuot Int64 Unsafe) w tile_size,
        tilingLevel = lvl,
        tilingSpace =
          SegSpace gid_flat $
            dims_on_top ++ [(gid_x, num_groups_x), (gid_y, num_groups_y)]
      }