{-# LANGUAGE TypeFamilies #-}

-- | Tries to turn a generalized reduction kernel into
--     a more specialized construct, for example:
--       (a) a map nest with a sequential redomap ripe for tiling
--       (b) a SegRed kernel followed by a smallish accumulation kernel.
--       (c) a histogram (for this we need to track the withAccs)
--   The idea is to identify the first accumulation and
--     to separate the initial kernel into two:
--     1. the code up to and including the accumulation,
--        which is optimized to turn the accumulation either
--        into a map-reduce composition or a histogram, and
--     2. the remaining code, which is recursively optimized.
--   Since this is mostly prototyping, when the accumulation
--     can be rewritten as a map-reduce, we sequentialize the
--     map-reduce so as to potentially enable tiling opportunities.
module Futhark.Optimise.GenRedOpt (optimiseGenRed) where

import Control.Monad.Reader
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Builder
import Futhark.IR.GPU
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename

-- | The monad of this pass: read access to the current type scope
-- ('Scope' 'GPU') on top of a 'State' holding the name source.
type GenRedM = ReaderT (Scope GPU) (State VNameSource)

-- | The pass definition. Each group of statements handed over by
-- 'intraproceduralTransformation' is processed by 'optimiseStms',
-- starting from an empty 'Env'.
optimiseGenRed :: Pass GPU GPU
optimiseGenRed =
  Pass
    "optimise generalized reductions"
    "Specializes generalized reductions into map-reductions or histograms"
    $ intraproceduralTransformation onStms
  where
    -- Run the 'GenRedM' computation with the given scope, threading the
    -- pass-wide name source through it.
    onStms :: Scope GPU -> Stms GPU -> PassM (Stms GPU)
    onStms scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms (M.empty, M.empty) stms) scope

-- | Optimise the statements of a body; the body's result is kept unchanged.
optimiseBody :: Env -> Body GPU -> GenRedM (Body GPU)
optimiseBody env (Body () stms res) = do
  stms' <- optimiseStms env stms
  pure $ Body () stms' res

-- | Optimise a sequence of statements from first to last, with their own
-- bindings brought into scope. Each statement is optimised with the 'Env'
-- returned for the previous one, and the statements produced for each one
-- are concatenated in order. The final 'Env' is discarded.
optimiseStms :: Env -> Stms GPU -> GenRedM (Stms GPU)
optimiseStms env stms =
  localScope (scopeOf stms) $
    snd <$> foldM step (env, mempty) (stmsToList stms)
  where
    step :: (Env, Stms GPU) -> Stm GPU -> GenRedM (Env, Stms GPU)
    step (cur_env, done) stm = do
      (next_env, new_stms) <- optimiseStm cur_env stm
      pure (next_env, done <> new_stms)

-- | Optimise a single statement. Returns the environment to use for the
-- following statements, together with the statement(s) that replace it.
--
-- A thread-level 'SegMap' is handed to 'genRedOpts' and kept as-is when no
-- rewrite applies; the environment is passed through unchanged in that case.
-- Any other statement first updates the environment with 'changeEnv'
-- (passing the first name bound by its pattern). Its nested bodies are then
-- optimised recursively with the updated environment.
optimiseStm :: Env -> Stm GPU -> GenRedM (Env, Stms GPU)
optimiseStm env stm@(Let _ _ (Op (SegOp (SegMap SegThread {} _ _ _)))) = do
  res_genred_opt <- genRedOpts env stm
  pure (env, fromMaybe (oneStm stm) res_genred_opt)
optimiseStm env (Let pat aux e) = do
  env' <- changeEnv env (head $ patNames pat) e
  e' <- mapExpM (optimise env') e
  pure (env', oneStm $ Let pat aux e')
  where
    -- Recurse into every nested body, extending the scope with the
    -- bindings the mapper reports for that body.
    optimise :: Env -> Mapper GPU GPU GenRedM
    optimise env' =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody env'
        }

------------------------

