{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup
  (intraGroupParallelise)
where

import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Prelude hiding (log)

import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.SOACS
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel hiding (HistOp)
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Util (chunks)
import Futhark.Util.Log

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
intraGroupParallelise :: (MonadFreshNames m, LocalScope Out.Kernels m) =>
                         KernelNest -> Lambda
                      -> m (Maybe ((SubExp, SubExp), SubExp, Log,
                                   Out.Stms Out.Kernels, Out.Stms Out.Kernels))
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of the sizes of the flattened
  -- outer map nest.
  (num_groups, w_stms) <- lift $ runBinder $
    letSubExp "intra_num_groups" =<<
    foldBinOp (Mul Int32) (intConst Int32 1) (map snd ispace)

  let body = lambdaBody lam

  -- The group size is only bound later (in the prelude statements),
  -- once the amount of parallelism in the body is known.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $ localScope (scopeOfLParams $ lambdaParams lam) $
    intraGroupParalleliseBody intra_lvl body

  -- The parallelism sizes are used to compute the group size outside
  -- the kernel, so every name they mention must be bound outside.
  -- Query the scope map directly instead of searching a list of its
  -- keys.
  outer_scope <- lift askScope
  let knownOutside v = v `M.member` outer_scope
  unless (all knownOutside $ namesToList $ freeIn $ wss_min ++ wss_avail) $
    -- In 'MaybeT', 'fail' makes the whole computation yield 'Nothing'.
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $ runBinder $ do
    -- Like 'foldBinOp', but produces the constant 0 for an empty list.
    let foldBinOp' _ [] = eSubExp $ intConst Int32 0
        foldBinOp' bop (x:xs) = foldBinOp bop x xs

    -- Each recorded construct is a list of dimension sizes; its total
    -- parallelism is their product.
    ws_min <- mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int32)) $
              filter (not . null) wss_min
    ws_avail <- mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int32)) $
                filter (not . null) wss_avail

    -- The amount of parallelism available *in the worst case* is
    -- equal to the smallest parallel loop.
    -- NOTE(review): when ws_avail is empty this evaluates to 0;
    -- presumably callers reject such results - confirm.
    intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int32) ws_avail

    -- The group size is either the maximum of the minimum parallelism
    -- exploited, or the desired parallelism (bounded by the max group
    -- size) in case there is no minimum.
    letBindNames_ [group_size] =<<
      if null ws_min
      then eBinOp (SMin Int32)
           (eSubExp =<< letSubExp "max_group_size"
                        (Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup))
           (eSubExp intra_avail_par)
      else foldBinOp' (SMax Int32) ws_min

    -- Only read the kernel inputs that the body actually refers to.
    let inputIsUsed input = kernelInputName input `nameIn` freeIn body
        used_inps = filter inputIsUsed inps

    addStms w_stms
    read_input_stms <- runBinder_ $ mapM readKernelInput used_inps
    space <- mkSegSpace ispace
    return (intra_avail_par, space, read_input_stms)

  -- The input reads are placed before the converted body statements.
  let kbody' =
        kbody { kernelBodyStms = read_input_stms <> kernelBodyStms kbody }

  let nested_pat = loopNestingPattern first_nest
      -- Result types of a single iteration: the pattern types with the
      -- dimensions of the outer nest stripped off.
      rts = map (length ispace `stripArray`) $ patternTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm = Let nested_pat (StmAux cs ()) $ Op $ SegOp $
             SegMap lvl kspace rts kbody'

  -- The reported minimum parallelism is currently the same as the
  -- worst-case available parallelism.
  let intra_min_par = intra_avail_par
  return ((intra_min_par, intra_avail_par), Var group_size, log,
          prelude_stms, oneStm kstm)
  where first_nest = fst knest
        cs = loopNestingCertificates first_nest

-- | Writer accumulator of 'IntraGroupM'.
data Acc = Acc { accMinPar :: S.Set [SubExp]
                 -- ^ Minimum parallelism: each element is the list of
                 -- dimension sizes of one parallel construct.
               , accAvailPar :: S.Set [SubExp]
                 -- ^ Available parallelism, in the same form.
               , accLog :: Log
                 -- ^ Accumulated log messages.
               }

-- | Combine two accumulators field by field (set union for the
-- parallelism sets, concatenation for the log).
instance Semigroup Acc where
  Acc min_x avail_x log_x <> Acc min_y avail_y log_y =
    Acc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

