{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where
import Control.Monad.Except
import Data.List (foldl', genericLength, zip5)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | The intermediate subhistogram array backing one destination array
-- of a histogram operator, together with the code that sets up its
-- memory.
data SubhistosInfo = SubhistosInfo
  { -- | The array holding all subhistograms for this destination.
    subhistosArray :: VName,
    -- | Code that initialises the memory of 'subhistosArray'.  It
    -- inspects the run-time number of subhistograms, so it must only be
    -- run after that number has been assigned.
    subhistosAlloc :: CallKernelGen ()
  }
-- | Everything needed to compute one histogram operator of a segmented
-- histogram.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Run-time variable holding the number of subhistograms used for
    -- this operator.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination array of the operator.
    slugSubhistos :: [SubhistosInfo],
    -- | How updates to the histogram are performed atomically.
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }
-- | Bytes needed to store one complete histogram of the operator: for
-- every result type of the operator lambda, an array of shape
-- @histShape op <> histOpShape op@.
histSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histSpaceUsage op =
  sum [typeSize (t `arrayOfShape` full_shape) | t <- lambdaReturnType (histOp op)]
  where
    full_shape = histShape op <> histOpShape op
-- | Number of buckets in the histogram: the product of the dimensions of
-- its histogram shape.
histSize :: HistOp GPUMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape
-- | Rank (number of dimensions) of the histogram shape.
histRank :: HistOp GPUMem -> Int
histRank = shapeRank . histShape
-- | Compute the memory footprint of a histogram operator and prepare its
-- intermediate subhistogram arrays.
--
-- All but the innermost dimension of the 'SegSpace' are treated as
-- segment dimensions.  Returns
--
-- * the size in bytes of one histogram ('histSpaceUsage'),
--
-- * that size multiplied by the number of segments, and
--
-- * a 'SegHistSlug' for the operator.
--
-- No subhistogram memory is allocated here.  Each 'SubhistosInfo' only
-- carries the allocation code, which depends on the number of
-- subhistograms decided later.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- The number of subhistograms is typed 'TV Int64' and is used as an
  -- array dimension below, so it is declared as a 64-bit integer (it was
  -- previously declared int32, inconsistent with its type).
  num_subhistos <- dPrim "num_subhistos" int64

  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Segment dimensions, then one dimension per subhistogram, then the
    -- non-segment dimensions of the destination.
    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_mem
        (IxFun.iota $ map pe64 $ shapeDims subhistos_shape)

    pure $
      SubhistosInfo subhistos $ do
        let -- A single subhistogram simply aliases the destination memory.
            unitHistoCase =
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            multiHistoCase = do
              let num_elems = product $ map pe64 $ shapeDims subhistos_shape
                  subhistos_mem_size =
                    Imp.bytes . Imp.unCount $
                      Imp.elements num_elems `Imp.withElemType` elemType dest_t
              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              -- Fill every subhistogram with the neutral element, then copy
              -- the destination into subhistogram 0 of every segment
              -- (presumably so the destination's existing contents take
              -- part in the final combination).
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map pe64 $ arrayDims subhistos_t) $
                      map (unitSlice 0 . pe64 . snd) segment_dims ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histSpaceUsage op
      segmented_h =
        h * product (map (Imp.bytes . pe64) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  pure
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $
          histOp op
    )
-- | Build the global-memory atomic update function for one histogram
-- operator, writing into the given destination arrays.
--
-- The @Maybe Locking@ argument is a lock array that an earlier operator
-- may have created, and it is returned alongside the update function.
--
-- * Operators updated with 'AtomicPrim' or 'AtomicCAS' do not use it and
--   pass it through unchanged.
--
-- * An 'AtomicLocking' operator reuses an existing lock array.  If there
--   is none, it creates a zero-initialised static lock array and returns
--   it, so later operators can share it.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l dests slug =
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> pure (l, f global_space dests)
    (_, AtomicCAS f) -> pure (l, f global_space dests)
    (Just l', AtomicLocking f) -> pure (l, f l' global_space dests)
    (Nothing, AtomicLocking f) -> do
      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
          Imp.ArrayZeros num_locks
      -- An element's lock is chosen by flattening its index over 'dims'
      -- and reducing modulo the number of locks, so distinct elements may
      -- share a lock.  NOTE(review): the 0/1/0 values are presumably the
      -- unlocked/locked/released lock states; confirm against 'Locking'.
      let l' =
            Locking locks 0 1 0 $
              pure . (`rem` fromIntegral num_locks) . flattenIndex dims
      pure (Just l', f l' global_space dests)
  where
    global_space = Space "global"

    -- Fixed size of the lock array.
    num_locks :: Int
    num_locks = 100151

    dims :: [Imp.TExp Int64]
    dims =
      map pe64 $
        shapeDims (histOpShape (slugOp slug))
          ++ [tvSize (slugNumSubhistos slug)]
          ++ shapeDims (histShape (slugOp slug))
-- | Whether the histogram computation may split the work into several
-- chunks processed in separate passes ('MayBeMultiPass'), or must use a
-- single pass ('MustBeSinglePass').  The derived 'Ord' follows the
-- constructor order, so @MustBeSinglePass < MayBeMultiPass@.
data Passage = MustBeSinglePass | MayBeMultiPass
  deriving (Eq, Ord)
-- | The 'Passage' a kernel body permits.  A body that, according to alias
-- analysis, consumes no arrays may be run in several passes.  Otherwise
-- it must run in a single pass, presumably because repeating in-place
-- updates would not be safe.
bodyPassage :: KernelBody GPUMem -> Passage
bodyPassage kbody =
  if consumed == mempty
    then MayBeMultiPass
    else MustBeSinglePass
  where
    consumed = consumedInKernelBody $ aliasAnalyseKernelBody mempty kbody
