{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegRed
( compileSegRed,
compileSegRed',
DoSegBody,
)
where
import Control.Monad.Except
import Data.List (genericLength, zip7)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | The largest number of reduction operators accepted in a single
-- segmented reduction.  'compileSegRed'' rejects operator lists longer
-- than this, and 'nonsegmentedReduction' uses it as the length of the
-- statically allocated @counter@ array.
maxNumOps :: Int32
maxNumOps = 10
-- | Code generation for the body of a segmented reduction, in
-- continuation-passing style.  The action is handed a continuation
-- and is expected to call it with the results of the body, each given
-- as a 'SubExp' paired with a list of 'Imp.TExp' 'Int64' values
-- (presumably indexes into that 'SubExp' -- TODO confirm; the only
-- producer visible here, 'compileSegRed', always passes empty lists).
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) -> InKernelGen ()
-- | Compile a segmented reduction whose body is given as a
-- 'KernelBody'.
--
-- The kernel body is wrapped into a 'DoSegBody' and handed to
-- 'compileSegRed''.  The body produces two kinds of results:
--
-- * the first @'segBinOpResults' reds@ results are the values to be
--   reduced; they are passed to the continuation, each paired with an
--   empty index list;
--
-- * the remaining results are written directly, via
--   'compileThreadResult', to the pattern elements that follow the
--   first @'segBinOpResults' reds@ elements of the pattern.
compileSegRed ::
  Pat GPUMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      let (red_res, map_res) =
            splitAt (segBinOpResults reds) $ kernelBodyResult body

      sComment "save map-out results" $ do
        let map_arrs = drop (segBinOpResults reds) $ patElems pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      -- The reduction inputs are used as-is, so no extra indexing.
      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []
-- | Like 'compileSegRed', but the body is supplied as a 'DoSegBody'
-- action instead of a 'KernelBody'.
--
-- One of three strategies is chosen:
--
-- * If there are more than 'maxNumOps' operators, compilation is
--   aborted through 'compilerLimitationS'.
--
-- * If the segment space has exactly two dimensions and the outer one
--   is the 64-bit constant 1, 'nonsegmentedReduction' is used.
--
-- * Otherwise a run-time branch is generated with 'sIf': when twice
--   the innermost dimension is smaller than the group size,
--   'smallSegmentsReduction' is used, else 'largeSegmentsReduction'.
--
-- The number of groups and the group size are taken from the
-- 'SegLevel'.
compileSegRed' ::
  Pat GPUMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
    compilerLimitationS $
      "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int64Value 1))), _] <- unSegSpace space =
    nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
    let group_size' = toInt64Exp $ unCount group_size
        -- The innermost dimension of the segment space is the size of
        -- each segment.
        segment_size = toInt64Exp $ last $ segSpaceDims space
        use_small_segments = segment_size * 2 .<. group_size'
    sIf
      use_small_segments
      (smallSegmentsReduction pat num_groups group_size space reds body)
      (largeSegmentsReduction pat num_groups group_size space reds body)
  where
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl
-- | Create the per-operator intermediate arrays (all named
-- @red_arr@), one for each accumulator parameter of the reduction
-- operator.  The accumulator parameters are the first @length nes@
-- parameters of the operator lambda.
--
-- * If the parameter is an array residing in some memory block, the
--   new array is declared with 'sArray' in that same memory block,
--   with shape @num_threads : shape@ and an 'IxFun.iota' index
--   function over that shape.
--
-- * Otherwise an array of @group_size@ elements of the parameter's
--   element type is allocated in @Space "local"@.
--
-- Arguments: the group size, the total number of threads, and the
-- reduction operator.  Returns the names of the arrays, in parameter
-- order.
intermediateArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  SegBinOp GPUMem ->
  InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
  let red_op_params = lambdaParams red_op
      (red_acc_params, _) = splitAt (length nes) red_op_params
  forM red_acc_params $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' mem $
          IxFun.iota $ map pe64 $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"
-- | Allocate the arrays (all named @segred_tmp@, in @Space "device"@)
-- that hold per-group results.  One array is allocated for every
-- return type of every reduction operator; the result is grouped by
-- operator.
--
-- Each array has shape
-- @[extra_dim, virt_num_groups] <> shape <> arrayShape t@, where
-- @shape@ is the operator's shape, @t@ is the return type, and
-- @extra_dim@ is 1 when @t@ is primitive and @shape@ has rank 0, and
-- the group size otherwise.  The array is allocated with a
-- permutation that places dimension 0 (@extra_dim@) last.
groupResultArrays ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  [SegBinOp GPUMem] ->
  CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          extra_dim
            | primType t, shapeRank shape == 0 = intConst Int64 1
            | otherwise = group_size
          full_shape = Shape [extra_dim, virt_num_groups] <> shape <> arrayShape t
          -- Rotate the first dimension to the end.
          -- NOTE(review): presumably for coalesced memory access --
          -- confirm.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "segred_tmp" pt full_shape (Space "device") perm