-- | The empty accumulator: no recorded parallelism and an empty log.
instance Monoid Acc where
  mempty = Acc mempty mempty mempty

-- | Monad for converting statements to intra-group kernel code: a
-- binder of 'Out.Kernels' statements on top of an 'RWS' monad with a
-- trivial reader, an 'Acc' writer and a 'VNameSource' state.
type IntraGroupM = BinderT Out.Kernels (RWS () Acc VNameSource)

-- | Logging writes the message into the 'accLog' field of the
-- accumulator; the parallelism sets are left empty.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty { accLog = log }

-- | Run an 'IntraGroupM' computation in the current scope, threading
-- the name source of the surrounding monad through it.  Returns the
-- accumulated 'Acc' together with the statements that were bound.
runIntraGroupM :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                  IntraGroupM () -> m (Acc, Out.Stms Out.Kernels)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    let (((), kstms), src', acc) = runRWS (runBinderT m scope) () src
    in ((acc, kstms), src')

-- | Record a parallel construct with the given dimension sizes.  It is
-- counted both as minimum and as available parallelism.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws = tell mempty { accMinPar = S.singleton ws
                             , accAvailPar = S.singleton ws
                             }

-- | Convert a SOACS body by converting each of its statements with
-- 'intraGroupStm' at the given level.  The body result is kept as-is.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.Kernels)
intraGroupBody lvl body = do
  stms <- collectStms_ $ mapM_ (intraGroupStm lvl) $ bodyStms body
  return $ mkBody stms $ bodyResult body

intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm SegLevel
lvl stm :: Stm
stm@(Let Pattern SOACS
pat StmAux (ExpAttr SOACS)
aux Exp SOACS
e) = do
  Scope Kernels
scope <- BinderT Kernels (RWST () Acc VNameSource Identity) (Scope Kernels)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
  let lvl' :: SegLevel
lvl' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegVirt
SegNoVirt

  case Exp SOACS
e of
    DoLoop [(FParam SOACS, SubExp)]
ctx [(FParam SOACS, SubExp)]
val LoopForm SOACS
form BodyT SOACS
loopbody ->
      Scope Kernels -> IntraGroupM () -> IntraGroupM ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (LoopForm Kernels -> Scope Kernels
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm Kernels
form') (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
      Scope Kernels -> IntraGroupM () -> IntraGroupM ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param DeclType] -> Scope Kernels
forall lore attr.
(FParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope Kernels)
-> [Param DeclType] -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst ([(Param DeclType, SubExp)] -> [Param DeclType])
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> a -> b
$ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
ctx [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
      Body Kernels
loopbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
loopbody
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [(FParam Kernels, SubExp)]
-> [(FParam Kernels, SubExp)]
-> LoopForm Kernels
-> Body Kernels
-> ExpT Kernels
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
ctx [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
val LoopForm Kernels
form' Body Kernels
loopbody'
          where form' :: LoopForm Kernels
form' = case LoopForm SOACS
form of
                          ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps -> VName
-> IntType
-> SubExp
-> [(LParam Kernels, VName)]
-> LoopForm Kernels
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam Kernels, VName)]
inps
                          WhileLoop VName
cond          -> VName -> LoopForm Kernels
forall lore. VName -> LoopForm lore
WhileLoop VName
cond

    If SubExp
cond BodyT SOACS
tbody BodyT SOACS
fbody IfAttr (BranchType SOACS)
ifattr -> do
      Body Kernels
tbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
tbody
      Body Kernels
fbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
fbody
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> Body Kernels
-> Body Kernels
-> IfAttr (BranchType Kernels)
-> ExpT Kernels
forall lore.
SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
If SubExp
cond Body Kernels
tbody' Body Kernels
fbody' IfAttr (BranchType SOACS)
IfAttr (BranchType Kernels)
ifattr

    Op (Screma w form arrs)
      | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form -> do
      let loopnest :: LoopNesting
loopnest = Pattern Kernels
-> Certificates -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pattern SOACS
Pattern Kernels
pat (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda
lam) [VName]
arrs
          env :: DistEnv IntraGroupM
env = DistEnv :: forall (m :: * -> *).
Nestings
-> Scope Kernels
-> (Seq Stm -> DistNestT m (Stms Kernels))
-> (MapLoop -> DistAcc -> DistNestT m DistAcc)
-> MkSegLevel m
-> DistEnv m
DistEnv { distNest :: Nestings
distNest =
                            Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest
                        , distScope :: Scope Kernels
distScope =
                            PatternT Type -> Scope Kernels
forall lore attr.
(LetAttr lore ~ attr) =>
PatternT attr -> Scope lore
scopeOfPattern PatternT Type
Pattern SOACS
pat Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<>
                            Scope SOACS -> Scope Kernels
scopeForKernels (Lambda -> Scope SOACS
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Lambda
lam) Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope Kernels
scope
                        , distOnInnerMap :: MapLoop -> DistAcc -> DistNestT IntraGroupM DistAcc
distOnInnerMap =
                            MapLoop -> DistAcc -> DistNestT IntraGroupM DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
MapLoop -> DistAcc -> DistNestT m DistAcc
distributeMap
                        , distOnTopLevelStms :: Seq Stm -> DistNestT IntraGroupM (Stms Kernels)
distOnTopLevelStms =
                            BinderT Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> DistNestT IntraGroupM (Stms Kernels)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (BinderT Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
 -> DistNestT IntraGroupM (Stms Kernels))
-> (Seq Stm
    -> BinderT
         Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> Seq Stm
-> DistNestT IntraGroupM (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntraGroupM ()
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) a. MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ (IntraGroupM ()
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> (Seq Stm -> IntraGroupM ())
-> Seq Stm
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegLevel -> Seq Stm -> IntraGroupM ()
intraGroupStms SegLevel
lvl
                        , distSegLevel :: MkSegLevel IntraGroupM
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                            IntraGroupM () -> BinderT Kernels IntraGroupM ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntraGroupM () -> BinderT Kernels IntraGroupM ())
-> IntraGroupM () -> BinderT Kernels IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
                            SegLevel -> BinderT Kernels IntraGroupM SegLevel
forall (m :: * -> *) a. Monad m => a -> m a
return SegLevel
lvl
                        }
          acc :: DistAcc
acc = DistAcc :: Targets -> Stms Kernels -> DistAcc
DistAcc { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (Pattern SOACS
Pattern Kernels
pat, BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT SOACS -> [SubExp]) -> BodyT SOACS -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda
lam)
                        , distStms :: Stms Kernels
distStms = Stms Kernels
forall a. Monoid a => a
mempty
                        }

      Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
        DistEnv IntraGroupM
-> DistNestT IntraGroupM DistAcc
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *).
MonadLogger m =>
DistEnv m -> DistNestT m DistAcc -> m (Stms Kernels)
runDistNestT DistEnv IntraGroupM
env (DistAcc -> Seq Stm -> DistNestT IntraGroupM DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc -> Seq Stm -> DistNestT m DistAcc
distributeMapBodyStms DistAcc
acc (BodyT SOACS -> Seq Stm
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT SOACS -> Seq Stm) -> BodyT SOACS -> Seq Stm
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda
lam))

    Op (Screma w form arrs)
      | Just (Lambda
scanfun, [SubExp]
nes, Lambda
mapfun) <- ScremaForm SOACS -> Maybe (Lambda, [SubExp], Lambda)
forall lore.
ScremaForm lore -> Maybe (Lambda lore, [SubExp], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form -> do
      let scanfun' :: Lambda Kernels
scanfun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
scanfun
          mapfun' :: Lambda Kernels
mapfun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
mapfun
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegLevel
-> Pattern Kernels
-> SubExp
-> Lambda Kernels
-> Lambda Kernels
-> [SubExp]
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
SegLevel
-> Pattern Kernels
-> SubExp
-> Lambda Kernels
-> Lambda Kernels
-> [SubExp]
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms Kernels)
segScan SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w Lambda Kernels
scanfun' Lambda Kernels
mapfun' [SubExp]
nes [VName]
arrs [] []
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]

    Op (Screma w form arrs)
      | Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
        Reduce Commutativity