-- | Try to specialize a generalized-reduction kernel.
--
-- 'genRed2Tile2d' is tried first, and 'genRed2SegRed' only if that fails.
-- A successful rewrite produces some statements to run first plus a
-- residual kernel. The residual kernel is then optimised recursively, and
-- kept as-is if no further rewrite applies. Returns 'Nothing' if neither
-- rewrite applies to the given kernel.
genRedOpts :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU))
genRedOpts env ker = do
  res_tile <- genRed2Tile2d env ker
  res <- case res_tile of
    Nothing -> genRed2SegRed env ker
    Just _ -> pure res_tile
  traverse optimiseRest res
  where
    -- Prepend the statements produced by the rewrite to the (recursively
    -- optimised) residual kernel.
    optimiseRest :: (Stms GPU, Stm GPU) -> GenRedM (Stms GPU)
    optimiseRest (stms_before, ker_snd) = do
      mb_stms_after <- genRedOpts env ker_snd
      pure $ stms_before <> fromMaybe (oneStm ker_snd) mb_stms_after

-- | The 64-bit integer constant @1@, used as the seed of the grid-size
-- product in 'genRed2Tile2d'.
se1 :: SubExp
se1 = intConst Int64 1

-- | Try to split a thread-level 'SegMap' whose kernel body has the shape
-- @code1; acc' <- update_acc(acc, inds, vals); code2@ into two kernels:
--
--   1. A new 'SegMap' over all parallel dimensions except one (@invar_gid@)
--      to which the accumulator indices are invariant. Each of its threads
--      runs a map-reduce ('Screma') over the removed dimension, reducing
--      with the accumulator's own operator. It then performs a single
--      'UpdateAcc' with the result.
--   2. The original kernel with the accumulation statement removed, i.e.
--      with body @code1 <> code2@ (so @code1@ is executed again).
--
-- On success, the function returns the statements to run before the second
-- kernel, paired with the second kernel. Those statements are the host
-- statements produced by 'transposeFVs', the grid-size computation, and
-- the first kernel. It returns 'Nothing' whenever the kernel does not match
-- the pattern or any of the checks below fails.
genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2Tile2d env kerstm@(Let pat_ker aux (Op (SegOp (SegMap seg_thd seg_space kres_tps old_kbody))))
  | SegThread _ seg_group_size _novirt <- seg_thd,
    -- novirt == SegNoVirtFull || novirt == SegNoVirt,
    KernelBody () kstms kres <- old_kbody,
    Just (css, r_ses) <- allGoodReturns kres,
    null css,
    -- build the variance table, which records, for each variable name,
    -- the variables it depends on
    initial_variance <- M.map mempty $ scopeOfSegSpace seg_space,
    variance <- varianceInStms initial_variance kstms,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one accumulation, followed by some `code2`
    -- UpdateAcc VName [SubExp] [SubExp]
    (code1, Just accum_stmt, code2) <- matchCodeAccumCode kstms,
    Let pat_accum _aux_acc (BasicOp (UpdateAcc acc_nm acc_inds acc_vals)) <- accum_stmt,
    [pat_acc_nm] <- patNames pat_accum,
    -- check that the `acc_inds` are invariant to at least one
    -- parallel kernel dimension, and return the innermost such one:
    Just (invar_gid, gid_ind) <- isInvarToParDim mempty seg_space variance acc_inds,
    gid_dims_new_0 <- filter ((/= invar_gid) . fst) (unSegSpace seg_space),
    -- reorder the variant dimensions such that inner(most) accum-indices
    -- correspond to inner(most) parallel dimensions, so that the babysitter
    -- does not introduce transpositions
    gid_dims_new <- reorderParDims variance acc_inds gid_dims_new_0,
    -- check that all global-memory accesses in `code1` on which
    --   `accum_stmt` depends are invariant to at least one of
    --   the remaining parallel dimensions (i.e., excluding `invar_gid`)
    all (isTileable invar_gid gid_dims_new variance pat_acc_nm) (stmsToList code1),
    -- need to establish a cost model for the stms that would now
    --   be redundantly executed by the two kernels. If any recurrence
    --   is redundant then it is a no go. Otherwise we need to look at
    --   memory accesses: if more than two are re-executed, then we
    --   should abort.
    cost <- costRedundantExecution variance pat_acc_nm r_ses kstms,
    maxCost cost (Small 2) == Small 2 = do
      -- 1. Create the first kernel.
      acc_tp <- lookupType acc_nm
      -- `inv_dim_len` is the length of the dimension being sequentialized;
      -- this assumes `gid_ind` is the position of `invar_gid` in `seg_space`.
      -- 1.1. Get the accumulation operator and its neutral element.
      let inv_dim_len = segSpaceDims seg_space !! gid_ind
          ((redop0, neutral), el_tps) = getAccLambda acc_tp
      redop <- renameLambda redop0
      -- 1.2. Build the sequential map-reduce screma. Its map function keeps
      -- only the statements of `code1` on which the accumulator result
      -- depends (according to the variance table).
      let red =
            Reduce
              { redComm = Commutative,
                redLambda = redop,
                redNeutral = neutral
              }
          code1' =
            stmsFromList $
              filter (dependsOnAcc pat_acc_nm variance) $
                stmsToList code1
      (code1'', code1_tr_host) <-
        transposeFVs (freeIn kerstm) variance invar_gid code1'
      let map_lam_body = mkBody code1'' $ map (SubExpRes (Certs [])) acc_vals
          map_lam0 = Lambda [Param mempty invar_gid (Prim int64)] map_lam_body el_tps
      map_lam <- renameLambda map_lam0
      (k1_res, ker1_stms) <- runBuilderT' $ do
        iota <-
          letExp "iota" $
            BasicOp $
              Iota inv_dim_len (intConst Int64 0) (intConst Int64 1) Int64
        let op_exp = Op (OtherOp (Screma inv_dim_len [iota] (ScremaForm [] [red] map_lam)))
        res_redmap <- letTupExp "res_mapred" op_exp
        letSubExp (baseString pat_acc_nm ++ "_big_update") $
          BasicOp (UpdateAcc acc_nm acc_inds $ map Var res_redmap)

      -- 1.3. Build the kernel expression and rename it.
      gid_flat_1 <- newVName "gid_flat"
      let space1 = SegSpace gid_flat_1 gid_dims_new

      -- Number of groups = ceil(product of remaining dims / group size).
      (grid_size, host_stms1) <- runBuilder $ do
        let grid_pexp = L.foldl' (\x d -> x * pe64 d) (pe64 se1) $ map snd gid_dims_new
        dim_prod <- letSubExp "dim_prod" =<< toExp grid_pexp
        letSubExp "grid_size" =<< ceilDiv dim_prod (unCount seg_group_size)
      -- NOTE(review): is SegNoVirtFull the right virtualisation here?
      let level1 = SegThread (Count grid_size) seg_group_size (SegNoVirtFull (SegSeqDims []))
          kbody1 = KernelBody () ker1_stms [Returns ResultMaySimplify (Certs []) k1_res]

      -- NOTE(review): is it OK here to use the "aux" from the parent kernel?
      ker_exp <- renameExp $ Op (SegOp (SegMap level1 space1 [acc_tp] kbody1))
      let ker1 = Let pat_accum aux ker_exp

      -- 2. Build the second kernel: the original one minus the accumulation.
      let ker2_body = old_kbody {kernelBodyStms = code1 <> code2}
      ker2_exp <- renameExp $ Op (SegOp (SegMap seg_thd seg_space kres_tps ker2_body))
      let ker2 = Let pat_ker aux ker2_exp
      pure $ Just (code1_tr_host <> host_stms1 <> oneStm ker1, ker2)
  where
    -- Does the accumulator index refer to the parallel dimension `gid`,
    -- either directly or through its dependencies? Constants never do.
    isIndVarToParDim :: VarianceTable -> SubExp -> (VName, b) -> Bool
    isIndVarToParDim _ (Constant _) _ = False
    isIndVarToParDim variance (Var acc_ind) (gid, _) =
      acc_ind == gid
        || gid `nameIn` M.findWithDefault mempty acc_ind variance

    -- Move the first still-unused dimension that `acc_ind` refers to onto
    -- the front of `inner_dims`. Leave both lists alone if there is none.
    foldfunReorder ::
      VarianceTable ->
      ([(VName, b)], [(VName, b)]) ->
      SubExp ->
      ([(VName, b)], [(VName, b)])
    foldfunReorder variance (unused_dims, inner_dims) acc_ind =
      case L.break (isIndVarToParDim variance acc_ind) unused_dims of
        (before, dim : after) -> (before ++ after, dim : inner_dims)
        (_, []) -> (unused_dims, inner_dims)

    -- Dimensions not referred to by any accumulator index stay in front, in
    -- their original order. They are followed by the referred-to dimensions,
    -- in the order of the accumulator indices that pick them.
    reorderParDims :: VarianceTable -> [SubExp] -> [(VName, b)] -> [(VName, b)]
    reorderParDims variance acc_inds gid_dims_new_0 =
      let (invar_dims, inner_dims) =
            L.foldl'
              (foldfunReorder variance)
              (gid_dims_new_0, [])
              (reverse acc_inds)
       in invar_dims ++ inner_dims

    -- Ceiling division on 64-bit integers.
    ceilDiv :: (Applicative f) => SubExp -> SubExp -> f (Exp GPU)
    ceilDiv x y = pure $ BasicOp $ BinOp (SDivUp Int64 Unsafe) x y

    -- Look up the operator and neutral element registered in the
    -- environment for an accumulator type, together with its element types.
    getAccLambda :: Type -> ((Lambda GPU, [SubExp]), [Type])
    getAccLambda (Acc tp_id _shp el_tps _) =
      case M.lookup tp_id (fst env) of
        Just lam -> (lam, el_tps)
        Nothing ->
          error $
            "Lookup in environment failed! "
              ++ prettyString tp_id
              ++ " env: "
              ++ show (fst env)
    getAccLambda _ = error "Illegal accumulator type!"

    -- is a subexp invariant to a gid of a parallel dimension?
    isSeInvar2 :: VarianceTable -> VName -> SubExp -> Bool
    isSeInvar2 variance gid (Var x) =
      gid /= x && gid `notNameIn` M.findWithDefault mempty x variance
    isSeInvar2 _ _ _ = True

    -- is a DimIndex invariant to a gid of a parallel dimension?
    isDimIdxInvar2 :: VarianceTable -> VName -> DimIndex SubExp -> Bool
    isDimIdxInvar2 variance gid (DimFix d) =
      isSeInvar2 variance gid d
    isDimIdxInvar2 variance gid (DimSlice d1 d2 d3) =
      all (isSeInvar2 variance gid) [d1, d2, d3]

    -- is an entire slice invariant to at least one of the given gids?
    isSliceInvar2 :: (Foldable t) => VarianceTable -> Slice SubExp -> t VName -> Bool
    isSliceInvar2 variance slc =
      any (\gid -> all (isDimIdxInvar2 variance gid) (unSlice slc))

    -- are all statements that touch memory invariant to at least one
    -- parallel dimension? Only single-result 'Index' statements on which
    -- the accumulator depends are checked.
    isTileable :: VName -> [(VName, SubExp)] -> VarianceTable -> VName -> Stm GPU -> Bool
    isTileable seq_gid gid_dims variance acc_nm (Let (Pat [pel]) _ (BasicOp (Index _ slc)))
      | patElemName pel `nameIn` M.findWithDefault mempty acc_nm variance =
          isSliceInvar2 variance slc (map fst gid_dims)
            || isSliceInvar2 variance slc [seq_gid]
    -- this relies on the cost model, that currently accepts only
    -- global-memory reads, and for example rejects in-place updates
    -- or loops inside the code that is transformed in a redomap.
    isTileable _ _ _ _ _ = True

    -- does the to-be-reduced accumulator depend on this statement?
    dependsOnAcc :: VName -> VarianceTable -> Stm GPU -> Bool
    dependsOnAcc pat_acc_nm variance (Let pat _ _) =
      any (`nameIn` M.findWithDefault mempty pat_acc_nm variance) (patNames pat)
