{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform a restricted form of loop tiling within SegMaps.  We only
-- tile primitive types, to avoid excessive local memory use.
module Futhark.Optimise.TileLoops (tileLoops) where

import Control.Monad.Reader
import Control.Monad.State
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import qualified Data.Sequence as Seq
import Futhark.IR.Kernels
import Futhark.MonadFreshNames
import Futhark.Optimise.BlkRegTiling
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | The pass definition.
--
-- 'intraproceduralTransformation' hands statement sequences to 'onStms',
-- which rewrites them with 'optimiseStms'.  The supplied scope becomes the
-- 'ReaderT' environment, and the pass's name source is threaded through a
-- 'State' computation via 'modifyNameSource'.
tileLoops :: Pass Kernels Kernels
tileLoops =
  Pass "tile loops" "Tile stream loops inside kernels" $
    intraproceduralTransformation onStms
  where
    onStms :: Scope Kernels -> Stms Kernels -> PassM (Stms Kernels)
    onStms scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms stms) scope

-- | Tile the statements of a body, leaving its result unchanged.
optimiseBody :: Body Kernels -> TileM (Body Kernels)
optimiseBody (Body () stms res) = do
  stms' <- optimiseStms stms
  pure $ Body () stms' res

-- | Tile each statement of a sequence and concatenate the results.
--
-- The bindings of the whole sequence are put in scope before any
-- statement is processed.
optimiseStms :: Stms Kernels -> TileM (Stms Kernels)
optimiseStms stms =
  localScope (scopeOf stms) $
    mconcat <$> traverse optimiseStm (stmsToList stms)

