{-# LANGUAGE TypeFamilies #-}

-- | Tries to turn a generalized reduction kernel into
--     a more specialized construct, for example:
--       (a) a map nest with a sequential redomap ripe for tiling
--       (b) a SegRed kernel followed by a smallish accumulation kernel.
--       (c) a histogram (for this we need to track the withAccs)
--   The idea is to identify the first accumulation and
--     to separate the initial kernels into two:
--     1. the code up to and including the accumulation,
--        which is optimized to turn the accumulation either
--        into a map-reduce composition or a histogram, and
--     2. the remaining code, which is recursively optimized.
--   Since this is mostly prototyping, when the accumulation
--     can be rewritten as a map-reduce, we sequentialize the
--     map-reduce, so as to potentially enable tiling opportunities.
module Futhark.Optimise.GenRedOpt (optimiseGenRed) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Builder
import Futhark.IR.GPU
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename

type GenRedM = ReaderT (Scope GPU) (State VNameSource)

-- | The pass definition.
--
-- Every statement sequence handed over by 'intraproceduralTransformation'
-- is rewritten with 'optimiseStms', starting from an environment made of
-- two empty maps.
optimiseGenRed :: Pass GPU GPU
optimiseGenRed =
  Pass
    "optimise generalized reductions"
    "Specializes generalized reductions into map-reductions or histograms"
    $ intraproceduralTransformation onStms
  where
    -- Run the 'GenRedM' computation (a reader over the scope stacked on a
    -- state over the name source) and thread the resulting name source
    -- back into the surrounding monad.
    onStms :: Scope GPU -> Stms GPU -> PassM (Stms GPU)
    onStms scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms (M.empty, M.empty) stms) scope

-- | Optimise the statements of a body with 'optimiseStms'; the body's
-- result is passed through unchanged.
optimiseBody :: Env -> Body GPU -> GenRedM (Body GPU)
optimiseBody env (Body () stms res) =
  Body () <$> optimiseStms env stms <*> pure res

-- | Optimise a sequence of statements from left to right.
--
-- The whole traversal runs with the scope of @stms@ added to the local
-- scope.  The environment returned by 'optimiseStm' for one statement is
-- the one used for the next; the final environment is discarded and only
-- the concatenated, rewritten statements are returned.
optimiseStms :: Env -> Stms GPU -> GenRedM (Stms GPU)
optimiseStms env stms =
  localScope (scopeOf stms) $ do
    (_, stms') <- foldM foldfun (env, mempty) $ stmsToList stms
    pure stms'
  where
    -- One fold step: optimise a single statement under the current
    -- environment and append whatever statements it was rewritten into.
    foldfun :: (Env, Stms GPU) -> Stm GPU -> GenRedM (Env, Stms GPU)
    foldfun (e, ss) s = do
      (e', s') <- optimiseStm e s
      pure (e', ss <> s')

-- | Optimise a single statement, returning the (possibly updated)
-- environment together with the statements that replace it.
--
-- * A thread-level 'SegMap' kernel is handed to 'genRedOpts'.  If that
--   yields a rewrite, the rewritten statements are returned; otherwise
--   the kernel is kept as is.  The environment is returned unchanged.
--
-- * For any other statement the environment is first updated with
--   'changeEnv', and then every body nested in the expression is
--   optimised with 'optimiseBody' under the updated environment.
optimiseStm :: Env -> Stm GPU -> GenRedM (Env, Stms GPU)
optimiseStm env stm@(Let _ _ (Op (SegOp (SegMap SegThread {} _ _ _)))) = do
  res_genred_opt <- genRedOpts env stm
  pure (env, fromMaybe (oneStm stm) res_genred_opt)
optimiseStm env (Let pat aux e) = do
  -- NOTE(review): 'head' is partial; this relies on 'changeEnv' only
  -- inspecting the name for statements that bind at least one variable
  -- -- confirm against 'changeEnv'.
  env' <- changeEnv env (head $ patNames pat) e
  e' <- mapExpM (optimise env') e
  pure (env', oneStm $ Let pat aux e')
  where
    -- A mapper that leaves everything but bodies alone; each body is
    -- optimised with the scope supplied by the traversal brought into
    -- the local scope.
    optimise :: Env -> Mapper GPU GPU GenRedM
    optimise env' =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody env'
        }

------------------------