genRed2Tile2d _ _ =
  pure Nothing

-- | Alternative rewrite tried by 'genRedOpts' when 'genRed2Tile2d' does not
-- apply. Presumably the hook for the SegRed-based specialization mentioned
-- in the module header; currently a stub that never applies.
genRed2SegRed :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2SegRed _env _ker = pure Nothing

transposeFVs ::
  Names ->
  VarianceTable ->
  VName ->
  Stms GPU ->
  GenRedM (Stms GPU, Stms GPU)
transposeFVs :: Names
-> VarianceTable
-> VName
-> Stms GPU
-> GenRedM (Stms GPU, Stms GPU)
transposeFVs Names
fvs VarianceTable
variance VName
gid Stms GPU
stms = do
  (Map VName ([Int], VName, Stms GPU)
tab, Stms GPU
stms') <- forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Map VName ([Int], VName, Stms GPU), Stms GPU)
-> Stm GPU
-> ReaderT
     (Scope GPU)
     (State VNameSource)
     (Map VName ([Int], VName, Stms GPU), Stms GPU)
foldfun (forall k a. Map k a
M.empty, forall a. Monoid a => a
mempty) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stms rep -> [Stm rep]
stmsToList Stms GPU
stms
  let stms_host :: Stms GPU
stms_host = forall a b k. (a -> b -> b) -> b -> Map k a -> b
M.foldr (\([Int]
_, VName
_, Stms GPU
s) Stms GPU
ss -> Stms GPU
ss forall a. Semigroup a => a -> a -> a
<> Stms GPU
s) forall a. Monoid a => a
mempty Map VName ([Int], VName, Stms GPU)
tab
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
stms', Stms GPU
stms_host)
  where
    foldfun :: (Map VName ([Int], VName, Stms GPU), Stms GPU)