comm Lambda
red_lam [SubExp]
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds -> do
      let red_lam' :: Lambda Kernels
red_lam' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
red_lam
          map_lam' :: Lambda Kernels
map_lam' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
map_lam
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegLevel
-> Pattern Kernels
-> SubExp
-> [SegRedOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
SegLevel
-> Pattern Kernels
-> SubExp
-> [SegRedOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms Kernels)
segRed SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [Commutativity
-> Lambda Kernels -> [SubExp] -> Shape -> SegRedOp Kernels
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegRedOp lore
SegRedOp Commutativity
comm Lambda Kernels
red_lam' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda Kernels
map_lam' [VName]
arrs [] []
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]


    Op (Hist w ops bucket_fun arrs) -> do
      [HistOp Kernels]
ops' <- [HistOp SOACS]
-> (HistOp SOACS
    -> BinderT
         Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
  -> BinderT
       Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels])
-> (HistOp SOACS
    -> BinderT
         Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels]
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda
op) -> do
        (Lambda Kernels
op', [SubExp]
nes', Shape
shape) <- Lambda
-> [SubExp]
-> BinderT
     Kernels
     (RWST () Acc VNameSource Identity)
     (Lambda Kernels, [SubExp], Shape)
forall (m :: * -> *).
(MonadBinder m, Lore m ~ Kernels) =>
Lambda -> [SubExp] -> m (Lambda Kernels, [SubExp], Shape)
determineReduceOp Lambda
op [SubExp]
nes
        HistOp Kernels
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp Kernels
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> HistOp Kernels
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels)
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda Kernels
-> HistOp Kernels
forall lore.
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda lore
-> HistOp lore
Out.HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda Kernels
op'

      let bucket_fun' :: Lambda Kernels