-- | Decide how a (possibly horizontally fused) group of histogram
-- operators is computed in global memory, and set up each operator's
-- intermediate subhistograms.
--
-- Arguments:
--
-- * the 'Passage' permitted by the kernel body;
--
-- * @hist_T@, presumably the total number of threads (TODO confirm
--   against the caller);
--
-- * @hist_N@, presumably the number of input elements;
--
-- * the operators.
--
-- Returns the number of chunks (S) the histograms are split into.  It
-- also returns, per operator, the update function produced by
-- 'prepareAtomicUpdateGlobal'.
--
-- With several operators, sums and averages over all of them are used,
-- so the group is treated as one histogram.  The quantities (H, RF, C,
-- M, S, k_max, RACE^exp) carry the names used in the debug output.
prepareIntermediateArraysGlobal ::
  Passage ->
  Imp.TExp Int32 ->
  Imp.TExp Int64 ->
  [SegHistSlug] ->
  CallKernelGen
    ( Imp.TExp Int32,
      [[Imp.TExp Int64] -> InKernelGen ()]
    )
prepareIntermediateArraysGlobal passage hist_T hist_N slugs = do
  -- Total number of buckets over all operators.
  hist_H <- dPrimVE "hist_H" $ sum $ map (histSize . slugOp) slugs

  -- Average race factor over all operators.
  hist_RF <-
    dPrimVE "hist_RF" $
      sum (map (r64 . pe64 . histRaceFactor . slugOp) slugs)
        / genericLength slugs

  hist_el_size <- dPrimVE "hist_el_size" $ sum $ map slugElAvgSize slugs

  hist_C_max <-
    dPrimVE "hist_C_max" $
      fMin64 (r64 hist_T) $
        r64 hist_H / hist_k_ct_min

  hist_M_min <-
    dPrimVE "hist_M_min" $
      sMax32 1 $
        sExt32 $
          t64 $
            r64 hist_T / hist_C_max

  -- The L2 size comes from the tunable size "L2_for_histogram", with a
  -- default of 4 MiB.
  let hist_L2_def :: Int64
      hist_L2_def = 4 * 1024 * 1024
  hist_L2 <- dPrim "L2_size" int32
  -- Pins the otherwise ambiguous phantom type to match the int32
  -- declaration.  NOTE(review): confirm GetSize stores a value that fits.
  let hist_L2_e :: Imp.TExp Int32
      hist_L2_e = tvExp hist_L2
  entry <- askFunction
  let l2_key =
        keyWithEntryPoint entry $ nameFromString $ prettyString $ tvVar hist_L2
  sOp . Imp.GetSize (tvVar hist_L2) l2_key $
    Imp.SizeBespoke (nameFromString "L2_for_histogram") hist_L2_def

  -- Presumably an approximation of the L2 cache line size in bytes.
  let hist_L2_ln_sz :: Imp.TExp Double
      hist_L2_ln_sz = 16 * 4
  hist_RACE_exp <-
    dPrimVE "hist_RACE_exp" $
      fMax64 1 $
        (hist_k_RF * hist_RF)
          / (hist_L2_ln_sz / r64 hist_el_size)

  -- The number of chunks S.
  --
  -- * A body that must run in a single pass uses one chunk.
  --
  -- * Otherwise S is ceil((M_min * H * el_size) / (F_L2 * L2 * RACE^exp)).
  --
  -- * In either case, one chunk is used when there are fewer input
  --   elements than buckets (N < H).
  let chunks_wanted :: Imp.TExp Int32
      chunks_wanted =
        case passage of
          MayBeMultiPass ->
            sExt32 $
              (sExt64 hist_M_min * hist_H * sExt64 hist_el_size)
                `divUp` t64 (hist_F_L2 * r64 hist_L2_e * hist_RACE_exp)
          MustBeSinglePass ->
            1
  hist_S <- dPrim "hist_S" int32
  sIf
    (hist_N .<. hist_H)
    (hist_S <-- (1 :: Imp.TExp Int32))
    (hist_S <-- chunks_wanted)

  emit $ Imp.DebugPrint "Race expansion factor (RACE^exp)" $ Just $ untyped hist_RACE_exp
  emit $ Imp.DebugPrint "Number of chunks (S)" $ Just $ untyped $ tvExp hist_S

  -- Set up each operator in turn.  A lock array created for one operator
  -- is threaded through to the following ones.
  histograms <-
    snd
      <$> mapAccumLM
        (onOp hist_L2_e hist_M_min (tvExp hist_S) hist_RACE_exp)
        Nothing
        slugs

  pure (tvExp hist_S, histograms)
  where
    -- Tuning constants of the cost model.
    hist_k_ct_min, hist_k_RF, hist_F_L2 :: Imp.TExp Double
    hist_k_ct_min = 2
    hist_k_RF = 0.75
    hist_F_L2 = 0.4

    -- Convert an integer expression to a double.  The operand is first
    -- sign-extended to 64 bits.  This matters because 64-bit operands
    -- (hist_H, hist_N, hist_u * hist_H_chk, race factors) were previously
    -- converted with SIToFP Int32 and thus truncated.
    r64 x = isF64 $ ConvOpExp (SIToFP Int64 Float64) $ untyped $ sExt64 x

    -- Convert a double to a 64-bit integer.
    t64 = isInt64 . ConvOpExp (FPToSI Float64 Int64) . untyped

    -- 'slugElSize' divided by the number of operator results.  When
    -- locking, the extra lock word counts as one more result.
    slugElAvgSize :: SegHistSlug -> Imp.TExp Int32
    slugElAvgSize slug@(SegHistSlug op _ _ do_op) =
      slugElSize slug `quot` num_parts
      where
        num_results = genericLength $ lambdaReturnType $ histOp op
        num_parts =
          case do_op of
            AtomicLocking {} -> 1 + num_results
            _ -> num_results

    -- Bytes per bucket for one operator: the sizes of all results
    -- (expanded by the operator shape).  When the update uses locking,
    -- an extra int32 (presumably the lock) is included.
    slugElSize :: SegHistSlug -> Imp.TExp Int32
    slugElSize (SegHistSlug op _ _ do_op) =
      sExt32 $
        unCount $
          sum $
            map (typeSize . (`arrayOfShape` histOpShape op)) elem_types
      where
        result_types = lambdaReturnType $ histOp op
        elem_types =
          case do_op of
            AtomicLocking {} -> Prim int32 : result_types
            _ -> result_types

    -- Set up one operator.  Steps:
    --
    -- 1. Compute its chunk size (H_chk), k_max, cooperation level (C) and
    --    multiplication degree (M).
    -- 2. Store M as the operator's number of subhistograms.
    -- 3. Initialise the subhistogram memory.
    -- 4. Build its atomic update function, threading the lock array.
    onOp ::
      Imp.TExp Int32 ->
      Imp.TExp Int32 ->
      Imp.TExp Int32 ->
      Imp.TExp Double ->
      Maybe Locking ->
      SegHistSlug ->
      CallKernelGen (Maybe Locking, [Imp.TExp Int64] -> InKernelGen ())
    onOp hist_L2 hist_M_min hist_S hist_RACE_exp l slug = do
      let SegHistSlug op num_subhistos subhisto_info do_op = slug
          hist_H = histSize op

      hist_H_chk <- dPrimVE "hist_H_chk" $ hist_H `divUp` sExt64 hist_S

      emit $ Imp.DebugPrint "Chunk size (H_chk)" $ Just $ untyped hist_H_chk

      hist_k_max <-
        dPrimVE "hist_k_max" $
          fMin64
            (hist_F_L2 * (r64 hist_L2 / r64 (slugElSize slug)) * hist_RACE_exp)
            (r64 hist_N)
            / r64 hist_T

      hist_u <-
        dPrimVE "hist_u" $
          case do_op of
            AtomicPrim {} -> 2
            _ -> 1

      hist_C <-
        dPrimVE "hist_C" $
          fMin64 (r64 hist_T) $
            r64 (hist_u * hist_H_chk) / hist_k_max

      -- Operators with primitive atomics always use one subhistogram.
      hist_M <-
        dPrimVE "hist_M" $
          case do_op of
            AtomicPrim {} -> 1
            _ ->
              sMax32 hist_M_min $
                sExt32 $
                  t64 $
                    r64 hist_T / hist_C

      emit $ Imp.DebugPrint "Elements/thread in L2 cache (k_max)" $ Just $ untyped hist_k_max
      emit $ Imp.DebugPrint "Multiplication degree (M)" $ Just $ untyped hist_M
      emit $ Imp.DebugPrint "Cooperation level (C)" $ Just $ untyped hist_C

      -- Record the number of subhistograms actually used.
      num_subhistos <-- sExt64 hist_M

      dests <- forM (zip (histDest op) subhisto_info) $ \(dest, info) -> do
        dest_mem <- entryArrayLoc <$> lookupArray dest
        sub_mem <- memLocName . entryArrayLoc <$> lookupArray (subhistosArray info)

        -- With one subhistogram, alias the destination memory.  Otherwise
        -- run the allocation code prepared for this destination.
        sIf
          (hist_M .==. 1)
          (emit $ Imp.SetMem sub_mem (memLocName dest_mem) $ Space "device")
          (subhistosAlloc info)

        pure $ subhistosArray info

      prepareAtomicUpdateGlobal l dests slug
