{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.SegRed
( compileSegRed,
compileSegRed',
DoSegBody,
)
where
import Control.Monad.Except
import Data.List (genericLength, zip7)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.Error
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | The maximum number of reduction operators a single 'SegRed' may
-- carry.  'compileSegRed'' rejects anything larger, and
-- 'nonsegmentedReduction' sizes its static counter array with this
-- value.
maxNumOps :: Int32
maxNumOps = 10
-- | Code generation for the body of a segmented reduction, written in
-- continuation-passing style.  The generator is handed a continuation
-- and is expected to call it with the per-thread results, each given
-- as a 'SubExp' paired with a list of 64-bit index expressions
-- (see 'compileSegRed', which passes an empty index list for every
-- result).
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) -> InKernelGen ()
-- | Compile a 'SegRed' whose body is an ordinary 'KernelBody'.
--
-- This is a thin wrapper around 'compileSegRed'': it builds the
-- 'DoSegBody' that compiles the statements of the kernel body, writes
-- the map-out (non-reduction) results to the trailing pattern
-- elements, and then passes the reduction results to the
-- continuation.
compileSegRed ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      -- The first 'segBinOpResults reds' results feed the reduction
      -- operators; whatever remains are map-out results.
      let (red_res, map_res) =
            splitAt (segBinOpResults reds) $ kernelBodyResult body

      sComment "save map-out results" $ do
        -- The pattern is split at the same position as the results.
        let map_arrs = drop (segBinOpResults reds) $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      -- Each reduction result is passed on with an empty index list.
      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []
-- | Like 'compileSegRed', but the body is supplied as an arbitrary
-- 'DoSegBody' code generator.
--
-- Chooses one of three strategies:
--
-- * more than 'maxNumOps' operators: reported as a compiler
--   limitation;
--
-- * a two-dimensional space whose outer dimension is the constant 1:
--   'nonsegmentedReduction';
--
-- * otherwise: a run-time choice ('sIf') between
--   'smallSegmentsReduction' and 'largeSegmentsReduction', taking the
--   small-segment version when twice the innermost dimension is
--   smaller than the group size.
compileSegRed' ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
    compilerLimitationS $
      "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int64Value 1))), _] <- unSegSpace space =
    nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
    let group_size' = toInt64Exp $ unCount group_size
        -- The innermost dimension of the space is the segment size.
        segment_size = toInt64Exp $ last $ segSpaceDims space
        use_small_segments = segment_size * 2 .<. group_size'
    sIf
      use_small_segments
      (smallSegmentsReduction pat num_groups group_size space reds body)
      (largeSegmentsReduction pat num_groups group_size space reds body)
  where
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl
-- | Create one intermediate array per accumulator parameter of the
-- reduction operator (the first @length nes@ lambda parameters).
--
-- * If the parameter is already an array in memory, the new array
--   lives in that same memory block, with an extra outer dimension of
--   size @num_threads@ and an 'IxFun.iota' index function over the
--   extended shape.
--
-- * Otherwise a fresh array of @group_size@ elements is allocated in
--   the @"local"@ space.
--
-- Returns the names of the arrays, in parameter order.
intermediateArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  SegBinOp KernelsMem ->
  InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
  let red_op_params = lambdaParams red_op
      (red_acc_params, _) = splitAt (length nes) red_op_params
  forM red_acc_params $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $ map pe64 $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"
-- | Allocate, in the @"device"@ space, the arrays that hold per-group
-- results: one array for every return value of every reduction
-- operator.
--
-- The logical shape of each array is
-- @[group_size, virt_num_groups] <> operator shape <> element shape@.
-- The result is a list (one entry per operator) of lists (one entry
-- per return value) of array names.
groupResultArrays ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  [SegBinOp KernelsMem] ->
  CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          full_shape = Shape [group_size, virt_num_groups] <> shape <> arrayShape t
          -- Rotate dimension 0 (the group-size dimension) to the last
          -- position of the permutation.
          -- NOTE(review): presumably chosen for coalesced memory
          -- access -- confirm against sAllocArrayPerm.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "group_res_arr" pt full_shape (Space "device") perm
