{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpGen.Kernels.SegRed
( compileSegRed
, compileSegRed'
, DoSegBody
)
where
import Control.Monad.Except
import Data.Maybe
import Data.List (genericLength, zip7)
import Prelude hiding (quot, rem)
import Futhark.Error
import Futhark.Transform.Rename
import Futhark.IR.KernelsMem
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
-- | The maximum number of reduction operators accepted in a single
-- SegRed.  'compileSegRed'' rejects anything larger, and
-- 'nonsegmentedReduction' statically allocates this many counters.
maxNumOps :: Int32
maxNumOps = 10
-- | Code generation for the body of a SegRed, in continuation-passing
-- style: the body is handed a continuation and must invoke it with the
-- values to be reduced.  Each value is a 'SubExp' paired with a list
-- of expressions; 'compileSegRed' always passes an empty list here.
-- NOTE(review): the list presumably holds indexes into the 'SubExp'
-- used when reading the value -- confirm against the consumers.
type DoSegBody = ([(SubExp, [Imp.Exp])] -> InKernelGen ()) -> InKernelGen ()
-- | Compile a SegRed whose body is given as a 'KernelBody'.
--
-- The body is wrapped as a 'DoSegBody' and handed to 'compileSegRed''.
-- The first @segBinOpResults reds@ results of the body are the values
-- to be reduced; they are passed to the reduction continuation, each
-- paired with an empty index list.  The remaining results are map-out
-- results, written with 'compileThreadResult' to the pattern elements
-- that follow the reduction results.
compileSegRed :: Pattern KernelsMem
              -> SegLevel -> SegSpace
              -> [SegBinOp KernelsMem]
              -> KernelBody KernelsMem
              -> CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      -- Number of body results that belong to the reduction operators.
      let num_red_res = segBinOpResults reds
          (red_res, map_res) = splitAt num_red_res $ kernelBodyResult body

      sComment "save map-out results" $ do
        let map_arrs = drop num_red_res $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []
-- | Like 'compileSegRed', but the body is supplied as a 'DoSegBody'.
--
-- Chooses a code generation strategy:
--
-- * More than 'maxNumOps' operators: reported as a compiler
--   limitation.
--
-- * A two-dimensional space whose outer dimension is the constant 1:
--   'nonsegmentedReduction'.
--
-- * Otherwise a run-time branch is emitted: if twice the innermost
--   dimension is less than the group size, 'smallSegmentsReduction' is
--   used, else 'largeSegmentsReduction'.
compileSegRed' :: Pattern KernelsMem
               -> SegLevel -> SegSpace
               -> [SegBinOp KernelsMem]
               -> DoSegBody
               -> CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
      compilerLimitationS $
      "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int32Value 1))), _] <- unSegSpace space =
      nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
      group_size' <- toExp $ unCount group_size
      -- The innermost dimension is the size of each segment.
      segment_size <- toExp $ last $ segSpaceDims space
      let use_small_segments = segment_size * 2 .<. group_size'
      sIf use_small_segments
        (smallSegmentsReduction pat num_groups group_size space reds body)
        (largeSegmentsReduction pat num_groups group_size space reds body)
  where num_groups = segNumGroups lvl
        group_size = segGroupSize lvl
-- | Create one intermediate array (named @red_arr@) for each
-- accumulator parameter of a reduction operator, i.e. for the first
-- @length nes@ parameters of its lambda.
--
-- * If the parameter is an array residing in some memory block, the
--   new array is placed in that same memory block, with @num_threads@
--   prepended to the parameter's shape and an 'IxFun.iota' index
--   function over that extended shape.
--
-- * Otherwise an array of @group_size@ elements of the parameter's
--   element type is allocated in the @"local"@ space.
intermediateArrays :: Count GroupSize SubExp -> SubExp
                   -> SegBinOp KernelsMem
                   -> InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
  let red_op_params = lambdaParams red_op
      (red_acc_params, _) = splitAt (length nes) red_op_params
  forM red_acc_params $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $
          map (primExpFromSubExp int32) $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"
-- | Allocate the arrays (named @group_res_arr@) that hold per-group
-- results, one list of arrays per reduction operator and one array per
-- return type of the operator's lambda.
--
-- Each array lives in the @"device"@ space and has shape
-- @[group_size, virt_num_groups]@ followed by the operator's shape and
-- the shape of the result type.  It is allocated with a permutation
-- that moves the first dimension to the end.
-- NOTE(review): presumably this makes the @group_size@ dimension the
-- innermost one in memory -- confirm against 'sAllocArrayPerm'.
groupResultArrays :: Count NumGroups SubExp -> Count GroupSize SubExp
                  -> [SegBinOp KernelsMem]
                  -> CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          full_shape = Shape [group_size, virt_num_groups] <> shape <> arrayShape t
          -- Rotate left by one: dimension 0 becomes the last dimension.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "group_res_arr" pt full_shape (Space "device") perm
-- | Generate code for a reduction with a single segment (see the
-- dispatch in 'compileSegRed'').
--
-- On the host side this sets up a static @counter@ array of
-- 'maxNumOps' zeroes in the @"device"@ space, the per-group result
-- arrays and the @num_threads@ variable, and then launches the
-- @segred_nonseg@ kernel.  Inside the kernel, 'reductionStageOne' is
-- run once for all operators, followed by one 'reductionStageTwo' per
-- operator.
nonsegmentedReduction :: Pattern KernelsMem
                      -> Count NumGroups SubExp -> Count GroupSize SubExp -> SegSpace
                      -> [SegBinOp KernelsMem]
                      -> DoSegBody
                      -> CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims

  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size

  let global_tid = Imp.vi32 $ segFlat space
      -- The innermost dimension is the number of elements to reduce.
      w = last dims'

  -- One zero-initialised counter per possible operator; operator @i@
  -- passes index @i@ alongside this array to 'reductionStageTwo'.
  counter <-
    sStaticArray "counter" (Space "device") int32 $
    Imp.ArrayValues $ replicate (fromIntegral maxNumOps) $ IntValue $ Int32Value 0

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <- dPrimV "num_threads" $ unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (Var num_threads)) reds

    -- Every thread index variable of the space starts out as 0.
    -- NOTE(review): the innermost one is presumably reassigned by
    -- 'reductionStageOne' -- confirm there.
    forM_ gtids $ \v -> dPrimV_ v 0

    let num_elements = Imp.elements w
    let elems_per_thread =
          num_elements `quotRoundingUp` Imp.elements (kernelNumThreads constants)

    slugs <- mapM (segBinOpSlug
                   (kernelLocalThreadId constants)
                   (kernelGroupId constants)) $
             zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne constants (zip gtids dims') num_elements
      global_tid elems_per_thread num_threads
      slugs body

    -- Split the pattern elements into one chunk per operator, sized by
    -- the number of neutral elements of that operator.
    let segred_pes = chunks (map (length . segBinOpNeutral) reds) $
                     patternElements segred_pat
    forM_ (zip7 reds reds_arrs reds_group_res_arrs segred_pes
           slugs reds_op_renamed [0..]) $
      \(SegBinOp _ red_op nes _,
        red_arrs, group_res_arrs, pes,
        slug, red_op_renamed, i) -> do
        let (red_x_params, red_y_params) = splitAt (length nes) $ lambdaParams red_op
        reductionStageTwo constants pes (kernelGroupId constants) 0 [0] 0
          (kernelNumGroups constants) slug red_x_params red_y_params
          red_op_renamed nes
          1 counter (ValueExp $ IntValue $ Int32Value i)
          sync_arr group_res_arrs red_arrs
smallSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp -> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
smallSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
smallSegmentsReduction (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
segred_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
reds DoSegBody
body = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
[Exp]
dims' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
let segment_size :: Exp
segment_size = [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims'
VName
segment_size_nonzero_v <- String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"segment_size_nonzero" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 Exp
segment_size
Count NumGroups Exp
num_groups' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
Count GroupSize Exp
group_size' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM KernelsMem HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
VName
num_threads <- String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_threads" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
let segment_size_nonzero :: Exp
segment_size_nonzero = VName -> PrimType -> Exp
Imp.var VName
segment_size_nonzero_v PrimType
int32
num_segments :: Exp
num_segments = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims'
segments_per_group :: Exp
segments_per_group = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
segment_size_nonzero
required_groups :: Exp
required_groups = Exp
num_segments Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quotRoundingUp` Exp
segments_per_group
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-small" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
num_segments
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
segment_size
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segments_per_group" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
segments_per_group
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"required_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
required_groups
String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_small" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
[[VName]]
reds_arrs <- (SegBinOp KernelsMem -> InKernelGen [VName])
-> [SegBinOp KernelsMem]
-> ImpM KernelsMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp KernelsMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var VName
num_threads)) [SegBinOp KernelsMem]
reds
SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt Exp
required_groups ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
group_id_var' -> do
let group_id' :: Exp
group_id' = VName -> Exp
Imp.vi32 VName
group_id_var'
let ltid :: Exp
ltid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
segment_index :: Exp
segment_index = (Exp
ltid Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
segment_size_nonzero) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ (Exp
group_id' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
segments_per_group)
index_within_segment :: Exp
index_within_segment = Exp
ltid Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`rem` Exp
segment_size
(VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids) ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims') Exp
segment_index
VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) Exp
index_within_segment
let out_of_bounds :: InKernelGen ()
out_of_bounds =
[(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda KernelsMem
_ [SubExp]
nes Shape
_, [VName]
red_arrs) ->
[(VName, SubExp)]
-> ((VName, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
red_arrs [SubExp]
nes) (((VName, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
ne) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
arr [Exp
ltid] SubExp
ne []
in_bounds :: InKernelGen ()
in_bounds =
DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [Exp])]
red_res ->
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save results to be reduced" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let red_dests :: [(VName, [Exp])]
red_dests = [VName] -> [[Exp]] -> [(VName, [Exp])]
forall a b. [a] -> [b] -> [(a, b)]
zip ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs) ([[Exp]] -> [(VName, [Exp])]) -> [[Exp]] -> [(VName, [Exp])]
forall a b. (a -> b) -> a -> b
$ [Exp] -> [[Exp]]
forall a. a -> [a]
repeat [Exp
ltid]
[((VName, [Exp]), (SubExp, [Exp]))]
-> (((VName, [Exp]), (SubExp, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])]
-> [(SubExp, [Exp])] -> [((VName, [Exp]), (SubExp, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, [Exp])]
red_dests [(SubExp, [Exp])]
red_res) ((((VName, [Exp]), (SubExp, [Exp])) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [Exp]), (SubExp, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
d,[Exp]
d_is), (SubExp
res, [Exp]
res_is)) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
d [Exp]
d_is SubExp
res [Exp]
res_is
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply map function if in bounds" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
segment_size Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. Exp
0 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
[(VName, SubExp)] -> Exp
isActive ([(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [SubExp]
dims) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
Exp
ltid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
segment_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
segments_per_group) InKernelGen ()
in_bounds InKernelGen ()
out_of_bounds
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
let crossesSegment :: Exp -> Exp -> Exp
crossesSegment Exp
from Exp
to = (Exp
toExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
from) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. (Exp
to Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`rem` Exp
segment_size)
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
segment_size Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform segmented scan to imitate reduction" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda KernelsMem
red_op [SubExp]
_ Shape
_, [VName]
red_arrs) ->
Maybe (Exp -> Exp -> Exp)
-> Exp -> Exp -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupScan ((Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp)
forall a. a -> Maybe a
Just Exp -> Exp -> Exp
crossesSegment) (VName -> Exp
Imp.vi32 VName
num_threads)
(Exp
segment_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
segments_per_group) Lambda KernelsMem
red_op [VName]
red_arrs
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save final values of segments" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
group_id' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
segments_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
num_segments Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
Exp
ltid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
segments_per_group) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetDecMem]
segred_pes ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs)) (((PatElemT LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
arr) -> do
let flat_segment_index :: Exp
flat_segment_index = Exp
group_id' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
segments_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid
gtids' :: [Exp]
gtids' = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims') Exp
flat_segment_index
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe) [Exp]
gtids'
(VName -> SubExp
Var VName
arr) [(Exp
ltidExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
1) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
segment_size_nonzero Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
largeSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp -> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction Pattern KernelsMem
segred_pat Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
reds DoSegBody
body = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
[Exp]
dims' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
let segment_size :: Exp
segment_size = [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims'
num_segments :: Exp
num_segments = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims'
Count NumGroups Exp
num_groups' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
Count GroupSize Exp
group_size' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM KernelsMem HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
(Exp
groups_per_segment, Count Elements Exp
elems_per_thread) <-
Exp
-> Exp
-> Count NumGroups Exp
-> Count GroupSize Exp
-> CallKernelGen (Exp, Count Elements Exp)
groupsPerSegmentAndElementsPerThread Exp
segment_size Exp
num_segments
Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size'
VName
virt_num_groups <- String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"virt_num_groups" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
num_segments
VName
num_threads <- String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_threads" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
VName
threads_per_segment <- String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"threads_per_segment" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-large" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
num_segments
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
segment_size
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"virt_num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
virt_num_groups
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count NumGroups Exp
num_groups'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"group_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count GroupSize Exp
group_size'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"elems_per_thread" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
elems_per_thread
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"groups_per_segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
groups_per_segment
[[VName]]
reds_group_res_arrs <- Count NumGroups SubExp
-> Count GroupSize SubExp
-> [SegBinOp KernelsMem]
-> CallKernelGen [[VName]]
groupResultArrays (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (VName -> SubExp
Var VName
virt_num_groups)) Count GroupSize SubExp
group_size [SegBinOp KernelsMem]
reds
let num_counters :: Int
num_counters = Int32 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int32
maxNumOps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
1024
VName
counter <-
String
-> Space
-> PrimType
-> ArrayContents
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> Space -> PrimType -> ArrayContents -> ImpM lore r op VName
sStaticArray String
"counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM KernelsMem HostEnv HostOp VName)
-> ArrayContents -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Int -> ArrayContents
Imp.ArrayZeros Int
num_counters
String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_large" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
[[VName]]
reds_arrs <- (SegBinOp KernelsMem -> InKernelGen [VName])
-> [SegBinOp KernelsMem]
-> ImpM KernelsMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp KernelsMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var VName
num_threads)) [SegBinOp KernelsMem]
reds
VName
sync_arr <- String
-> PrimType
-> Shape
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"sync_arr" PrimType
Bool ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1]) (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (VName -> Exp
Imp.vi32 VName
virt_num_groups) ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
group_id_var -> do
let segment_gtids :: [VName]
segment_gtids = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids
w :: SubExp
w = [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims
group_id :: Exp
group_id = VName -> Exp
Imp.vi32 VName
group_id_var
local_tid :: Exp
local_tid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
Exp
flat_segment_id <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"flat_segment_id" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
group_id Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
groups_per_segment
Exp
global_tid <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"global_tid" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
(Exp
group_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
local_tid)
Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`rem` (Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
groups_per_segment)
let first_group_for_segment :: Exp
first_group_for_segment = Exp
flat_segment_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
groups_per_segment
(VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
segment_gtids ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims') Exp
flat_segment_id
VName -> PrimType -> InKernelGen ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int32
Count Elements Exp
num_elements <- Exp -> Count Elements Exp
Imp.elements (Exp -> Count Elements Exp)
-> ImpM KernelsMem KernelEnv KernelOp Exp
-> ImpM KernelsMem KernelEnv KernelOp (Count Elements Exp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w
[SegBinOpSlug]
slugs <- ((SegBinOp KernelsMem, [VName], [VName])
-> ImpM KernelsMem KernelEnv KernelOp SegBinOpSlug)
-> [(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Exp
-> Exp
-> (SegBinOp KernelsMem, [VName], [VName])
-> ImpM KernelsMem KernelEnv KernelOp SegBinOpSlug
segBinOpSlug Exp
local_tid Exp
group_id) ([(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug])
-> [(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug]
forall a b. (a -> b) -> a -> b
$
[SegBinOp KernelsMem]
-> [[VName]]
-> [[VName]]
-> [(SegBinOp KernelsMem, [VName], [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs
[Lambda KernelsMem]
reds_op_renamed <-
KernelConstants
-> [(VName, Exp)]
-> Count Elements Exp
-> Exp
-> Count Elements Exp
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen [Lambda KernelsMem]
reductionStageOne KernelConstants
constants ([VName] -> [Exp] -> [(VName, Exp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [Exp]
dims') Count Elements Exp
num_elements
Exp
global_tid Count Elements Exp
elems_per_thread VName
threads_per_segment
[SegBinOpSlug]
slugs DoSegBody
body
let segred_pes :: [[PatElemT LetDecMem]]
segred_pes = [Int] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp KernelsMem -> Int) -> [SegBinOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp KernelsMem -> [SubExp]) -> SegBinOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp KernelsMem]
reds) ([PatElemT LetDecMem] -> [[PatElemT LetDecMem]])
-> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a b. (a -> b) -> a -> b
$
PatternT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LetDecMem
segred_pat
multiple_groups_per_segment :: InKernelGen ()
multiple_groups_per_segment =
[(SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda KernelsMem, Int32)]
-> ((SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda KernelsMem, Int32)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]]
-> [[VName]]
-> [[PatElemT LetDecMem]]
-> [SegBinOpSlug]
-> [Lambda KernelsMem]
-> [Int32]
-> [(SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda KernelsMem, Int32)]
forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7 [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs [[PatElemT LetDecMem]]
segred_pes
[SegBinOpSlug]
slugs [Lambda KernelsMem]
reds_op_renamed [Int32
0..]) (((SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda KernelsMem, Int32)
-> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda KernelsMem, Int32)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\(SegBinOp Commutativity
_ Lambda KernelsMem
red_op [SubExp]
nes Shape
_, [VName]
red_arrs, [VName]
group_res_arrs, [PatElemT LetDecMem]
pes,
SegBinOpSlug
slug, Lambda KernelsMem
red_op_renamed, Int32
i) -> do
let ([Param LetDecMem]
red_x_params, [Param LetDecMem]
red_y_params) =
Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
red_op
KernelConstants
-> [PatElem KernelsMem]
-> Exp
-> Exp
-> [Exp]
-> Exp
-> Exp
-> SegBinOpSlug
-> [LParam KernelsMem]
-> [LParam KernelsMem]
-> Lambda KernelsMem
-> [SubExp]
-> Exp
-> VName
-> Exp
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo KernelConstants
constants [PatElem KernelsMem]
[PatElemT LetDecMem]
pes
Exp
group_id Exp
flat_segment_id ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
segment_gtids)
Exp
first_group_for_segment Exp
groups_per_segment
SegBinOpSlug
slug [LParam KernelsMem]
[Param LetDecMem]
red_x_params [LParam KernelsMem]
[Param LetDecMem]
red_y_params Lambda KernelsMem
red_op_renamed [SubExp]
nes
(Int -> Exp
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_counters) VName
counter (PrimValue -> Exp
forall v. PrimValue -> PrimExp v
ValueExp (PrimValue -> Exp) -> PrimValue -> Exp
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ Int32 -> IntValue
Int32Value Int32
i)
VName
sync_arr [VName]
group_res_arrs [VName]
red_arrs
one_group_per_segment :: InKernelGen ()
one_group_per_segment =
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"first thread in group saves final result to memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOpSlug, [PatElemT LetDecMem])]
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LetDecMem]] -> [(SegBinOpSlug, [PatElemT LetDecMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LetDecMem]]
segred_pes) (((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LetDecMem]
pes) ->
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
local_tid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, (VName, [Exp]))]
-> ((PatElemT LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [(VName, [Exp])] -> [(PatElemT LetDecMem, (VName, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
pes (SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug)) (((PatElemT LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
v, (VName
acc, [Exp]
acc_is)) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
v) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
segment_gtids) (VName -> SubExp
Var VName
acc) [Exp]
acc_is
Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
groups_per_segment Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
1) InKernelGen ()
one_group_per_segment InKernelGen ()
multiple_groups_per_segment
-- | Compute how many groups cooperate on each segment and how many
-- elements each thread then has to process.
--
-- Arguments, in order: the segment size, the number of segments, the
-- requested number of groups, and the group size.
--
-- The result is a pair of:
--
--   * @groups_per_segment@: the group-count hint divided (via
--     'quotRoundingUp') by @max 1 num_segments@.  The signed maximum
--     with 1 keeps the divisor nonzero when there are no segments.
--
--   * @elements_per_thread@: the segment size divided (via
--     'quotRoundingUp') by @group_size * groups_per_segment@, wrapped
--     as an element count.
--
-- Both intermediate values are bound with 'dPrimVE' under the names
-- @"groups_per_segment"@ and @"elements_per_thread"@.
groupsPerSegmentAndElementsPerThread :: Imp.Exp -> Imp.Exp
                                     -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                                     -> CallKernelGen (Imp.Exp, Imp.Count Imp.Elements Imp.Exp)
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size = do
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
    unCount num_groups_hint `quotRoundingUp` BinOpExp (SMax Int32) 1 num_segments
  elements_per_thread <-
    dPrimVE "elements_per_thread" $
    segment_size `quotRoundingUp` (unCount group_size * groups_per_segment)
  return (groups_per_segment, Imp.elements elements_per_thread)
-- | A reduction operator bundled with the storage used while
-- generating code for it.
data SegBinOpSlug =
  SegBinOpSlug
  { slugOp :: SegBinOp KernelsMem
    -- ^ The reduction operator itself.
  , slugArrs :: [VName]
    -- ^ The arrays that each thread writes its partial result into (at
    -- its local thread index) and that are then handed to the
    -- group-wide reduction.
  , slugAccs :: [(VName, [Imp.Exp])]
    -- ^ Where each accumulator lives: a variable or array name,
    -- together with the leading indexes needed to reach this thread's
    -- slot (empty when the accumulator is a plain scalar variable).
  }
-- | The body of the slug's reduction lambda.
slugBody :: SegBinOpSlug -> Body KernelsMem
slugBody = lambdaBody . segBinOpLambda . slugOp
-- | All parameters of the slug's reduction lambda, in declaration
-- order.  See 'accParams' and 'nextParams' for the two halves.
slugParams :: SegBinOpSlug -> [LParam KernelsMem]
slugParams = lambdaParams . segBinOpLambda . slugOp
-- | The neutral elements of the slug's reduction operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp
-- | The shape recorded on the slug's reduction operator (as returned
-- by 'segBinOpShape').
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp
-- | The combined commutativity of a list of slugs, obtained by
-- monoidally merging the commutativity of every operator.  An empty
-- list yields the 'Monoid' identity of 'Commutativity'.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = foldMap (segBinOpComm . slugOp)
-- | Split views of the reduction lambda's parameters.  The lambda has
-- one parameter per neutral element for the accumulator, followed by
-- the parameters for the next value: 'accParams' returns the former
-- (the first @length (slugNeutral slug)@ parameters), 'nextParams'
-- the remainder.
accParams, nextParams :: SegBinOpSlug -> [LParam KernelsMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug
-- | Build the 'SegBinOpSlug' for one reduction operator.
--
-- The first two arguments are the local thread id and the group id.
-- The triple holds the operator, the arrays stored as 'slugArrs', and
-- the per-parameter arrays used as fallback accumulator storage.
--
-- One accumulator is chosen per lambda parameter (pairing stops at the
-- shorter of the parameter list and @param_arrs@):
--
--   * If the parameter is of primitive type and the operator's shape
--     has rank zero, a fresh primitive variable named after the
--     parameter with an @_acc@ suffix is declared, addressed with no
--     indexes.
--
--   * Otherwise the corresponding array from @param_arrs@ is used,
--     indexed by @[local_tid, group_id]@.
segBinOpSlug :: Imp.Exp -> Imp.Exp -> (SegBinOp KernelsMem, [VName], [VName]) -> InKernelGen SegBinOpSlug
segBinOpSlug local_tid group_id (op, group_res_arrs, param_arrs) =
  SegBinOpSlug op group_res_arrs <$>
  zipWithM mkAcc (lambdaParams (segBinOpLambda op)) param_arrs
  where mkAcc p param_arr
          -- Scalar operator on a primitive: keep the accumulator in a
          -- dedicated scalar variable.
          | Prim t <- paramType p,
            shapeRank (segBinOpShape op) == 0 = do
              acc <- dPrim (baseString (paramName p) <> "_acc") t
              return (acc, [])
          -- Anything else is accumulated in this thread's slot of the
          -- supplied array.
          | otherwise =
              return (param_arr, [local_tid, group_id])
reductionStageZero :: KernelConstants
-> [(VName, Imp.Exp)]
-> Imp.Count Imp.Elements Imp.Exp
-> Imp.Exp
-> Imp.Count Imp.Elements Imp.Exp
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
reductionStageZero :: KernelConstants
-> [(VName, Exp)]
-> Count Elements Exp
-> Exp
-> Count Elements Exp
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
reductionStageZero KernelConstants
constants [(VName, Exp)]
ispace Count Elements Exp
num_elements Exp
global_tid Count Elements Exp
elems_per_thread VName
threads_per_segment [SegBinOpSlug]
slugs DoSegBody
body = do
let ([VName]
gtids, [Exp]
_dims) = [(VName, Exp)] -> ([VName], [Exp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, Exp)]
ispace
gtid :: VName
gtid = [VName] -> VName
forall a. [a] -> a
last [VName]
gtids
local_tid :: Exp
local_tid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
VName
chunk_size <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"chunk_size" PrimType
int32
let ordering :: SplitOrdering
ordering = case [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs of
Commutativity
Commutative -> SubExp -> SplitOrdering
SplitStrided (SubExp -> SplitOrdering) -> SubExp -> SplitOrdering
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
threads_per_segment
Commutativity
Noncommutative -> SplitOrdering
SplitContiguous
SplitOrdering
-> Exp
-> Count Elements Exp
-> Count Elements Exp
-> VName
-> InKernelGen ()
forall lore r op.
SplitOrdering
-> Exp
-> Count Elements Exp
-> Count Elements Exp
-> VName
-> ImpM lore r op ()
computeThreadChunkSize SplitOrdering
ordering Exp
global_tid Count Elements Exp
elems_per_thread Count Elements Exp
num_elements VName
chunk_size
Maybe (Exp KernelsMem) -> Scope KernelsMem -> InKernelGen ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp KernelsMem)
forall a. Maybe a
Nothing (Scope KernelsMem -> InKernelGen ())
-> Scope KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> Scope KernelsMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LetDecMem] -> Scope KernelsMem)
-> [Param LetDecMem] -> Scope KernelsMem
forall a b. (a -> b) -> a -> b
$ (SegBinOpSlug -> [Param LetDecMem])
-> [SegBinOpSlug] -> [Param LetDecMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegBinOpSlug -> [LParam KernelsMem]
SegBinOpSlug -> [Param LetDecMem]
slugParams [SegBinOpSlug]
slugs
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"neutral-initialise the accumulators" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
[((VName, [Exp]), SubExp)]
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])] -> [SubExp] -> [((VName, [Exp]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), SubExp
ne) ->
Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is) SubExp
ne []
[Lambda KernelsMem]
slugs_op_renamed <- (SegBinOpSlug
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem))
-> [SegBinOpSlug] -> InKernelGen [Lambda KernelsMem]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem))
-> (SegBinOpSlug -> Lambda KernelsMem)
-> SegBinOpSlug
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda (SegBinOp KernelsMem -> Lambda KernelsMem)
-> (SegBinOpSlug -> SegBinOp KernelsMem)
-> SegBinOpSlug
-> Lambda KernelsMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> SegBinOp KernelsMem
slugOp) [SegBinOpSlug]
slugs
let doTheReduction :: InKernelGen ()
doTheReduction =
[(Lambda KernelsMem, SegBinOpSlug)]
-> ((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Lambda KernelsMem]
-> [SegBinOpSlug] -> [(Lambda KernelsMem, SegBinOpSlug)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Lambda KernelsMem]
slugs_op_renamed [SegBinOpSlug]
slugs) (((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ())
-> ((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Lambda KernelsMem
slug_op_renamed, SegBinOpSlug
slug) ->
Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is -> do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"to reduce current chunk, first store our result in memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[(Param LetDecMem, (VName, [Exp]))]
-> ((Param LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(VName, [Exp])] -> [(Param LetDecMem, (VName, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
slugParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug)) (((Param LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (VName
acc, [Exp]
acc_is)) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is)
[(VName, Param LetDecMem)]
-> ((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LetDecMem] -> [(VName, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug) (SegBinOpSlug -> [LParam KernelsMem]
slugParams SegBinOpSlug
slug)) (((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LetDecMem
p) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LetDecMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
arr [Exp
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
Exp -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupReduce (KernelConstants -> Exp
kernelGroupSize KernelConstants
constants) Lambda KernelsMem
slug_op_renamed (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug)
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread saves the result in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
local_tid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[((VName, [Exp]), Param LetDecMem)]
-> (((VName, [Exp]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])]
-> [Param LetDecMem] -> [((VName, [Exp]), Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug) (Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
slug_op_renamed)) ((((VName, [Exp]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [Exp]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), Param LetDecMem
p) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []
let comm :: Commutativity
comm = [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs
(Exp
bound, InKernelGen () -> InKernelGen ()
check_bounds) =
case Commutativity
comm of
Commutativity
Commutative -> (VName -> PrimType -> Exp
Imp.var VName
chunk_size PrimType
int32, InKernelGen () -> InKernelGen ()
forall a. a -> a
id)
Commutativity
Noncommutative -> (Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
elems_per_thread,
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (VName -> PrimType -> Exp
Imp.var VName
gtid PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
num_elements))
String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" Exp
bound ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
VName
gtid VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<--
case Commutativity
comm of
Commutativity
Commutative ->
Exp
global_tid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
VName -> PrimType -> Exp
Imp.var VName
threads_per_segment PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
i
Commutativity
Noncommutative ->
let index_in_segment :: Exp
index_in_segment = Exp
global_tid Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
in Exp
local_tid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
(Exp
index_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
elems_per_thread Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
i) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
*
KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
InKernelGen () -> InKernelGen ()
check_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply map function" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [Exp])]
all_red_res -> do
let slugs_res :: [[(SubExp, [Exp])]]
slugs_res = [Int] -> [(SubExp, [Exp])] -> [[(SubExp, [Exp])]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOpSlug -> Int) -> [SegBinOpSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOpSlug -> [SubExp]) -> SegBinOpSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> [SubExp]
slugNeutral) [SegBinOpSlug]
slugs) [(SubExp, [Exp])]
all_red_res
[(SegBinOpSlug, [(SubExp, [Exp])])]
-> ((SegBinOpSlug, [(SubExp, [Exp])]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[(SubExp, [Exp])]] -> [(SegBinOpSlug, [(SubExp, [Exp])])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[(SubExp, [Exp])]]
slugs_res) (((SegBinOpSlug, [(SubExp, [Exp])]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOpSlug, [(SubExp, [Exp])]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [(SubExp, [Exp])]
red_res) ->
Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is -> do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, (VName, [Exp]))]
-> ((Param LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(VName, [Exp])] -> [(Param LetDecMem, (VName, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
accParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug)) (((Param LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, (VName, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (VName
acc, [Exp]
acc_is)) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) ([Exp]
acc_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
vec_is)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load new values" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, (SubExp, [Exp]))]
-> ((Param LetDecMem, (SubExp, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(SubExp, [Exp])] -> [(Param LetDecMem, (SubExp, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
nextParams SegBinOpSlug
slug) [(SubExp, [Exp])]
red_res) (((Param LetDecMem, (SubExp, [Exp])) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, (SubExp, [Exp])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (SubExp
res, [Exp]
res_is)) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
res ([Exp]
res_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
vec_is)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply reduction operator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body KernelsMem -> Stms KernelsMem)
-> Body KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"store in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[((VName, [Exp]), SubExp)]
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])] -> [SubExp] -> [((VName, [Exp]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip
(SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug)
(Body KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body KernelsMem -> [SubExp]) -> Body KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug)) ((((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), SubExp
se) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
vec_is) SubExp
se []
case Commutativity
comm of
Commutativity
Noncommutative -> do
InKernelGen ()
doTheReduction
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread keeps accumulator; others reset to neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let reset_to_neutral :: InKernelGen ()
reset_to_neutral =
[SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
[((VName, [Exp]), SubExp)]
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])] -> [SubExp] -> [((VName, [Exp]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [Exp])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), SubExp
ne) ->
Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is) SubExp
ne []
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless (Exp
local_tid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) InKernelGen ()
reset_to_neutral
Commutativity
_ -> () -> InKernelGen ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
([Lambda KernelsMem], InKernelGen ())
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([Lambda KernelsMem]
slugs_op_renamed, InKernelGen ()
doTheReduction)
-- | Stage one of the reduction: delegate to 'reductionStageZero' with the
-- very same arguments, then finish according to the commutativity that
-- 'slugsComm' reports for the operators.
--
-- * 'Noncommutative': for every slug, each accumulator from 'slugAccs'
--   is copied into the matching accumulator parameter from 'accParams'.
--   The deferred reduction action is /not/ run here.
--   NOTE(review): the noncommutative branch of 'reductionStageZero'
--   appears to run it itself -- confirm there.
--
-- * any other commutativity: the deferred reduction action returned by
--   'reductionStageZero' is run now.
--
-- Returns the renamed operator lambdas produced by 'reductionStageZero'.
reductionStageOne :: KernelConstants
                  -> [(VName, Imp.Exp)]
                  -> Imp.Count Imp.Elements Imp.Exp
                  -> Imp.Exp
                  -> Imp.Count Imp.Elements Imp.Exp
                  -> VName
                  -> [SegBinOpSlug]
                  -> DoSegBody
                  -> InKernelGen [Lambda KernelsMem]
reductionStageOne constants ispace num_elements global_tid elems_per_thread
                  threads_per_segment slugs body = do
  (slugs_op_renamed, doTheReduction) <-
    reductionStageZero constants ispace num_elements global_tid
                       elems_per_thread threads_per_segment slugs body

  case slugsComm slugs of
    Noncommutative ->
      -- Move every accumulator into its operator's accumulator parameter.
      forM_ slugs $ \slug ->
        forM_ (zip (accParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
          copyDWIMFix (paramName p) [] (Var acc) acc_is
    _ -> doTheReduction

  return slugs_op_renamed
-- | Stage two of the reduction: every group publishes its partial result,
-- and the group that finds itself to be the last one for its segment
-- combines all partial results and writes the final value.
--
-- Arguments, as they are used in the body:
--
-- * @constants@: supplies the local thread id and the group size.
-- * @segred_pes@: pattern elements that receive the final result, written
--   at index @segment_gtids ++ vec_is@.
-- * @group_id@: position @[0, group_id]@ at which this group's partial
--   result is stored in each of @group_res_arrs@.
-- * @flat_segment_id@: taken modulo @num_counters@ to select the counter.
-- * @segment_gtids@: leading indices of the final write.
-- * @first_group_for_segment@: offset added to a per-thread result id to
--   obtain the index read from @group_res_arrs@.
-- * @groups_per_segment@: number of partial results to combine; the group
--   that observes @old_counter == groups_per_segment - 1@ is flagged as
--   the last one.
-- * @slug@: supplies the accumulators ('slugAccs'), the shape looped over
--   ('slugShape') and the operator body ('slugBody').
-- * @red_x_params@ / @red_y_params@: operator parameters; the former are
--   initialised from @nes@ and updated with each body result, the latter
--   are loaded with one partial result at a time.
-- * @red_op_renamed@: operator handed to 'groupReduce'; its parameters
--   are zipped with @segred_pes@ for the final write.
-- * @nes@: neutral elements; their count also bounds how many
--   accumulators are published.
-- * @num_counters@, @counter@, @counter_i@: the counter array, indexed at
--   @counter_i * num_counters + flat_segment_id \`rem\` num_counters@.
-- * @sync_arr@: element 0 carries the "last group" flag from thread 0 to
--   the remaining threads of the group.
-- * @group_res_arrs@: arrays holding the per-group partial results.
-- * @red_arrs@: arrays indexed by local thread id that 'groupReduce'
--   operates on.
reductionStageTwo :: KernelConstants
                  -> [PatElem KernelsMem]
                  -> Imp.Exp
                  -> Imp.Exp
                  -> [Imp.Exp]
                  -> Imp.Exp
                  -> Imp.Exp
                  -> SegBinOpSlug
                  -> [LParam KernelsMem] -> [LParam KernelsMem]
                  -> Lambda KernelsMem -> [SubExp]
                  -> Imp.Exp -> VName -> Imp.Exp -> VName -> [VName] -> [VName]
                  -> InKernelGen ()
reductionStageTwo constants segred_pes
                  group_id flat_segment_id segment_gtids
                  first_group_for_segment groups_per_segment
                  slug red_x_params red_y_params
                  red_op_renamed nes
                  num_counters counter counter_i sync_arr
                  group_res_arrs red_arrs = do
  let local_tid = kernelLocalThreadId constants
      group_size = kernelGroupSize constants

  old_counter <- dPrim "old_counter" int32
  (counter_mem, _, counter_offset) <-
    fullyIndexArray counter [counter_i * num_counters +
                             flat_segment_id `rem` num_counters]

  comment "first thread in group saves group result to global memory" $
    sWhen (local_tid .==. 0) $ do
      forM_ (take (length nes) $ zip group_res_arrs (slugAccs slug)) $
        \(v, (acc, acc_is)) ->
          copyDWIMFix v [0, group_id] (Var acc) acc_is

      -- Global memory fence between storing the partial result and
      -- bumping the counter.
      sOp $ Imp.MemFence Imp.FenceGlobal

      -- Atomically add 1 to the counter.  Presumably old_counter receives
      -- the value from before the addition -- TODO confirm against the
      -- definition of Imp.AtomicAdd.
      sOp $ Imp.Atomic DefaultSpace $
        Imp.AtomicAdd Int32 old_counter counter_mem counter_offset 1

      -- Record in sync_arr[0] whether this group saw the counter value
      -- groups_per_segment - 1.
      sWrite sync_arr [0] $ Imp.var old_counter int32 .==. groups_per_segment - 1

  -- Executed by every thread, not only thread 0.
  sOp $ Imp.Barrier Imp.FenceGlobal

  -- Every thread reads the flag that thread 0 stored in sync_arr[0].
  is_last_group <- dPrim "is_last_group" Bool
  copyDWIMFix is_last_group [] (Var sync_arr) [0]

  sWhen (Imp.var is_last_group Bool) $ do
    -- Thread 0 atomically subtracts groups_per_segment from the counter
    -- (presumably resetting it for later use -- confirm with callers).
    sWhen (local_tid .==. 0) $
      sOp $ Imp.Atomic DefaultSpace $
        Imp.AtomicAdd Int32 old_counter counter_mem counter_offset $
          negate groups_per_segment

    sLoopNest (slugShape slug) $ \vec_is -> do
      comment "read in the per-group-results" $ do
        -- Each thread handles ceil(groups_per_segment / group_size)
        -- consecutive partial results.
        read_per_thread <- dPrimVE "read_per_thread" $
                           groups_per_segment `quotRoundingUp` group_size

        -- Start the running result from the neutral elements.
        forM_ (zip red_x_params nes) $ \(p, ne) ->
          copyDWIMFix (paramName p) [] ne []

        sFor "i" read_per_thread $ \i -> do
          group_res_id <- dPrimVE "group_res_id" $
                          local_tid * read_per_thread + i
          index_of_group_res <- dPrimVE "index_of_group_res" $
                                first_group_for_segment + group_res_id

          -- Ids at or beyond groups_per_segment are skipped.
          sWhen (group_res_id .<. groups_per_segment) $ do
            forM_ (zip red_y_params group_res_arrs) $
              \(p, group_res_arr) ->
                copyDWIMFix (paramName p) []
                  (Var group_res_arr)
                  ([0, index_of_group_res] ++ vec_is)

            -- Run the operator body and feed its results back into the
            -- x parameters.
            compileStms mempty (bodyStms $ slugBody slug) $
              forM_ (zip red_x_params (bodyResult $ slugBody slug)) $
                \(p, se) ->
                  copyDWIMFix (paramName p) [] se []

      -- Only parameters of primitive type are copied into red_arrs at
      -- this thread's slot.
      forM_ (zip red_x_params red_arrs) $ \(p, arr) ->
        when (primType $ paramType p) $
          copyDWIMFix arr [local_tid] (Var $ paramName p) []

      sOp $ Imp.Barrier Imp.FenceLocal

      sComment "reduce the per-group results" $ do
        groupReduce group_size red_op_renamed red_arrs

        sComment "and back to memory with the final result" $
          sWhen (local_tid .==. 0) $
            forM_ (zip segred_pes $ lambdaParams red_op_renamed) $ \(pe, p) ->
              copyDWIMFix (patElemName pe) (segment_gtids++vec_is)
                (Var $ paramName p) []