-- | Generate one global-memory pass of a segmented histogram.
--
-- Launches a @seghist_global@ kernel in which every thread walks the
-- flattened input space, evaluates the kernel body, writes the map-out
-- results, and then applies every histogram operator.  Only buckets whose
-- flat index falls inside chunk number @chk_i@ are updated; each chunk
-- covers @histSize op `divUp` hist_S@ buckets of the corresponding
-- histogram.
--
-- A thread writes into the global subhistogram with index
-- @global_thread_id `quot` (num_threads `divUp` num_subhistos)@.
histKernelGlobalPass ::
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  [[Imp.TExp Int64] -> InKernelGen ()] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelGlobalPass map_pes num_groups group_size space slugs kbody histograms hist_S chk_i = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims_64 = map (sExt64 . pe64) dims
      input_elems = product dims_64
      ops = map slugOp slugs

  -- Number of buckets of each histogram handled by a single pass.
  chunk_sizes <- forM ops $ \op ->
    dPrimVE "hist_H_chk" $ histSize op `divUp` sExt64 hist_S

  sKernelThread "seghist_global" (segFlat space) (defKernelAttrs num_groups group_size) $ do
    constants <- kernelConstants <$> askEnv

    -- Index of the global subhistogram this thread updates, per operator.
    subhisto_inds <- forM slugs $ \slug ->
      dPrimVE "subhisto_ind" $
        kernelGlobalThreadId constants
          `quot` ( kernelNumThreads constants
                     `divUp` sExt32 (tvExp (slugNumSubhistos slug))
                 )

    let global_tid = sExt64 $ kernelGlobalThreadId constants
        thread_count = sExt64 $ kernelNumThreads constants

        -- Apply one histogram operator to the bucket/values produced by
        -- the kernel body, but only if the bucket lies in this pass's chunk.
        updateHistogram (HistOp dest_shape _ _ _ op_shape lam, do_op, (bucket, vals), subhisto_ind, chunk_size) = do
          let chunk_start = sExt64 chk_i * chunk_size
              bucket_64 = map pe64 bucket
              dest_dims = map pe64 $ shapeDims dest_shape
              flat_bucket = flattenIndex dest_dims bucket_64
              in_this_chunk =
                chunk_start .<=. flat_bucket
                  .&&. flat_bucket .<. (chunk_start + chunk_size)
                  .&&. inBounds (Slice (map DimFix bucket_64)) dest_dims
              val_params = takeLast (length vals) $ lambdaParams lam
          sWhen in_this_chunk $ do
            -- Segment indices, then the subhistogram, then the bucket.
            let dest_is =
                  map Imp.le64 (init gtids)
                    ++ [sExt64 subhisto_ind]
                    ++ unflattenIndex dest_dims flat_bucket
            dLParams $ lambdaParams lam
            sLoopNest op_shape $ \is -> do
              forM_ (zip val_params vals) $ \(p, v) ->
                copyDWIMFix (paramName p) [] v is
              do_op $ dest_is ++ is

    kernelLoop global_tid thread_count input_elems $ \offset -> do
      dIndexSpace (zip gtids dims_64) offset
      sWhen (offset .<. input_elems) $
        compileStms mempty (kernelBodyStms kbody) $ do
          let (red_res, map_res) =
                splitFromEnd (length map_pes) $ kernelBodyResult kbody

          sComment "save map-out results" $
            forM_ (zip map_pes map_res) $ \(pe, res) ->
              copyDWIMFix (patElemName pe) (map Imp.le64 gtids) (kernelResultSubExp res) []

          let red_res_split =
                splitHistResults ops $ map kernelResultSubExp red_res

          sComment "perform atomic updates" $
            mapM_ updateHistogram $
              zip5 ops histograms red_res_split subhisto_inds chunk_sizes