bucket_fun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
bucket_fun
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegLevel
-> Pattern Kernels
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
SegLevel
-> Pattern Kernels
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp Kernels]
-> Lambda Kernels
-> [VName]
-> m (Stms Kernels)
segHist SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [] [] [HistOp Kernels]
ops' Lambda Kernels
bucket_fun' [VName]
arrs
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]

    Op (Stream w (Sequential accs) lam arrs)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda
lam -> do
      Scope SOACS
types <- (Scope Kernels -> Scope SOACS)
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
forall fromlore tolore.
SameScope fromlore tolore =>
Scope fromlore -> Scope tolore
castScope
      ((), Seq Stm
stream_bnds) <-
        BinderT SOACS IntraGroupM ()
-> Scope SOACS
-> BinderT Kernels (RWST () Acc VNameSource Identity) ((), Seq Stm)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Pattern (Lore (BinderT SOACS IntraGroupM))
-> SubExp
-> [SubExp]
-> LambdaT (Lore (BinderT SOACS IntraGroupM))
-> [VName]
-> BinderT SOACS IntraGroupM ()
forall (m :: * -> *).
(MonadBinder m, Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> [SubExp] -> LambdaT (Lore m) -> [VName] -> m ()
sequentialStreamWholeArray Pattern (Lore (BinderT SOACS IntraGroupM))
Pattern SOACS
pat SubExp
w [SubExp]
accs LambdaT (Lore (BinderT SOACS IntraGroupM))
Lambda
lam [VName]
arrs) Scope SOACS
types
      let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
          replace SubExp
se = SubExp
se
          replaceSets :: Acc -> Acc
replaceSets (Acc Set [SubExp]
x Set [SubExp]
y Log
log) =
            Set [SubExp] -> Set [SubExp] -> Log -> Acc
Acc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
      (Acc -> Acc) -> IntraGroupM () -> IntraGroupM ()
forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor Acc -> Acc
replaceSets (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ (Stm -> IntraGroupM ()) -> Seq Stm -> IntraGroupM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (SegLevel -> Stm -> IntraGroupM ()
intraGroupStm SegLevel
lvl) Seq Stm
stream_bnds

    Op (Scatter w lam ivs dests) -> do
      VName
write_i <- String -> BinderT Kernels (RWST () Acc VNameSource Identity) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      SegSpace
space <- [(VName, SubExp)]
-> BinderT Kernels (RWST () Acc VNameSource Identity) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]

      let lam' :: Lambda Kernels
lam' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
lam
          ([SubExp]
dests_ws, [Int]
dests_ns, [VName]
dests_vs) = [(SubExp, Int, VName)] -> ([SubExp], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExp, Int, VName)]
dests
          ([SubExp]
i_res, [SubExp]
v_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
dests_ns) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ Body Kernels -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body Kernels -> [SubExp]) -> Body Kernels -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam'
          krets :: [KernelResult]