-> Stm GPU
-> ReaderT
     (Scope GPU)
     (State VNameSource)
     (Map VName ([Int], VName, Stms GPU), Stms GPU)
foldfun (Map VName ([Int], VName, Stms GPU)
tab, Stms GPU
all_stms) Stm GPU
stm = do
      (Map VName ([Int], VName, Stms GPU)
tab', Stm GPU
stm') <- (Map VName ([Int], VName, Stms GPU), Stm GPU)
-> ReaderT
     (Scope GPU)
     (State VNameSource)
     (Map VName ([Int], VName, Stms GPU), Stm GPU)
transposeFV (Map VName ([Int], VName, Stms GPU)
tab, Stm GPU
stm)
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (Map VName ([Int], VName, Stms GPU)
tab', Stms GPU
all_stms forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Stm rep -> Stms rep
oneStm Stm GPU
stm')
    -- ToDo: currently handles only 2-dim arrays, please generalize
    transposeFV :: (Map VName ([Int], VName, Stms GPU), Stm GPU)
-> ReaderT
     (Scope GPU)
     (State VNameSource)
     (Map VName ([Int], VName, Stms GPU), Stm GPU)
transposeFV (Map VName ([Int], VName, Stms GPU)
tab, Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (BasicOp (Index VName
arr Slice SubExp
slc)))
      | [DimIndex SubExp]
dims <- forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slc,
        forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all forall {d}. DimIndex d -> Bool
isFixDim [DimIndex SubExp]
dims,
        VName
arr VName -> Names -> Bool
`nameIn` Names
fvs,
        [Int]
iis <- forall a. (a -> Bool) -> [a] -> [Int]
L.findIndices DimIndex SubExp -> Bool
depOnGid [DimIndex SubExp]
dims,
        [Int
ii] <- [Int]
iis,
        -- generalize below: treat any rearange and add to tab if not there.
        Maybe ([Int], VName, Stms GPU)
Nothing <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
arr Map VName ([Int], VName, Stms GPU)
tab,
        Int
ii forall a. Eq a => a -> a -> Bool
/= forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
dims forall a. Num a => a -> a -> a
- Int
1,
        [Int]
perm <- [Int
0 .. Int
ii forall a. Num a => a -> a -> a
- Int
1] forall a. [a] -> [a] -> [a]
++ [Int
ii forall a. Num a => a -> a -> a
+ Int
1 .. forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
dims forall a. Num a => a -> a -> a
- Int
1] forall a. [a] -> [a] -> [a]
++ [Int
ii] = do
          (VName
arr_tr, Stms GPU
stms_tr) <- forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
BuilderT rep m a -> m (a, Stms rep)
runBuilderT' forall a b. (a -> b) -> a -> b
$ do
            VName
arr' <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
arr forall a. [a] -> [a] -> [a]
++ [Char]
"_trsp") forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
perm VName
arr -- Manifest [1,0] arr
            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
arr' forall a. [a] -> [a] -> [a]
++ [Char]
"_opaque") forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ OpaqueOp -> SubExp -> BasicOp
Opaque OpaqueOp
OpaqueNil forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr'
          let tab' :: Map VName ([Int], VName, Stms GPU)