-- | Compute a segmented histogram directly in global memory.
--
-- 'prepareIntermediateArraysGlobal' sets up the global subhistograms and
-- yields the number of chunks @hist_S@ together with the per-operator
-- update functions.  One 'histKernelGlobalPass' kernel is then generated
-- for every chunk index in @[0, hist_S)@.
histKernelGlobal ::
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  let total_threads :: Imp.TExp Int32
      total_threads = sExt32 $ pe64 (unCount num_groups) * pe64 (unCount group_size)
      -- Size of the innermost dimension of the iteration space.
      innermost_size = pe64 $ snd $ last $ unSegSpace space

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal (bodyPassage kbody) total_threads innermost_size slugs

  -- One kernel launch per histogram chunk.
  sFor "chk_i" hist_S $
    histKernelGlobalPass map_pes num_groups group_size space slugs kbody histograms hist_S
-- | Per-operator setup for the local-memory histogram kernel.
--
-- Each element pairs the global-memory subhistogram arrays with a
-- kernel-side initialiser.  The initialiser takes the number of buckets in
-- one chunk, allocates the local-memory subhistograms, and returns them
-- together with the update function to apply at a given index.
type InitLocalHistograms =
  [([VName], SubExp -> InKernelGen ([VName], [Imp.TExp Int64] -> InKernelGen ()))]
-- | Prepare what the local-memory histogram kernel needs for each operator.
--
-- For every slug this does three things:
--
-- * Sets the slug's number of global subhistograms per segment to
--   @groups_per_segment@.
-- * Allocates those global subhistograms.
-- * Returns an in-kernel initialiser.
--
-- Given a chunk size, the initialiser allocates local-memory
-- subhistograms with an outer dimension of @num_subhistos_per_group@.
-- For locking-based operators it also allocates a local lock array, with
-- every lock set to 0. It then yields the local arrays along with the
-- update function for them.
prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_group groups_per_segment slugs =
  forM slugs prepareSlug
  where
    prepareSlug (SegHistSlug op num_subhistos subhisto_info do_op) = do
      num_subhistos <-- sExt64 (unCount groups_per_segment)
      emit . Imp.DebugPrint "Number of subhistograms in global memory per segment" $
        Just (untyped (tvExp num_subhistos))

      let -- Shape of one operator's local subhistogram array: one
          -- subhistogram per entry of num_subhistos_per_group, with the
          -- histogram dimensions replaced by the chunk size.
          localShape t chunk_size =
            Shape [tvSize num_subhistos_per_group]
              <> setOuterDims (arrayShape t) (histRank op) (Shape [chunk_size])
          initLocal chunk_size = do
            local_subhistos <- forM (histType op) $ \t ->
              sAllocArray "subhistogram_local" (elemType t) (localShape t chunk_size) (Space "local")
            update <- atomicUpdateFor do_op chunk_size
            pure (local_subhistos, update (Space "local") local_subhistos)

      glob_subhistos <- forM subhisto_info $ \info ->
        subhistosArray info <$ subhistosAlloc info

      pure (glob_subhistos, initLocal)

    -- Instantiate an atomic-update strategy for a chunk of the given size.
    -- Only the locking strategy needs in-kernel setup: a local lock array
    -- indexed by (subhistogram, bucket).
    atomicUpdateFor (AtomicPrim f) _ = pure f
    atomicUpdateFor (AtomicCAS f) _ = pure f
    atomicUpdateFor (AtomicLocking f) chunk_size = do
      let lock_shape = Shape [tvSize num_subhistos_per_group, chunk_size]
      locks <- sAllocArray "locks" int32 lock_shape $ Space "local"
      sComment "All locks start out unlocked" $
        groupCoverSpace (map pe64 $ shapeDims lock_shape) $ \is ->
          copyDWIMFix locks is (intConst Int32 0) []
      pure $ f $ Locking locks 0 1 0 id