krets = do (SubExp
a_w, VName
a, [(SubExp, SubExp)]
is_vs) <- [SubExp]
-> [VName]
-> [[(SubExp, SubExp)]]
-> [(SubExp, VName, [(SubExp, SubExp)])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SubExp]
dests_ws [VName]
dests_vs ([[(SubExp, SubExp)]] -> [(SubExp, VName, [(SubExp, SubExp)])])
-> [[(SubExp, SubExp)]] -> [(SubExp, VName, [(SubExp, SubExp)])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [(SubExp, SubExp)] -> [[(SubExp, SubExp)]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
dests_ns ([(SubExp, SubExp)] -> [[(SubExp, SubExp)]])
-> [(SubExp, SubExp)] -> [[(SubExp, SubExp)]]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
i_res [SubExp]
v_res
                     KernelResult -> [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> [([SubExp], SubExp)] -> KernelResult
WriteReturns [SubExp
a_w] VName
a [ ([SubExp
i],SubExp
v) | (SubExp
i,SubExp
v) <- [(SubExp, SubExp)]
is_vs ]
          inputs :: [KernelInput]
inputs = do (Param Type
p, VName
p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
lam') [VName]
ivs
                      KernelInput -> [KernelInput]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelInput -> [KernelInput]) -> KernelInput -> [KernelInput]
forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
p) (Param Type -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]

      Stms Kernels
kstms <- BinderT Kernels (State VNameSource) ()
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (BinderT Kernels (State VNameSource) ()
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> BinderT Kernels (State VNameSource) ()
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Scope Kernels
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (BinderT Kernels (State VNameSource) ()
 -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ do
        (KernelInput -> BinderT Kernels (State VNameSource) ())
-> [KernelInput] -> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
(MonadBinder m, Lore m ~ Kernels) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
        Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body Kernels -> Stms Kernels
forall lore. BodyT lore -> Stms lore
bodyStms (Body Kernels -> Stms Kernels) -> Body Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam'

      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
        let ts :: [Type]
ts = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall attr. Typed attr => PatternT attr -> [Type]
patternTypes PatternT Type
Pattern SOACS
pat
            body :: KernelBody Kernels
body = BodyAttr Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyAttr lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
kstms [KernelResult]
krets
        Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp lore -> HostOp lore op
SegOp (SegOp Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace -> [Type] -> KernelBody Kernels -> SegOp Kernels
forall lore.
SegLevel -> SegSpace -> [Type] -> KernelBody lore -> SegOp lore
SegMap SegLevel
lvl' SegSpace
space [Type]
ts KernelBody Kernels
body

      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]

    Exp SOACS
_ ->
      Stm (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore IntraGroupM) -> IntraGroupM ())
-> Stm (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stm -> Stm Kernels
soacsStmToKernels Stm
stm

-- | Transform each statement of a SOACS statement sequence, in order,
-- with 'intraGroupStm' at the given 'SegLevel'.  The generated kernel
-- statements and accumulated information are emitted into the
-- surrounding 'IntraGroupM' computation.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl = mapM_ (intraGroupStm lvl)

-- | Run the intra-group transformation over the statements of a SOACS
-- 'Body' at the given 'SegLevel' and package the outcome as a kernel
-- body.
--
-- The result tuple contains, in order:
--
-- * the first width set accumulated in the 'Acc' (bound here as
--   @min_ws@; presumably the minimum group-size requirements recorded
--   via 'parallelMin' -- confirm against the definition of 'Acc'),
-- * the second width set accumulated in the 'Acc' (bound here as
--   @avail_ws@; presumably the exploitable parallelism -- confirm),
-- * the 'Log' produced during the transformation,
-- * a 'KernelBody' holding the generated statements.  Each of its results
--   is the corresponding original body result wrapped as
--   @Returns ResultMaySimplify@.
--
-- Both width lists are obtained with 'S.toList', so they contain no
-- duplicates and are in ascending order.
intraGroupParalleliseBody :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                             SegLevel -> Body
                          -> m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.Kernels)
intraGroupParalleliseBody lvl body = do
  (Acc min_ws avail_ws log, kstms) <-
    runIntraGroupM $ intraGroupStms lvl $ bodyStms body
  -- The kernel body simply returns what the original body returned.
  let kres = map (Returns ResultMaySimplify) $ bodyResult body
  return (S.toList min_ws, S.toList avail_ws, log, KernelBody () kstms kres)