-- | Generate code for the nonsegmented case, selected by
-- 'compileSegRed'' when the outer dimension of the space is the
-- constant 1.
--
-- On the host side this sets up a static @counter@ array of
-- 'maxNumOps' zeroed 32-bit integers, the per-group result arrays
-- ('groupResultArrays') and a @num_threads@ variable
-- (@num_groups * group_size@).  It then launches a single
-- @"segred_nonseg"@ kernel that runs 'reductionStageOne' once over
-- all operators, followed by 'reductionStageTwo' once per operator.
nonsegmentedReduction ::
  Pattern KernelsMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  DoSegBody ->
  CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
      num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
      global_tid = Imp.vi64 $ segFlat space
      -- The innermost dimension is the number of elements to reduce.
      w = last dims'

  -- One zero-initialised slot per possible operator; operator number
  -- @i@ later passes its own index along with this array to
  -- 'reductionStageTwo'.
  counter <-
    sStaticArray "counter" (Space "device") int32 $
      Imp.ArrayValues $ replicate (fromIntegral maxNumOps) $ IntValue $ Int32Value 0

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <-
    dPrimV "num_threads" $
      unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (tvSize num_threads)) reds

    -- Every thread index variable of the space starts out bound to 0.
    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    let num_elements = Imp.elements w
        -- Split the work evenly across all threads, rounding up.
        elems_per_thread =
          num_elements
            `divUp` Imp.elements (sExt64 (kernelNumThreads constants))

    slugs <-
      mapM
        ( segBinOpSlug
            (kernelLocalThreadId constants)
            (kernelGroupId constants)
        )
        $ zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne
        constants
        (zip gtids dims')
        num_elements
        global_tid
        elems_per_thread
        (tvVar num_threads)
        slugs
        body

    -- Partition the pattern elements so that each operator gets as
    -- many as it has neutral elements.
    let segred_pes =
          chunks (map (length . segBinOpNeutral) reds) $
            patternElements segred_pat
    forM_
      ( zip7
          reds
          reds_arrs
          reds_group_res_arrs
          segred_pes
          slugs
          reds_op_renamed
          [0 ..]
      )
      $ \( SegBinOp _ red_op nes _,
           red_arrs,
           group_res_arrs,
           pes,
           slug,
           red_op_renamed,
           i
           ) -> do
          -- The operator lambda takes the accumulator parameters
          -- first, then the element parameters.
          let (red_x_params, red_y_params) =
                splitAt (length nes) $ lambdaParams red_op
          reductionStageTwo
            constants
            pes
            (kernelGroupId constants)
            0
            [0]
            0
            (sExt64 $ kernelNumGroups constants)
            slug
            red_x_params
            red_y_params
            red_op_renamed
            nes
            1
            counter
            (fromInteger i)
            sync_arr
            group_res_arrs
            red_arrs
smallSegmentsReduction ::
Pattern KernelsMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegBinOp KernelsMem] ->
DoSegBody ->
CallKernelGen ()
smallSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
smallSegmentsReduction (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
segred_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
reds DoSegBody
body = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 ExpLeaf]
dims' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> [SubExp] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp [SubExp]
dims
segment_size :: TPrimExp Int64 ExpLeaf
segment_size = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a. [a] -> a
last [TPrimExp Int64 ExpLeaf]
dims'
TPrimExp Int64 ExpLeaf
segment_size_nonzero <-
String
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TPrimExp Int64 ExpLeaf)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"segment_size_nonzero" (TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TPrimExp Int64 ExpLeaf
1 TPrimExp Int64 ExpLeaf
segment_size
let num_groups' :: Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count NumGroups SubExp
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count GroupSize SubExp
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp Count GroupSize SubExp
group_size
TV Int64
num_threads <- String
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"num_threads" (TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
let num_segments :: TPrimExp Int64 ExpLeaf
num_segments = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf)
-> [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims'
segments_per_group :: TPrimExp Int64 ExpLeaf
segments_per_group = Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 ExpLeaf
segment_size_nonzero
required_groups :: TExp Int32
required_groups = TPrimExp Int64 ExpLeaf -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 ExpLeaf -> TExp Int32)
-> TPrimExp Int64 ExpLeaf -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf
num_segments TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 ExpLeaf
segments_per_group
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-small" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
num_segments
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
segment_size
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segments_per_group" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
segments_per_group
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"required_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
required_groups
String
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_small" Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
[[VName]]
reds_arrs <- (SegBinOp KernelsMem -> InKernelGen [VName])
-> [SegBinOp KernelsMem]
-> ImpM KernelsMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp KernelsMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
num_threads)) [SegBinOp KernelsMem]
reds
SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt TExp Int32
required_groups ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id' -> do
let ltid :: TPrimExp Int64 ExpLeaf
ltid = TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 ExpLeaf)
-> TExp Int32 -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
segment_index :: TPrimExp Int64 ExpLeaf
segment_index =
(TPrimExp Int64 ExpLeaf
ltid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 ExpLeaf
segment_size_nonzero)
TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ (TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
segments_per_group)
index_within_segment :: TPrimExp Int64 ExpLeaf
index_within_segment = TPrimExp Int64 ExpLeaf
ltid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 ExpLeaf
segment_size
(VName -> TPrimExp Int64 ExpLeaf -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 ExpLeaf] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 ExpLeaf -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids) ([TPrimExp Int64 ExpLeaf] -> InKernelGen ())
-> [TPrimExp Int64 ExpLeaf] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 ExpLeaf]
-> TPrimExp Int64 ExpLeaf -> [TPrimExp Int64 ExpLeaf]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims') TPrimExp Int64 ExpLeaf
segment_index
VName -> TPrimExp Int64 ExpLeaf -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) TPrimExp Int64 ExpLeaf
index_within_segment
let out_of_bounds :: InKernelGen ()
out_of_bounds =
[(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda KernelsMem
_ [SubExp]
nes Shape
_, [VName]
red_arrs) ->
[(VName, SubExp)]
-> ((VName, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
red_arrs [SubExp]
nes) (((VName, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
ne) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
arr [TPrimExp Int64 ExpLeaf
ltid] SubExp
ne []
in_bounds :: InKernelGen ()
in_bounds =
DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TPrimExp Int64 ExpLeaf])]
red_res ->
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save results to be reduced" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let red_dests :: [(VName, [TPrimExp Int64 ExpLeaf])]
red_dests = [VName]
-> [[TPrimExp Int64 ExpLeaf]]
-> [(VName, [TPrimExp Int64 ExpLeaf])]
forall a b. [a] -> [b] -> [(a, b)]
zip ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs) ([[TPrimExp Int64 ExpLeaf]] -> [(VName, [TPrimExp Int64 ExpLeaf])])
-> [[TPrimExp Int64 ExpLeaf]]
-> [(VName, [TPrimExp Int64 ExpLeaf])]
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 ExpLeaf] -> [[TPrimExp Int64 ExpLeaf]]
forall a. a -> [a]
repeat [TPrimExp Int64 ExpLeaf
ltid]
[((VName, [TPrimExp Int64 ExpLeaf]),
(SubExp, [TPrimExp Int64 ExpLeaf]))]
-> (((VName, [TPrimExp Int64 ExpLeaf]),
(SubExp, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TPrimExp Int64 ExpLeaf])]
-> [(SubExp, [TPrimExp Int64 ExpLeaf])]
-> [((VName, [TPrimExp Int64 ExpLeaf]),
(SubExp, [TPrimExp Int64 ExpLeaf]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, [TPrimExp Int64 ExpLeaf])]
red_dests [(SubExp, [TPrimExp Int64 ExpLeaf])]
red_res) ((((VName, [TPrimExp Int64 ExpLeaf]),
(SubExp, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TPrimExp Int64 ExpLeaf]),
(SubExp, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
d, [TPrimExp Int64 ExpLeaf]
d_is), (SubExp
res, [TPrimExp Int64 ExpLeaf]
res_is)) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
d [TPrimExp Int64 ExpLeaf]
d_is SubExp
res [TPrimExp Int64 ExpLeaf]
res_is
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply map function if in bounds" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool ExpLeaf
-> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
( TPrimExp Int64 ExpLeaf
segment_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TPrimExp Int64 ExpLeaf
0
TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. [(VName, SubExp)] -> TPrimExp Bool ExpLeaf
isActive ([(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [SubExp]
dims)
TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 ExpLeaf
ltid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 ExpLeaf
segment_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
segments_per_group
)
InKernelGen ()
in_bounds
InKernelGen ()
out_of_bounds
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
let crossesSegment :: TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
crossesSegment TExp Int32
from TExp Int32
to =
(TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 ExpLeaf
segment_size)
TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TPrimExp Int64 ExpLeaf
segment_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TPrimExp Int64 ExpLeaf
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform segmented scan to imitate reduction" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda KernelsMem
red_op [SubExp]
_ Shape
_, [VName]
red_arrs) ->
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
-> TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf
-> Lambda KernelsMem
-> [VName]
-> InKernelGen ()
groupScan
((TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
crossesSegment)
(TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
num_threads)
(TPrimExp Int64 ExpLeaf
segment_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
segments_per_group)
Lambda KernelsMem
red_op
[VName]
red_arrs
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save final values of segments" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen
( TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
segments_per_group TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
ltid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 ExpLeaf
num_segments
TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 ExpLeaf
ltid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 ExpLeaf
segments_per_group
)
(InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(PatElemT LParamMem, VName)]
-> ((PatElemT LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem] -> [VName] -> [(PatElemT LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LParamMem]
segred_pes ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs)) (((PatElemT LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, VName
arr) -> do
let flat_segment_index :: TPrimExp Int64 ExpLeaf
flat_segment_index =
TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
segments_per_group TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
ltid
gtids' :: [TPrimExp Int64 ExpLeaf]
gtids' =
[TPrimExp Int64 ExpLeaf]
-> TPrimExp Int64 ExpLeaf -> [TPrimExp Int64 ExpLeaf]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims') TPrimExp Int64 ExpLeaf
flat_segment_index
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix
(PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
[TPrimExp Int64 ExpLeaf]
gtids'
(VName -> SubExp
Var VName
arr)
[(TPrimExp Int64 ExpLeaf
ltid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
1) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
segment_size_nonzero TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
largeSegmentsReduction ::
Pattern KernelsMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegBinOp KernelsMem] ->
DoSegBody ->
CallKernelGen ()
largeSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction Pattern KernelsMem
segred_pat Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
reds DoSegBody
body = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 ExpLeaf]
dims' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> [SubExp] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp [SubExp]
dims
num_segments :: TPrimExp Int64 ExpLeaf
num_segments = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf)
-> [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims'
segment_size :: TPrimExp Int64 ExpLeaf
segment_size = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a. [a] -> a
last [TPrimExp Int64 ExpLeaf]
dims'
num_groups' :: Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count NumGroups SubExp
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count GroupSize SubExp
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp Count GroupSize SubExp
group_size
(TPrimExp Int64 ExpLeaf
groups_per_segment, Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread) <-
TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
-> CallKernelGen
(TPrimExp Int64 ExpLeaf, Count Elements (TPrimExp Int64 ExpLeaf))
groupsPerSegmentAndElementsPerThread
TPrimExp Int64 ExpLeaf
segment_size
TPrimExp Int64 ExpLeaf
num_segments
Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups'
Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
TV Int64
virt_num_groups <-
String
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"virt_num_groups" (TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 ExpLeaf
groups_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
num_segments
TV Int64
num_threads <-
String
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"num_threads" (TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
TV Int64
threads_per_segment <-
String
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"threads_per_segment" (TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 ExpLeaf
groups_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-large" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
num_segments
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
segment_size
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"virt_num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"group_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"elems_per_thread" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ Count Elements (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"groups_per_segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
groups_per_segment
[[VName]]
reds_group_res_arrs <- Count NumGroups SubExp
-> Count GroupSize SubExp
-> [SegBinOp KernelsMem]
-> CallKernelGen [[VName]]
groupResultArrays (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
virt_num_groups)) Count GroupSize SubExp
group_size [SegBinOp KernelsMem]
reds
let num_counters :: Int
num_counters = Int32 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int32
maxNumOps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
1024
VName
counter <-
String
-> Space
-> PrimType
-> ArrayContents
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> Space -> PrimType -> ArrayContents -> ImpM lore r op VName
sStaticArray String
"counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM KernelsMem HostEnv HostOp VName)
-> ArrayContents -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Int -> ArrayContents
Imp.ArrayZeros Int
num_counters
String
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_large" Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
[[VName]]
reds_arrs <- (SegBinOp KernelsMem -> InKernelGen [VName])
-> [SegBinOp KernelsMem]
-> ImpM KernelsMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp KernelsMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
num_threads)) [SegBinOp KernelsMem]
reds
VName
sync_arr <- String
-> PrimType
-> Shape
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"sync_arr" PrimType
Bool ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1]) (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TPrimExp Int64 ExpLeaf -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups)) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
let segment_gtids :: [VName]
segment_gtids = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids
w :: SubExp
w = [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims
local_tid :: TExp Int32
local_tid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
flat_segment_id <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"flat_segment_id" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 ExpLeaf -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 ExpLeaf
groups_per_segment
TPrimExp Int64 ExpLeaf
global_tid <-
String
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"global_tid" (TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
(TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size') TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid)
TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`rem` (TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size') TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
groups_per_segment)
let first_group_for_segment :: TPrimExp Int64 ExpLeaf
first_group_for_segment = TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
groups_per_segment
(VName -> TPrimExp Int64 ExpLeaf -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 ExpLeaf] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 ExpLeaf -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
segment_gtids ([TPrimExp Int64 ExpLeaf] -> InKernelGen ())
-> [TPrimExp Int64 ExpLeaf] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[TPrimExp Int64 ExpLeaf]
-> TPrimExp Int64 ExpLeaf -> [TPrimExp Int64 ExpLeaf]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims') (TPrimExp Int64 ExpLeaf -> [TPrimExp Int64 ExpLeaf])
-> TPrimExp Int64 ExpLeaf -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
VName -> PrimType -> InKernelGen ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
let num_elements :: Count Elements (TPrimExp Int64 ExpLeaf)
num_elements = TPrimExp Int64 ExpLeaf -> Count Elements (TPrimExp Int64 ExpLeaf)
forall a. a -> Count Elements a
Imp.elements (TPrimExp Int64 ExpLeaf -> Count Elements (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp SubExp
w
[SegBinOpSlug]
slugs <-
((SegBinOp KernelsMem, [VName], [VName])
-> ImpM KernelsMem KernelEnv KernelOp SegBinOpSlug)
-> [(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (TExp Int32
-> TExp Int32
-> (SegBinOp KernelsMem, [VName], [VName])
-> ImpM KernelsMem KernelEnv KernelOp SegBinOpSlug
segBinOpSlug TExp Int32
local_tid TExp Int32
group_id) ([(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug])
-> [(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug]
forall a b. (a -> b) -> a -> b
$
[SegBinOp KernelsMem]
-> [[VName]]
-> [[VName]]
-> [(SegBinOp KernelsMem, [VName], [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs
[Lambda KernelsMem]
reds_op_renamed <-
KernelConstants
-> [(VName, TPrimExp Int64 ExpLeaf)]
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen [Lambda KernelsMem]
reductionStageOne
KernelConstants
constants
([VName]
-> [TPrimExp Int64 ExpLeaf] -> [(VName, TPrimExp Int64 ExpLeaf)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TPrimExp Int64 ExpLeaf]
dims')
Count Elements (TPrimExp Int64 ExpLeaf)
num_elements
TPrimExp Int64 ExpLeaf
global_tid
Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread
(TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
threads_per_segment)
[SegBinOpSlug]
slugs
DoSegBody
body
let segred_pes :: [[PatElemT LParamMem]]
segred_pes =
[Int] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp KernelsMem -> Int) -> [SegBinOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp KernelsMem -> [SubExp]) -> SegBinOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp KernelsMem]
reds) ([PatElemT LParamMem] -> [[PatElemT LParamMem]])
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a b. (a -> b) -> a -> b
$
PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LParamMem
segred_pat
multiple_groups_per_segment :: InKernelGen ()
multiple_groups_per_segment =
[(SegBinOp KernelsMem, [VName], [VName], [PatElemT LParamMem],
SegBinOpSlug, Lambda KernelsMem, Integer)]
-> ((SegBinOp KernelsMem, [VName], [VName], [PatElemT LParamMem],
SegBinOpSlug, Lambda KernelsMem, Integer)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_
( [SegBinOp KernelsMem]
-> [[VName]]
-> [[VName]]
-> [[PatElemT LParamMem]]
-> [SegBinOpSlug]
-> [Lambda KernelsMem]
-> [Integer]
-> [(SegBinOp KernelsMem, [VName], [VName], [PatElemT LParamMem],
SegBinOpSlug, Lambda KernelsMem, Integer)]
forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7
[SegBinOp KernelsMem]
reds
[[VName]]
reds_arrs
[[VName]]
reds_group_res_arrs
[[PatElemT LParamMem]]
segred_pes
[SegBinOpSlug]
slugs
[Lambda KernelsMem]
reds_op_renamed
[Integer
0 ..]
)
(((SegBinOp KernelsMem, [VName], [VName], [PatElemT LParamMem],
SegBinOpSlug, Lambda KernelsMem, Integer)
-> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName], [VName], [PatElemT LParamMem],
SegBinOpSlug, Lambda KernelsMem, Integer)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \( SegBinOp Commutativity
_ Lambda KernelsMem
red_op [SubExp]
nes Shape
_,
[VName]
red_arrs,
[VName]
group_res_arrs,
[PatElemT LParamMem]
pes,
SegBinOpSlug
slug,
Lambda KernelsMem
red_op_renamed,
Integer
i
) -> do
let ([Param LParamMem]
red_x_params, [Param LParamMem]
red_y_params) =
Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
red_op
KernelConstants
-> [PatElem KernelsMem]
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 ExpLeaf]
-> TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf
-> SegBinOpSlug
-> [LParam KernelsMem]
-> [LParam KernelsMem]
-> Lambda KernelsMem
-> [SubExp]
-> TExp Int32
-> VName
-> TExp Int32
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo
KernelConstants
constants
[PatElem KernelsMem]
[PatElemT LParamMem]
pes
TExp Int32
group_id
TExp Int32
flat_segment_id
((VName -> TPrimExp Int64 ExpLeaf)
-> [VName] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 [VName]
segment_gtids)
(TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
first_group_for_segment)
TPrimExp Int64 ExpLeaf
groups_per_segment
SegBinOpSlug
slug
[LParam KernelsMem]
[Param LParamMem]
red_x_params
[LParam KernelsMem]
[Param LParamMem]
red_y_params
Lambda KernelsMem
red_op_renamed
[SubExp]
nes
(Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_counters)
VName
counter
(Integer -> TExp Int32
forall a. Num a => Integer -> a
fromInteger Integer
i)
VName
sync_arr
[VName]
group_res_arrs
[VName]
red_arrs
one_group_per_segment :: InKernelGen ()
one_group_per_segment =
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"first thread in group saves final result to memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOpSlug, [PatElemT LParamMem])]
-> ((SegBinOpSlug, [PatElemT LParamMem]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LParamMem]] -> [(SegBinOpSlug, [PatElemT LParamMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LParamMem]]
segred_pes) (((SegBinOpSlug, [PatElemT LParamMem]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOpSlug, [PatElemT LParamMem]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LParamMem]
pes) ->
TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
-> ((PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [(VName, [TPrimExp Int64 ExpLeaf])]
-> [(PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
pes (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug)) (((PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
v, (VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is)) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
v) ((VName -> TPrimExp Int64 ExpLeaf)
-> [VName] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 [VName]
segment_gtids) (VName -> SubExp
Var VName
acc) [TPrimExp Int64 ExpLeaf]
acc_is
TPrimExp Bool ExpLeaf
-> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (TPrimExp Int64 ExpLeaf
groups_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf
1) InKernelGen ()
one_group_per_segment InKernelGen ()
multiple_groups_per_segment
-- | Compute how many groups cooperate on each segment and how many
-- elements each thread is responsible for.
--
-- Arguments, in order: the segment size, the number of segments, the
-- suggested total number of groups, and the group size.
--
-- The result is @(groups_per_segment, elements_per_thread)@ where
--
-- * @groups_per_segment = num_groups_hint \`divUp\` max 1 num_segments@
--   (the @sMax64 1@ keeps the divisor non-zero when there are no
--   segments), and
--
-- * @elements_per_thread = segment_size \`divUp\` (group_size * groups_per_segment)@.
--
-- Both values are bound through 'dPrimVE' under the names
-- @"groups_per_segment"@ and @"elements_per_thread"@, and the returned
-- expressions are the ones 'dPrimVE' hands back.
--
-- NOTE(review): if the group-count hint were zero, @groups_per_segment@
-- would presumably be zero too and the second 'divUp' would have a zero
-- divisor; confirm that callers always pass a positive hint.
groupsPerSegmentAndElementsPerThread ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  CallKernelGen
    ( Imp.TExp Int64,
      Imp.Count Imp.Elements (Imp.TExp Int64)
    )
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size = do
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
      unCount num_groups_hint `divUp` sMax64 1 num_segments
  elements_per_thread <-
    dPrimVE "elements_per_thread" $
      segment_size `divUp` (unCount group_size * groups_per_segment)
  return (groups_per_segment, Imp.elements elements_per_thread)
-- | A 'SegBinOp' bundled with the auxiliary names needed while
-- generating code for it.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp KernelsMem,
    -- | Arrays associated with this operator; filled from the second
    -- component of the triple given to 'segBinOpSlug'.
    slugArrs :: [VName],
    -- | One accumulator location per operator parameter paired with an
    -- array: a variable name together with the indices at which it is
    -- accessed (an empty index list denotes a scalar variable).
    slugAccs :: [(VName, [Imp.TExp Int64])]
  }
-- | The body of the operator's lambda.
slugBody :: SegBinOpSlug -> Body KernelsMem
slugBody = lambdaBody . segBinOpLambda . slugOp
-- | All parameters of the operator's lambda, in declaration order.
-- 'accParams' and 'nextParams' split this list in two.
slugParams :: SegBinOpSlug -> [LParam KernelsMem]
slugParams = lambdaParams . segBinOpLambda . slugOp
-- | The neutral elements of the operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp
-- | The shape recorded on the operator (its 'segBinOpShape').
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp
-- | The commutativity of a whole list of operators, obtained by
-- combining the individual 'segBinOpComm' values with the 'Monoid'
-- instance of 'Commutativity'.
--
-- NOTE(review): presumably the result is commutative only when every
-- operator is; confirm against the 'Monoid' instance.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = mconcat . map (segBinOpComm . slugOp)
-- | Split the operator's lambda parameters at the number of neutral
-- elements: 'accParams' is the leading part (as many parameters as
-- there are neutral elements) and 'nextParams' is everything after it.
accParams, nextParams :: SegBinOpSlug -> [LParam KernelsMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug
-- | Build the 'SegBinOpSlug' for one operator.
--
-- The first two arguments are the local thread id and the group id;
-- the triple carries the operator and two lists of array names.  The
-- first list is stored unchanged as 'slugArrs'.  The second list is
-- zipped with the operator's lambda parameters to choose an
-- accumulator location for each pair (see @mkAcc@).
--
-- NOTE(review): at the call site visible in this file the triple is
-- built with the intermediate (per-thread) arrays second and the
-- per-group result arrays third, so the local names @group_res_arrs@
-- and @param_arrs@ look swapped relative to what they receive; confirm
-- against the other callers before renaming.
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp KernelsMem, [VName], [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug local_tid group_id (op, group_res_arrs, param_arrs) =
  SegBinOpSlug op group_res_arrs
    <$> zipWithM mkAcc (lambdaParams (segBinOpLambda op)) param_arrs
  where
    -- A primitive-typed parameter of an operator whose shape has rank
    -- zero gets a freshly declared scalar named "<param>_acc", accessed
    -- with no indices.  Anything else accumulates directly in the
    -- supplied array, at index [local_tid, group_id].
    mkAcc p param_arr
      | Prim t <- paramType p,
        shapeRank (segBinOpShape op) == 0 = do
          acc <- dPrim (baseString (paramName p) <> "_acc") t
          return (tvVar acc, [])
      | otherwise =
          return (param_arr, [sExt64 local_tid, sExt64 group_id])
reductionStageZero ::
KernelConstants ->
[(VName, Imp.TExp Int64)] ->
Imp.Count Imp.Elements (Imp.TExp Int64) ->
Imp.TExp Int64 ->
Imp.Count Imp.Elements (Imp.TExp Int64) ->
VName ->
[SegBinOpSlug] ->
DoSegBody ->
InKernelGen ([Lambda KernelsMem], InKernelGen ())
reductionStageZero :: KernelConstants
-> [(VName, TPrimExp Int64 ExpLeaf)]
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
reductionStageZero KernelConstants
constants [(VName, TPrimExp Int64 ExpLeaf)]
ispace Count Elements (TPrimExp Int64 ExpLeaf)
num_elements TPrimExp Int64 ExpLeaf
global_tid Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread VName
threads_per_segment [SegBinOpSlug]
slugs DoSegBody
body = do
let ([VName]
gtids, [TPrimExp Int64 ExpLeaf]
_dims) = [(VName, TPrimExp Int64 ExpLeaf)]
-> ([VName], [TPrimExp Int64 ExpLeaf])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, TPrimExp Int64 ExpLeaf)]
ispace
gtid :: TV Int64
gtid = VName -> PrimType -> TV Int64
forall t. VName -> PrimType -> TV t
mkTV ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
local_tid :: TPrimExp Int64 ExpLeaf
local_tid = TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 ExpLeaf)
-> TExp Int32 -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TV Int64
chunk_size <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"chunk_size" PrimType
int64
let ordering :: SplitOrdering
ordering = case [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs of
Commutativity
Commutative -> SubExp -> SplitOrdering
SplitStrided (SubExp -> SplitOrdering) -> SubExp -> SplitOrdering
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
threads_per_segment
Commutativity
Noncommutative -> SplitOrdering
SplitContiguous
SplitOrdering
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> TV Int64
-> InKernelGen ()
forall lore r op.
SplitOrdering
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> TV Int64
-> ImpM lore r op ()
computeThreadChunkSize SplitOrdering
ordering (TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
global_tid) Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread Count Elements (TPrimExp Int64 ExpLeaf)
num_elements TV Int64
chunk_size
Maybe (Exp KernelsMem) -> Scope KernelsMem -> InKernelGen ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp KernelsMem)
forall a. Maybe a
Nothing (Scope KernelsMem -> InKernelGen ())
-> Scope KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope KernelsMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LParamMem] -> Scope KernelsMem)
-> [Param LParamMem] -> Scope KernelsMem
forall a b. (a -> b) -> a -> b
$ (SegBinOpSlug -> [Param LParamMem])
-> [SegBinOpSlug] -> [Param LParamMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegBinOpSlug -> [LParam KernelsMem]
SegBinOpSlug -> [Param LParamMem]
slugParams [SegBinOpSlug]
slugs
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"neutral-initialise the accumulators" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
[((VName, [TPrimExp Int64 ExpLeaf]), SubExp)]
-> (((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TPrimExp Int64 ExpLeaf])]
-> [SubExp] -> [((VName, [TPrimExp Int64 ExpLeaf]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is), SubExp
ne) ->
Shape
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape
-> ([TPrimExp Int64 ExpLeaf] -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 ExpLeaf]
vec_is ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
acc ([TPrimExp Int64 ExpLeaf]
acc_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is) SubExp
ne []
[Lambda KernelsMem]
slugs_op_renamed <- (SegBinOpSlug
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem))
-> [SegBinOpSlug] -> InKernelGen [Lambda KernelsMem]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem))
-> (SegBinOpSlug -> Lambda KernelsMem)
-> SegBinOpSlug
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda (SegBinOp KernelsMem -> Lambda KernelsMem)
-> (SegBinOpSlug -> SegBinOp KernelsMem)
-> SegBinOpSlug
-> Lambda KernelsMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> SegBinOp KernelsMem
slugOp) [SegBinOpSlug]
slugs
let doTheReduction :: InKernelGen ()
doTheReduction =
[(Lambda KernelsMem, SegBinOpSlug)]
-> ((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Lambda KernelsMem]
-> [SegBinOpSlug] -> [(Lambda KernelsMem, SegBinOpSlug)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Lambda KernelsMem]
slugs_op_renamed [SegBinOpSlug]
slugs) (((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ())
-> ((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Lambda KernelsMem
slug_op_renamed, SegBinOpSlug
slug) ->
Shape
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape
-> ([TPrimExp Int64 ExpLeaf] -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 ExpLeaf]
vec_is -> do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"to reduce current chunk, first store our result in memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[(Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
-> ((Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [(VName, [TPrimExp Int64 ExpLeaf])]
-> [(Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
slugParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug)) (((Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, (VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is)) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
acc) ([TPrimExp Int64 ExpLeaf]
acc_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is)
[(VName, Param LParamMem)]
-> ((VName, Param LParamMem) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LParamMem] -> [(VName, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug) (SegBinOpSlug -> [LParam KernelsMem]
slugParams SegBinOpSlug
slug)) (((VName, Param LParamMem) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, Param LParamMem) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LParamMem
p) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
arr [TPrimExp Int64 ExpLeaf
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
TExp Int32 -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupReduce (TPrimExp Int64 ExpLeaf -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants)) Lambda KernelsMem
slug_op_renamed (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug)
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread saves the result in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TPrimExp Int64 ExpLeaf
local_tid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[((VName, [TPrimExp Int64 ExpLeaf]), Param LParamMem)]
-> (((VName, [TPrimExp Int64 ExpLeaf]), Param LParamMem)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TPrimExp Int64 ExpLeaf])]
-> [Param LParamMem]
-> [((VName, [TPrimExp Int64 ExpLeaf]), Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug) (Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
slug_op_renamed)) ((((VName, [TPrimExp Int64 ExpLeaf]), Param LParamMem)
-> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TPrimExp Int64 ExpLeaf]), Param LParamMem)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is), Param LParamMem
p) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
acc ([TPrimExp Int64 ExpLeaf]
acc_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []
let comm :: Commutativity
comm = [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs
(TPrimExp Int64 ExpLeaf
bound, InKernelGen () -> InKernelGen ()
check_bounds) =
case Commutativity
comm of
Commutativity
Commutative -> (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
chunk_size, InKernelGen () -> InKernelGen ()
forall a. a -> a
id)
Commutativity
Noncommutative ->
( Count Elements (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread,
TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
gtid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. Count Elements (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count Elements (TPrimExp Int64 ExpLeaf)
num_elements)
)
String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> InKernelGen ())
-> InKernelGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
bound ((TPrimExp Int64 ExpLeaf -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int64 ExpLeaf -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
TV Int64
gtid
TV Int64 -> TPrimExp Int64 ExpLeaf -> InKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- case Commutativity
comm of
Commutativity
Commutative ->
TPrimExp Int64 ExpLeaf
global_tid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 VName
threads_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
i
Commutativity
Noncommutative ->
let index_in_segment :: TPrimExp Int64 ExpLeaf
index_in_segment = TPrimExp Int64 ExpLeaf
global_tid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`quot` KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants
in TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
local_tid
TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ (TPrimExp Int64 ExpLeaf
index_in_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count Elements (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
i)
TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants
InKernelGen () -> InKernelGen ()
check_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply map function" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TPrimExp Int64 ExpLeaf])]
all_red_res -> do
let slugs_res :: [[(SubExp, [TPrimExp Int64 ExpLeaf])]]
slugs_res = [Int]
-> [(SubExp, [TPrimExp Int64 ExpLeaf])]
-> [[(SubExp, [TPrimExp Int64 ExpLeaf])]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOpSlug -> Int) -> [SegBinOpSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOpSlug -> [SubExp]) -> SegBinOpSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> [SubExp]
slugNeutral) [SegBinOpSlug]
slugs) [(SubExp, [TPrimExp Int64 ExpLeaf])]
all_red_res
[(SegBinOpSlug, [(SubExp, [TPrimExp Int64 ExpLeaf])])]
-> ((SegBinOpSlug, [(SubExp, [TPrimExp Int64 ExpLeaf])])
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[(SubExp, [TPrimExp Int64 ExpLeaf])]]
-> [(SegBinOpSlug, [(SubExp, [TPrimExp Int64 ExpLeaf])])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[(SubExp, [TPrimExp Int64 ExpLeaf])]]
slugs_res) (((SegBinOpSlug, [(SubExp, [TPrimExp Int64 ExpLeaf])])
-> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOpSlug, [(SubExp, [TPrimExp Int64 ExpLeaf])])
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [(SubExp, [TPrimExp Int64 ExpLeaf])]
red_res) ->
Shape
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape
-> ([TPrimExp Int64 ExpLeaf] -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 ExpLeaf]
vec_is -> do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
-> ((Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [(VName, [TPrimExp Int64 ExpLeaf])]
-> [(Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
accParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug)) (((Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, (VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is)) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
acc) ([TPrimExp Int64 ExpLeaf]
acc_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load new values" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, (SubExp, [TPrimExp Int64 ExpLeaf]))]
-> ((Param LParamMem, (SubExp, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [(SubExp, [TPrimExp Int64 ExpLeaf])]
-> [(Param LParamMem, (SubExp, [TPrimExp Int64 ExpLeaf]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
nextParams SegBinOpSlug
slug) [(SubExp, [TPrimExp Int64 ExpLeaf])]
red_res) (((Param LParamMem, (SubExp, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, (SubExp, [TPrimExp Int64 ExpLeaf]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, (SubExp
res, [TPrimExp Int64 ExpLeaf]
res_is)) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
res ([TPrimExp Int64 ExpLeaf]
res_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply reduction operator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body KernelsMem -> Stms KernelsMem)
-> Body KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"store in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[((VName, [TPrimExp Int64 ExpLeaf]), SubExp)]
-> (((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_
( [(VName, [TPrimExp Int64 ExpLeaf])]
-> [SubExp] -> [((VName, [TPrimExp Int64 ExpLeaf]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip
(SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug)
(Body KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body KernelsMem -> [SubExp]) -> Body KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug)
)
((((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is), SubExp
se) ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
acc ([TPrimExp Int64 ExpLeaf]
acc_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is) SubExp
se []
case Commutativity
comm of
Commutativity
Noncommutative -> do
InKernelGen ()
doTheReduction
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread keeps accumulator; others reset to neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let reset_to_neutral :: InKernelGen ()
reset_to_neutral =
[SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
[((VName, [TPrimExp Int64 ExpLeaf]), SubExp)]
-> (((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TPrimExp Int64 ExpLeaf])]
-> [SubExp] -> [((VName, [TPrimExp Int64 ExpLeaf]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TPrimExp Int64 ExpLeaf]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is), SubExp
ne) ->
Shape
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape
-> ([TPrimExp Int64 ExpLeaf] -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 ExpLeaf] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 ExpLeaf]
vec_is ->
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall lore r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
acc ([TPrimExp Int64 ExpLeaf]
acc_is [TPrimExp Int64 ExpLeaf]
-> [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 ExpLeaf]
vec_is) SubExp
ne []
TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sUnless (TPrimExp Int64 ExpLeaf
local_tid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf
0) InKernelGen ()
reset_to_neutral
Commutativity
_ -> () -> InKernelGen ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
([Lambda KernelsMem], InKernelGen ())
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([Lambda KernelsMem]
slugs_op_renamed, InKernelGen ()
doTheReduction)
-- | Stage one of the segmented reduction.
--
-- All arguments are forwarded unchanged to 'reductionStageZero', which
-- yields the renamed reduction operator lambdas together with an action
-- (@doTheReduction@) that performs the reduction.  What happens next
-- depends on the combined commutativity of the operators ('slugsComm'):
--
-- * 'Noncommutative': the reduction action is /not/ run here.  Instead,
--   each accumulator of every slug is copied into the corresponding
--   accumulator parameter ('accParams') of that slug's operator.
--
-- * Anything else: the reduction action returned by stage zero is run.
--
-- The renamed operator lambdas obtained from stage zero are returned.
reductionStageOne ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  VName ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda KernelsMem]
reductionStageOne constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body = do
  (slugs_op_renamed, doTheReduction) <-
    reductionStageZero
      constants
      ispace
      num_elements
      global_tid
      elems_per_thread
      threads_per_segment
      slugs
      body

  case slugsComm slugs of
    Noncommutative ->
      -- Only move the accumulators into the operator's accumulator
      -- parameters; the reduction action itself is not invoked on this
      -- path.
      forM_ slugs $ \slug ->
        forM_ (zip (accParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
          copyDWIMFix (paramName p) [] (Var acc) acc_is
    _ -> doTheReduction

  return slugs_op_renamed
-- | Generate the second stage of a multi-group segmented reduction.
--
-- The emitted code does the following:
--
-- 1. The thread with local id 0 copies the accumulators of the slug
--    into the per-group result arrays at index @[0, group_id]@, issues
--    a global memory fence, atomically adds 1 to the counter element
--    selected for this segment, and stores in @sync_arr[0]@ whether
--    the counter value it observed was @groups_per_segment - 1@.
--
-- 2. After a barrier, every thread reads @sync_arr[0]@.  When it is
--    set, the group resets the counter (by atomically adding
--    @-groups_per_segment@), reads the per-group results for the
--    segment, combines them with the reduction operator, and lets the
--    thread with local id 0 write the final values to the pattern
--    elements at @segment_gtids ++ vec_is@.
--
-- Arguments, in order: kernel constants; destination pattern elements;
-- group id; flat segment id; segment indices used for the final write;
-- index of the first group of the segment; number of groups per
-- segment; the operator slug; the x- and y-parameters of the operator;
-- the renamed operator lambda passed to 'groupReduce'; the neutral
-- elements; number of counters; the counter array; the counter block
-- index; the one-element synchronisation array; the per-group result
-- arrays; and the arrays handed to 'groupReduce'.
reductionStageTwo ::
  KernelConstants ->
  [PatElem KernelsMem] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  [LParam KernelsMem] ->
  [LParam KernelsMem] ->
  Lambda KernelsMem ->
  [SubExp] ->
  Imp.TExp Int32 ->
  VName ->
  Imp.TExp Int32 ->
  VName ->
  [VName] ->
  [VName] ->
  InKernelGen ()
reductionStageTwo
  constants
  segred_pes
  group_id
  flat_segment_id
  segment_gtids
  first_group_for_segment
  groups_per_segment
  slug
  red_x_params
  red_y_params
  red_op_renamed
  nes
  num_counters
  counter
  counter_i
  sync_arr
  group_res_arrs
  red_arrs = do
    let local_tid = kernelLocalThreadId constants
        group_size = kernelGroupSize constants

    -- NOTE(review): 'old_counter' is declared with prim type int32 and
    -- is the destination of Int32 atomics below, yet it is compared
    -- against the Int64 expression @groups_per_segment - 1@, and the
    -- reset passes an Int64 operand to an Int32 atomic.  Behaviour is
    -- left exactly as it was; confirm whether explicit conversions are
    -- wanted here.
    old_counter <- dPrim "old_counter" int32

    -- Locate the counter element used by this segment.
    (counter_mem, _, counter_offset) <-
      fullyIndexArray
        counter
        [ sExt64 $
            counter_i * num_counters
              + flat_segment_id `rem` num_counters
        ]

    comment "first thread in group saves group result to global memory" $
      sWhen (local_tid .==. 0) $ do
        -- Only as many accumulators as there are neutral elements are
        -- written out.
        forM_ (take (length nes) $ zip group_res_arrs (slugAccs slug)) $
          \(v, (acc, acc_is)) ->
            copyDWIMFix v [0, sExt64 group_id] (Var acc) acc_is

        -- The fence is emitted before the counter increment, so the
        -- result copy precedes the increment in the generated code.
        sOp $ Imp.MemFence Imp.FenceGlobal

        -- Increment the counter; the previous value lands in
        -- 'old_counter'.
        sOp $
          Imp.Atomic DefaultSpace $
            Imp.AtomicAdd
              Int32
              (tvVar old_counter)
              counter_mem
              counter_offset
              $ untyped (1 :: Imp.TExp Int32)

        -- Record whether this group observed the last-but-one counter
        -- value, i.e. whether it is the last group of the segment to
        -- arrive.
        sWrite sync_arr [0] $
          untyped $
            tvExp old_counter .==. groups_per_segment - 1

    sOp $ Imp.Barrier Imp.FenceGlobal

    -- Every thread of the group reads the flag written by thread 0.
    is_last_group <- dPrim "is_last_group" Bool
    copyDWIMFix (tvVar is_last_group) [] (Var sync_arr) [0]

    sWhen (tvExp is_last_group) $ do
      -- Reset the counter by atomically subtracting the number of
      -- groups that incremented it.
      sWhen (local_tid .==. 0) $
        sOp $
          Imp.Atomic DefaultSpace $
            Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
              untyped $ negate groups_per_segment

      sLoopNest (slugShape slug) $ \vec_is -> do
        comment "read in the per-group-results" $ do
          -- Each thread handles a contiguous chunk of
          -- 'read_per_thread' group results.
          read_per_thread <-
            dPrimVE "read_per_thread" $
              groups_per_segment `divUp` sExt64 group_size

          -- Start from the neutral elements.
          forM_ (zip red_x_params nes) $ \(p, ne) ->
            copyDWIMFix (paramName p) [] ne []

          sFor "i" read_per_thread $ \i -> do
            group_res_id <-
              dPrimVE "group_res_id" $
                sExt64 local_tid * read_per_thread + i
            index_of_group_res <-
              dPrimVE "index_of_group_res" $
                first_group_for_segment + group_res_id

            -- Guard against chunk positions past the last group of
            -- the segment.
            sWhen (group_res_id .<. groups_per_segment) $ do
              forM_ (zip red_y_params group_res_arrs) $
                \(p, group_res_arr) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var group_res_arr)
                    ([0, index_of_group_res] ++ vec_is)

              -- Apply the operator body and store its results back
              -- into the x-parameters.
              compileStms mempty (bodyStms $ slugBody slug) $
                forM_ (zip red_x_params (bodyResult $ slugBody slug)) $
                  \(p, se) ->
                    copyDWIMFix (paramName p) [] se []

        -- Publish each thread's partial result; only parameters of
        -- primitive type are copied.
        forM_ (zip red_x_params red_arrs) $ \(p, arr) ->
          when (primType $ paramType p) $
            copyDWIMFix arr [sExt64 local_tid] (Var $ paramName p) []

        sOp $ Imp.Barrier Imp.FenceLocal

        sComment "reduce the per-group results" $ do
          groupReduce (sExt32 group_size) red_op_renamed red_arrs

          sComment "and back to memory with the final result" $
            sWhen (local_tid .==. 0) $
              forM_ (zip segred_pes $ lambdaParams red_op_renamed) $
                \(pe, p) ->
                  copyDWIMFix
                    (patElemName pe)
                    (segment_gtids ++ vec_is)
                    (Var $ paramName p)
                    []