-- | Try to specialise a generalized-reduction kernel.
--
-- 'genRed2Tile2d' is attempted first; only if it fails is
-- 'genRed2SegRed' attempted.  A successful rewrite produces some
-- statements together with a residual kernel, and the residual kernel is
-- then optimised recursively.  Returns 'Nothing' when neither rewrite
-- applies to @ker@.
genRedOpts :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU))
genRedOpts env ker = do
  res_tile <- genRed2Tile2d env ker
  case res_tile of
    Nothing -> genRed2SegRed env ker >>= helperGenRed
    Just _ -> helperGenRed res_tile
  where
    -- Given the outcome of one rewrite attempt, recursively optimise the
    -- residual kernel.  If the recursion finds nothing further to do,
    -- the residual kernel itself is appended after the new statements.
    helperGenRed :: Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU))
    helperGenRed Nothing = pure Nothing
    helperGenRed (Just (stms_before, ker_snd)) = do
      mb_stms_after <- genRedOpts env ker_snd
      pure $ Just $ stms_before <> fromMaybe (oneStm ker_snd) mb_stms_after

genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2Tile2d Env
env kerstm :: Stm GPU
kerstm@(Let Pat (LetDec GPU)
pat_ker StmAux (ExpDec GPU)
aux (Op (SegOp (SegMap SegLevel
seg_thd SegSpace
seg_space [Type]
kres_tps KernelBody GPU
old_kbody))))
  | SegThread SegVirt
_novirt Maybe KernelGrid
_ <- SegLevel
seg_thd,
    -- novirt == SegNoVirtFull || novirt == SegNoVirt,
    KernelBody () Stms GPU
kstms [KernelResult]
kres <- KernelBody GPU
old_kbody,
    Just ([VName]
css, [SubExp]
r_ses) <- [KernelResult] -> Maybe ([VName], [SubExp])
allGoodReturns [KernelResult]
kres,
    [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
css,
    -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    VarianceTable
initial_variance <- (NameInfo Any -> Names)
-> Map VName (NameInfo Any) -> VarianceTable
forall a b k. (a -> b) -> Map k a -> Map k b
M.map NameInfo Any -> Names
forall a. Monoid a => a
mempty (Map VName (NameInfo Any) -> VarianceTable)
-> Map VName (NameInfo Any) -> VarianceTable
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
seg_space,
    VarianceTable
variance <- VarianceTable -> Stms GPU -> VarianceTable
varianceInStms VarianceTable
initial_variance Stms GPU
kstms,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one accumulation, followed by some `code2`
    -- UpdateAcc VName [SubExp] [SubExp]
    (Stms GPU
code1, Just Stm GPU
accum_stmt, Stms GPU
code2) <- Stms GPU -> (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeAccumCode Stms GPU
kstms,
    Let Pat (LetDec GPU)
pat_accum StmAux (ExpDec GPU)
_aux_acc (BasicOp (UpdateAcc VName
acc_nm [SubExp]
acc_inds [SubExp]
acc_vals)) <- Stm GPU
accum_stmt,
    [VName
pat_acc_nm] <- Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec GPU)
pat_accum,
    -- check that the `acc_inds` are invariant to at least one
    -- parallel kernel dimensions, and return the innermost such one:
    Just (VName
invar_gid, Int
gid_ind) <- Names
-> SegSpace -> VarianceTable -> [SubExp] -> Maybe (VName, Int)
isInvarToParDim Names
forall a. Monoid a => a
mempty SegSpace
seg_space VarianceTable
variance [SubExp]
acc_inds,
    [(VName, SubExp)]
gid_dims_new_0 <- ((VName, SubExp) -> Bool) -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(VName, SubExp)
x -> VName
invar_gid VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst (VName, SubExp)
x) (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
seg_space),
    -- reorder the variant dimensions such that inner(most) accum-indices
    -- correspond to inner(most) parallel dimensions, so that the babysitter
    -- does not introduce transpositions
    -- gid_dims_new <- gid_dims_new_0,
    [(VName, SubExp)]
gid_dims_new <- VarianceTable -> [SubExp] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall {b}.
VarianceTable -> [SubExp] -> [(VName, b)] -> [(VName, b)]
reorderParDims VarianceTable
variance [SubExp]
acc_inds [(VName, SubExp)]
gid_dims_new_0,
    -- check that all global-memory accesses in `code1` on which
    --   `accum_stmt` depends on are invariant to at least one of
    --   the remaining parallel dimensions (i.e., excluding `invar_gid`)
    (Stm GPU -> Bool) -> [Stm GPU] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName
-> [(VName, SubExp)] -> VarianceTable -> VName -> Stm GPU -> Bool
isTileable VName
invar_gid [(VName, SubExp)]
gid_dims_new VarianceTable
variance VName
pat_acc_nm) (Stms GPU -> [Stm GPU]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms GPU
code1),
    -- need to establish a cost model for the stms that would now
    --   be redundantly executed by the two kernels. If any recurence
    --   is redundant than it is a no go. Otherwise we need to look at
    --   memory accesses: if more than two are re-executed, then we
    --   should abort.
    Cost
cost <- VarianceTable -> VName -> [SubExp] -> Stms GPU -> Cost
costRedundantExecution VarianceTable
variance VName
pat_acc_nm [SubExp]
r_ses Stms GPU
kstms,
    Cost -> Cost -> Cost
maxCost Cost
cost (Int -> Cost
Small Int
2) Cost -> Cost -> Bool
forall a. Eq a => a -> a -> Bool
== Int -> Cost
Small Int
2 = do
      -- 1. create the first kernel
      Type
acc_tp <- VName -> ReaderT (Scope GPU) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
acc_nm
      let inv_dim_len :: SubExp
inv_dim_len = SegSpace -> [SubExp]
segSpaceDims SegSpace
seg_space [SubExp] -> Int -> SubExp
forall a. HasCallStack => [a] -> Int -> a
!! Int
gid_ind
          -- 1.1. get the accumulation operator
          ((Lambda GPU
redop0, [SubExp]
neutral), [Type]
el_tps) = Type -> ((Lambda GPU, [SubExp]), [Type])
getAccLambda Type
acc_tp
      Lambda GPU
redop <- Lambda GPU -> ReaderT (Scope GPU) (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
redop0
      let red :: Reduce GPU
red =
            Reduce
              { redComm :: Commutativity
redComm = Commutativity
Commutative,
                redLambda :: Lambda GPU
redLambda = Lambda GPU
redop,
                redNeutral :: [SubExp]
redNeutral = [SubExp]
neutral
              }
          -- 1.2. build the sequential map-reduce screma
          code1' :: Stms GPU
code1' =
            [Stm GPU] -> Stms GPU
forall rep. [Stm rep] -> Stms rep
stmsFromList ([Stm GPU] -> Stms GPU) -> [Stm GPU] -> Stms GPU
forall a b. (a -> b) -> a -> b
$
              (Stm GPU -> Bool) -> [Stm GPU] -> [Stm GPU]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> VarianceTable -> Stm GPU -> Bool
forall {k} {rep}. Ord k => k -> Map k Names -> Stm rep -> Bool
dependsOnAcc VName
pat_acc_nm VarianceTable
variance) ([Stm GPU] -> [Stm GPU]) -> [Stm GPU] -> [Stm GPU]
forall a b. (a -> b) -> a -> b
$
                Stms GPU -> [Stm GPU]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms GPU
code1
      (Stms GPU
code1'', Stms GPU
code1_tr_host) <- Names
-> VarianceTable
-> VName
-> Stms GPU
-> GenRedM (Stms GPU, Stms GPU)
transposeFVs (Stm GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stm GPU
kerstm) VarianceTable
variance VName
invar_gid Stms GPU
code1'
      let map_lam_body :: Body GPU
map_lam_body = Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
code1'' (Result -> Body GPU) -> Result -> Body GPU
forall a b. (a -> b) -> a -> b
$ (SubExp -> SubExpRes) -> [SubExp] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (Certs -> SubExp -> SubExpRes
SubExpRes ([VName] -> Certs
Certs [])) [SubExp]
acc_vals
          map_lam0 :: Lambda GPU
map_lam0 = [LParam GPU] -> [Type] -> Body GPU -> Lambda GPU
forall rep. [LParam rep] -> [Type] -> Body rep -> Lambda rep
Lambda [Attrs -> VName -> Type -> Param Type
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
invar_gid (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)] [Type]
el_tps Body GPU
map_lam_body
      Lambda GPU
map_lam <- Lambda GPU -> ReaderT (Scope GPU) (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
map_lam0
      (SubExp
k1_res, Stms GPU
ker1_stms) <- BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
-> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
BuilderT rep m a -> m (a, Stms rep)
runBuilderT' (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
 -> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
-> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
        VName
iota <- [Char]
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"iota" (Exp (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
 -> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) VName)
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) VName
forall a b. (a -> b) -> a -> b
$ BasicOp
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp
 -> Exp
      (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)))))