histKernelLocalPass ::
TV Int32 ->
Count NumGroups (Imp.TExp Int64) ->
[PatElem LetDecMem] ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegHistSlug] ->
KernelBody GPUMem ->
InitLocalHistograms ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumGroups (TExp Int64)
-> [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
TV Int32
num_subhistos_per_group_var
Count NumGroups (TExp Int64)
groups_per_segment
[PatElem LParamMem]
map_pes
Count NumGroups SubExp
num_groups
Count GroupSize SubExp
group_size
SegSpace
space
[SegHistSlug]
slugs
KernelBody GPUMem
kbody
InitLocalHistograms
init_histograms
TExp Int32
hist_S
TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
segment_is :: [VName]
segment_is = forall a. [a] -> [a]
init [VName]
space_is
segment_dims :: [SubExp]
segment_dims = forall a. [a] -> [a]
init [SubExp]
space_sizes
(VName
i_in_segment, SubExp
segment_size) = forall a. [a] -> a
last forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
num_subhistos_per_group :: TExp Int32
num_subhistos_per_group = forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_subhistos_per_group_var
segment_size' :: TExp Int64
segment_size' = SubExp -> TExp Int64
pe64 SubExp
segment_size
TExp Int64
num_segments <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"num_segments" forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims
[TV Int64]
hist_H_chks <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) forall a b. (a -> b) -> a -> b
$ \HistOp GPUMem
op ->
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_H_chk" forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> TExp Int64
histSize HistOp GPUMem
op forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
[([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [TV Int64]
hist_H_chks) forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
let histo_dims :: [TExp Int64]
histo_dims =
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk
forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 (forall d. ShapeBase d -> [d]
shapeDims (forall {k} (rep :: k). HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
TExp Int64
histo_size <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
histo_dims
let group_hists_size :: TExp Int64
group_hists_size =
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_group forall a. Num a => a -> a -> a
* TExp Int64
histo_size
TExp Int32
init_per_thread <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"init_per_thread" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ TExp Int64
group_hists_size forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
pe64 (forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)
let attrs :: KernelAttrs
attrs = (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) {kAttrCheckLocalMemory :: Bool
kAttrCheckLocalMemory = Bool
False}
[Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_local" (SegSpace -> VName
segFlat SegSpace
space) KernelAttrs
attrs forall a b. (a -> b) -> a -> b
$
SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment forall a. Num a => a -> a -> a
* TExp Int64
num_segments) forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
TExp Int32
flat_segment_id <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"flat_segment_id" forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id forall e. IntegralExp e => e -> e -> e
`quot` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment)
TExp Int32
gid_in_segment <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"gid_in_segment" forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id forall e. IntegralExp e => e -> e -> e
`rem` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment)
TExp Int32
pgtid_in_segment <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"pgtid_in_segment" forall a b. (a -> b) -> a -> b
$
TExp Int32
gid_in_segment forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
threads_per_segment <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"threads_per_segment" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
segment_is forall a b. (a -> b) -> a -> b
$
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex (forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims) forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [TV Int64]
hist_H_chks) forall a b. (a -> b) -> a -> b
$
\(([VName]
glob_subhistos, SubExp
-> ImpM
GPUMem KernelEnv KernelOp ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
([VName]
local_subhistos, [TExp Int64] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
GPUMem KernelEnv KernelOp ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> VName
tvVar TV Int64
hist_H_chk
forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op)
TExp Int32
thread_local_subhisto_i <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"thread_local_subhisto_i" forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_subhistos_per_group
let onSlugs :: (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes) forall a b. (a -> b) -> a -> b
$
\(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
_), ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)) ->
SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread
let onAllHistograms :: (VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f =
(SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread -> do
let group_hists_size :: TExp Int32
group_hists_size = TExp Int32
num_subhistos_per_group forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (forall {k} (rep :: k). HistOp rep -> [SubExp]
histNeutral forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)) forall a b. (a -> b) -> a -> b
$
\((VName
dest_global, VName
dest_local), SubExp
ne) ->
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
init_per_thread forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" forall a b. (a -> b) -> a -> b
$
TExp Int32
i forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
j_offset <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j_offset" forall a b. (a -> b) -> a -> b
$
TExp Int32
num_subhistos_per_group forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size forall a. Num a => a -> a -> a
* TExp Int32
gid_in_segment forall a. Num a => a -> a -> a
+ TExp Int32
j
TExp Int32
local_subhisto_i <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_subhisto_i" forall a b. (a -> b) -> a -> b
$ TExp Int32
j forall e. IntegralExp e => e -> e -> e
`quot` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
let local_bucket_is :: [TExp Int64]
local_bucket_is = forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ TExp Int32
j forall e. IntegralExp e => e -> e -> e
`rem` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
nested_hist_size :: [TExp Int64]
nested_hist_size =
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). HistOp rep -> Shape
histShape forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
global_bucket_is :: [TExp Int64]
global_bucket_is =
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
[TExp Int64]
nested_hist_size
(forall a. [a] -> a
head [TExp Int64]
local_bucket_is forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk)
forall a. [a] -> [a] -> [a]
++ forall a. [a] -> [a]
tail [TExp Int64]
local_bucket_is
TExp Int32
global_subhisto_i <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"global_subhisto_i" forall a b. (a -> b) -> a -> b
$ TExp Int32
j_offset forall e. IntegralExp e => e -> e -> e
`quot` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
group_hists_size) forall a b. (a -> b) -> a -> b
$
VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f
VName
dest_local
VName
dest_global
(SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)
SubExp
ne
TExp Int32
local_subhisto_i
TExp Int32
global_subhisto_i
[TExp Int64]
local_bucket_is
[TExp Int64]
global_bucket_is
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"initialize histograms in local memory" forall a b. (a -> b) -> a -> b
$
(VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp GPUMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TExp Int64]
local_bucket_is [TExp Int64]
global_bucket_is ->
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"First subhistogram is initialised from global memory; others with neutral element." forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TExp Int64]
global_is = forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is forall a. [a] -> [a] -> [a]
++ [TExp Int64
0] forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
local_is :: [TExp Int64]
local_is = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is
forall {k} (rep :: k) r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int32
global_subhisto_i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
(forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local [TExp Int64]
local_is (VName -> SubExp
Var VName
dest_global) [TExp Int64]
global_is)
( forall {k} (rep :: k) r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (forall {k} (rep :: k). HistOp rep -> Shape
histOpShape HistOp GPUMem
op) forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local ([TExp Int64]
local_is forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is) SubExp
ne []
)
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
forall {k} (t :: k).
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int32
pgtid_in_segment TExp Int32
threads_per_segment (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
segment_size') forall a b. (a -> b) -> a -> b
$ \TExp Int32
ie -> do
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
i_in_segment forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ie
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
forall a. Int -> [a] -> ([a], [a])
splitFromEnd (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chk_i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
se) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
(forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
space_is)
SubExp
se
[]
let red_res_split :: [([SubExp], [SubExp])]
red_res_split = forall {k} (rep :: k).
[HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults (forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [SubExp]
red_res
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([SubExp], [SubExp])]
red_res_split) forall a b. (a -> b) -> a -> b
$
\( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
([(VName, VName)]
_, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op),
([SubExp]
bucket, [SubExp]
vs')
) -> do
let chk_beg :: TExp Int64
chk_beg = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i forall a. Num a => a -> a -> a
* forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk
bucket' :: [TExp Int64]
bucket' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' :: [TExp Int64]
dest_shape' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
flat_bucket :: TExp Int64
flat_bucket = forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TExp Int64]
dest_shape' [TExp Int64]
bucket'
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds (forall d. [DimIndex d] -> Slice d
Slice (forall a b. (a -> b) -> [a] -> [b]
map forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
chk_beg forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
flat_bucket
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
flat_bucket forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg forall a. Num a => a -> a -> a
+ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk)
bucket_is :: [TExp Int64]
bucket_is =
[forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TExp Int64
flat_bucket forall a. Num a => a -> a -> a
- TExp Int64
chk_beg]
vs_params :: [Param LParamMem]
vs_params = forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform atomic updates" forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds forall a b. (a -> b) -> a -> b
$ do
forall {k} (rep :: k) inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
forall {k} (rep :: k) r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TExp Int64]
is
[TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Compact the multiple local memory subhistograms to result in global memory" forall a b. (a -> b) -> a -> b
$
(SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
_histo_size TExp Int32
bins_per_thread -> do
TV Int64
trunc_H <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"trunc_H" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_H_chk forall a b. (a -> b) -> a -> b
$
HistOp GPUMem -> TExp Int64
histSize (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug) forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i forall a. Num a => a -> a -> a
* forall a. [a] -> a
head [TExp Int64]
histo_dims
let trunc_histo_dims :: [TExp Int64]
trunc_histo_dims =
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
trunc_H
forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 (forall d. ShapeBase d -> [d]
shapeDims (forall {k} (rep :: k). HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
TExp Int32
trunc_histo_size <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
trunc_histo_dims
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
bins_per_thread forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" forall a b. (a -> b) -> a -> b
$
TExp Int32
i forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
trunc_histo_size) forall a b. (a -> b) -> a -> b
$ do
let local_bucket_is :: [TExp Int64]
local_bucket_is = forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
nested_hist_size :: [TExp Int64]
nested_hist_size =
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). HistOp rep -> Shape
histShape forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
global_bucket_is :: [TExp Int64]
global_bucket_is =
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
[TExp Int64]
nested_hist_size
(forall a. [a] -> a
head [TExp Int64]
local_bucket_is forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk)
forall a. [a] -> [a] -> [a]
++ forall a. [a] -> [a]
tail [TExp Int64]
local_bucket_is
forall {k} (rep :: k) inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
let ([VName]
global_dests, [VName]
local_dests) = forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
([Param LParamMem]
xparams, [Param LParamMem]
yparams) =
forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp forall a b. (a -> b) -> a -> b
$
SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Read values from subhistogram 0." forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
local_dests) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
subhisto) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(forall dec. Param dec -> VName
paramName Param LParamMem
xp)
[]
(VName -> SubExp
Var VName
subhisto)
(TExp Int64
0 forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Accumulate based on values in other subhistograms." forall a b. (a -> b) -> a -> b
$
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"subhisto_id" (TExp Int32
num_subhistos_per_group forall a. Num a => a -> a -> a
- TExp Int32
1) forall a b. (a -> b) -> a -> b
$ \TExp Int32
subhisto_id -> do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
yparams [VName]
local_dests) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
yp, VName
subhisto) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(forall dec. Param dec -> VName
paramName Param LParamMem
yp)
[]
(VName -> SubExp
Var VName
subhisto)
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id forall a. Num a => a -> a -> a
+ TExp Int64
1 forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
forall {k} dec (rep :: k) r op.
[Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
xparams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Put final bucket value in global memory." forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TExp Int64]
global_is =
forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is
forall a. [a] -> [a] -> [a]
++ [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id forall e. IntegralExp e => e -> e -> e
`rem` forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
global_dests) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
global_dest) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global_dest [TExp Int64]
global_is (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param LParamMem
xp) []
-- | Generate the local-memory variant of the histogram computation.
--
-- The intermediate local-memory subhistogram arrays are prepared exactly once
-- with 'prepareIntermediateArraysLocal'. The result is then handed to
-- 'histKernelLocalPass', which is invoked once for every chunk index @chk_i@
-- in @[0, hist_S)@. Within a pass, only buckets whose flat index lies in
-- @[chk_i * H_chk, (chk_i + 1) * H_chk)@ are updated, so @hist_S@ is the number
-- of chunks the bucket range is split into.
--
-- Arguments, in order:
--
-- * the variable holding the number of local subhistograms per group;
-- * the number of groups per segment;
-- * the pattern elements that receive the map-out results;
-- * the number of groups and the group size;
-- * the segment space;
-- * @hist_S@, the number of chunks (passes);
-- * the histogram slugs;
-- * the kernel body.
--
-- NOTE(review): whether each pass is a separate kernel launch is decided by
-- 'histKernelLocalPass' and is not restated here.
histKernelLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelLocal
  num_subhistos_per_group_var
  groups_per_segment
  map_pes
  num_groups
  group_size
  space
  hist_S
  slugs
  kbody = do
    let num_subhistos_per_group :: Imp.TExp Int32
        num_subhistos_per_group = tvExp num_subhistos_per_group_var

    emit $
      Imp.DebugPrint "Number of local subhistograms per group" $
        Just $
          untyped num_subhistos_per_group

    -- Prepared once and shared by every chunk pass below.
    init_histograms <-
      prepareIntermediateArraysLocal
        num_subhistos_per_group_var
        groups_per_segment
        slugs

    -- One pass per chunk of the bucket range.
    sFor "chk_i" hist_S $ \chk_i ->
      histKernelLocalPass
        num_subhistos_per_group_var
        groups_per_segment
        map_pes
        num_groups
        group_size
        space
        slugs
        kbody
        init_histograms
        hist_S
        chk_i