-- | Tile a single statement.
--
-- A 'SegMap' at 'SegThread' level is offered, in order, to
-- 'doRegTiling3D', then 'mmBlkRegTiling', and finally to the generic
-- redomap/loop tiling of 'tileInKernelBody'.  The first strategy that
-- produces a result wins.  Any other statement is kept as is, except that
-- bodies nested inside its expression are optimised recursively.
optimiseStm :: Stm Kernels -> TileM (Stms Kernels)
optimiseStm stm@(Let pat aux (Op (SegOp (SegMap lvl@SegThread {} space ts kbody)))) = do
  reg_tiled <- firstJustM [doRegTiling3D stm, mmBlkRegTiling stm]
  case reg_tiled of
    Just (extra_stms, stm') ->
      return $ extra_stms <> oneStm stm'
    Nothing -> do
      (host_stms, (lvl', space', kbody')) <-
        tileInKernelBody mempty initial_variance lvl space ts kbody
      return $
        host_stms
          <> oneStm (Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody')
  where
    -- Every name bound by the segment space starts with an empty
    -- variance set.
    initial_variance :: VarianceTable
    initial_variance = M.map mempty $ scopeOfSegSpace space

    -- Run the actions left to right and stop at the first 'Just'.
    -- Actions after a successful one are never executed, so they have
    -- no effect on the name source.
    firstJustM :: [TileM (Maybe a)] -> TileM (Maybe a)
    firstJustM [] = return Nothing
    firstJustM (m : ms) = maybe (firstJustM ms) (return . Just) =<< m
optimiseStm (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM optimise e)
  where
    -- Only nested bodies are rewritten; they are processed under the
    -- scope supplied by 'mapExpM'.
    optimise :: Mapper Kernels Kernels TileM
    optimise = identityMapper {mapOnBody = \scope -> localScope scope . optimiseBody}

-- | Try to tile the body of a thread-level 'SegMap'.
--
-- Tiling is only attempted when every kernel result is a plain 'Returns'.
-- In that case the statements and returned values are passed to
-- 'tileInBody'.
--
-- * On success, the result is the host statements to place before the
--   kernel, plus the level, space and body of the tiled kernel.
-- * Otherwise, the original level, space and body are returned with no
--   host statements.
tileInKernelBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  KernelBody Kernels ->
  TileM (Stms Kernels, (SegLevel, SegSpace, KernelBody Kernels))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody
  | Just kbody_res <- traverse isSimpleResult $ kernelBodyResult kbody = do
    maybe_tiled <-
      tileInBody branch_variant initial_variance lvl initial_kspace ts $
        Body () (kernelBodyStms kbody) kbody_res
    case maybe_tiled of
      Just (host_stms, tiling, tiledBody) -> do
        -- Generate the tiled statements and turn each name they return
        -- into a kernel result using the tiling's 'tilingTileReturns'.
        (res', stms') <-
          runBinder $ mapM (tilingTileReturns tiling) =<< tiledBody mempty mempty
        return
          ( host_stms,
            (tilingLevel tiling, tilingSpace tiling, KernelBody () stms' res')
          )
      Nothing -> unchanged
  | otherwise = unchanged
  where
    unchanged :: TileM (Stms Kernels, (SegLevel, SegSpace, KernelBody Kernels))
    unchanged = return (mempty, (lvl, initial_kspace, kbody))

    isSimpleResult :: KernelResult -> Maybe SubExp
    isSimpleResult (Returns _ se) = Just se
    isSimpleResult _ = Nothing

-- | Search a body for the first statement that can be tiled, and tile it.
--
-- Statements are visited in order.  For each one, these cases are tried
-- in turn:
--
-- 1. 2D tiling of a redomap ('tileable').  Requires a segment space with
--    at least two dimensions, every input array accepted by
--    'invariantToOneOfTwoInnerDims', and at least one input actually
--    tiled ('tiledInputs').
--
-- 2. 1D tiling of a redomap over the innermost thread index.  Requires
--    inputs classified by 'is1DTileable' with at least one tiled input,
--    and the index must not be in @branch_variant@.
--
-- 3. A 'DoLoop' (empty first merge list, 'ForLoop' form with an empty
--    trailing list).  Its body is searched recursively and, if that
--    succeeds, the result is wrapped by 'tileDoLoop'.
--
-- The first success is returned.  Statements before the tiled one that it
-- does not use are moved after it (see 'preludeToPostlude').  Returns
-- 'Nothing' if no statement can be tiled.
tileInBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  Body Kernels ->
  TileM (Maybe (Stms Kernels, Tiling, TiledBody))
tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) =
  descend mempty $ stmsToList initial_kstms
  where
    variance :: VarianceTable
    variance = varianceInStms initial_variance initial_kstms

    -- The first argument holds the statements already passed over (the
    -- prelude); the second holds those not yet examined.
    descend ::
      Stms Kernels ->
      [Stm Kernels] ->
      TileM (Maybe (Stms Kernels, Tiling, TiledBody))
    descend _ [] =
      return Nothing
    descend prestms (stm_to_tile : poststms)
      -- 2D tiling of redomap.
      | (gtids, kdims) <- unzip $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        Just inputs <-
          traverse (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs,
        not $ null $ tiledInputs inputs,
        gtid_y : gtid_x : top_gtids_rev <- reverse gtids,
        kdim_y : kdim_x : top_kdims_rev <- reverse kdims =
        Just . injectPrelude initial_space variance prestms' used
          <$> tileGeneric
            (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev)
            initial_lvl
            res_ts
            (stmPattern stm_to_tile)
            (gtid_x, gtid_y)
            (kdim_x, kdim_y)
            w
            form
            inputs
            poststms'
            stms_res
      -- 1D tiling of redomap.
      | (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        inputs <- map (is1DTileable gtid variance) arrs,
        not $ null $ tiledInputs inputs,
        not $ gtid `nameIn` branch_variant =
        Just . injectPrelude initial_space variance prestms' used
          <$> tileGeneric
            (tiling1d $ reverse top_space_rev)
            initial_lvl
            res_ts
            (stmPattern stm_to_tile)
            gtid
            kdim
            w
            form
            inputs
            poststms'
            stms_res
      -- Tiling inside for-loop.
      | DoLoop [] merge (ForLoop i it bound []) loopbody <- stmExp stm_to_tile = do
        let merge_params = map fst merge
            -- Anything the loop bound varies with is treated as
            -- branch-variant while searching the loop body.
            branch_variant' =
              branch_variant
                <> foldMap
                  (\v -> M.findWithDefault mempty v variance)
                  (namesToList (freeIn bound))

        -- Search the loop body with the loop index and merge
        -- parameters in scope.
        maybe_tiled <-
          localScope (M.insert i (IndexName it) $ scopeOfFParams merge_params) $
            tileInBody
              branch_variant'
              variance
              initial_lvl
              initial_space
              (map paramType merge_params)
              $ mkBody (bodyStms loopbody) (bodyResult loopbody)

        case maybe_tiled of
          Nothing -> next
          Just tiled ->
            Just
              <$> tileDoLoop
                initial_space
                variance
                prestms'
                (freeIn loopbody <> freeIn merge)
                tiled
                res_ts
                (stmPattern stm_to_tile)
                (stmAux stm_to_tile)
                merge
                i
                it
                bound
                poststms'
                stms_res
      | otherwise = next
      where
        -- Shared by all three tiling cases.  These are pure and lazily
        -- bound, so they are only computed when a case actually uses them.
        (prestms', poststms') =
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms)
        used = freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res

        -- Give up on this statement: add it to the prelude, bring its
        -- bindings into scope, and continue with the rest.
        next :: TileM (Maybe (Stms Kernels, Tiling, TiledBody))
        next =
          localScope (scopeOf stm_to_tile) $
            descend (prestms <> oneStm stm_to_tile) poststms

-- | Move prelude statements that the tiled statement does not need to the
-- front of the postlude.
--
-- A prelude statement stays in the prelude if any name it binds is free
-- in the tiled statement, or appears in the variance-table entry of one
-- of those free names.  Every other prelude statement is prepended to the
-- postlude.  Relative statement order is preserved in both results.
preludeToPostlude ::
  VarianceTable ->
  Stms Kernels ->
  Stm Kernels ->
  Stms Kernels ->
  (Stms Kernels, Stms Kernels)
preludeToPostlude variance prelude stm_to_tile postlude =
  let (kept, moved) = Seq.partition neededByTiled prelude
   in (kept, moved <> postlude)
  where
    direct_deps :: Names
    direct_deps = freeIn stm_to_tile

    -- Free names of the tiled statement, plus whatever the variance table
    -- records for each of them.
    all_deps :: Names
    all_deps =
      direct_deps
        <> foldMap (\v -> M.findWithDefault mempty v variance) (namesToList direct_deps)

    neededByTiled :: Stm Kernels -> Bool
    neededByTiled stm = any (`nameIn` all_deps) (patternNames (stmPattern stm))

-- | Split the prelude statements that precede a tiled loop (or something
-- containing a tiled loop) into three groups:
--
-- 1) Group-level statements that are invariant to the threads in the group.
--
-- 2) Thread-variant statements that are computed once, with a
-- segmap_thread_scalar.
--
-- 3) Thread-variant statements that are recomputed wherever they are
-- needed.
--
-- The third group duplicates work, so it is used only when strictly
-- necessary.  Currently, that covers two kinds of statement:
--
-- * statements whose results are array views (slicing, rotate, rearrange,
--   reshape) and are used after the prelude;
--
-- * any statement that depends on such a view.
--
-- Views cannot be represented efficiently by a scalar segmap, because
-- they would have to be manifested in memory.
partitionPrelude ::
  VarianceTable ->
  Stms Kernels ->
  Names ->
  Names ->
  (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude variance prestms private used_after =
  (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms)
  where
    boundBy :: Stm Kernels -> [VName]
    boundBy = patternNames . stmPattern

    -- Only the variance of the first bound name is consulted.  A statement
    -- that binds nothing is treated as invariant, since its placement does
    -- not matter.
    invariantTo :: Names -> Stm Kernels -> Bool
    invariantTo names stm
      | v : _ <- boundBy stm =
          all (not . (`nameIn` names)) $
            namesToList (M.findWithDefault mempty v variance)
      | otherwise = True

    (invariant_prestms, variant_prestms) =
      Seq.partition (invariantTo private) prestms

    -- Expressions that produce array views.  An index whose slice has no
    -- remaining (non-fixed) dimensions is not counted as a view.
    isViewExp = \case
      BasicOp (Index _ slice) -> not $ null $ sliceDims slice
      BasicOp Rotate {} -> True
      BasicOp Rearrange {} -> True
      BasicOp Reshape {} -> True
      _ -> False

    isUsedView :: Stm Kernels -> Bool
    isUsedView stm =
      isViewExp (stmExp stm) && any (`nameIn` used_after) (boundBy stm)

    -- Names bound by thread-variant view statements that are used later.
    view_names :: Names
    view_names =
      namesFromList $ foldMap boundBy $ Seq.filter isUsedView variant_prestms

    mustRecompute :: Stm Kernels -> Bool
    mustRecompute stm =
      any (`nameIn` view_names) (boundBy stm)
        || not (invariantTo view_names stm)

    (recomputed_variant_prestms, precomputed_variant_prestms) =
      Seq.partition mustRecompute variant_prestms

-- | Wrap the tiled body so that it first processes the given prelude
-- statements.
--
-- The statements are classified by 'partitionPrelude'.  For that
-- classification, these names count as thread-private:
--
-- * every name already passed as private;
--
-- * every thread index of @initial_space@ that is not part of the tiling
--   space.
--
-- Anything variant to those names is considered thread-local, and the
-- three resulting groups are handled as follows:
--
-- * invariant statements are added directly;
--
-- * precomputed variant statements are evaluated by 'doPrelude', and the
--   resulting arrays are read back through 'mkReadPreludeValues';
--
-- * recomputed variant statements are passed on as private statements.
--
-- The host statements and the tiling are returned unchanged.
injectPrelude ::
  SegSpace ->
  VarianceTable ->
  Stms Kernels ->
  Names ->
  (Stms Kernels, Tiling, TiledBody) ->
  (Stms Kernels, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, withPrelude)
  where
    tiled_dims :: [(VName, SubExp)]
    tiled_dims = unSegSpace $ tilingSpace tiling

    -- Thread indices of the original space that the tiling does not cover.
    untiled_gtids :: Names
    untiled_gtids =
      namesFromList
        [gtid | dim@(gtid, _) <- unSegSpace initial_space, dim `notElem` tiled_dims]

    withPrelude :: TiledBody
    withPrelude private privstms = do
      let thread_private = private <> untiled_gtids
          (group_stms, precomputed, recomputed) =
            partitionPrelude variance prestms thread_private used
          live =
            namesToList $
              liveSet precomputed $ used <> freeIn recomputed

      addStms group_stms

      prelude_arrs <-
        inScopeOf precomputed $
          doPrelude tiling privstms precomputed live

      let reread =
            PrivStms recomputed $ mkReadPreludeValues prelude_arrs live

      tiledBody thread_private (reread <> privstms)

tileDoLoop ::
  SegSpace ->
  VarianceTable ->
  Stms Kernels ->
  Names ->
  (Stms Kernels, Tiling, TiledBody) ->
  [Type] ->
  Pattern Kernels ->
  StmAux (ExpDec Kernels) ->
  [(FParam Kernels, SubExp)] ->
  VName ->
  IntType ->
  SubExp ->
  Stms Kernels ->
  Result ->
  TileM (Stms Kernels, Tiling, TiledBody)
tileDoLoop :: SegSpace
-> VarianceTable
-> Stms Kernels
-> Names
-> (Stms Kernels, Tiling, TiledBody)
-> [Type]
-> Pattern Kernels
-> StmAux (ExpDec Kernels)
-> [(FParam Kernels, SubExp)]
-> VName
-> IntType
-> SubExp
-> Stms Kernels
-> Result
-> ReaderT
     (Scope Kernels)
     (State VNameSource)
     (Stms Kernels, Tiling, TiledBody)
tileDoLoop SegSpace
initial_space VarianceTable
variance Stms Kernels
prestms Names
used_in_body (Stms Kernels
host_stms, Tiling
tiling, TiledBody
tiledBody) [Type]
res_ts Pattern Kernels
pat StmAux (ExpDec Kernels)
aux [(FParam Kernels, SubExp)]
merge VName
i IntType
it SubExp
bound Stms Kernels
poststms Result
poststms_res = do
  let prestms_used :: Names
prestms_used = Names
used_in_body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Stms Kernels -> Names
forall a. FreeIn a => a -> Names
freeIn Stms Kernels
poststms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
poststms_res
      ( Stms Kernels
invariant_prestms,
        Stms Kernels
precomputed_variant_prestms,
        Stms Kernels
recomputed_variant_prestms
        ) =
          VarianceTable
-> Stms Kernels
-> Names
-> Names
-> (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude VarianceTable
variance Stms Kernels
prestms Names
tiled_kdims Names
prestms_used

  let ([Param (TypeBase Shape Uniqueness)]
mergeparams, Result
mergeinits) = [(Param (TypeBase Shape Uniqueness), SubExp)]
-> ([Param (TypeBase Shape Uniqueness)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam Kernels, SubExp)]
merge

      -- Expand the loop merge parameters to be arrays.
      tileDim :: TypeBase Shape Uniqueness -> TypeBase Shape Uniqueness
tileDim TypeBase Shape Uniqueness
t = TypeBase Shape Uniqueness
-> Shape -> Uniqueness -> TypeBase Shape Uniqueness
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase Shape Uniqueness
t (Tiling -> Shape
tilingTileShape Tiling
tiling) (Uniqueness -> TypeBase Shape Uniqueness)
-> Uniqueness -> TypeBase Shape Uniqueness
forall a b. (a -> b) -> a -> b
$ TypeBase Shape Uniqueness -> Uniqueness
forall shape. TypeBase shape Uniqueness -> Uniqueness
uniqueness TypeBase Shape Uniqueness
t

      merge_scope :: Scope Kernels
merge_scope = VName -> NameInfo Kernels -> Scope Kernels -> Scope Kernels
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
i (IntType -> NameInfo Kernels
forall lore. IntType -> NameInfo lore
IndexName IntType
it) (Scope Kernels -> Scope Kernels) -> Scope Kernels -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase Shape Uniqueness)] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams

      tiledBody' :: TiledBody
tiledBody' Names
private PrivStms
privstms = Scope Kernels
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (Stms Kernels -> Scope Kernels
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Stms Kernels
host_stms Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope Kernels
merge_scope) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
        Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms (Lore (BinderT Kernels (State VNameSource)))
Stms Kernels
invariant_prestms

        let live_set :: [VName]
live_set =
              Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
                Stms Kernels -> Names -> Names
forall a. FreeIn a => Stms Kernels -> a -> Names
liveSet Stms Kernels
precomputed_variant_prestms (Names -> Names) -> Names -> Names
forall a b. (a -> b) -> a -> b
$
                  Stms Kernels -> Names
forall a. FreeIn a => a -> Names
freeIn Stms Kernels
recomputed_variant_prestms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
prestms_used

        [VName]
prelude_arrs <-
          Stms Kernels
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf Stms Kernels
precomputed_variant_prestms (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            Tiling
-> PrivStms
-> Stms Kernels
-> [VName]
-> BinderT Kernels (State VNameSource) [VName]
doPrelude Tiling
tiling PrivStms
privstms Stms Kernels
precomputed_variant_prestms [VName]
live_set

        [Param (TypeBase Shape Uniqueness)]
mergeparams' <- [Param (TypeBase Shape Uniqueness)]
-> (Param (TypeBase Shape Uniqueness)
    -> BinderT
         Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BinderT
     Kernels (State VNameSource) [Param (TypeBase Shape Uniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param (TypeBase Shape Uniqueness)]
mergeparams ((Param (TypeBase Shape Uniqueness)
  -> BinderT
       Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness)))
 -> BinderT
      Kernels (State VNameSource) [Param (TypeBase Shape Uniqueness)])
-> (Param (TypeBase Shape Uniqueness)
    -> BinderT
         Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BinderT
     Kernels (State VNameSource) [Param (TypeBase Shape Uniqueness)]
forall a b. (a -> b) -> a -> b
$ \(Param VName
pname TypeBase Shape Uniqueness
pt) ->
          VName
-> TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness)
forall dec. VName -> dec -> Param dec
Param (VName
 -> TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
-> BinderT Kernels (State VNameSource) VName
-> BinderT
     Kernels
     (State VNameSource)
     (TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
pname String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_group") BinderT
  Kernels
  (State VNameSource)
  (TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
-> BinderT Kernels (State VNameSource) (TypeBase Shape Uniqueness)
-> BinderT
     Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeBase Shape Uniqueness
-> BinderT Kernels (State VNameSource) (TypeBase Shape Uniqueness)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TypeBase Shape Uniqueness -> TypeBase Shape Uniqueness
tileDim TypeBase Shape Uniqueness
pt)

        let merge_ts :: [Type]
merge_ts = (Param (TypeBase Shape Uniqueness) -> Type)
-> [Param (TypeBase Shape Uniqueness)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType [Param (TypeBase Shape Uniqueness)]
mergeparams

        let inloop_privstms :: PrivStms
inloop_privstms =
              Stms Kernels -> ReadPrelude -> PrivStms
PrivStms Stms Kernels
recomputed_variant_prestms (ReadPrelude -> PrivStms) -> ReadPrelude -> PrivStms
forall a b. (a -> b) -> a -> b
$
                [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues [VName]
prelude_arrs [VName]
live_set

        Result
mergeinit' <-
          ([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
            Certificates
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec Kernels)
aux) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
              Tiling
-> String
-> SegLevel
-> ResultManifest
-> (PrimExp VName
    -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
tilingSegMap Tiling
tiling String
"tiled_loopinit" (Tiling -> SegLevel
scalarLevel Tiling
tiling) ResultManifest
ResultPrivate ((PrimExp VName
  -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
 -> BinderT Kernels (State VNameSource) [VName])
-> (PrimExp VName
    -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
                \PrimExp VName
in_bounds Slice SubExp
slice ->
                  ([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
                    String
-> PrimExp VName
-> [Type]
-> BinderT Kernels (State VNameSource) Result
-> BinderT Kernels (State VNameSource) [VName]
protectOutOfBounds String
"loopinit" PrimExp VName
in_bounds [Type]
merge_ts (BinderT Kernels (State VNameSource) Result
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) Result
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
                      Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms Slice SubExp
slice PrivStms
inloop_privstms
                      Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms Slice SubExp
slice PrivStms
privstms
                      Result -> BinderT Kernels (State VNameSource) Result
forall (m :: * -> *) a. Monad m => a -> m a
return Result
mergeinits

        let merge' :: [(Param (TypeBase Shape Uniqueness), SubExp)]
merge' = [Param (TypeBase Shape Uniqueness)]
-> Result -> [(Param (TypeBase Shape Uniqueness), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
mergeparams' Result
mergeinit'

        let indexMergeParams :: ReadPrelude
indexMergeParams Slice SubExp
slice =
              Scope Kernels
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams') (BinderT Kernels (State VNameSource) ()
 -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                [(Param (TypeBase Shape Uniqueness),
  Param (TypeBase Shape Uniqueness))]
-> ((Param (TypeBase Shape Uniqueness),
     Param (TypeBase Shape Uniqueness))
    -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (TypeBase Shape Uniqueness)]
-> [Param (TypeBase Shape Uniqueness)]
-> [(Param (TypeBase Shape Uniqueness),
     Param (TypeBase Shape Uniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
mergeparams [Param (TypeBase Shape Uniqueness)]
mergeparams') (((Param (TypeBase Shape Uniqueness),
   Param (TypeBase Shape Uniqueness))
  -> BinderT Kernels (State VNameSource) ())
 -> BinderT Kernels (State VNameSource) ())
-> ((Param (TypeBase Shape Uniqueness),
     Param (TypeBase Shape Uniqueness))
    -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase Shape Uniqueness)
to, Param (TypeBase Shape Uniqueness)
from) ->
                  [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
to] (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                    BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
                      VName -> Slice SubExp -> BasicOp
Index (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
from) (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                        Type -> Slice SubExp -> Slice SubExp
fullSlice (Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param (TypeBase Shape Uniqueness)
from) Slice SubExp
slice

            private' :: Names
private' =
              Names
private Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
namesFromList ((Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
mergeparams [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ (Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
mergeparams')

            privstms' :: PrivStms
privstms' =
              Stms Kernels -> ReadPrelude -> PrivStms
PrivStms Stms Kernels
forall a. Monoid a => a
mempty ReadPrelude
indexMergeParams PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
privstms PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
inloop_privstms

        BodyT Kernels
loopbody' <-
          Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels))
-> Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels)
forall a b. (a -> b) -> a -> b
$
            Result -> BodyT Kernels
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> BodyT Kernels)
-> ([VName] -> Result) -> [VName] -> BodyT Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var
              ([VName] -> BodyT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels (BodyT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TiledBody
tiledBody Names
private' PrivStms
privstms'
        [VName]
accs' <-
          String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"tiled_inside_loop" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) [VName])
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            [(FParam Kernels, SubExp)]
-> [(FParam Kernels, SubExp)]
-> LoopForm Kernels
-> BodyT Kernels
-> ExpT Kernels
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam Kernels, SubExp)]
merge' (VName
-> IntType
-> SubExp
-> [(LParam Kernels, VName)]
-> LoopForm Kernels
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
bound []) BodyT Kernels
loopbody'

        Tiling
-> PrivStms
-> Pattern Kernels
-> [VName]
-> Stms Kernels
-> Result
-> [Type]
-> BinderT Kernels (State VNameSource) [VName]
postludeGeneric Tiling
tiling (PrivStms
privstms PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
inloop_privstms) Pattern Kernels
pat [VName]
accs' Stms Kernels
poststms Result
poststms_res [Type]
res_ts

  (Stms Kernels, Tiling, TiledBody)
-> ReaderT
     (Scope Kernels)
     (State VNameSource)
     (Stms Kernels, Tiling, TiledBody)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels
host_stms, Tiling
tiling, TiledBody
tiledBody')
  where
    tiled_kdims :: Names
tiled_kdims =
      [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
        ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst ([(VName, SubExp)] -> [VName]) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$
          ((VName, SubExp) -> Bool) -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName, SubExp) -> [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` SegSpace -> [(VName, SubExp)]
unSegSpace (Tiling -> SegSpace
tilingSpace Tiling
tiling)) ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$
            SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
initial_space

-- | Create a thread-level SegMap (described as @"prelude"@) that takes
-- care of the prelude for every thread.  A thread that is in bounds
-- emits the private statements for its slice followed by @prestms@ and
-- returns the values of @prestms_live@; a thread that is out of bounds
-- returns blank values of the same types instead.  Returns the names
-- bound to the SegMap's results.
doPrelude :: Tiling -> PrivStms -> Stms Kernels -> [VName] -> Binder Kernels [VName]
doPrelude tiling privstms prestms prestms_live =
  tilingSegMap tiling "prelude" (scalarLevel tiling) ResultPrivate perThread
  where
    perThread :: PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp]
    perThread in_bounds slice = do
      live_ts <- mapM lookupType prestms_live
      let whenInBounds = do
            addPrivStms slice privstms
            addStms prestms
            resultBodyM (map Var prestms_live)
          whenOutOfBounds = eBody (map eBlank live_ts)
      branch <- eIf (toExp in_bounds) whenInBounds whenOutOfBounds
      map Var <$> letTupExp "pre" branch

-- | The names bound by @stms@ that are also free in @after@, i.e. the
-- values produced by @stms@ that whatever comes after still refers to.
liveSet :: FreeIn a => Stms Kernels -> a -> Names
liveSet stms after = namesIntersection boundByStms usedAfter
  where
    boundByStms = namesFromList $ foldMap (patternNames . stmPattern) stms
    usedAfter = freeIn after

-- | Recognise a statement that can be tiled: a redomap-shaped 'Screma'
-- (a map followed by reductions, combined into one by 'singleReduce')
-- that has at least one input array, produces no map-out arrays, and
-- whose map function takes and returns only values of primitive type.
-- On success, returns the width of the 'Screma', its input arrays, and
-- the reduction's commutativity, operator and neutral elements together
-- with the map function.
tileable ::
  Stm Kernels ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
    )
tileable stm =
  case stmExp stm of
    Op (OtherOp (Screma w arrs form))
      | Just (reds, map_lam) <- isRedomapSOAC form,
        Reduce red_comm red_lam red_nes <- singleReduce reds,
        noMapOut map_lam red_lam,
        not (null arrs),
        onlyPrimitive map_lam ->
        Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
    _ -> Nothing
  where
    -- When the map function returns exactly the values consumed by the
    -- reduction, there are no map-out arrays.
    noMapOut :: Lambda Kernels -> Lambda Kernels -> Bool
    noMapOut map_lam red_lam =
      lambdaReturnType map_lam == lambdaReturnType red_lam

    -- Only primitive types are tiled, to avoid excessive local memory use.
    onlyPrimitive :: Lambda Kernels -> Bool
    onlyPrimitive lam =
      all primType (lambdaReturnType lam)
        && all (primType . paramType) (lambdaParams lam)

-- | The inputs to the tiled loop, classified according to whether they
-- are tileable (and if so, with what permutation of the kernel indexes)
-- or not.  In practice, we should have at least one tileable array per
-- loop, but this is not enforced in our representation.
data InputArray
  = -- | A tileable array, together with its permutation of the kernel
    -- indexes.
    InputTile [Int] VName
  | -- | An array that is not tiled.
    InputDontTile VName

-- | The tileable inputs, each paired with its permutation of the kernel
-- indexes, in their original order.  Untiled inputs are left out.
tiledInputs :: [InputArray] -> [(VName, [Int])]
tiledInputs = mapMaybe $ \case
  InputTile perm arr -> Just (arr, perm)
  InputDontTile _ -> Nothing

-- | A tile (or an original untiled array).
data InputTile
  = -- | A tile, with the permutation of the kernel indexes carried over
    -- from the corresponding 'InputTile' input.
    InputTiled [Int] VName
  | -- | An original array that was not tiled.
    InputUntiled VName

-- | Pair the inputs with the tiles that have been read for them.  The
-- list of names is expected to hold one tile per tileable input, in
-- order.  Each 'InputTile' becomes an 'InputTiled' that keeps its
-- permutation but refers to the next tile.  Each 'InputDontTile'
-- becomes an 'InputUntiled' of its original array and consumes no tile.
-- If the tiles run out before the tileable inputs do, the result stops
-- at that point.
inputsToTiles :: [InputArray] -> [VName] -> [InputTile]
inputsToTiles [] _ = []
inputsToTiles (input : inputs) tiles =
  case (input, tiles) of
    (InputDontTile arr, _) ->
      InputUntiled arr : inputsToTiles inputs tiles
    (InputTile perm _, tile : tiles') ->
      InputTiled perm tile : inputsToTiles inputs tiles'
    (InputTile _ _, []) -> []

-- | Slice out the part of @arr@ that belongs to tile number @tile_id@:
-- @this_tile_size@ elements of the outermost dimension, starting at
-- offset @tile_id * full_tile_size@ with stride 1.  The remaining
-- dimensions are taken in full via 'fullSlice'.  The actual tile size
-- may be smaller than @full_tile_size@ for the last tile, so the offset
-- and the length are computed from separate sizes.
sliceUntiled ::
  MonadBinder m =>
  VName ->
  SubExp ->
  SubExp ->
  SubExp ->
  m VName
sliceUntiled arr tile_id full_tile_size this_tile_size = do
  arr_t <- lookupType arr
  offset_e <- toExp $ pe64 tile_id * pe64 full_tile_size
  offset <- letSubExp "slice_offset" offset_e
  let unit_stride = intConst Int64 1
      outer_slice = fullSlice arr_t [DimSlice offset this_tile_size unit_stride]
  letExp "untiled_slice" (BasicOp (Index arr outer_slice))

-- | Statements that we insert directly into every thread-private
-- SegMap.  This is for things that cannot efficiently be computed
-- once in advance in the prelude SegMap, primarily (exclusively?)
-- array slicing operations.
--
-- The 'ReadPrelude' component is run on the thread's slice before the
-- statements are added (see 'addPrivStms').  Combining two 'PrivStms'
-- with '<>' concatenates the statements and runs both readers, left
-- one first.
data PrivStms = PrivStms (Stms Kernels) ReadPrelude

-- | Wrap statements as 'PrivStms' whose prelude reader does nothing.
privStms :: Stms Kernels -> PrivStms
privStms stms = PrivStms stms noPreludeRead
  where
    noPreludeRead _ = pure ()

-- | Emit private statements for one thread: first run the prelude
-- reader on the thread's slice, then add the statements themselves.
addPrivStms :: Slice SubExp -> PrivStms -> Binder Kernels ()
addPrivStms local_slice (PrivStms stms readPrelude) =
  readPrelude local_slice >> addStms stms

-- | Concatenate the statements, and run both prelude readers on the same
-- slice, the left one first.
instance Semigroup PrivStms where
  PrivStms xs read_x <> PrivStms ys read_y =
    PrivStms (xs <> ys) $ \slice -> do
      read_x slice
      read_y slice

-- | No statements, and a prelude reader that does nothing.
instance Monoid PrivStms where
  mempty = PrivStms mempty (const (pure ()))

-- | An action that is given a thread's slice and may add statements for
-- that thread (presumably reads of values computed by the prelude).
type ReadPrelude = Slice SubExp -> Binder Kernels ()

-- | Arguments to 'tilingProcessTile'.
data ProcessTileArgs = ProcessTileArgs
  { -- | Private statements to emit in each thread.
    processPrivStms :: PrivStms
  , -- | Commutativity of the reduction operator.
    processComm :: Commutativity
  , -- | The reduction operator.
    processRedLam :: Lambda Kernels
  , -- | The map function (presumably applied to tile elements before
    -- reducing).
    processMapLam :: Lambda Kernels
  , -- | The tiles (or untiled arrays) to process.
    processTiles :: [InputTile]
  , -- | Accumulator variables (presumably the running reduction results).
    processAcc :: [VName]
  , -- | Which tile is being processed (presumably its index).
    processTileId :: SubExp
  }

-- | Arguments to 'tilingProcessResidualTile' (presumably used for the
-- final tile that may not be full).
data ResidualTileArgs = ResidualTileArgs
  { -- | Private statements to emit in each thread.
    residualPrivStms :: PrivStms
  , -- | Commutativity of the reduction operator.
    residualComm :: Commutativity
  , -- | The reduction operator.
    residualRedLam :: Lambda Kernels
  , -- | The map function.
    residualMapLam :: Lambda Kernels
  , -- | The original (not yet tiled) inputs.
    residualInput :: [InputArray]
  , -- | Accumulator variables (presumably the running reduction results).
    residualAcc :: [VName]
  , -- | Size of the tiled input (presumably the loop width).
    residualInputSize :: SubExp
  , -- | Number of whole tiles already processed (presumably).
    residualNumWholeTiles :: SubExp
  }

-- | Information about a loop that has been tiled inside a kernel, as
-- well as the kinds of changes that we would then like to perform on
-- the kernel.  Most fields are code generators supplied by the
-- particular tiling strategy.
data Tiling = Tiling
  { -- | Build a SegMap at the given level with the given result manifest.
    -- The 'String' is a descriptive name such as @"prelude"@.  The
    -- function produces one thread's results: its boolean 'PrimExp'
    -- indicates whether the thread is in bounds, and its 'Slice'
    -- locates the thread's elements.
    tilingSegMap ::
      String ->
      SegLevel ->
      ResultManifest ->
      (PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp]) ->
      Binder Kernels [VName]
  , -- | Read a tile of each of the given inputs.  The 'SubExp' is
    -- presumably the index of the tile to read (confirm against the
    -- tiling strategies).
    tilingReadTile ::
      TileKind ->
      PrivStms ->
      SubExp ->
      [InputArray] ->
      Binder Kernels [InputTile]
  , -- | Process one tile, returning accumulator names (presumably the
    -- updated accumulators).
    tilingProcessTile ::
      ProcessTileArgs ->
      Binder Kernels [VName]
  , -- | Process the residual tile, returning accumulator names.
    tilingProcessResidualTile ::
      ResidualTileArgs ->
      Binder Kernels [VName]
  , -- | Produce the kernel result for the given array.
    tilingTileReturns :: VName -> Binder Kernels KernelResult
  , -- | The segment space of the tiled kernel.
    tilingSpace :: SegSpace
  , -- | The shape of a tile.
    tilingTileShape :: Shape
  , -- | The level of the tiled kernel; see also 'scalarLevel'.
    tilingLevel :: SegLevel
  , -- | Compute the number of whole tiles.
    tilingNumWholeTiles :: Binder Kernels SubExp
  }

-- | A tiling strategy.  Given the initial segment level, the thread
-- indexes (@gtids@) and kernel dimensions (@kdims@) in a
-- strategy-specific representation, and the width of the tiled loop,
-- construct a 'Tiling'.  Any statements it adds to the binder are kept
-- by the caller.
type DoTiling gtids kdims =
  SegLevel -> gtids -> kdims -> SubExp -> Binder Kernels Tiling

-- | A thread-level 'SegLevel' with the same number of groups and group
-- size as the tiling's level, but without virtualisation.
scalarLevel :: Tiling -> SegLevel
scalarLevel tiling =
  let lvl = tilingLevel tiling
   in SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt

-- | Bind, under the description @desc@, an @if@ on @in_bounds@.  Its
-- true branch contains the statements and results of @m@.  Its false
-- branch produces blank values of the types @ts@, which are presumably
-- expected to match the results of @m@.  Returns the bound names.
protectOutOfBounds ::
  String ->
  PrimExp VName ->
  [Type] ->
  Binder Kernels [SubExp] ->
  Binder Kernels [VName]
protectOutOfBounds desc in_bounds ts m = do
  guarded <- eIf (toExp in_bounds) (fmap resultBody m) blanks
  letTupExp desc guarded
  where
    blanks = eBody (map eBlank ts)

-- | Build a thread-level SegMap (described as @"thread_res"@) that
-- computes each thread's result after the tiled loop.
--
-- Every thread first binds each name of @pat@ to its own element of the
-- corresponding array in @accs'@.  What happens next depends on
-- @poststms@:
--
-- * If there are no post-loop statements, the thread emits the private
--   statements (they may still be needed for the result) and returns
--   @poststms_res@ directly.
-- * Otherwise, the private statements and @poststms@ are emitted inside
--   an in-bounds check, and out-of-bounds threads return blank values
--   of types @res_ts@.
postludeGeneric ::
  Tiling ->
  PrivStms ->
  Pattern Kernels ->
  [VName] ->
  Stms Kernels ->
  Result ->
  [Type] ->
  Binder Kernels [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" (scalarLevel tiling) ResultPrivate perThread
  where
    perThread :: PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp]
    perThread in_bounds slice = do
      mapM_ (readOwnResult slice) (zip (patternNames pat) accs')
      if null poststms
        then addPrivStms slice privstms >> pure poststms_res
        else map Var <$> protectOutOfBounds "postlude" in_bounds res_ts (epilogue slice)

    -- Bind @us@ to this thread's element of the tiled loop result @everyone@.
    readOwnResult :: Slice SubExp -> (VName, VName) -> Binder Kernels ()
    readOwnResult slice (us, everyone) = do
      everyone_t <- lookupType everyone
      letBindNames [us] $ BasicOp $ Index everyone $ fullSlice everyone_t slice

    -- The post-loop statements, preceded by the private statements they
    -- may depend on.
    epilogue :: Slice SubExp -> Binder Kernels [SubExp]
    epilogue slice = do
      addPrivStms slice privstms
      addStms poststms
      pure poststms_res

-- | The code generator returned by 'tileGeneric' for the tiled part of a
-- computation.  It is applied to a 'Names' set and a 'PrivStms' value,
-- emits statements into a 'Binder', and returns the names of the results.
-- (NOTE(review): 'tileGeneric' ignores the 'Names' argument; presumably it
-- identifies thread-private names — confirm at the call sites.)
type TiledBody =
  Names ->
  PrivStms ->
  Binder Kernels [VName]

-- | Generic tiling driver shared by the concrete tiling strategies.
--
-- It first runs the tiling strategy @doTiling@ to construct a 'Tiling'.
-- The statements that strategy emits are returned together with the
-- 'Tiling' itself.  The third component is a 'TiledBody' which, when run,
-- emits the tiled computation:
--
-- 1. A "mergeinit" segmented map giving every thread its own copy of the
--    reduction's neutral elements.
-- 2. A sequential loop over the whole tiles.  Each iteration collectively
--    reads a tile ('tilingReadTile', with 'TilePartial') and lets every
--    thread fold it into its accumulators ('tilingProcessTile').
-- 3. A pass over a possible residual tile ('tilingProcessResidualTile').
-- 4. The postlude ('postludeGeneric').  It is given the final
--    accumulators, @poststms@ and @poststms_res@, and produces results of
--    types @res_ts@.
--
-- @form@ is (commutativity, reduction lambda, neutral elements,
-- map lambda).
tileGeneric ::
  DoTiling gtids kdims ->
  SegLevel ->
  [Type] ->
  Pattern Kernels ->
  gtids ->
  kdims ->
  SubExp ->
  (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels) ->
  [InputArray] ->
  Stms Kernels ->
  Result ->
  TileM (Stms Kernels, Tiling, TiledBody)
tileGeneric doTiling initial_lvl res_ts pat gtids kdims w form inputs poststms poststms_res = do
  (tiling, tiling_stms) <- runBinder $ doTiling initial_lvl gtids kdims w
  pure (tiling_stms, tiling, tiledBody tiling)
  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> Names -> PrivStms -> Binder Kernels [VName]
    tiledBody tiling _private privstms = do
      let tile_shape = tilingTileShape tiling
      num_whole_tiles <- tilingNumWholeTiles tiling

      -- A segmented map is used instead of a Replicate so that the initial
      -- accumulators are forced into a scalar memory space.
      mergeinits <-
        tilingSegMap tiling "mergeinit" (scalarLevel tiling) ResultPrivate $ \in_bounds slice ->
          -- Neutral elements with no free variables (the common case) are
          -- constants and need no out-of-bounds protection.
          if freeIn red_nes == mempty
            then pure red_nes
            else do
              nes <-
                protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
                  addPrivStms slice privstms
                  pure red_nes
              pure $ map Var nes

      -- Each reduction parameter gets a unique loop parameter of the tile
      -- shape, initialised from the corresponding mergeinit result.
      let mergeParam (p, mergeinit) = do
            p_merge <-
              newParam (baseString (paramName p) ++ "_merge") $
                paramType p `arrayOfShape` tile_shape `toDecl` Unique
            pure (p_merge, Var mergeinit)
      merge <- mapM mergeParam $ zip (lambdaParams red_lam) mergeinits

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int64 num_whole_tiles []
          merge_params = map fst merge
          acc_names = map paramName merge_params

      loopbody_fresh <-
        runBodyBinder . inScopeOf loopform . localScope (scopeOfFParams merge_params) $ do
          -- Collectively read a tile.
          tile <- tilingReadTile tiling TilePartial privstms (Var tile_id) inputs
          -- Each thread then traverses the tile and updates its accumulator.
          let tile_args =
                ProcessTileArgs privstms red_comm red_lam map_lam tile acc_names (Var tile_id)
          resultBody . map Var <$> tilingProcessTile tiling tile_args
      loopbody <- renameBody loopbody_fresh

      accs <- letTupExp "accs" $ DoLoop [] merge loopform loopbody

      -- A residual (partial) tile may remain after the whole tiles.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      accs' <-
        tilingProcessResidualTile tiling $
          ResidualTileArgs privstms red_comm red_lam' map_lam' inputs accs w num_whole_tiles

      -- Create a SegMap that takes care of the postlude for every thread.
      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts

-- | How a tile is read from its input arrays.  In 'readTile1D', a
-- 'TilePartial' read of each element is guarded by a bounds check of the
-- element index against the array size.  NOTE(review): 'TileFull'
-- presumably marks tiles known to lie entirely in bounds — confirm at the
-- use sites.
data TileKind
  = TilePartial
  | TileFull

-- | Build a 'ReadPrelude' from two lists.  For a given slice, it binds
-- each name in @prestms_live@ to an 'Index' into the array at the same
-- position in @prestms_live_arrs@.  The lists are paired as by 'zip', so
-- surplus elements of the longer list are ignored.  The slice is completed
-- with 'fullSlice' against the array's type before indexing.
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  -- Only the bindings matter.  Sequence them with mapM_ instead of building
  -- a list of () results and folding it with mconcat.
  mapM_ readValue $ zip prestms_live_arrs prestms_live
  where
    readValue (arr, v) = do
      arr_t <- lookupType arr
      letBindNames [v] $ BasicOp $ Index arr $ fullSlice arr_t slice

-- | Build a 'TileReturns' kernel result for @arr@.
--
-- Every entry of @dims_on_top@ contributes one leading unit dimension.
-- When @dims_on_top@ is non-empty, @arr@ is reshaped to prepend that many
-- dimensions of size 1.  Each such dimension appears in the tile
-- dimension list as @(size, 1)@, where @size@ is the second component of
-- the entry, and @dims@ follows.
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Binder Kernels KernelResult
tileReturns dims_on_top dims arr = do
  let unit_dims = replicate (length dims_on_top) (intConst Int64 1)
      tile_dims = zip (map snd dims_on_top) unit_dims ++ dims
  arr' <- case dims_on_top of
    [] -> pure arr
    _ -> do
      arr_t <- lookupType arr
      let new_shape = unit_dims ++ arrayDims arr_t
      letExp (baseString arr) $ BasicOp $ Reshape (map DimNew new_shape) arr
  pure $ TileReturns tile_dims arr'

-- | Decide whether @arr@ can be tiled along its outermost dimension.
--
-- The array's variance set is looked up in @variance@ (empty if absent).
-- If that set contains @gtid@, the array varies with the thread index and
-- is not tiled ('InputDontTile').  Otherwise it is tiled on dimension 0.
is1DTileable :: VName -> M.Map VName Names -> VName -> InputArray
is1DTileable gtid variance arr
  | gtid `nameIn` arr_variance = InputDontTile arr
  | otherwise = InputTile [0] arr
  where
    arr_variance = M.findWithDefault mempty arr variance

-- | Emit a one-dimensional 'SegMap' at level @lvl@ and bind its results
-- to names derived from @desc@.
--
-- The thread space has a single dimension of size @segGroupSize lvl@.
-- The body is produced by applying @f@ to a fresh local thread index.
-- Its statements are given fresh names with 'renameBody' before being
-- placed in the kernel body, and every result is returned with
-- @manifest@.
segMap1D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (VName -> Binder Kernels [SubExp]) ->
  Binder Kernels [VName]
segMap1D desc lvl manifest f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid, unCount (segGroupSize lvl))]

  ((ret_ts, body_res), body_stms) <- runBinder $ do
    ses <- f ltid
    ses_ts <- mapM subExpType ses
    pure (ses_ts, ses)

  Body _ stms' res' <- renameBody $ mkBody body_stms body_res
  let kbody = KernelBody () stms' $ map (Returns manifest) res'
  letTupExp desc $ Op $ SegOp $ SegMap lvl space ret_ts kbody

-- | Bind @gtid@ to @gid * group_size + ltid@, computed as a 64-bit integer
-- expression.
reconstructGtids1D ::
  Count GroupSize SubExp ->
  VName ->
  VName ->
  VName ->
  Binder Kernels ()
reconstructGtids1D group_size gtid gid ltid = do
  gtid_e <- toExp $ le64 gid * pe64 (unCount group_size) + le64 ltid
  letBindNames [gtid] gtid_e

readTile1D ::
  SubExp ->
  VName ->
  VName ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Binder Kernels [InputTile]
readTile1D :: SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile]
readTile1D SubExp
tile_size VName
gid VName
gtid Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TileKind
kind PrivStms
privstms SubExp
tile_id [InputArray]
inputs =
  ([VName] -> [InputTile])
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels [InputTile]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([InputArray] -> [VName] -> [InputTile]
inputsToTiles [InputArray]
inputs)
    (BinderT Kernels (State VNameSource) [VName]
 -> Binder Kernels [InputTile])
-> ((VName -> BinderT Kernels (State VNameSource) Result)
    -> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
"full_tile" SegLevel
lvl ResultManifest
ResultNoSimplify
    ((VName -> BinderT Kernels (State VNameSource) Result)
 -> Binder Kernels [InputTile])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
      SubExp
j <-
        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"j"
          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid)

      Count GroupSize SubExp
-> VName
-> VName
-> VName
-> BinderT Kernels (State VNameSource) ()
reconstructGtids1D Count GroupSize SubExp
group_size VName
gtid VName
gid VName
ltid
      Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid] PrivStms
privstms

      let arrs :: [VName]
arrs = ((VName, [Int]) -> VName) -> [(VName, [Int])] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, [Int]) -> VName
forall a b. (a, b) -> a
fst ([(VName, [Int])] -> [VName]) -> [(VName, [Int])] -> [VName]
forall a b. (a -> b) -> a -> b
$ [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs
      [Type]
arr_ts <- (VName -> BinderT Kernels (State VNameSource) Type)
-> [VName] -> BinderT Kernels (State VNameSource) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
arrs
      let tile_ts :: [Type]
tile_ts = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType [Type]
arr_ts
          w :: SubExp
w = Int -> [Type] -> SubExp
forall u. Int -> [TypeBase Shape u] -> SubExp
arraysSize Int
0 [Type]
arr_ts

      let readTileElem :: VName -> BinderT Kernels (State VNameSource) VName
readTileElem VName
arr =
            -- No need for fullSlice because we are tiling only prims.
            String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"tile_elem" (BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
j])
      ([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
        case TileKind
kind of
          TileKind
TilePartial ->
            String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"pre"
              (ExpT Kernels -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
j TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
w)
                (Result -> BodyT Kernels
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> BodyT Kernels)
-> BinderT Kernels (State VNameSource) Result
-> Binder Kernels (BodyT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Binder Kernels SubExp)
-> [VName] -> BinderT Kernels (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName -> SubExp)
-> BinderT Kernels (State VNameSource) VName
-> Binder Kernels SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var (BinderT Kernels (State VNameSource) VName
 -> Binder Kernels SubExp)
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> VName
-> Binder Kernels SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> BinderT Kernels (State VNameSource) VName
readTileElem) [VName]
arrs)
                ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody ([BinderT
    Kernels
    (State VNameSource)
    (Exp (Lore (BinderT Kernels (State VNameSource))))]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (Type -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> [Type] -> [BinderT Kernels (State VNameSource) (ExpT Kernels)]
forall a b. (a -> b) -> [a] -> [b]
map Type -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank [Type]
tile_ts)
          TileKind
TileFull ->
            (VName -> BinderT Kernels (State VNameSource) VName)
-> [VName] -> BinderT Kernels (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BinderT Kernels (State VNameSource) VName
readTileElem [VName]
arrs
  where
    lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegNoVirt

-- | Process one whole tile in a 1D-tiled kernel.
--
-- Runs a thread-level 'segMap1D' in which every thread:
--
-- * reconstructs its global thread index @gtid@ from @gid@ and its
--   local index;
-- * re-binds the privatised statements;
-- * reads its private accumulators at its local index; and then
-- * if @gtid < kdim@, folds the tile into those accumulators with a
--   redomap built from the map/reduce lambdas in the
--   'ProcessTileArgs'.
--
-- Out-of-bounds threads return their accumulators unchanged.
processTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Binder Kernels [VName]
processTile1D gid gtid kdim tile_size num_groups group_size tile_args =
  -- A whole tile is exactly one full tile long.
  processTile1DSized gid gtid kdim tile_size tile_size num_groups group_size tile_args

-- | Worker for 'processTile1D' that separates two sizes:
--
-- * @full_tile_size@: the nominal tile size, used to locate the current
--   tile (index @processTileId@) inside untiled inputs.
-- * @this_tile_size@: the number of elements actually present in this
--   tile. It is smaller than @full_tile_size@ for the residual tile.
processTile1DSized ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Binder Kernels [VName]
processTile1DSized gid gtid kdim full_tile_size this_tile_size num_groups group_size tile_args = do
  let red_comm = processComm tile_args
      privstms = processPrivStms tile_args
      map_lam = processMapLam tile_args
      red_lam = processRedLam tile_args
      tiles = processTiles tile_args
      tile_id = processTileId tile_args
      accs = processAcc tile_args

  segMap1D "acc" lvl ResultPrivate $ \ltid -> do
    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    -- We replace the neutral elements with the accumulators (this is
    -- OK because the parallel semantics are not used after this
    -- point).
    thread_accs <- forM accs $ \acc ->
      letSubExp "acc" $ BasicOp $ Index acc [DimFix $ Var ltid]

    let sliceTile (InputTiled _ arr) =
          pure arr
        sliceTile (InputUntiled arr) =
          -- The tile's position in an untiled input is a multiple of
          -- the *full* tile size, even when this tile is truncated.
          -- NOTE(review): assumes 'sliceUntiled' takes
          -- (arr, tile_id, full_tile_size, this_tile_size) -- confirm
          -- against TileLoops.Shared.
          sliceUntiled arr tile_id full_tile_size this_tile_size

    tiles' <- mapM sliceTile tiles

    let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam
    fmap (map Var) $
      letTupExp "acc"
        =<< eIf
          (toExp $ le64 gtid .<. pe64 kdim)
          (eBody [pure $ Op $ OtherOp $ Screma this_tile_size tiles' form'])
          (resultBodyM thread_accs)
  where
    lvl = SegThread num_groups group_size SegNoVirt

-- | Handle the elements left over after all whole tiles have been
-- processed.
--
-- The residual count is @w rem tile_size@:
--
-- * If it is zero, the accumulators are returned unchanged.
-- * Otherwise a partial tile is read at tile index @num_whole_tiles@.
--   Threads that are out of bounds contribute blank values. Tiled
--   inputs are cut down to the residual length, and the result is
--   folded into the accumulators like an ordinary tile.
processResidualTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Binder Kernels [VName]
processResidualTile1D gid gtid kdim tile_size num_groups group_size args = do
  -- The number of residual elements that are not covered by
  -- the whole tiles.
  residual_input <-
    letSubExp "residual_input" $
      BasicOp $ BinOp (SRem Int64 Unsafe) w tile_size

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp $ pe64 residual_input .==. 0)
      (resultBodyM $ map Var accs)
      (nonemptyTile residual_input)
  where
    red_comm = residualComm args
    map_lam = residualMapLam args
    red_lam = residualRedLam args
    privstms = residualPrivStms args
    inputs = residualInput args
    accs = residualAcc args
    num_whole_tiles = residualNumWholeTiles args
    w = residualInputSize args

    nonemptyTile residual_input = runBodyBinder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tiles <-
        readTile1D
          tile_size
          gid
          gtid
          num_groups
          group_size
          TilePartial
          privstms
          num_whole_tiles
          inputs

      -- Only the first residual_input elements of a tiled input are
      -- meaningful; untiled inputs are sliced later, per thread.
      let sliceTile (InputUntiled arr) =
            pure $ InputUntiled arr
          sliceTile (InputTiled perm tile) = do
            let slice =
                  DimSlice (intConst Int64 0) residual_input (intConst Int64 1)
            InputTiled perm
              <$> letExp "partial_tile" (BasicOp $ Index tile [slice])

      tiles <- mapM sliceTile full_tiles

      -- Now each thread performs a traversal of the tile and
      -- updates its accumulator.  The residual tile sits at index
      -- num_whole_tiles.  Untiled inputs must therefore be offset by
      -- the full tile size, while only residual_input elements are
      -- consumed.
      let tile_args =
            ProcessTileArgs privstms red_comm red_lam map_lam tiles accs num_whole_tiles
      resultBody . map Var
        <$> processTile1DSized gid gtid kdim tile_size residual_input num_groups group_size tile_args

-- | Build the 'Tiling' operations for one-dimensional tiling.
--
-- The innermost parallel dimension is @(gtid, kdim)@, and @w@ is the
-- size of the dimension that is traversed tile by tile.
--
-- Choice of level and segment space:
--
-- * With no outer parallel dimensions (@dims_on_top@ empty), the
--   initial level's group count, group size and virtualisation are
--   reused.
-- * Otherwise the group size is @min(initial group size, kdim)@. The
--   number of groups is then @ceil(kdim / group size)@ times the
--   product of the outer dimension sizes.
--
-- The tile size always equals the chosen group size.
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d dims_on_top initial_lvl gtid kdim w = do
  gid <- newVName "gid"
  gid_flat <- newVName "gid_flat"

  (lvl, space) <- chooseLevelAndSpace gid gid_flat

  let tile_size = unCount $ segGroupSize lvl
      lvl_groups = segNumGroups lvl
      lvl_group_size = segGroupSize lvl

  pure $
    Tiling
      { tilingSegMap = \desc lvl' manifest f ->
          segMap1D desc lvl' manifest $ \ltid -> do
            -- Global thread index within the innermost dimension.
            letBindNames [gtid]
              =<< toExp (le64 gid * pe64 tile_size + le64 ltid)
            f (untyped $ le64 gtid .<. pe64 kdim) [DimFix $ Var ltid],
        tilingReadTile =
          readTile1D tile_size gid gtid lvl_groups lvl_group_size,
        tilingProcessTile =
          processTile1D gid gtid kdim tile_size lvl_groups lvl_group_size,
        tilingProcessResidualTile =
          processResidualTile1D gid gtid kdim tile_size lvl_groups lvl_group_size,
        tilingTileReturns = tileReturns dims_on_top [(kdim, tile_size)],
        tilingTileShape = Shape [tile_size],
        tilingNumWholeTiles =
          letSubExp "num_whole_tiles" . BasicOp $
            BinOp (SQuot Int64 Unsafe) w tile_size,
        tilingLevel = lvl,
        tilingSpace = space
      }
  where
    -- Pick the group-level SegLevel and the SegSpace of the tiled kernel.
    chooseLevelAndSpace group_id flat_id
      | null dims_on_top =
          pure
            ( SegGroup
                (segNumGroups initial_lvl)
                (segGroupSize initial_lvl)
                (segVirt initial_lvl),
              SegSpace flat_id [(group_id, unCount (segNumGroups initial_lvl))]
            )
      | otherwise = do
          group_size <-
            letSubExp "computed_group_size" . BasicOp $
              BinOp (SMin Int64) (unCount (segGroupSize initial_lvl)) kdim

          -- How many groups we need to exhaust the innermost dimension.
          ldim <-
            letSubExp "ldim" . BasicOp $
              BinOp (SDivUp Int64 Unsafe) kdim group_size

          num_groups <-
            letSubExp "computed_num_groups"
              =<< foldBinOp (Mul Int64 OverflowUndef) ldim (map snd dims_on_top)

          pure
            ( SegGroup (Count num_groups) (Count group_size) SegNoVirt,
              SegSpace flat_id (dims_on_top ++ [(group_id, ldim)])
            )

-- | Decide how an array read inside a 2D-tiled kernel should be tiled.
--
-- Only the two innermost entries of @dims@ matter; with fewer than two
-- entries the result is 'Nothing'.  Calling the innermost entry @j@ and the
-- one before it @i@:
--
-- * if either @i@ or @j@ is in @branch_variant@, the array is not tiled;
-- * otherwise, if @arr@ is variant to @i@ but not to @j@, it is tiled with
--   permutation @[0, 1]@;
-- * otherwise, if @arr@ is variant to @j@ but not to @i@, it is tiled with
--   permutation @[1, 0]@;
-- * in every other case it is not tiled.
--
-- The variance of @arr@ is looked up in @variance@; an array without an
-- entry is treated as variant to nothing.
invariantToOneOfTwoInnerDims ::
  Names ->
  M.Map VName Names ->
  [VName] ->
  VName ->
  Maybe InputArray
invariantToOneOfTwoInnerDims branch_variant variance dims arr =
  case reverse dims of
    inner : outer : _ -> Just $ classify outer inner
    _ -> Nothing
  where
    variantTo v = v `nameIn` M.findWithDefault mempty arr variance

    classify outer inner
      | not branchInvariant = InputDontTile arr
      | variantTo outer && not (variantTo inner) = InputTile [0, 1] arr
      | variantTo inner && not (variantTo outer) = InputTile [1, 0] arr
      | otherwise = InputDontTile arr
      where
        branchInvariant =
          not (outer `nameIn` branch_variant || inner `nameIn` branch_variant)

-- | Reconstruct the original gtids from group and local IDs.
--
-- Binds @gtid_x@ to @gid_x * tile_size + ltid_x@ and then @gtid_y@ to
-- @gid_y * tile_size + ltid_y@, both built as 'Int64' expressions.
reconstructGtids2D ::
  SubExp ->
  (VName, VName) ->
  (VName, VName) ->
  (VName, VName) ->
  Binder Kernels ()
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) =
  mapM_ bindGtid [(gtid_x, gid_x, ltid_x), (gtid_y, gid_y, ltid_y)]
  where
    -- gtid = gid * tile_size + ltid
    bindGtid (gtid, gid, ltid) =
      letBindNames [gtid] =<< toExp (le64 gid * pe64 tile_size + le64 ltid)

-- | Read one tile of every tiled input of a 2D-tiled loop.
--
-- The read is performed by a 'segMap2D' named @"full_tile"@ over
-- @(tile_size, tile_size)@ threads at level
-- @SegThread num_groups group_size SegNoVirtFull@ with 'ResultNoSimplify'.
-- The resulting arrays are paired with @inputs@ via 'inputsToTiles'.
--
-- Thread @(ltid_x, ltid_y)@ binds @i = tile_id * tile_size + ltid_x@ and
-- @j = tile_id * tile_size + ltid_y@, rebinds the global thread IDs with
-- 'reconstructGtids2D', replays @privstms@ at index @[ltid_x, ltid_y]@,
-- and then reads one element from each array returned by 'tiledInputs'.
-- The array's permutation picks the index: the last entry of
-- @rearrangeShape perm [i, j]@.
--
-- * 'TileFull': the element is read unconditionally.
-- * 'TilePartial': the element is read only when the index is below the
--   array's outer size and the permutation-selected global thread ID is
--   below its kernel dimension; otherwise an 'eBlank' of the array's row
--   type is produced.
readTile2D ::
  (SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Binder Kernels [InputTile]
readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size num_groups group_size kind privstms tile_id inputs =
  inputsToTiles inputs
    <$> segMap2D
      "full_tile"
      (SegThread num_groups group_size SegNoVirtFull)
      ResultNoSimplify
      (tile_size, tile_size)
      readElems
  where
    -- Position along the tiled loop dimension handled by a local thread.
    loopPos ltid = pe64 tile_id * pe64 tile_size + le64 ltid

    readElems (ltid_x, ltid_y) = do
      i <- letSubExp "i" =<< toExp (loopPos ltid_x)
      j <- letSubExp "j" =<< toExp (loopPos ltid_y)

      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
      addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms

      let -- Last entry after permuting; selects i or j (and the matching
          -- bounds check) depending on the array's permutation.
          innermost perm = last . rearrangeShape perm

          -- No need for fullSlice because we are tiling only prims.
          readDirect (arr, perm) =
            letExp "tile_elem" $
              BasicOp $ Index arr [DimFix $ innermost perm [i, j]]

          readGuarded (arr, perm) = do
            arr_t <- lookupType arr
            let idx = innermost perm [i, j]
                gtid_ok =
                  innermost
                    perm
                    [ le64 gtid_y .<. pe64 kdim_y,
                      le64 gtid_x .<. pe64 kdim_x
                    ]
                in_bounds = pe64 idx .<. pe64 (arraySize 0 arr_t) .&&. gtid_ok
            eIf
              (toExp in_bounds)
              (eBody [pure $ BasicOp $ Index arr [DimFix idx]])
              (eBody [eBlank $ rowType arr_t])

          readElem = case kind of
            TilePartial -> letExp "pre" <=< readGuarded
            TileFull -> readDirect

      map Var <$> mapM readElem (tiledInputs inputs)

-- | The outer size of the first tiled input in the list, or the 'Int64'
-- constant @0@ when no input is tiled.  Untiled inputs are ignored.
findTileSize :: HasScope lore m => [InputTile] -> m SubExp
findTileSize tiles =
  case [tile | InputTiled _ tile <- tiles] of
    tile : _ -> arraySize 0 <$> lookupType tile
    [] -> pure $ intConst Int64 0
tile

-- | Fold one tile into the per-thread accumulators of a 2D-tiled loop.
--
-- Runs a 'segMap2D' named @"acc"@ over @(tile_size, tile_size)@ threads at
-- level @SegThread num_groups group_size SegNoVirtFull@ with
-- 'ResultPrivate'.  Each thread @(ltid_x, ltid_y)@:
--
-- 1. rebinds the global thread IDs with 'reconstructGtids2D';
-- 2. replays the private statements at index @[ltid_x, ltid_y]@;
-- 3. reads its own element of every accumulator in @tile_args@;
-- 4. slices its part of every tile ('sliceUntiled' for untiled inputs;
--    for tiled inputs, an index by the permutation-selected local ID along
--    dimension @head perm@);
-- 5. if @gtid_x < kdim_x@ and @gtid_y < kdim_y@, runs a redomap over
--    @actual_tile_size@ elements using the accumulator values as neutral
--    elements; otherwise returns the accumulator values unchanged.
processTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Binder Kernels [VName]
processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size num_groups group_size tile_args = do
  -- Might be truncated in case of a partial tile.
  actual_tile_size <- findTileSize tiles
  segMap2D
    "acc"
    (SegThread num_groups group_size SegNoVirtFull)
    ResultPrivate
    (tile_size, tile_size)
    (accumulate actual_tile_size)
  where
    tiles = processTiles tile_args
    tile_id = processTileId tile_args

    accumulate actual_tile_size (ltid_x, ltid_y) = do
      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)

      let thread_slice = [DimFix $ Var ltid_x, DimFix $ Var ltid_y]
      addPrivStms thread_slice $ processPrivStms tile_args

      -- We replace the neutral elements with the accumulators (this is
      -- OK because the parallel semantics are not used after this
      -- point).
      thread_accs <-
        mapM
          (letSubExp "acc" . BasicOp . (`Index` thread_slice))
          (processAcc tile_args)

      let redomap_form =
            redomapSOAC
              [Reduce (processComm tile_args) (processRedLam tile_args) thread_accs]
              (processMapLam tile_args)
          in_kernel =
            le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<. pe64 kdim_y

      tiles' <- mapM (sliceTile actual_tile_size (ltid_x, ltid_y)) tiles

      fmap (map Var) . letTupExp "acc"
        =<< eIf
          (toExp in_kernel)
          (eBody [pure $ Op $ OtherOp $ Screma actual_tile_size tiles' redomap_form])
          (resultBodyM thread_accs)

    -- The part of one input tile that a thread at the given local IDs uses.
    sliceTile actual_tile_size _ (InputUntiled arr) =
      sliceUntiled arr tile_id tile_size actual_tile_size
    sliceTile _ (ltid_x, ltid_y) (InputTiled perm tile) = do
      tile_t <- lookupType tile
      let ltid = head $ rearrangeShape perm [ltid_x, ltid_y]
      letExp "tile" $
        BasicOp $ Index tile $ sliceAt tile_t (head perm) [DimFix $ Var ltid]

processResidualTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Binder Kernels [VName]
processResidualTile2D :: (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ResidualTileArgs
-> BinderT Kernels (State VNameSource) [VName]
-- Handle the last, partial tile that remains after every whole tile has
-- been consumed.  When the input size is an exact multiple of the tile
-- size, the accumulators are returned unchanged.  Otherwise the group
-- collectively reads one more tile, cuts every tiled input down to the
-- residual length in both dimensions, and folds that partial tile into
-- the accumulators via 'processTile2D'.
processResidualTile2D gids gtids kdims tile_size num_groups group_size args = do
  -- Number of input elements not covered by the whole tiles.
  leftover <-
    letSubExp "residual_input" . BasicOp $
      BinOp (SRem Int64 Unsafe) input_size tile_size

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp $ pe64 leftover .==. 0)
      (resultBodyM $ map Var accs)
      (partialTileBody leftover)
  where
    privstms = residualPrivStms args
    accs = residualAcc args
    whole_tiles = residualNumWholeTiles args
    input_size = residualInputSize args

    -- Restrict a tiled input to its first @n@ elements along both of its
    -- dimensions; untiled inputs are passed through unchanged.
    truncateTile n (InputTiled perm full) =
      let prefix = DimSlice (intConst Int64 0) n (intConst Int64 1)
       in InputTiled perm <$> letExp "partial_tile" (BasicOp $ Index full [prefix, prefix])
    truncateTile _ untiled@(InputUntiled _) =
      pure untiled

    -- Body of the branch taken when the residual length @n@ is nonzero.
    partialTileBody n = renameBody <=< runBodyBinder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tiles <-
        readTile2D
          kdims gtids gids tile_size num_groups group_size
          TilePartial privstms whole_tiles (residualInput args)
      partial_tiles <- traverse (truncateTile n) full_tiles
      -- Each thread traverses the partial tile and updates its accumulator.
      new_accs <-
        processTile2D gids gtids kdims tile_size num_groups group_size $
          ProcessTileArgs
            privstms
            (residualComm args)
            (residualRedLam args)
            (residualMapLam args)
            partial_tiles
            accs
            whole_tiles
      pure $ resultBody $ map Var new_accs

-- | Build a 'Tiling' that tiles the two innermost parallel dimensions
-- @(kdim_x, kdim_y)@ with square @tile_size * tile_size@ tiles.  Each
-- group covers one tile.  The group grid is @dims_on_top@ followed by the
-- number of tiles along x and y.  The incoming 'SegLevel' is ignored, and
-- a freshly computed group-level one is used instead.
tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d dims_on_top _initial_lvl (gtid_x, gtid_y) (kdim_x, kdim_y) w = do
  gid_x <- newVName "gid_x"
  gid_y <- newVName "gid_y"

  -- The tile size is a run-time tunable ('SizeTile') keyed by a fresh name.
  size_key <- newVName "tile_size"
  tile_size <-
    letSubExp "tile_size" . Op . SizeOp $
      GetSize (nameFromString (pretty size_key)) SizeTile
  group_size <-
    letSubExp "group_size" . BasicOp $
      BinOp (Mul Int64 OverflowUndef) tile_size tile_size

  -- One group per (possibly partial) tile along each axis.
  let tilesAlong desc extent =
        letSubExp desc . BasicOp $
          BinOp (SDivUp Int64 Unsafe) extent tile_size
  num_groups_x <- tilesAlong "num_groups_x" kdim_x
  num_groups_y <- tilesAlong "num_groups_y" kdim_y

  -- Total number of groups: the product of all grid dimensions.
  num_groups <-
    letSubExp "num_groups_top"
      =<< foldBinOp
        (Mul Int64 OverflowUndef)
        num_groups_x
        (num_groups_y : map snd dims_on_top)

  gid_flat <- newVName "gid_flat"

  let gids = (gid_x, gid_y)
      gtids = (gtid_x, gtid_y)
      kdims = (kdim_x, kdim_y)
      lvl = SegGroup (Count num_groups) (Count group_size) SegNoVirtFull
      grid = segNumGroups lvl
      threads = segGroupSize lvl
      space =
        SegSpace gid_flat $
          dims_on_top ++ [(gid_x, num_groups_x), (gid_y, num_groups_y)]

      -- Run @f@ in every thread of a tile_size x tile_size SegMap, after
      -- reconstructing the global thread indices from the group and local
      -- ids.  @f@ receives an in-bounds predicate and the local index.
      segMapOverTile desc lvl' manifest f =
        segMap2D desc lvl' manifest (tile_size, tile_size) $ \ltids@(ltid_x, ltid_y) -> do
          reconstructGtids2D tile_size gtids gids ltids
          let in_bounds =
                le64 gtid_x .<. pe64 kdim_x
                  .&&. le64 gtid_y .<. pe64 kdim_y
          f (untyped in_bounds) [DimFix (Var ltid_x), DimFix (Var ltid_y)]

  pure
    Tiling
      { tilingLevel = lvl,
        tilingSpace = space,
        tilingSegMap = segMapOverTile,
        tilingReadTile = readTile2D kdims gtids gids tile_size grid threads,
        tilingProcessTile = processTile2D gids gtids kdims tile_size grid threads,
        tilingProcessResidualTile = processResidualTile2D gids gtids kdims tile_size grid threads,
        tilingTileReturns = tileReturns dims_on_top [(kdim_x, tile_size), (kdim_y, tile_size)],
        tilingTileShape = Shape [tile_size, tile_size],
        tilingNumWholeTiles =
          letSubExp "num_whole_tiles" . BasicOp $
            BinOp (SQuot Int64 Unsafe) w tile_size
      }