-> BasicOp
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
inv_dim_len (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
        let op_exp :: Exp GPU
op_exp = OpC GPU GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (SOAC GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. op rep -> HostOp op rep
OtherOp (SubExp -> [VName] -> ScremaForm GPU -> SOAC GPU
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
inv_dim_len [VName
iota] ([Scan GPU] -> [Reduce GPU] -> Lambda GPU -> ScremaForm GPU
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [] [Reduce GPU
red] Lambda GPU
map_lam)))
        [VName]
res_redmap <- [Char]
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"res_mapred" Exp (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
Exp GPU
op_exp
        [Char]
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp (VName -> [Char]
baseString VName
pat_acc_nm [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"_big_update") (Exp (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
 -> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp)
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
forall a b. (a -> b) -> a -> b
$
          BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc VName
acc_nm [SubExp]
acc_inds ([SubExp] -> BasicOp) -> [SubExp] -> BasicOp
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
res_redmap)

      -- 1.3. build the kernel expression and rename it!
      VName
gid_flat_1 <- [Char] -> ReaderT (Scope GPU) (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_flat"
      let space1 :: SegSpace
space1 = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat_1 [(VName, SubExp)]
gid_dims_new

      let level1 :: SegLevel
level1 = SegVirt -> Maybe KernelGrid -> SegLevel
SegThread (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims [])) Maybe KernelGrid
forall a. Maybe a
Nothing -- novirt ?
          kbody1 :: KernelBody GPU
kbody1 = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
ker1_stms [ResultManifest -> Certs -> SubExp -> KernelResult
Returns ResultManifest
ResultMaySimplify ([VName] -> Certs
Certs []) SubExp
k1_res]

      -- is it OK here to use the "aux" from the parrent kernel?
      Exp GPU
ker_exp <- Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Exp rep -> m (Exp rep)
renameExp (Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU))
-> Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall a b. (a -> b) -> a -> b
$ OpC GPU GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
level1 SegSpace
space1 [Type
acc_tp] KernelBody GPU
kbody1))
      let ker1 :: Stm GPU
ker1 = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat_accum StmAux (ExpDec GPU)
aux Exp GPU
ker_exp

      -- 2 build the second kernel
      let ker2_body :: KernelBody GPU
ker2_body = KernelBody GPU
old_kbody {kernelBodyStms :: Stms GPU
kernelBodyStms = Stms GPU
code1 Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stms GPU
code2}
      Exp GPU
ker2_exp <- Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Exp rep -> m (Exp rep)
renameExp (Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU))
-> Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall a b. (a -> b) -> a -> b
$ OpC GPU GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
seg_thd SegSpace
seg_space [Type]
kres_tps KernelBody GPU
ker2_body))
      let ker2 :: Stm GPU
ker2 = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat_ker StmAux (ExpDec GPU)
aux Exp GPU
ker2_exp
      Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU)))
-> Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU))
forall a b. (a -> b) -> a -> b
$
        (Stms GPU, Stm GPU) -> Maybe (Stms GPU, Stm GPU)
forall a. a -> Maybe a
Just (Stms GPU
code1_tr_host Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm Stm GPU
ker1, Stm GPU
ker2)
  where
    isIndVarToParDim :: VarianceTable -> SubExp -> (VName, b) -> Bool
isIndVarToParDim VarianceTable
_ (Constant PrimValue
_) (VName, b)
_ = Bool
False
    isIndVarToParDim VarianceTable
variance (Var VName
acc_ind) (VName, b)
par_dim =
      VName
acc_ind VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== (VName, b) -> VName
forall a b. (a, b) -> a
fst (VName, b)
par_dim
        Bool -> Bool -> Bool
|| VName -> Names -> Bool
nameIn ((VName, b) -> VName
forall a b. (a, b) -> a
fst (VName, b)
par_dim) (Names -> VName -> VarianceTable -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault Names
forall a. Monoid a => a
mempty VName
acc_ind VarianceTable
variance)
    foldfunReorder :: VarianceTable
-> ([(VName, b)], [(VName, b)])
-> SubExp
-> ([(VName, b)], [(VName, b)])
foldfunReorder VarianceTable
variance ([(VName, b)]
unused_dims, [(VName, b)]
inner_dims) SubExp
acc_ind =
      case ((VName, b) -> Bool) -> [(VName, b)] -> Maybe Int
forall a. (a -> Bool) -> [a] -> Maybe Int
L.findIndex (VarianceTable -> SubExp -> (VName, b) -> Bool
forall {b}. VarianceTable -> SubExp -> (VName, b) -> Bool
isIndVarToParDim VarianceTable
variance SubExp
acc_ind) [(VName, b)]
unused_dims of
        Maybe Int
Nothing -> ([(VName, b)]
unused_dims, [(VName, b)]
inner_dims)
        Just Int