chk_i
-- | A per-slug pass limit for the local-memory strategy, selected by the
-- slug's atomic update method:
--
-- * 'AtomicPrim' (primitive atomics): 3
-- * 'AtomicCAS' (compare-and-swap): 4
-- * 'AtomicLocking' (lock-based updates): 6
--
-- NOTE(review): judging by the name, callers treat this as the maximum number
-- of local-memory passes for the slug; confirm at the call site. The specific
-- values look like tuning constants, and their rationale is not visible here.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _ -> 3
    AtomicCAS _ -> 4
    AtomicLocking _ -> 6
localMemoryCase ::
[PatElem LetDecMem] ->
Imp.TExp Int32 ->
SegSpace ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int32 ->
[SegHistSlug] ->
KernelBody GPUMem ->
CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase :: [PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
_ [SegHistSlug]
slugs KernelBody GPUMem
kbody = do
let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
segment_dims :: [SubExp]
segment_dims = forall a. [a] -> [a]
init [SubExp]
space_sizes
segmented :: Bool
segmented = Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims
TV Int64
hist_L <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
[Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"hist_L" PrimType
int32
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
hist_L) SizeClass
Imp.SizeLocalMemory
TV Any
max_group_size <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
[Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"max_group_size" PrimType
int32
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (forall {k} (t :: k). TV t -> VName
tvVar TV Any
max_group_size) SizeClass
Imp.SizeGroup
let withSizeMax :: Map VName (VarEntry GPUMem) -> Map VName (VarEntry GPUMem)
withSizeMax Map VName (VarEntry GPUMem)
vtable =
case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (forall {k} (t :: k). TV t -> VName
tvVar TV Any
max_group_size) Map VName (VarEntry GPUMem)
vtable of
Just (ScalarVar Maybe (Exp GPUMem)
_ ScalarEntry
se) ->
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert
(forall {k} (t :: k). TV t -> VName
tvVar TV Any
max_group_size)
(forall {k} (rep :: k).
Maybe (Exp rep) -> ScalarEntry -> VarEntry rep
ScalarVar (forall a. a -> Maybe a
Just (forall {k} (rep :: k). Op rep -> Exp rep
Op (forall inner. inner -> MemOp inner
Inner (forall {k} (rep :: k) op. SizeOp -> HostOp rep op
SizeOp (SizeClass -> SizeOp
GetSizeMax SizeClass
SizeGroup))))) ScalarEntry
se)
Map VName (VarEntry GPUMem)
vtable
Maybe (VarEntry GPUMem)
_ -> Map VName (VarEntry GPUMem)
vtable
let group_size :: Count GroupSize SubExp
group_size = forall {k} (u :: k) e. e -> Count u e
Imp.Count forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> VName
tvVar TV Any
max_group_size
Count NumGroups SubExp
num_groups <-
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (u :: k) e. e -> Count u e
Imp.Count forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k). TV t -> SubExp
tvSize) forall a b. (a -> b) -> a -> b
$
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"num_groups" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
pe64 (forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = SubExp -> TExp Int64
pe64 forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = SubExp -> TExp Int64
pe64 forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size
let r64 :: TPrimExp t v -> TPrimExp Double v
r64 = forall v. PrimExp v -> TPrimExp Double v
isF64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int64 FloatType
Float64) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = forall v. PrimExp v -> TPrimExp Int64 v
isInt64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
TExp Double
hist_m' <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_m_prime" forall a b. (a -> b) -> a -> b
$
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64
( forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
(forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
hist_el_size))
(TExp Int64
hist_N forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups'))
)
forall a. Fractional a => a -> a -> a
/ forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H
let hist_B :: TExp Int64
hist_B = forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
TExp Int64
hist_M0 <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M0" forall a b. (a -> b) -> a -> b
$
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TExp Int64
1 forall a b. (a -> b) -> a -> b
$
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 TExp Double
hist_m') TExp Int64
hist_B
let q_small :: TExp Int64
q_small = TExp Int64
2
TExp Int64
hist_Nout <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nout" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims
TExp Int64
hist_Nin <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nin" forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
last [SubExp]
space_sizes
TExp Int64
work_asymp_M_max <-
if Bool
segmented
then do
TExp Int32
hist_T_hist_min <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_T_hist_min" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nin forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout) (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T)
forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout
let r :: TExp Int32
r = TExp Int32
hist_T_hist_min forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
hist_B
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_Nin forall e. IntegralExp e => e -> e -> e
`quot` (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
r forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
else
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" forall a b. (a -> b) -> a -> b
$
(TExp Int64
hist_Nout forall a. Num a => a -> a -> a
* TExp Int64
hist_N)
forall e. IntegralExp e => e -> e -> e
`quot` ( (TExp Int64
q_small forall a. Num a => a -> a -> a
* forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
forall e. IntegralExp e => e -> e -> e
`quot` forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
)
TV Int32
hist_M <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_M" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_M0 TExp Int64
work_asymp_M_max
let hist_M_nonzero :: TExp Int32
hist_M_nonzero = forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
TExp Int64
hist_C <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C" forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_B forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_nonzero
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local hist_M0" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_M0
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local work asymp M max" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
work_asymp_M_max
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local C" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local B" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local M" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local memory needed" forall a b. (a -> b) -> a -> b
$
forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_H forall a. Num a => a -> a -> a
* TExp Int64
hist_el_size forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int64
local_mem_needed <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_mem_needed" forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_el_size forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int32
hist_S <-
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_S" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
(TExp Int64
hist_H forall a. Num a => a -> a -> a
* TExp Int64
local_mem_needed) forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L
let max_S :: TExp Int32
max_S = case KernelBody GPUMem -> Passage
bodyPassage KernelBody GPUMem
kbody of
Passage
MustBeSinglePass -> TExp Int32
1
Passage
MayBeMultiPass -> forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs
Count NumGroups (TExp Int64)
groups_per_segment <-
if Bool
segmented
then
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall {k} (u :: k) e. e -> Count u e
Count forall a b. (a -> b) -> a -> b
$
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"groups_per_segment" forall a b. (a -> b) -> a -> b
$
forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_Nout
else forall (f :: * -> *) a. Applicative f => a -> f a
pure Count NumGroups (TExp Int64)
num_groups'
let pick_local :: TExp Bool
pick_local =
TExp Int64
hist_Nin forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int64
hist_H
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int64
local_mem_needed forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L)
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int32
hist_S forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
max_S)
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
hist_C forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
hist_B
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0
run :: CallKernelGen ()
run = do
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"## Using local memory" forall a. Maybe a
Nothing
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Multiplication degree (M)" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Cooperation level (C)" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_S
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented forall a b. (a -> b) -> a -> b
$
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Groups per segment" forall a b. (a -> b) -> a -> b
$
forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$
forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment
forall {k} (rep :: k) r op a.
(VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable Map VName (VarEntry GPUMem) -> Map VName (VarEntry GPUMem)
withSizeMax forall a b. (a -> b) -> a -> b
$
TV Int32
-> Count NumGroups (TExp Int64)
-> [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelLocal
TV Int32
hist_M
Count NumGroups (TExp Int64)
groups_per_segment
[PatElem LParamMem]
map_pes
Count NumGroups SubExp
num_groups
Count GroupSize SubExp
group_size
SegSpace
space
TExp Int32
hist_S
[SegHistSlug]
slugs
KernelBody GPUMem
kbody
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TExp Bool
pick_local, CallKernelGen ()
run)
-- | Compile a 'SegHist' operation to host code.
--
-- The pattern elements are split into the histogram results (one per
-- 'histDest' array of every operation, in operation order) and any
-- remaining map results.  A run-time condition produced by
-- 'localMemoryCase' selects between the local-memory kernel and
-- 'histKernelGlobal'.  Afterwards, each operation's subhistograms are
-- combined: when there is exactly one subhistogram, the result arrays
-- are made to alias its memory block; otherwise a segmented reduction
-- ('compileSegRed'') collapses the subhistogram dimension.
--
-- All of this is guarded by a run-time check that the combined
-- subhistogram memory (@seg_h@) is nonzero.
compileSegHist ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [HistOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegHist (Pat pes) lvl space ops kbody = do
  KernelAttrs _ _ num_groups group_size <- lvlKernelAttrs lvl
  let num_groups' = fmap pe64 num_groups
      group_size' = fmap pe64 group_size
      -- The innermost dimension of the segment space is the one being
      -- histogrammed; the outer ones are segments.
      segment_size = last $ map pe64 $ segSpaceDims space

      -- The pattern holds one element per destination array of each
      -- operation, followed by any map results.  This count must agree
      -- with the chunking into 'pes_per_op' below.  (It used to be
      -- @length ops + sum (map (length . histNeutral) ops)@, which
      -- over-counts by one per operation and would misclassify trailing
      -- map results as histogram results.)
      num_red_res = sum $ map (length . histDest) ops
      (all_red_pes, map_pes) = splitAt num_red_res pes

  (op_hs, op_seg_hs, slugs) <- unzip3 <$> mapM (computeHistoUsage space) ops
  h <- dPrimVE "h" $ Imp.unCount $ sum op_hs
  seg_h <- dPrimVE "seg_h" $ Imp.unCount $ sum op_seg_hs

  -- NOTE(review): presumably guards against empty histograms and the
  -- divisions in the heuristics below -- confirm.
  sUnless (seg_h .==. 0) $ do
    -- Group size (B).
    let hist_B = unCount group_size'

    -- Total number of histogram elements (H) across all operations.
    hist_H <- dPrimVE "hist_H" $ sum $ map histSize ops

    -- Bytes per histogram element: the byte count 'h' divided by H
    -- (rounded up), plus the size of an int32 lock for every operation
    -- whose atomic update uses locking.
    let lockBytes :: SegHistSlug -> Maybe (Imp.TExp Int64)
        lockBytes slug = case slugAtomicUpdate slug of
          AtomicLocking {} -> Just $ primByteSize int32
          _ -> Nothing
    hist_el_size <-
      dPrimVE "hist_el_size" $
        foldl' (+) (h `divUp` hist_H) $
          mapMaybe lockBytes slugs

    -- Input elements contributing to each histogram (N).
    hist_N <- dPrimVE "hist_N" segment_size

    -- Race factor (RF): the integer average of the operations' race
    -- factors.
    hist_RF <-
      dPrimVE "hist_RF" $
        sExt32 $
          sum (map (pe64 . histRaceFactor . slugOp) slugs)
            `quot` genericLength slugs

    -- Total number of threads (T): number of groups times group size.
    let hist_T = sExt32 $ unCount num_groups' * unCount group_size'

    -- Emit a labelled debug print of a typed expression.
    let debugVal :: String -> Imp.TExp t -> CallKernelGen ()
        debugVal desc = emit . Imp.DebugPrint desc . Just . untyped

    emit $ Imp.DebugPrint "\n# SegHist" Nothing
    debugVal "Number of threads (T)" hist_T
    debugVal "Desired group size (B)" hist_B
    debugVal "Histogram size (H)" hist_H
    debugVal "Input elements per histogram (N)" hist_N
    debugVal "Number of segments" $ product $ map (pe64 . snd) segment_dims
    debugVal "Histogram element size (el_size)" hist_el_size
    debugVal "Race factor (RF)" hist_RF
    debugVal "Memory per set of subhistograms per segment" h
    debugVal "Memory per set of subhistograms times segments" seg_h

    (use_local_memory, run_in_local_memory) <-
      localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N hist_RF slugs kbody

    sIf use_local_memory run_in_local_memory $
      histKernelGlobal map_pes num_groups group_size space slugs kbody

    let pes_per_op = chunks (map (length . histDest) ops) all_red_pes

    -- Combine the subhistograms of each operation into its results.
    forM_ (zip3 slugs pes_per_op ops) $ \(slug, red_pes, op) -> do
      let num_histos = slugNumSubhistos slug
          subhistos = map subhistosArray $ slugSubhistos slug

          -- With a single subhistogram, the result arrays simply take
          -- over the subhistogram's (device) memory block.
          unitHistoCase :: CallKernelGen ()
          unitHistoCase =
            forM_ (zip red_pes subhistos) $ \(pe, subhisto) -> do
              pe_mem <- memLocName . entryArrayLoc <$> lookupArray (patElemName pe)
              subhisto_mem <- memLocName . entryArrayLoc <$> lookupArray subhisto
              emit $ Imp.SetMem pe_mem subhisto_mem $ Space "device"

      sIf (tvExp num_histos .==. 1) unitHistoCase $ do
        -- Segmented reduction keeping the original segment dimensions,
        -- extended with the bucket and vector dimensions of this
        -- operation; the innermost (reduced) dimension ranges over the
        -- subhistograms.
        bucket_ids <- replicateM (shapeRank (histShape op)) (newVName "bucket_id")
        subhistogram_id <- newVName "subhistogram_id"
        vector_ids <- replicateM (shapeRank (histOpShape op)) (newVName "vector_id")
        flat_gtid <- newVName "flat_gtid"

        let grid = KernelGrid num_groups group_size
            segred_space =
              SegSpace flat_gtid $
                segment_dims
                  ++ zip bucket_ids (shapeDims (histShape op))
                  ++ zip vector_ids (shapeDims $ histOpShape op)
                  ++ [(subhistogram_id, Var $ tvVar num_histos)]
            segred_op = SegBinOp Commutative (histOp op) (histNeutral op) mempty

            -- Element of a subhistogram array read by the reduction:
            -- segment indices, then the subhistogram index, then the
            -- bucket and vector indices.
            subhistoArg subhisto =
              ( Var subhisto,
                map Imp.le64 $
                  map fst segment_dims ++ [subhistogram_id] ++ bucket_ids ++ vector_ids
              )

        compileSegRed' (Pat red_pes) grid segred_space [segred_op] $ \red_cont ->
          red_cont $ map subhistoArg subhistos

    emit $ Imp.DebugPrint "" Nothing
  where
    -- All but the innermost dimension of the segment space.
    segment_dims :: [(VName, SubExp)]
    segment_dims = init $ unSegSpace space