-- | Generate code for a reduction with a single segment, as selected
-- by 'compileSegRed''.
--
-- On the host side this declares a static @counter@ array of
-- 'maxNumOps' zero-initialised 32-bit integers in @Space "device"@,
-- allocates the per-group result arrays with 'groupResultArrays',
-- binds @num_threads@ to the number of groups times the group size,
-- and launches the @segred_nonseg@ kernel.  Debug prints are emitted
-- before and after the kernel.
--
-- Inside the kernel, every index variable of the segment space is
-- bound to 0, 'reductionStageOne' is run once over all operators, and
-- 'reductionStageTwo' is then run once per operator, with the
-- operator's position in the list passed as its index into
-- @counter@.
--
-- The reduced dimension is the last dimension of the segment space,
-- which is therefore assumed to be non-empty.
nonsegmentedReduction ::
  Pat GPUMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
      num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
      global_tid = Imp.le64 $ segFlat space
      -- Number of elements being reduced.
      w = last dims'

  counter <-
    sStaticArray "counter" (Space "device") int32 $
      Imp.ArrayValues $ replicate (fromIntegral maxNumOps) $ IntValue $ Int32Value 0

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <-
    dPrimV "num_threads" $
      unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (tvSize num_threads)) reds

    -- There is only one segment, so every index variable starts out
    -- as 0.
    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    let num_elements = Imp.elements w
        elems_per_thread =
          num_elements
            `divUp` Imp.elements (sExt64 (kernelNumThreads constants))

    slugs <-
      mapM (segBinOpSlug (kernelLocalThreadId constants) (kernelGroupId constants)) $
        zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne
        constants
        (zip gtids dims')
        num_elements
        global_tid
        elems_per_thread
        (tvVar num_threads)
        slugs
        body

    -- Split the pattern elements so that each operator gets as many
    -- as it has neutral elements.
    let segred_pes =
          chunks (map (length . segBinOpNeutral) reds) $
            patElems segred_pat
    forM_ (zip7 reds reds_arrs reds_group_res_arrs segred_pes slugs reds_op_renamed [0 ..]) $
      \(SegBinOp _ red_op nes _, red_arrs, group_res_arrs, pes, slug, red_op_renamed, i) -> do
        let (red_x_params, red_y_params) = splitAt (length nes) $ lambdaParams red_op
        reductionStageTwo
          constants
          pes
          (kernelGroupId constants)
          0
          [0]
          0
          (sExt64 $ kernelNumGroups constants)
          slug
          red_x_params
          red_y_params
          red_op_renamed
          nes
          1
          counter
          (fromInteger i)
          sync_arr
          group_res_arrs
          red_arrs

  emit $ Imp.DebugPrint "" Nothing
smallSegmentsReduction ::
Pat GPUMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegBinOp GPUMem] ->
DoSegBody ->
CallKernelGen ()
smallSegmentsReduction :: Pat GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
smallSegmentsReduction (Pat [PatElem GPUMem]
segred_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
reds DoSegBody
body = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims'
TExp Int64
segment_size_nonzero <-
String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segment_size_nonzero" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TExp Int64
1 TExp Int64
segment_size
let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
TV Int64
num_threads <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_threads" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
let num_segments :: TExp Int64
num_segments = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims'
segments_per_group :: TExp Int64
segments_per_group = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
segment_size_nonzero
required_groups :: TPrimExp Int32 VName
required_groups = TExp Int64 -> TPrimExp Int32 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 VName)
-> TExp Int64 -> TPrimExp Int32 VName
forall a b. (a -> b) -> a -> b
$ TExp Int64
num_segments TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
segments_per_group
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-small" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
num_segments
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
segment_size
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segments_per_group" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
segments_per_group
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"required_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int32 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int32 VName
required_groups
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_small" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
[[VName]]
reds_arrs <- (SegBinOp GPUMem -> InKernelGen [VName])
-> [SegBinOp GPUMem] -> ImpM GPUMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp GPUMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
num_threads)) [SegBinOp GPUMem]
reds
SegVirt
-> TPrimExp Int32 VName
-> (TPrimExp Int32 VName -> InKernelGen ())
-> InKernelGen ()
virtualiseGroups SegVirt
SegVirt TPrimExp Int32 VName
required_groups ((TPrimExp Int32 VName -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int32 VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
group_id' -> do
let ltid :: TExp Int64
ltid = TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 VName -> TExp Int64)
-> TPrimExp Int32 VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants
segment_index :: TExp Int64
segment_index =
(TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
segment_size_nonzero)
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ (TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
segments_per_group)
index_within_segment :: TExp Int64
index_within_segment = TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64
segment_size
[(VName, TExp Int64)] -> TExp Int64 -> InKernelGen ()
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace ([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids) ([TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims')) TExp Int64
segment_index
VName -> TExp Int64 -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) TExp Int64
index_within_segment
let out_of_bounds :: InKernelGen ()
out_of_bounds =
[(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
reds [[VName]]
reds_arrs) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda GPUMem
_ [SubExp]
nes Shape
_, [VName]
red_arrs) ->
[(VName, SubExp)]
-> ((VName, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
red_arrs [SubExp]
nes) (((VName, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
ne) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int64
ltid] SubExp
ne []
in_bounds :: InKernelGen ()
in_bounds =
DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])]
red_res ->
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save results to be reduced" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let red_dests :: [(VName, [TExp Int64])]
red_dests = [VName] -> [[TExp Int64]] -> [(VName, [TExp Int64])]
forall a b. [a] -> [b] -> [(a, b)]
zip ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs) ([[TExp Int64]] -> [(VName, [TExp Int64])])
-> [[TExp Int64]] -> [(VName, [TExp Int64])]
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [[TExp Int64]]
forall a. a -> [a]
repeat [TExp Int64
ltid]
[((VName, [TExp Int64]), (SubExp, [TExp Int64]))]
-> (((VName, [TExp Int64]), (SubExp, [TExp Int64]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [(SubExp, [TExp Int64])]
-> [((VName, [TExp Int64]), (SubExp, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, [TExp Int64])]
red_dests [(SubExp, [TExp Int64])]
red_res) ((((VName, [TExp Int64]), (SubExp, [TExp Int64]))
-> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TExp Int64]), (SubExp, [TExp Int64]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
d, [TExp Int64]
d_is), (SubExp
res, [TExp Int64]
res_is)) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
d [TExp Int64]
d_is SubExp
res [TExp Int64]
res_is
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"apply map function if in bounds" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool VName
-> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
( TExp Int64
segment_size TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int64
0
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. [(VName, SubExp)] -> TPrimExp Bool VName
isActive ([(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [SubExp]
dims)
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
ltid TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
segment_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
segments_per_group
)
InKernelGen ()
in_bounds
InKernelGen ()
out_of_bounds
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
let crossesSegment :: TPrimExp Int32 VName -> TPrimExp Int32 VName -> TPrimExp Bool VName
crossesSegment TPrimExp Int32 VName
from TPrimExp Int32 VName
to =
(TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
from) TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64
segment_size)
TPrimExp Bool VName -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64
segment_size TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int64
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform segmented scan to imitate reduction" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
reds [[VName]]
reds_arrs) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda GPUMem
red_op [SubExp]
_ Shape
_, [VName]
red_arrs) ->
Maybe
(TPrimExp Int32 VName
-> TPrimExp Int32 VName -> TPrimExp Bool VName)
-> TExp Int64
-> TExp Int64
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
((TPrimExp Int32 VName
-> TPrimExp Int32 VName -> TPrimExp Bool VName)
-> Maybe
(TPrimExp Int32 VName
-> TPrimExp Int32 VName -> TPrimExp Bool VName)
forall a. a -> Maybe a
Just TPrimExp Int32 VName -> TPrimExp Int32 VName -> TPrimExp Bool VName
crossesSegment)
(TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
num_threads)
(TExp Int64
segment_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
segments_per_group)
Lambda GPUMem
red_op
[VName]
red_arrs
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save final values of segments" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool VName -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen
( TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
segments_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
ltid TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
num_segments
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
ltid TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
segments_per_group
)
(InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem GPUMem]
[PatElemT LetDecMem]
segred_pes ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs)) (((PatElemT LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
arr) -> do
let flat_segment_index :: TExp Int64
flat_segment_index =
TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
segments_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
ltid
gtids' :: [TExp Int64]
gtids' =
[TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims') TExp Int64
flat_segment_index
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
[TExp Int64]
gtids'
(VName -> SubExp
Var VName
arr)
[(TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
segment_size_nonzero TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"" Maybe Exp
forall a. Maybe a
Nothing
-- | Host-side code generation for the segmented reduction that is launched
-- as the @segred_large@ kernel.
--
-- Arguments, in order: the result pattern (its elements are later chunked
-- per operator by the number of neutral elements of each 'SegBinOp'), the
-- requested number of groups, the requested group size, the 'SegSpace'
-- being reduced over, the reduction operators, and the 'DoSegBody' that
-- is handed on to 'reductionStageOne'.
--
-- The last dimension of the 'SegSpace' is taken as @segment_size@ and the
-- product of the remaining dimensions as @num_segments@.  Together with
-- the group count and group size these are passed to
-- 'groupsPerSegmentAndElementsPerThread', which yields
-- @groups_per_segment@ and @elems_per_thread@; @virt_num_groups@ is then
-- @groups_per_segment * num_segments@.
--
-- Before the kernel is launched this emits a series of 'Imp.DebugPrint'
-- statements with those quantities, obtains group result arrays from
-- 'groupResultArrays', and declares a static, zero-filled @counter@ array
-- of @maxNumOps * 1024@ @int32@ elements in the @device@ space.
--
-- Inside the kernel, 'virtualiseGroups' iterates over @virt_num_groups@
-- group ids; each one computes @flat_segment_id@ as
-- @group_id \`quot\` groups_per_segment@, binds the segment indices with
-- 'dIndexSpace', and runs 'reductionStageOne'.
--
-- NOTE(review): the dimension list is taken apart with 'init' and 'last',
-- which are partial, so this presumably relies on the 'SegSpace' having at
-- least one dimension — confirm against the callers.
largeSegmentsReduction ::
Pat GPUMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegBinOp GPUMem] ->
DoSegBody ->
CallKernelGen ()
largeSegmentsReduction :: Pat GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction Pat GPUMem
segred_pat Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
reds DoSegBody
body = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
num_segments :: TExp Int64
num_segments = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims'
segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims'
num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
(TExp Int64
groups_per_segment, Count Elements (TExp Int64)
elems_per_thread) <-
TExp Int64
-> TExp Int64
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> CallKernelGen (TExp Int64, Count Elements (TExp Int64))
groupsPerSegmentAndElementsPerThread
TExp Int64
segment_size
TExp Int64
num_segments
Count NumGroups (TExp Int64)
num_groups'
Count GroupSize (TExp Int64)
group_size'
TV Int64
virt_num_groups <-
String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"virt_num_groups" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
num_segments
TV Int64
num_threads <-
String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_threads" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
TV Int64
threads_per_segment <-
String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"threads_per_segment" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-large" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
num_segments
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
segment_size
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"virt_num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count NumGroups (TExp Int64)
num_groups'
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"group_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count GroupSize (TExp Int64)
group_size'
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"elems_per_thread" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elems_per_thread
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"groups_per_segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
groups_per_segment
[[VName]]
reds_group_res_arrs <- Count NumGroups SubExp
-> Count GroupSize SubExp
-> [SegBinOp GPUMem]
-> CallKernelGen [[VName]]
groupResultArrays (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
virt_num_groups)) Count GroupSize SubExp
group_size [SegBinOp GPUMem]
reds
let num_counters :: Int
num_counters = Int32 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int32
maxNumOps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
1024
VName
counter <-
String
-> Space
-> PrimType
-> ArrayContents
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String -> Space -> PrimType -> ArrayContents -> ImpM rep r op VName
sStaticArray String
"counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM GPUMem HostEnv HostOp VName)
-> ArrayContents -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Int -> ArrayContents
Imp.ArrayZeros Int
num_counters
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_large" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
[[VName]]
reds_arrs <- (SegBinOp GPUMem -> InKernelGen [VName])
-> [SegBinOp GPUMem] -> ImpM GPUMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp GPUMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
num_threads)) [SegBinOp GPUMem]
reds
VName
sync_arr <- String
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray String
"sync_arr" PrimType
Bool ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1]) (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
SegVirt
-> TPrimExp Int32 VName
-> (TPrimExp Int32 VName -> InKernelGen ())
-> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TExp Int64 -> TPrimExp Int32 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups)) ((TPrimExp Int32 VName -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int32 VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
group_id -> do
let segment_gtids :: [VName]
segment_gtids = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids
w :: SubExp
w = [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims
local_tid :: TPrimExp Int32 VName
local_tid = KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants
TPrimExp Int32 VName
flat_segment_id <-
String
-> TPrimExp Int32 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_segment_id" (TPrimExp Int32 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 VName))
-> TPrimExp Int32 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int32 VName
group_id TPrimExp Int32 VName
-> TPrimExp Int32 VName -> TPrimExp Int32 VName
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TPrimExp Int32 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
groups_per_segment
TExp Int64
global_tid <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"global_tid" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size') TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
local_tid)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size') TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
groups_per_segment)
let first_group_for_segment :: TExp Int64
first_group_for_segment = TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
flat_segment_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
groups_per_segment
[(VName, TExp Int64)] -> TExp Int64 -> InKernelGen ()
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace ([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
segment_gtids ([TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims')) (TExp Int64 -> InKernelGen ()) -> TExp Int64 -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
flat_segment_id
VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
let num_elements :: Count Elements (TExp Int64)
num_elements = TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
Imp.elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
w
[SegBinOpSlug]
slugs <-
((SegBinOp GPUMem, [VName], [VName])
-> ImpM GPUMem KernelEnv KernelOp SegBinOpSlug)
-> [(SegBinOp GPUMem, [VName], [VName])]
-> ImpM GPUMem KernelEnv KernelOp [SegBinOpSlug]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (TPrimExp Int32 VName
-> TPrimExp Int32 VName
-> (SegBinOp GPUMem, [VName], [VName])
-> ImpM GPUMem KernelEnv KernelOp SegBinOpSlug
segBinOpSlug TPrimExp Int32 VName
local_tid TPrimExp Int32 VName
group_id) ([(SegBinOp GPUMem, [VName], [VName])]
-> ImpM GPUMem KernelEnv KernelOp [SegBinOpSlug])
-> [(SegBinOp GPUMem, [VName], [VName])]
-> ImpM GPUMem KernelEnv KernelOp [SegBinOpSlug]
forall a b. (a -> b) -> a -> b
$
[SegBinOp GPUMem]
-> [[VName]] -> [[VName]] -> [(SegBinOp GPUMem, [VName], [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegBinOp GPUMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs
[Lambda GPUMem]
reds_op_renamed <-
KernelConstants
-> [(VName, TExp Int64)]
-> Count Elements (TExp Int64)
-> TExp Int64
-> Count Elements (TExp Int64)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen [Lambda GPUMem]
reductionStageOne
KernelConstants
constants
([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims')
Count Elements (TExp Int64)
num_elements
TExp Int64
global_tid
Count Elements (TExp Int64)
elems_per_thread
(TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
threads_per_segment)
[SegBinOpSlug]
slugs
DoSegBody
body
let segred_pes :: [[PatElemT LetDecMem]]
segred_pes =
[Int] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp GPUMem -> Int) -> [SegBinOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp GPUMem -> [SubExp]) -> SegBinOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
reds) ([PatElemT LetDecMem] -> [[PatElemT LetDecMem]])
-> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a b. (a -> b) -> a -> b
$
PatT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatT dec -> [PatElemT dec]
patElems Pat GPUMem
PatT LetDecMem
segred_pat
multiple_groups_per_segment :: InKernelGen ()
multiple_groups_per_segment =
[(SegBinOp GPUMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda GPUMem, Integer)]
-> ((SegBinOp GPUMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda GPUMem, Integer)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem]
-> [[VName]]
-> [[VName]]
-> [[PatElemT LetDecMem]]
-> [SegBinOpSlug]
-> [Lambda GPUMem]
-> [Integer]
-> [(SegBinOp GPUMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda GPUMem, Integer)]
forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7 [SegBinOp GPUMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs [[PatElemT LetDecMem]]
segred_pes [SegBinOpSlug]
slugs [Lambda GPUMem]
reds_op_renamed [Integer
0 ..]) (((SegBinOp GPUMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda GPUMem, Integer)
-> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp GPUMem, [VName], [VName], [PatElemT LetDecMem],
SegBinOpSlug, Lambda GPUMem, Integer)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\(SegBinOp Commutativity
_ Lambda GPUMem
red_op [SubExp]
nes Shape
_, [VName]
red_arrs, [VName]
group_res_arrs, [PatElemT LetDecMem]
pes, SegBinOpSlug
slug, Lambda GPUMem
red_op_renamed, Integer
i) -> do
let ([Param LetDecMem]
red_x_params, [Param LetDecMem]
red_y_params) =
Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPUMem
red_op
KernelConstants
-> [PatElem GPUMem]
-> TPrimExp Int32 VName
-> TPrimExp Int32 VName
-> [TExp Int64]
-> TExp Int64
-> TExp Int64
-> SegBinOpSlug
-> [LParam GPUMem]
-> [LParam GPUMem]
-> Lambda GPUMem
-> [SubExp]
-> TPrimExp Int32 VName
-> VName
-> TPrimExp Int32 VName
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo
KernelConstants
constants
[PatElem GPUMem]
[PatElemT LetDecMem]
pes
TPrimExp Int32 VName
group_id
TPrimExp Int32 VName
flat_segment_id
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_gtids)
(TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
first_group_for_segment)
TExp Int64
groups_per_segment
SegBinOpSlug
slug
[LParam GPUMem]
[Param LetDecMem]
red_x_params
[LParam GPUMem]
[Param LetDecMem]
red_y_params
Lambda GPUMem
red_op_renamed
[SubExp]
nes
(Int -> TPrimExp Int32 VName
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_counters)
VName
counter
(Integer -> TPrimExp Int32 VName
forall a. Num a => Integer -> a
fromInteger Integer
i)
VName
sync_arr
[VName]
group_res_arrs
[VName]
red_arrs
one_group_per_segment :: InKernelGen ()
one_group_per_segment =
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"first thread in group saves final result to memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOpSlug, [PatElemT LetDecMem])]
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LetDecMem]] -> [(SegBinOpSlug, [PatElemT LetDecMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LetDecMem]]
segred_pes) (((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LetDecMem]
pes) ->
TPrimExp Bool VName -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int32 VName
local_tid TPrimExp Int32 VName -> TPrimExp Int32 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, (VName, [TExp Int64]))]
-> ((PatElemT LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [(VName, [TExp Int64])]
-> [(PatElemT LetDecMem, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
pes (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((PatElemT LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
v, (VName
acc, [TExp Int64]
acc_is)) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
v) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_gtids) (VName -> SubExp
Var VName
acc) [TExp Int64]
acc_is
TPrimExp Bool VName
-> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
1) InKernelGen ()
one_group_per_segment InKernelGen ()
multiple_groups_per_segment
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"" Maybe Exp
forall a. Maybe a
Nothing
-- | Declare and compute, on the host side, how many groups cooperate on
-- each segment and how many elements each thread must process.
--
-- * @groups_per_segment = num_groups_hint \`divUp\` max 1 num_segments@
--   (the 'sMax64' guards against a division by zero when there are no
--   segments).
-- * @elements_per_thread = segment_size \`divUp\` (group_size * groups_per_segment)@
--
-- Both values are bound to fresh variables named @"groups_per_segment"@
-- and @"elements_per_thread"@ via 'dPrimVE'; the returned expressions
-- refer to those variables.
groupsPerSegmentAndElementsPerThread ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  CallKernelGen
    ( Imp.TExp Int64,
      Imp.Count Imp.Elements (Imp.TExp Int64)
    )
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size = do
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
      unCount num_groups_hint `divUp` sMax64 1 num_segments
  elements_per_thread <-
    dPrimVE "elements_per_thread" $
      segment_size `divUp` (unCount group_size * groups_per_segment)
  return (groups_per_segment, Imp.elements elements_per_thread)
-- | A 'SegBinOp' bundled with the storage used while code is generated
-- for it.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp GPUMem,
    -- | Arrays used for the in-group reduction: in this module they are
    -- written at the local thread id and then handed to @groupReduce@.
    slugArrs :: [VName],
    -- | One accumulator per result: the destination name together with
    -- the index prefix that must be applied before any per-element
    -- (vectorised) indices.  The prefix is empty when the accumulator
    -- is a scalar variable.
    slugAccs :: [(VName, [Imp.TExp Int64])]
  }
-- | The body of the slug's operator lambda.
slugBody :: SegBinOpSlug -> Body GPUMem
slugBody = lambdaBody . segBinOpLambda . slugOp
-- | All parameters of the slug's operator lambda, in declaration order.
slugParams :: SegBinOpSlug -> [LParam GPUMem]
slugParams = lambdaParams . segBinOpLambda . slugOp
-- | The neutral elements of the slug's operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp
-- | The shape recorded on the slug's operator ('segBinOpShape').
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp
-- | The combined commutativity of a list of slugs, obtained by
-- monoidally combining the 'segBinOpComm' of every operator.  An empty
-- list yields 'mempty'.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = foldMap (segBinOpComm . slugOp)
-- | Split the operator's lambda parameters at the number of neutral
-- elements: 'accParams' is the leading part (one parameter per neutral
-- element) and 'nextParams' is whatever follows.
accParams, nextParams :: SegBinOpSlug -> [LParam GPUMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug
-- | Build the 'SegBinOpSlug' for one operator.
--
-- The tuple carries the operator, the arrays stored as 'slugArrs', and
-- the arrays that may back the accumulators.  Accumulators are created
-- by zipping the operator's lambda parameters with the third tuple
-- component, so only as many parameters as there are such arrays are
-- considered.
--
-- NOTE(review): the call site visible in this file passes
-- @(op, reds_arrs, reds_group_res_arrs)@; confirm that the second and
-- third components are used here as intended.
segBinOpSlug :: Imp.TExp Int32 -> Imp.TExp Int32 -> (SegBinOp GPUMem, [VName], [VName]) -> InKernelGen SegBinOpSlug
segBinOpSlug local_tid group_id (op, arrs, acc_arrs) =
  SegBinOpSlug op arrs
    <$> zipWithM mkAcc (lambdaParams (segBinOpLambda op)) acc_arrs
  where
    mkAcc p acc_arr
      -- A primitive parameter of a non-vectorised operator (shape rank
      -- 0) gets a freshly declared scalar accumulator, used without
      -- indices.
      | Prim t <- paramType p,
        shapeRank (segBinOpShape op) == 0 = do
        acc <- dPrim (baseString (paramName p) <> "_acc") t
        return (tvVar acc, [])
      -- Otherwise the accumulator lives in the supplied array, in the
      -- slot addressed by the local thread id and the group id.
      | otherwise =
        return (acc_arr, [sExt64 local_tid, sExt64 group_id])
reductionStageZero ::
KernelConstants ->
[(VName, Imp.TExp Int64)] ->
Imp.Count Imp.Elements (Imp.TExp Int64) ->
Imp.TExp Int64 ->
Imp.Count Imp.Elements (Imp.TExp Int64) ->
VName ->
[SegBinOpSlug] ->
DoSegBody ->
InKernelGen ([Lambda GPUMem], InKernelGen ())
reductionStageZero :: KernelConstants
-> [(VName, TExp Int64)]
-> Count Elements (TExp Int64)
-> TExp Int64
-> Count Elements (TExp Int64)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen ([Lambda GPUMem], InKernelGen ())
reductionStageZero KernelConstants
constants [(VName, TExp Int64)]
ispace Count Elements (TExp Int64)
num_elements TExp Int64
global_tid Count Elements (TExp Int64)
elems_per_thread VName
threads_per_segment [SegBinOpSlug]
slugs DoSegBody
body = do
let ([VName]
gtids, [TExp Int64]
_dims) = [(VName, TExp Int64)] -> ([VName], [TExp Int64])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, TExp Int64)]
ispace
gtid :: TV Int64
gtid = VName -> PrimType -> TV Int64
forall t. VName -> PrimType -> TV t
mkTV ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
local_tid :: TExp Int64
local_tid = TPrimExp Int32 VName -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 VName -> TExp Int64)
-> TPrimExp Int32 VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants
TV Int64
chunk_size <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"chunk_size" PrimType
int64
let ordering :: SplitOrdering
ordering = case [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs of
Commutativity
Commutative -> SubExp -> SplitOrdering
SplitStrided (SubExp -> SplitOrdering) -> SubExp -> SplitOrdering
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
threads_per_segment
Commutativity
Noncommutative -> SplitOrdering
SplitContiguous
SplitOrdering
-> TExp Int64
-> Count Elements (TExp Int64)
-> Count Elements (TExp Int64)
-> TV Int64
-> InKernelGen ()
forall rep r op.
SplitOrdering
-> TExp Int64
-> Count Elements (TExp Int64)
-> Count Elements (TExp Int64)
-> TV Int64
-> ImpM rep r op ()
computeThreadChunkSize SplitOrdering
ordering (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
global_tid) Count Elements (TExp Int64)
elems_per_thread Count Elements (TExp Int64)
num_elements TV Int64
chunk_size
Maybe (Exp GPUMem) -> Scope GPUMem -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope Maybe (Exp GPUMem)
forall a. Maybe a
Nothing (Scope GPUMem -> InKernelGen ()) -> Scope GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> Scope GPUMem
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param LetDecMem] -> Scope GPUMem)
-> [Param LetDecMem] -> Scope GPUMem
forall a b. (a -> b) -> a -> b
$ (SegBinOpSlug -> [Param LetDecMem])
-> [SegBinOpSlug] -> [Param LetDecMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegBinOpSlug -> [LParam GPUMem]
SegBinOpSlug -> [Param LetDecMem]
slugParams [SegBinOpSlug]
slugs
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"neutral-initialise the accumulators" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
[((VName, [TExp Int64]), SubExp)]
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [SubExp] -> [((VName, [TExp Int64]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), SubExp
ne) ->
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) SubExp
ne []
[Lambda GPUMem]
slugs_op_renamed <- (SegBinOpSlug -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem))
-> [SegBinOpSlug] -> InKernelGen [Lambda GPUMem]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem))
-> (SegBinOpSlug -> Lambda GPUMem)
-> SegBinOpSlug
-> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda (SegBinOp GPUMem -> Lambda GPUMem)
-> (SegBinOpSlug -> SegBinOp GPUMem)
-> SegBinOpSlug
-> Lambda GPUMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> SegBinOp GPUMem
slugOp) [SegBinOpSlug]
slugs
let doTheReduction :: InKernelGen ()
doTheReduction =
[(Lambda GPUMem, SegBinOpSlug)]
-> ((Lambda GPUMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Lambda GPUMem]
-> [SegBinOpSlug] -> [(Lambda GPUMem, SegBinOpSlug)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Lambda GPUMem]
slugs_op_renamed [SegBinOpSlug]
slugs) (((Lambda GPUMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ())
-> ((Lambda GPUMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Lambda GPUMem
slug_op_renamed, SegBinOpSlug
slug) ->
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"to reduce current chunk, first store our result in memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[(Param LetDecMem, (VName, [TExp Int64]))]
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(VName, [TExp Int64])]
-> [(Param LetDecMem, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam GPUMem]
slugParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (VName
acc, [TExp Int64]
acc_is)) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
[(VName, Param LetDecMem)]
-> ((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LetDecMem] -> [(VName, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug) (SegBinOpSlug -> [LParam GPUMem]
slugParams SegBinOpSlug
slug)) (((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LetDecMem
p) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LetDecMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int64
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
TPrimExp Int32 VName -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduce (TExp Int64 -> TPrimExp Int32 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)) Lambda GPUMem
slug_op_renamed (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug)
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"first thread saves the result in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool VName -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64
local_tid TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[((VName, [TExp Int64]), Param LetDecMem)]
-> (((VName, [TExp Int64]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [Param LetDecMem] -> [((VName, [TExp Int64]), Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug) (Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPUMem
slug_op_renamed)) ((((VName, [TExp Int64]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TExp Int64]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), Param LetDecMem
p) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []
let comm :: Commutativity
comm = [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs
(TExp Int64
bound, InKernelGen () -> InKernelGen ()
check_bounds) =
case Commutativity
comm of
Commutativity
Commutative -> (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_size, InKernelGen () -> InKernelGen ()
forall a. a -> a
id)
Commutativity
Noncommutative ->
( Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elems_per_thread,
TPrimExp Bool VName -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
gtid TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
num_elements)
)
String
-> TExp Int64 -> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
bound ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TV Int64
gtid
TV Int64 -> TExp Int64 -> InKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- case Commutativity
comm of
Commutativity
Commutative ->
TExp Int64
global_tid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
threads_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
i
Commutativity
Noncommutative ->
let index_in_segment :: TExp Int64
index_in_segment = TExp Int64
global_tid TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
in TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
local_tid
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ (TExp Int64
index_in_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elems_per_thread TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i)
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
InKernelGen () -> InKernelGen ()
check_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"apply map function" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])]
all_red_res -> do
let slugs_res :: [[(SubExp, [TExp Int64])]]
slugs_res = [Int] -> [(SubExp, [TExp Int64])] -> [[(SubExp, [TExp Int64])]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOpSlug -> Int) -> [SegBinOpSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOpSlug -> [SubExp]) -> SegBinOpSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> [SubExp]
slugNeutral) [SegBinOpSlug]
slugs) [(SubExp, [TExp Int64])]
all_red_res
[(SegBinOpSlug, [(SubExp, [TExp Int64])])]
-> ((SegBinOpSlug, [(SubExp, [TExp Int64])]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[(SubExp, [TExp Int64])]]
-> [(SegBinOpSlug, [(SubExp, [TExp Int64])])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[(SubExp, [TExp Int64])]]
slugs_res) (((SegBinOpSlug, [(SubExp, [TExp Int64])]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOpSlug, [(SubExp, [TExp Int64])]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [(SubExp, [TExp Int64])]
red_res) ->
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"load accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, (VName, [TExp Int64]))]
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(VName, [TExp Int64])]
-> [(Param LetDecMem, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam GPUMem]
accParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (VName
acc, [TExp Int64]
acc_is)) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"load new values" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, (SubExp, [TExp Int64]))]
-> ((Param LetDecMem, (SubExp, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(SubExp, [TExp Int64])]
-> [(Param LetDecMem, (SubExp, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam GPUMem]
nextParams SegBinOpSlug
slug) [(SubExp, [TExp Int64])]
red_res) (((Param LetDecMem, (SubExp, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, (SubExp, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (SubExp
res, [TExp Int64]
res_is)) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
res ([TExp Int64]
res_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"apply reduction operator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body GPUMem
slugBody SegBinOpSlug
slug) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"store in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[((VName, [TExp Int64]), SubExp)]
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_
( [(VName, [TExp Int64])]
-> [SubExp] -> [((VName, [TExp Int64]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip
(SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)
((SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. BodyT rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body GPUMem
slugBody SegBinOpSlug
slug)
)
((((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), SubExp
se) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) SubExp
se []
case Commutativity
comm of
Commutativity
Noncommutative -> do
InKernelGen ()
doTheReduction
String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"first thread keeps accumulator; others reset to neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let reset_to_neutral :: InKernelGen ()
reset_to_neutral =
[SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
[((VName, [TExp Int64]), SubExp)]
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [SubExp] -> [((VName, [TExp Int64]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ())
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), SubExp
ne) ->
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) SubExp
ne []
TPrimExp Bool VName -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sUnless (TExp Int64
local_tid TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) InKernelGen ()
reset_to_neutral
Commutativity
_ -> () -> InKernelGen ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
([Lambda GPUMem], InKernelGen ())
-> InKernelGen ([Lambda GPUMem], InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([Lambda GPUMem]
slugs_op_renamed, InKernelGen ()
doTheReduction)
-- | First stage of the reduction: run 'reductionStageZero' and, when the
-- operators are commutative, emit the reduction step that it hands back.
--
-- All arguments are forwarded unchanged to 'reductionStageZero'.  The
-- result is the list of renamed operator lambdas produced by stage zero.
reductionStageOne ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  VName ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda GPUMem]
reductionStageOne
  constants
  ispace
  num_elements
  global_tid
  elems_per_thread
  threads_per_segment
  slugs
  body = do
    (slugs_op_renamed, doTheReduction) <-
      reductionStageZero
        constants
        ispace
        num_elements
        global_tid
        elems_per_thread
        threads_per_segment
        slugs
        body

    case slugsComm slugs of
      -- Nothing more to emit here: stage zero appears to run the
      -- reduction step itself in the non-commutative case, so it must
      -- not be repeated.
      Noncommutative -> pure ()
      Commutative -> doTheReduction

    pure slugs_op_renamed
-- | Second stage of the reduction: combine the per-group partial results
-- of one segment into the final result.
--
-- Outline of the emitted code:
--
-- 1. The thread with local id 0 copies the accumulators of @slug@ into
--    @group_res_arrs@ (at index @[0, group_id]@), issues a global memory
--    fence, atomically adds 1 to the segment's counter, and records in
--    @sync_arr[0]@ whether the previous counter value was
--    @groups_per_segment - 1@.
--
-- 2. After a barrier, every thread reads that flag.  When it is set, the
--    threads read the @groups_per_segment@ partial results (possibly
--    several per thread), fold them with the body of @slug@, store their
--    private result in @red_arrs@, perform a 'groupReduce' with
--    @red_op_renamed@, and finally the thread with local id 0 writes the
--    result to @segred_pes@ at @segment_gtids@.
--
-- NOTE(review): the counter is handled as a 32-bit value (@int32@,
-- @Imp.AtomicAdd Int32@) while @groups_per_segment@ is a 64-bit
-- expression; both the comparison against the old counter and the operand
-- of the resetting atomic add mix the two widths.  Confirm that this is
-- intended.
reductionStageTwo ::
  KernelConstants ->
  -- | Pattern elements receiving the final result.
  [PatElem GPUMem] ->
  -- | Id of this group; index into the per-group result arrays.
  Imp.TExp Int32 ->
  -- | Flat id of the segment; selects the counter (modulo @num_counters@).
  Imp.TExp Int32 ->
  -- | Indexes of the segment in the result arrays.
  [Imp.TExp Int64] ->
  -- | Offset of the segment's first entry in the per-group result arrays.
  Imp.TExp Int64 ->
  -- | Number of groups contributing to each segment.
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  -- | Operator parameters used as the running accumulator.
  [LParam GPUMem] ->
  -- | Operator parameters receiving each per-group result.
  [LParam GPUMem] ->
  -- | Operator passed to 'groupReduce'.
  Lambda GPUMem ->
  -- | Neutral elements.
  [SubExp] ->
  -- | Number of counters per operator (stride into @counter@).
  Imp.TExp Int32 ->
  -- | Array holding the counters.
  VName ->
  -- | Which block of @num_counters@ counters to use.
  Imp.TExp Int32 ->
  -- | One-element array used to share the "last group" flag.
  VName ->
  -- | Arrays holding the per-group results.
  [VName] ->
  -- | Arrays used for the intra-group reduction.
  [VName] ->
  InKernelGen ()
reductionStageTwo
  constants
  segred_pes
  group_id
  flat_segment_id
  segment_gtids
  first_group_for_segment
  groups_per_segment
  slug
  red_x_params
  red_y_params
  red_op_renamed
  nes
  num_counters
  counter
  counter_i
  sync_arr
  group_res_arrs
  red_arrs = do
    let local_tid = kernelLocalThreadId constants
        group_size = kernelGroupSize constants
    old_counter <- dPrim "old_counter" int32
    (counter_mem, _, counter_offset) <-
      fullyIndexArray
        counter
        [ sExt64 $
            counter_i * num_counters
              + flat_segment_id `rem` num_counters
        ]
    comment "first thread in group saves group result to global memory" $
      sWhen (local_tid .==. 0) $ do
        forM_ (take (length nes) $ zip group_res_arrs (slugAccs slug)) $ \(v, (acc, acc_is)) ->
          copyDWIMFix v [0, sExt64 group_id] (Var acc) acc_is
        sOp $ Imp.MemFence Imp.FenceGlobal
        -- Increment the counter; old_counter receives its previous value.
        sOp $
          Imp.Atomic DefaultSpace $
            Imp.AtomicAdd
              Int32
              (tvVar old_counter)
              counter_mem
              counter_offset
              $ untyped (1 :: Imp.TExp Int32)
        -- If all other groups of the segment incremented the counter
        -- before us, we are the last group; publish that fact.
        sWrite sync_arr [0] $ untyped $ tvExp old_counter .==. groups_per_segment - 1

    sOp $ Imp.Barrier Imp.FenceGlobal

    is_last_group <- dPrim "is_last_group" Bool
    copyDWIMFix (tvVar is_last_group) [] (Var sync_arr) [0]
    sWhen (tvExp is_last_group) $ do
      -- Undo the groups_per_segment increments of the counter with a
      -- single atomic add of the negated count.
      sWhen (local_tid .==. 0) $
        sOp $
          Imp.Atomic DefaultSpace $
            Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
              untyped $ negate groups_per_segment

      sLoopNest (slugShape slug) $ \vec_is -> do
        comment "read in the per-group-results" $ do
          -- Each thread handles read_per_thread consecutive group
          -- results (groups_per_segment divided by group_size, rounded
          -- up).
          read_per_thread <-
            dPrimVE "read_per_thread" $
              groups_per_segment `divUp` sExt64 group_size

          -- Start the accumulator parameters from the neutral elements.
          forM_ (zip red_x_params nes) $ \(p, ne) ->
            copyDWIMFix (paramName p) [] ne []

          sFor "i" read_per_thread $ \i -> do
            group_res_id <-
              dPrimVE "group_res_id" $
                sExt64 local_tid * read_per_thread + i
            index_of_group_res <-
              dPrimVE "index_of_group_res" $
                first_group_for_segment + group_res_id

            -- Skip indexes beyond the groups of this segment.
            sWhen (group_res_id .<. groups_per_segment) $ do
              forM_ (zip red_y_params group_res_arrs) $
                \(p, group_res_arr) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var group_res_arr)
                    ([0, index_of_group_res] ++ vec_is)

              -- Apply the operator and keep the result in the
              -- accumulator parameters.
              compileStms mempty (bodyStms $ slugBody slug) $
                forM_ (zip red_x_params $ map resSubExp $ bodyResult $ slugBody slug) $ \(p, se) ->
                  copyDWIMFix (paramName p) [] se []

        -- Only primitive-typed accumulators are copied into red_arrs
        -- here.
        forM_ (zip red_x_params red_arrs) $ \(p, arr) ->
          when (primType $ paramType p) $
            copyDWIMFix arr [sExt64 local_tid] (Var $ paramName p) []

        sOp $ Imp.Barrier Imp.FenceLocal

        sComment "reduce the per-group results" $ do
          groupReduce (sExt32 group_size) red_op_renamed red_arrs

          sComment "and back to memory with the final result" $
            sWhen (local_tid .==. 0) $
              forM_ (zip segred_pes $ lambdaParams red_op_renamed) $ \(pe, p) ->
                copyDWIMFix
                  (patElemName pe)
                  (segment_gtids ++ vec_is)
                  (Var $ paramName p)
                  []