i ->
          ( Int -> [(VName, b)] -> [(VName, b)]
forall a. Int -> [a] -> [a]
take Int
i [(VName, b)]
unused_dims [(VName, b)] -> [(VName, b)] -> [(VName, b)]
forall a. [a] -> [a] -> [a]
++ Int -> [(VName, b)] -> [(VName, b)]
forall a. Int -> [a] -> [a]
drop (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [(VName, b)]
unused_dims,
            ([(VName, b)]
unused_dims [(VName, b)] -> Int -> (VName, b)
forall a. HasCallStack => [a] -> Int -> a
!! Int
i) (VName, b) -> [(VName, b)] -> [(VName, b)]
forall a. a -> [a] -> [a]
: [(VName, b)]
inner_dims
          )
    reorderParDims :: VarianceTable -> [SubExp] -> [(VName, b)] -> [(VName, b)]
reorderParDims VarianceTable
variance [SubExp]
acc_inds [(VName, b)]
gid_dims_new_0 =
      let ([(VName, b)]
invar_dims, [(VName, b)]
inner_dims) =
            (([(VName, b)], [(VName, b)])
 -> SubExp -> ([(VName, b)], [(VName, b)]))
-> ([(VName, b)], [(VName, b)])
-> [SubExp]
-> ([(VName, b)], [(VName, b)])
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
              (VarianceTable
-> ([(VName, b)], [(VName, b)])
-> SubExp
-> ([(VName, b)], [(VName, b)])
forall {b}.
VarianceTable
-> ([(VName, b)], [(VName, b)])
-> SubExp
-> ([(VName, b)], [(VName, b)])
foldfunReorder VarianceTable
variance)
              ([(VName, b)]
gid_dims_new_0, [])
              ([SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse [SubExp]
acc_inds)
       in [(VName, b)]
invar_dims [(VName, b)] -> [(VName, b)] -> [(VName, b)]
forall a. [a] -> [a] -> [a]
++ [(VName, b)]
inner_dims
    --
    getAccLambda :: Type -> ((Lambda GPU, [SubExp]), [Type])
getAccLambda Type
acc_tp =
      case Type
acc_tp of
        (Acc VName
tp_id ShapeBase SubExp
_shp [Type]
el_tps NoUniqueness
_) ->
          case VName -> WithEnv -> Maybe (Lambda GPU, [SubExp])
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
tp_id (Env -> WithEnv
forall a b. (a, b) -> a
fst Env
env) of
            Just (Lambda GPU, [SubExp])
lam -> ((Lambda GPU, [SubExp])
lam, [Type]
el_tps)
            Maybe (Lambda GPU, [SubExp])
_ -> [Char] -> ((Lambda GPU, [SubExp]), [Type])
forall a. HasCallStack => [Char] -> a
error ([Char] -> ((Lambda GPU, [SubExp]), [Type]))
-> [Char] -> ((Lambda GPU, [SubExp]), [Type])
forall a b. (a -> b) -> a -> b
$ [Char]
"Lookup in environment failed! " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
tp_id [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
" env: " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ WithEnv -> [Char]
forall a. Show a => a -> [Char]
show (Env -> WithEnv
forall a b. (a, b) -> a
fst Env
env)
        Type
_ -> [Char] -> ((Lambda GPU, [SubExp]), [Type])
forall a. HasCallStack => [Char] -> a
error [Char]
"Illegal accumulator type!"
    -- Is a subexpression invariant to the given gid (of a parallel
    -- dimension)?  A variable is invariant when it is not the gid
    -- itself and the gid is not among the names recorded for it in
    -- the variance table; anything that is not a variable is
    -- trivially invariant.
    isSeInvar2 variance gid se =
      case se of
        Var x ->
          gid /= x && gid `notNameIn` M.findWithDefault mempty x variance
        _ -> True
    -- Is a 'DimIndex' invariant to the given gid (of a parallel
    -- dimension)?  This holds when every subexpression making up the
    -- index (the fixed index, or all three components of a slice) is
    -- invariant to the gid.
    isDimIdxInvar2 variance gid dim_idx =
      all (isSeInvar2 variance gid) $
        case dim_idx of
          DimFix d -> [d]
          DimSlice d1 d2 d3 -> [d1, d2, d3]
    -- Is an entire slice invariant to at least one of the given gids
    -- (of parallel dimensions)?  The last argument, the collection of
    -- gids, is taken point-free.
    isSliceInvar2 variance slc = any invariantTo
      where
        dim_idxs = unSlice slc
        invariantTo gid = all (isDimIdxInvar2 variance gid) dim_idxs
    -- Does this statement leave the code tileable?  Only an indexing
    -- statement binding a single name that is among the names the
    -- variance table records for the accumulator @acc_nm@ is
    -- inspected: its slice must be invariant to at least one of the
    -- gids in @gid_dims@, or to @seq_gid@.
    isTileable :: VName -> [(VName, SubExp)] -> VarianceTable -> VName -> Stm GPU -> Bool
    isTileable seq_gid gid_dims variance acc_nm (Let (Pat [pel]) _ (BasicOp (Index _ slc)))
      | patElemName pel `nameIn` M.findWithDefault mempty acc_nm variance =
          any (isSliceInvar2 variance slc) [map fst gid_dims, [seq_gid]]
    -- Every other statement is accepted.  This relies on the cost
    -- model, that currently accepts only global-memory reads, and for
    -- example rejects in-place updates or loops inside the code that
    -- is transformed in a redomap.
    isTileable _ _ _ _ _ = True
    -- Does the to-be-reduced accumulator depend on this statement?
    -- True when any name bound by the statement's pattern is among
    -- the names that the variance table records for @pat_acc_nm@.
    dependsOnAcc pat_acc_nm variance stm =
      any (`nameIn` M.findWithDefault mempty pat_acc_nm variance) $
        patNames (stmPat stm)
genRed2Tile2d Env
_ Stm GPU
_ =
  Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stms GPU, Stm GPU)
forall a. Maybe a
Nothing

-- | Stub with the same shape as the other generalized-reduction
--   rewrites: both arguments are ignored and the result is always
--   'Nothing', i.e. no rewrite is ever produced.
genRed2SegRed :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2SegRed _env _stm = pure Nothing

-- | Rewrites indexing statements on the arrays in @fvs@ so that the
--   index that depends on @gid@ ends up in the last position.
--
--   For every qualifying statement a rearranged (and then opaque'd)
--   copy of the array is built, and the statement is changed to index
--   that copy with correspondingly permuted indices.  The result is
--   the pair of (1) the input statements with those rewrites applied
--   and (2) the statements that build the rearranged arrays
--   (presumably meant to be placed before the kernel, on the host --
--   confirm against the caller).
transposeFVs ::
  Names ->
  VarianceTable ->
  VName ->
  Stms GPU ->
  GenRedM (Stms GPU, Stms GPU)
transposeFVs fvs variance gid stms = do
  (transposed, stms') <- foldM step (M.empty, mempty) $ stmsToList stms
  -- Collect the statements that build each rearranged array recorded
  -- in the table.
  let stms_host = M.foldr (\(_, _, tr_stms) acc -> acc <> tr_stms) mempty transposed
  pure (stms', stms_host)
  where
    -- Processes one statement, threading the table of already
    -- rearranged arrays and appending the (possibly rewritten)
    -- statement to the ones processed so far.
    step (tab, done) stm = do
      (tab', stm') <- transposeFV tab stm
      pure (tab', done <> oneStm stm')

    -- ToDo: please generalize.  A statement is rewritten only if it
    -- indexes an array from @fvs@ with fixed indices only, exactly
    -- one of those indices depends on @gid@, that index is not
    -- already the last one, and the array has no entry in the table
    -- yet (later statements on an array that was already rearranged
    -- are left untouched).
    transposeFV tab (Let pat aux (BasicOp (Index arr slc)))
      | dims <- unSlice slc,
        rank <- length dims,
        all isFixDim dims,
        arr `nameIn` fvs,
        [ii] <- L.findIndices depOnGid dims,
        -- generalize below: treat any rearange and add to tab if not there.
        arr `M.notMember` tab,
        ii /= rank - 1 = do
          -- Move dimension @ii@ to the last position, keeping the
          -- relative order of all the other dimensions.
          let perm = filter (/= ii) [0 .. rank - 1] ++ [ii]
          (arr_tr, stms_tr) <- runBuilderT' $ do
            arr' <-
              letExp (baseString arr ++ "_trsp") $
                BasicOp $
                  Rearrange perm arr -- Manifest [1,0] arr
            letExp (baseString arr' ++ "_opaque") $
              BasicOp $
                Opaque OpaqueNil $
                  Var arr'
          let tab' = M.insert arr (perm, arr_tr, stms_tr) tab
              -- every position of @perm@ is within @dims@ by construction
              slc' = Slice $ map (dims !!) perm
          pure (tab', Let pat aux $ BasicOp $ Index arr_tr slc')
      where
        isFixDim DimFix {} = True
        isFixDim _ = False
        -- A fixed index depends on @gid@ when it is @gid@ itself or
        -- when @gid@ is among the names the variance table records
        -- for it.
        depOnGid (DimFix (Var nm)) =
          gid == nm || gid `nameIn` M.findWithDefault mempty nm variance
        depOnGid _ = False
    transposeFV tab stm = pure (tab, stm)

-- | Tries to identify the following pattern:
--   code followed by some UpdateAcc-statement
--   followed by more code.
--
--   The statements are split around the /first/ 'UpdateAcc'
--   statement: the result holds the statements before it, the
--   statement itself, and the statements after it.  When there is no
--   'UpdateAcc' statement at all, the result is all the statements,
--   'Nothing', and an empty sequence.
matchCodeAccumCode ::
  Stms GPU ->
  (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeAccumCode kstms =
  -- A single 'break' replaces a left fold that appended one element
  -- at a time to the end of a list (quadratic in the number of
  -- statements).
  case break isUpdateAcc (stmsToList kstms) of
    (code1, accum : code2) ->
      (stmsFromList code1, Just accum, stmsFromList code2)
    (code1, []) ->
      (stmsFromList code1, Nothing, mempty)
  where
    isUpdateAcc (Let _ _ (BasicOp UpdateAcc {})) = True
    isUpdateAcc _ = False

-- | Checks that there exists a parallel dimension of the kernel
--     space @kspace@ to which all the indices (@acc_inds@) are
--     invariant.
--   It returns the last such parallel dimension in the order of the
--     kernel space (i.e., the innermost one), as a tuple of the
--     pardim gid ('VName') and its index ('Int') in the parallel
--     space.
--   The result is 'Nothing' when no such dimension exists, or when
--     any gid of the kernel space occurs in @branch_variant@.
isInvarToParDim ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [SubExp] ->
  Maybe (VName, Int)
isInvarToParDim branch_variant kspace variance acc_inds
  | any (`nameIn` branch_variant) ker_gids = Nothing
  | otherwise =
      listToMaybe . reverse $
        filter ((`notNameIn` variant_gids) . fst) (zip ker_gids [0 ..])
  where
    ker_gids = map fst $ unSegSpace kspace
    -- All the kernel gids that at least one of the indices is variant to.
    variant_gids = namesFromList $ concatMap variantGids acc_inds
    -- The kernel gids that a single index is variant to: the index
    -- itself when it is a gid, plus every gid among the names the
    -- variance table records for it.  Constants are variant to none.
    variantGids (Var ind) =
      let deps = M.findWithDefault mempty ind variance
       in filter (\kid -> kid == ind || kid `nameIn` deps) ker_gids
    variantGids _ = []

-- | Succeeds only if every kernel result is a 'Returns' with
--   'ResultMaySimplify'.  In that case the result is the certificates
--   of all the results, concatenated in order, together with the
--   returned subexpressions, in order.  Any other kind of kernel
--   result makes the whole check yield 'Nothing'.  An empty list of
--   results yields @Just ([], [])@.
allGoodReturns :: [KernelResult] -> Maybe ([VName], [SubExp])
allGoodReturns kres = do
  -- A single traversal in 'Maybe' both validates and extracts, so no
  -- "impossible" error case is needed for results that fail the check.
  (certs, ses) <- unzip <$> mapM certsAndRes kres
  pure (concat certs, ses)
  where
    certsAndRes (Returns ResultMaySimplify c r_se) = Just (unCerts c, r_se)
    certsAndRes _ = Nothing

--------------------------
--- Cost Model Helpers ---
--------------------------

-- | Sums up (with 'addCosts') the cost of those statements of @kstms@
--   that bind at least one name on which both the accumulator
--   @pat_acc_nm@ and the results @r_ses@ depend.
--
--   The dependencies of the accumulator are taken from the given
--   variance table.  The dependencies of the results are taken from a
--   variance table that is rebuilt from @kstms@ while skipping the
--   statement that binds the accumulator.
costRedundantExecution ::
  VarianceTable ->
  VName ->
  [SubExp] ->
  Stms GPU ->
  Cost
costRedundantExecution variance pat_acc_nm r_ses kstms =
  L.foldl' addCostOfStmt (Small 0) kstms
  where
    acc_deps = M.findWithDefault mempty pat_acc_nm variance
    vartab_cut_acc =
      L.foldl' (varianceInStmWithout (oneName pat_acc_nm)) mempty kstms
    -- Only results that are variables can have dependencies.
    res_deps =
      foldMap
        (\nm -> M.findWithDefault mempty nm vartab_cut_acc)
        [nm | Var nm <- r_ses]
    common_deps = namesIntersection res_deps acc_deps

    -- Adds the cost of a statement only if it binds a common dependency.
    addCostOfStmt cur_cost stm
      | namesIntersect (namesFromList $ patNames $ stmPat stm) common_deps =
          addCosts cur_cost $ costRedundantStmt stm
      | otherwise = cur_cost

    -- Extends the variance table with the names bound by a statement,
    -- unless the statement binds one of the names in @cuts@, in which
    -- case the table is left unchanged.
    varianceInStmWithout :: Names -> VarianceTable -> Stm GPU -> VarianceTable
    varianceInStmWithout cuts vartab stm
      | namesIntersect (namesFromList pat_nms) cuts = vartab
      | otherwise =
          L.foldl' (\tab v -> M.insert v binding_variance tab) vartab pat_nms
      where
        pat_nms = patNames $ stmPat stm
        -- a free variable contributes itself plus whatever the table
        -- already records for it
        look v = oneName v <> M.findWithDefault mempty v vartab
        binding_variance = foldMap look $ namesToList $ freeIn stm

-- | The cost assigned to (redundantly) executing code.
data Cost
  = -- | A small cost that carries a count; two small costs are
    --   combined by adding ('addCosts') or by taking the maximum
    --   ('maxCost') of their counts.
    Small Int
  | -- | Dominates every 'Small' cost when costs are combined.
    Big
  | -- | Dominates every other cost when costs are combined.
    Break
  deriving (Eq)

-- | Combines two costs: 'Break' if either one is 'Break', otherwise
--   'Big' if either one is 'Big', otherwise the sum of the two small
--   counts.
addCosts :: Cost -> Cost -> Cost
addCosts lhs rhs =
  case (lhs, rhs) of
    (Break, _) -> Break
    (_, Break) -> Break
    (Big, _) -> Big
    (_, Big) -> Big
    (Small c1, Small c2) -> Small (c1 + c2)

-- | The larger of two costs: for two small costs the maximum of
--   their counts; in every other case the same as 'addCosts'.
maxCost :: Cost -> Cost -> Cost
maxCost lhs rhs =
  case (lhs, rhs) of
    (Small c1, Small c2) -> Small (max c1 c2)
    _ -> addCosts lhs rhs

-- | The cost of a body: the costs of all its statements (as given by
--   'costRedundantStmt'), combined with 'addCosts'.  An empty body
--   costs @Small 0@.
costBody :: Body GPU -> Cost
costBody bdy =
  -- strict left fold directly over the statements, matching the use
  -- of L.foldl' elsewhere in the cost model
  L.foldl'
    (\cost stm -> addCosts cost (costRedundantStmt stm))
    (Small 0)
    (bodyStms bdy)

-- | The cost assigned to redundantly executing a single statement,
--   determined solely by the kind of its expression:
--
--   * 'Big' for operations ('Op'), loops, function applications and
--     'WithAcc', as well as for array literals of arrays, 'Concat',
--     'Manifest' and 'Replicate';
--
--   * 'Break' for 'Update', 'FlatUpdate' and 'UpdateAcc';
--
--   * @Small 1@ for array literals of non-array elements and for
--     indexing with fixed indices only;
--
--   * for a 'Match', the 'maxCost' over the costs of all its bodies;
--
--   * @Small 0@ for everything else.
costRedundantStmt :: Stm GPU -> Cost
costRedundantStmt (Let _ _ e) =
  case e of
    Op _ -> Big
    Loop {} -> Big
    Apply {} -> Big
    WithAcc {} -> Big
    Match _ cases defbody _ ->
      L.foldl' maxCost (costBody defbody) [costBody (caseBody c) | c <- cases]
    BasicOp bop -> basicOpCost bop
  where
    -- The order of the clauses matters: the specific array-literal
    -- case must come before the general one.
    basicOpCost (ArrayLit _ Array {}) = Big
    basicOpCost ArrayLit {} = Small 1
    basicOpCost (Index _ slc)
      | all isFixDim (unSlice slc) = Small 1
      | otherwise = Small 0
    basicOpCost FlatIndex {} = Small 0
    basicOpCost Update {} = Break
    basicOpCost FlatUpdate {} = Break
    basicOpCost Concat {} = Big
    basicOpCost Manifest {} = Big
    basicOpCost Replicate {} = Big
    basicOpCost UpdateAcc {} = Break
    basicOpCost _ = Small 0

    isFixDim DimFix {} = True
    isFixDim _ = False