tab' = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
arr ([Int]
perm, VName
arr_tr, Stms GPU
stms_tr) Map VName ([Int], VName, Stms GPU)
tab
              slc' :: Slice SubExp
slc' = forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map ([DimIndex SubExp]
dims !!) [Int]
perm
              stm' :: Stm GPU
stm' = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr_tr Slice SubExp
slc'
          forall (f :: * -> *) a. Applicative f => a -> f a
pure (Map VName ([Int], VName, Stms GPU)
tab', Stm GPU
stm')
      where
        isFixDim :: DimIndex d -> Bool
isFixDim DimFix {} = Bool
True
        isFixDim DimIndex d
_ = Bool
False
        depOnGid :: DimIndex SubExp -> Bool
depOnGid (DimFix (Var VName
nm)) =
          VName
gid forall a. Eq a => a -> a -> Bool
== VName
nm Bool -> Bool -> Bool
|| VName -> Names -> Bool
nameIn VName
gid (forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault forall a. Monoid a => a
mempty VName
nm VarianceTable
variance)
        depOnGid DimIndex SubExp
_ = Bool
False
    transposeFV (Map VName ([Int], VName, Stms GPU), Stm GPU)
r = forall (f :: * -> *) a. Applicative f => a -> f a
pure (Map VName ([Int], VName, Stms GPU), Stm GPU)
r

-- | Splits a kernel's statements around its first @UpdateAcc@ statement,
--   i.e. tries to identify the pattern:
--     code, followed by some UpdateAcc-statement, followed by more code.
--
--   Returns @(code1, Just acc_stm, code2)@, where @code1@ holds the
--   statements before the first @UpdateAcc@ and @code2@ all statements
--   after it (including any later @UpdateAcc@s).  If there is no
--   @UpdateAcc@ statement, all statements end up in the first component
--   and the result is @(kstms, Nothing, mempty)@.
matchCodeAccumCode ::
  Stms GPU ->
  (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeAccumCode kstms =
  -- A single 'break' keeps this linear in the number of statements.
  case break isUpdateAcc (stmsToList kstms) of
    (code1, acc_stm : code2) ->
      (stmsFromList code1, Just acc_stm, stmsFromList code2)
    (code1, []) ->
      (stmsFromList code1, Nothing, mempty)
  where
    isUpdateAcc :: Stm GPU -> Bool
    isUpdateAcc (Let _ _ (BasicOp UpdateAcc {})) = True
    isUpdateAcc _ = False

-- | Checks that there exists a parallel dimension (among the kernel's
--     gids) to which all the indices (@acc_inds@) are invariant.
--   It returns the innermost such parallel dimension, as a tuple
--     of the pardim gid ('VName') and its index ('Int') in the
--     parallel space.  The result is 'Nothing' when no such dimension
--     exists, or when any kernel gid is a member of @branch_variant@.
isInvarToParDim ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [SubExp] ->
  Maybe (VName, Int)
isInvarToParDim branch_variant kspace variance acc_inds
  | any (`nameIn` branch_variant) ker_gids = Nothing
  | otherwise =
      -- The innermost dimension is the last one of the segment space,
      -- so the numbered gids are searched from the back.
      listToMaybe
        [ (gid, i)
        | (gid, i) <- reverse (zip ker_gids [0 ..]),
          gid `notNameIn` variant_gids
        ]
  where
    ker_gids :: [VName]
    ker_gids = map fst $ unSegSpace kspace

    -- Kernel gids that at least one of the indices varies with.
    variant_gids :: Names
    variant_gids = namesFromList $ concatMap gidsVariantTo acc_inds

    -- The kernel gids an index depends on: those recorded for it in the
    -- variance table, plus the index itself when it is a gid.
    -- Non-variable (constant) indices depend on no gid.
    gidsVariantTo :: SubExp -> [VName]
    gidsVariantTo (Var ind) =
      let deps =
            M.findWithDefault mempty ind variance
              <> (if ind `elem` ker_gids then oneName ind else mempty)
       in filter (`nameIn` deps) ker_gids
    gidsVariantTo _ = []

-- | If every kernel result is a 'Returns' with 'ResultMaySimplify', yields
--   the concatenation of all their certificates together with the returned
--   'SubExp's, both in result order.  Otherwise yields 'Nothing'.
allGoodReturns :: [KernelResult] -> Maybe ([VName], [SubExp])
allGoodReturns kres = do
  -- 'mapM' in 'Maybe' fails as soon as one result is not a good return,
  -- so no separate validity pre-check (and no impossible-case 'error')
  -- is needed.
  certs_and_res <- mapM goodReturn kres
  let (certss, res) = unzip certs_and_res
  pure (concat certss, res)
  where
    goodReturn :: KernelResult -> Maybe ([VName], SubExp)
    goodReturn (Returns ResultMaySimplify cs r_se) = Just (unCerts cs, r_se)
    goodReturn _ = Nothing

--------------------------
--- Cost Model Helpers ---
--------------------------

-- | Estimates the cost of the statements in @kstms@ that would be executed
--   redundantly.  A statement contributes its 'costRedundantStmt' when one
--   of the names it binds is both
--
--   * a dependency of the accumulator @pat_acc_nm@, according to
--     @variance@, and
--   * a dependency of one of the results @r_ses@, according to a variance
--     table recomputed over @kstms@ (starting from an empty table) in which
--     statements binding @pat_acc_nm@ are not recorded.
--
--   Per-statement costs are combined with 'addCosts', starting from
--   @Small 0@.
costRedundantExecution ::
  VarianceTable ->
  VName ->
  [SubExp] ->
  Stms GPU ->
  Cost
costRedundantExecution variance pat_acc_nm r_ses kstms =
  -- Strict left fold: the accumulated 'Cost' is forced at every statement
  -- rather than building a chain of thunks over the whole kernel.
  L.foldl' (addCostOfStmt common_deps) (Small 0) kstms
  where
    acc_deps :: Names
    acc_deps = M.findWithDefault mempty pat_acc_nm variance

    vartab_cut_acc :: VarianceTable
    vartab_cut_acc = varianceInStmsWithout (oneName pat_acc_nm) mempty kstms

    -- Only variable results have dependencies; constants are skipped.
    res_deps :: Names
    res_deps =
      mconcat
        [ M.findWithDefault mempty nm vartab_cut_acc
        | Var nm <- r_ses
        ]

    common_deps :: Names
    common_deps = namesIntersection res_deps acc_deps

    -- Adds the cost of @stm@ to @cur_cost@ if @stm@ binds any of @deps@.
    addCostOfStmt :: Names -> Cost -> Stm GPU -> Cost
    addCostOfStmt deps cur_cost stm
      | namesIntersect (namesFromList $ patNames $ stmPat stm) deps =
          addCosts cur_cost $ costRedundantStmt stm
      | otherwise = cur_cost

    -- Extends a variance table over the statements, skipping any statement
    -- that binds one of the given names.
    varianceInStmsWithout :: Names -> VarianceTable -> Stms GPU -> VarianceTable
    varianceInStmsWithout nms = L.foldl' (varianceInStmWithout nms)

    varianceInStmWithout :: Names -> VarianceTable -> Stm GPU -> VarianceTable
    varianceInStmWithout cuts vartab stm
      | namesIntersect (namesFromList pat_nms) cuts = vartab
      | otherwise = L.foldl' add vartab pat_nms
      where
        pat_nms = patNames $ stmPat stm
        -- Every name bound by the statement gets the same entry: the
        -- statement's free variables together with their recorded entries.
        add variance' v = M.insert v binding_variance variance'
        look variance' v = oneName v <> M.findWithDefault mempty v variance'
        binding_variance = mconcat $ map (look vartab) $ namesToList (freeIn stm)

-- | A coarse estimate of the cost of (redundantly) executing code.
--   Under 'addCosts', 'Break' absorbs every other cost and 'Big' absorbs
--   any 'Small' cost.
data Cost
  = -- | A small cost; two small costs are summed by 'addCosts'.
    Small Int
  | -- | A large cost, e.g. a kernel operation, a loop or a function call.
    Big
  | -- | Assigned to in-place updates (e.g. @Update@, @UpdateAcc@);
    --   presumably marks code that must not be executed redundantly.
    Break
  deriving (Eq)

-- | Combines two costs: 'Break' absorbs everything, otherwise 'Big'
--   absorbs any 'Small' cost, and two 'Small' costs are summed.
addCosts :: Cost -> Cost -> Cost
addCosts x y =
  case (x, y) of
    (Break, _) -> Break
    (_, Break) -> Break
    (Big, _) -> Big
    (_, Big) -> Big
    (Small a, Small b) -> Small (a + b)

-- | Combines the costs of alternative code paths: of two 'Small' costs
--   the larger is kept; any other combination falls back to 'addCosts'.
maxCost :: Cost -> Cost -> Cost
maxCost x y =
  case (x, y) of
    (Small a, Small b) -> Small (a `max` b)
    _ -> addCosts x y

-- | The redundant-execution cost of a body: the 'costRedundantStmt' of each
--   of its statements, combined left to right with 'addCosts' starting
--   from @Small 0@.
costBody :: Body GPU -> Cost
costBody bdy =
  -- Strict fold over the statement sequence; no intermediate list needed.
  L.foldl'
    (\acc stm -> addCosts acc (costRedundantStmt stm))
    (Small 0)
    (bodyStms bdy)

-- | Estimates the cost of redundantly executing a single statement.
--   Kernel operations, loops, function calls and 'WithAcc' are 'Big'; a
--   'Match' costs the 'maxCost' over the costs of all of its branches;
--   basic operations are classified by the local 'costBasicOp'.
costRedundantStmt :: Stm GPU -> Cost
costRedundantStmt stm =
  case stmExp stm of
    Op {} -> Big
    DoLoop {} -> Big
    Apply {} -> Big
    WithAcc {} -> Big
    Match _ cases defbody _ ->
      L.foldl' maxCost (costBody defbody) $ map (costBody . caseBody) cases
    BasicOp op -> costBasicOp op
  where
    costBasicOp :: BasicOp -> Cost
    -- Array literals of arrays are expensive; scalar-element ones are not.
    -- (These two equations must stay in this order.)
    costBasicOp (ArrayLit _ Array {}) = Big
    costBasicOp (ArrayLit _ _) = Small 1
    -- Reading a single element costs a little; taking a slice is free.
    costBasicOp (Index _ slc)
      | all isFixDim (unSlice slc) = Small 1
      | otherwise = Small 0
    costBasicOp FlatIndex {} = Small 0
    -- In-place updates.
    costBasicOp Update {} = Break
    costBasicOp FlatUpdate {} = Break
    costBasicOp UpdateAcc {} = Break
    -- Operations that produce whole new arrays.
    costBasicOp Concat {} = Big
    costBasicOp Copy {} = Big
    costBasicOp Manifest {} = Big
    costBasicOp Replicate {} = Big
    -- Everything else is considered free.
    costBasicOp _ = Small 0

    isFixDim :: DimIndex SubExp -> Bool
    isFixDim DimFix {} = True
    isFixDim _ = False