{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where
import Control.Monad.Except
import Data.List (foldl', genericLength, zip5)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | The intermediate ("subhistogram") array that 'computeHistoUsage'
-- creates for one destination array of a histogram operator.
data SubhistosInfo = SubhistosInfo
  { -- | Name of the subhistogram array.
    subhistosArray :: VName,
    -- | Action that provides memory for the subhistogram array.  As
    -- built by 'computeHistoUsage', it either aliases the destination's
    -- memory (one subhistogram) or allocates fresh device memory and
    -- initialises it (several subhistograms).
    subhistosAlloc :: CallKernelGen ()
  }
-- | Everything needed to compute one histogram operator, as assembled by
-- 'computeHistoUsage'.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Variable holding the number of subhistograms.  It is declared in
    -- 'computeHistoUsage', which does not itself assign it a value.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination array of the operator (each paired
    -- with its neutral element when created).
    slugSubhistos :: [SubhistosInfo],
    -- | The atomic update strategy (primitive atomic, CAS, or locking)
    -- used to combine values into buckets.
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }
-- | The number of bytes occupied by one complete histogram of the given
-- operator.  For every value returned by the operator lambda, this takes
-- the size of an array of that type with shape @histShape <> histOpShape@,
-- and sums the results.
histSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histSpaceUsage op =
  sum [typeSize (t `arrayOfShape` full_shape) | t <- lambdaReturnType (histOp op)]
  where
    -- Bucket dimensions followed by the operator's own element shape.
    full_shape = histShape op <> histOpShape op
-- | The number of buckets in the histogram: the product of the
-- dimensions of 'histShape'.
histSize :: HistOp GPUMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape
-- | The number of dimensions of the histogram, i.e. the rank of
-- 'histShape'.
histRank :: HistOp GPUMem -> Int
histRank = shapeRank . histShape
-- | Compute the memory footprint of a histogram operator, and set up the
-- intermediate ("subhistogram") arrays for each of its destinations.
--
-- Returns a triple:
--
-- * the number of bytes of one histogram ('histSpaceUsage');
--
-- * that size multiplied by the number of segments, i.e. the product of
--   all but the innermost dimension of the 'SegSpace';
--
-- * a 'SegHistSlug' bundling the operator, the declared (but not yet
--   assigned) number-of-subhistograms variable, one 'SubhistosInfo' per
--   destination, and a locking-capable atomic update built from the
--   host's atomic operations.
--
-- Each subhistogram array has shape
-- @segment_dims ++ [num_subhistos] ++ (destination shape minus segment dims)@.
-- Its allocation action aliases the destination memory when
-- @num_subhistos == 1@.  Otherwise it allocates fresh device memory, fills
-- it with the neutral element, and copies the destination into
-- subhistogram 0 of every segment.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- All but the innermost dimension of the space are segment dimensions.
  -- NOTE(review): 'init' is partial; this assumes a non-empty SegSpace.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Declared as int64 so that the variable's PrimType agrees with its
  -- 'TV Int64' type: it is read back via 'tvExp' (compared against an
  -- Int64 literal below) and via 'pe64 . tvSize' (in array shapes).
  num_subhistos <- dPrim "num_subhistos" int64

  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_mem
        $ IxFun.iota $ map pe64 $ shapeDims subhistos_shape

    pure $
      SubhistosInfo subhistos $ do
        let unitHistoCase =
              -- A single subhistogram simply reuses the destination's memory.
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            multiHistoCase = do
              -- Several subhistograms need their own memory: fill all of
              -- it with the neutral element, then copy the existing
              -- destination contents into subhistogram 0 of each segment.
              let num_elems = product $ map pe64 $ shapeDims subhistos_shape
                  subhistos_mem_size =
                    Imp.bytes $
                      Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map pe64 $ arrayDims subhistos_t) $
                      map (unitSlice 0 . pe64 . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  -- The segment dimensions are exactly the sizes in 'segment_dims'.
  let h = histSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . pe64 . snd) segment_dims)

  atomics <- hostAtomics <$> askEnv

  pure
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $ histOp op
    )
-- | Build the atomic update function that writes the given destination
-- arrays in @global@ space, for one histogram operator.
--
-- Returns the lock state to use from now on, paired with the update
-- function.
--
-- * Primitive-atomic and CAS updates need no locks; the incoming
--   @Maybe Locking@ is returned unchanged.
--
-- * A locking update with an existing 'Locking' reuses it and returns it
--   unchanged.
--
-- * A locking update with no existing lock array creates one.  This is a
--   zero-initialised static @int32@ array named @hist_locks@ in device
--   memory, with @num_locks@ entries.  The lock index is the index
--   flattened against @histOpShape ++ [num_subhistos] ++ histShape@,
--   taken modulo @num_locks@, so distinct buckets may share a lock.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l dests slug =
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> pure (l, f (Space "global") dests)
    (_, AtomicCAS f) -> pure (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> pure (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      let num_locks = 100151
          dims =
            map pe64 $
              shapeDims (histOpShape (slugOp slug))
                ++ [tvSize (slugNumSubhistos slug)]
                ++ shapeDims (histShape (slugOp slug))
      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
          Imp.ArrayZeros num_locks
      -- NOTE(review): the literals 0, 1, 0 presumably select the
      -- "unlocked", "locked" and "unlock-to" values of 'Locking';
      -- confirm against the field order of the 'Locking' constructor.
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      pure (Just l', f l' (Space "global") dests)
-- | Whether a histogram computation is restricted to a single pass
-- ('MustBeSinglePass') or may be executed in several passes
-- ('MayBeMultiPass'); see 'bodyPassage'.  Because of constructor order,
-- the derived 'Ord' instance places 'MustBeSinglePass' below
-- 'MayBeMultiPass'.
data Passage = MustBeSinglePass | MayBeMultiPass deriving (Eq, Ord)
-- | Determine the 'Passage' permitted for a kernel body.  Alias analysis
-- is run on the body with an empty alias table.  If the body consumes
-- nothing, the result is 'MayBeMultiPass'; if it consumes anything, the
-- result is 'MustBeSinglePass'.
--
-- NOTE(review): presumably a consuming (in-place updating) body cannot
-- safely be re-executed once per pass; confirm.
bodyPassage :: KernelBody GPUMem -> Passage
bodyPassage kbody
  | consumed == mempty = MayBeMultiPass
  | otherwise = MustBeSinglePass
  where
    consumed = consumedInKernelBody (aliasAnalyseKernelBody mempty kbody)
prepareIntermediateArraysGlobal ::
Passage ->
Imp.TExp Int32 ->
Imp.TExp Int64 ->
[SegHistSlug] ->
CallKernelGen
( Imp.TExp Int32,
[[Imp.TExp Int64] -> InKernelGen ()]
)
prepareIntermediateArraysGlobal :: Passage
-> TExp Int32
-> TPrimExp Int64 VName
-> [SegHistSlug]
-> CallKernelGen
(TExp Int32, [[TPrimExp Int64 VName] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage TExp Int32
hist_T TPrimExp Int64 VName
hist_N [SegHistSlug]
slugs = do
TPrimExp Int64 VName
hist_H <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TPrimExp Int64 VName)
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TPrimExp Int64 VName
histSize (HistOp GPUMem -> TPrimExp Int64 VName)
-> (SegHistSlug -> HistOp GPUMem)
-> SegHistSlug
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs
TPrimExp Double VName
hist_RF <-
[Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RF" (TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
[TPrimExp Double VName] -> TPrimExp Double VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TPrimExp Double VName)
-> [SegHistSlug] -> [TPrimExp Double VName]
forall a b. (a -> b) -> [a] -> [b]
map (TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TPrimExp Int64 VName -> TPrimExp Double VName)
-> (SegHistSlug -> TPrimExp Int64 VName)
-> SegHistSlug
-> TPrimExp Double VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ [SegHistSlug] -> TPrimExp Double VName
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
TExp Int32
hist_el_size <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_el_size" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ [TExp Int32] -> TExp Int32
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int32] -> TExp Int32) -> [TExp Int32] -> TExp Int32
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int32) -> [SegHistSlug] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> TExp Int32
slugElAvgSize [SegHistSlug]
slugs
TPrimExp Double VName
hist_C_max <-
[Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C_max" (TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TPrimExp Double VName -> TPrimExp Double VName)
-> TPrimExp Double VName -> TPrimExp Double VName
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Int64 VName
hist_H TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Double VName
hist_k_ct_min
TExp Int32
hist_M_min <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TPrimExp Double VName -> TPrimExp Int64 VName)
-> TPrimExp Double VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Double VName
hist_C_max
let hist_L2_def :: Int64
hist_L2_def = Int64
4 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024
TV Any
hist_L2 <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Any)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"L2_size" PrimType
int32
Maybe Name
entry <- ImpM GPUMem HostEnv HostOp (Maybe Name)
forall rep r op. ImpM rep r op (Maybe Name)
askFunction
HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp
(HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Name -> SizeClass -> HostOp
Imp.GetSize
(TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)
(Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ [Char] -> Name
nameFromString (VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)))
(SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$ Name -> Int64 -> SizeClass
Imp.SizeBespoke ([Char] -> Name
nameFromString [Char]
"L2_for_histogram") Int64
hist_L2_def
let hist_L2_ln_sz :: TPrimExp Double VName
hist_L2_ln_sz = TPrimExp Double VName
16 TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
4
TPrimExp Double VName
hist_RACE_exp <-
[Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RACE_exp" (TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 TPrimExp Double VName
1 (TPrimExp Double VName -> TPrimExp Double VName)
-> TPrimExp Double VName -> TPrimExp Double VName
forall a b. (a -> b) -> a -> b
$
(TPrimExp Double VName
hist_k_RF TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
hist_RF)
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ (TPrimExp Double VName
hist_L2_ln_sz TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_el_size)
TV Int32
hist_S <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"hist_S" PrimType
int32
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Int64 VName
hist_N TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
hist_H)
(TV Int32
hist_S TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32))
(CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ TV Int32
hist_S
TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- case Passage
passage of
Passage
MayBeMultiPass ->
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
(TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_min TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_el_size)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TPrimExp Double VName
hist_F_L2 TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Any VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TV Any -> TPrimExp Any VName
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
hist_RACE_exp)
Passage
MustBeSinglePass ->
TExp Int32
1
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Double VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Double VName
hist_RACE_exp
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S
[[TPrimExp Int64 VName] -> InKernelGen ()]
histograms <-
(Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> [[TPrimExp Int64 VName] -> InKernelGen ()]
forall a b. (a, b) -> b
snd
((Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> [[TPrimExp Int64 VName] -> InKernelGen ()])
-> ImpM
GPUMem
HostEnv
HostOp
(Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> ImpM
GPUMem HostEnv HostOp [[TPrimExp Int64 VName] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
-> SegHistSlug
-> CallKernelGen
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
GPUMem
HostEnv
HostOp
(Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM
(TPrimExp Any VName
-> TExp Int32
-> TExp Int32
-> TPrimExp Double VName
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
onOp (TV Any -> TPrimExp Any VName
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TExp Int32
hist_M_min (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S) TPrimExp Double VName
hist_RACE_exp)
Maybe Locking
forall a. Maybe a
Nothing
[SegHistSlug]
slugs
(TExp Int32, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> CallKernelGen
(TExp Int32, [[TPrimExp Int64 VName] -> InKernelGen ()])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S, [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms)
where
hist_k_ct_min :: TPrimExp Double VName
hist_k_ct_min = TPrimExp Double VName
2
hist_k_RF :: TPrimExp Double VName
hist_k_RF = TPrimExp Double VName
0.75
hist_F_L2 :: TPrimExp Double VName
hist_F_L2 = TPrimExp Double VName
0.4
r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
slugElAvgSize :: SegHistSlug -> TExp Int32
slugElAvgSize slug :: SegHistSlug
slug@(SegHistSlug HistOp GPUMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate GPUMem KernelEnv
do_op) =
case AtomicUpdate GPUMem KernelEnv
do_op of
AtomicLocking {} ->
SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32
1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ [Type] -> TExp Int32
forall i a. Num i => [a] -> i
genericLength (Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)))
AtomicUpdate GPUMem KernelEnv
_ ->
SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` [Type] -> TExp Int32
forall i a. Num i => [a] -> i
genericLength (Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op))
slugElSize :: SegHistSlug -> TExp Int32
slugElSize (SegHistSlug HistOp GPUMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate GPUMem KernelEnv
do_op) =
case AtomicUpdate GPUMem KernelEnv
do_op of
AtomicLocking {} ->
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount (Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName)
-> Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
[Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName))
-> [Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
(Type -> Count Bytes (TPrimExp Int64 VName))
-> [Type] -> [Count Bytes (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TPrimExp Int64 VName)
typeSize (Type -> Count Bytes (TPrimExp Int64 VName))
-> (Type -> Type) -> Type -> Count Bytes (TPrimExp Int64 VName)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Type] -> [Count Bytes (TPrimExp Int64 VName)])
-> [Type] -> [Count Bytes (TPrimExp Int64 VName)]
forall a b. (a -> b) -> a -> b
$
PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)
AtomicUpdate GPUMem KernelEnv
_ ->
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount (Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName)
-> Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
[Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName))
-> [Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
(Type -> Count Bytes (TPrimExp Int64 VName))
-> [Type] -> [Count Bytes (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TPrimExp Int64 VName)
typeSize (Type -> Count Bytes (TPrimExp Int64 VName))
-> (Type -> Type) -> Type -> Count Bytes (TPrimExp Int64 VName)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Type] -> [Count Bytes (TPrimExp Int64 VName)])
-> [Type] -> [Count Bytes (TPrimExp Int64 VName)]
forall a b. (a -> b) -> a -> b
$
Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)
onOp :: TPrimExp Any VName
-> TExp Int32
-> TExp Int32
-> TPrimExp Double VName
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
onOp TPrimExp Any VName
hist_L2 TExp Int32
hist_M_min TExp Int32
hist_S TPrimExp Double VName
hist_RACE_exp Maybe Locking
l SegHistSlug
slug = do
let SegHistSlug HistOp GPUMem
op TV Int64
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate GPUMem KernelEnv
do_op = SegHistSlug
slug
hist_H :: TPrimExp Int64 VName
hist_H = HistOp GPUMem -> TPrimExp Int64 VName
histSize HistOp GPUMem
op
TPrimExp Int64 VName
hist_H_chk <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H_chk" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Chunk size (H_chk)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_H_chk
TPrimExp Double VName
hist_k_max <-
[Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_k_max" (TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64
(TPrimExp Double VName
hist_F_L2 TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* (TPrimExp Any VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Any VName
hist_L2 TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 (SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug)) TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
hist_RACE_exp)
(TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Int64 VName
hist_N)
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T
TPrimExp Int64 VName
hist_u <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_u" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
case AtomicUpdate GPUMem KernelEnv
do_op of
AtomicPrim {} -> TPrimExp Int64 VName
2
AtomicUpdate GPUMem KernelEnv
_ -> TPrimExp Int64 VName
1
TPrimExp Double VName
hist_C <-
[Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C" (TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TPrimExp Double VName -> TPrimExp Double VName)
-> TPrimExp Double VName -> TPrimExp Double VName
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TPrimExp Int64 VName
hist_u TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk) TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Double VName
hist_k_max
TExp Int32
hist_M <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
case SegHistSlug -> AtomicUpdate GPUMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicPrim {} -> TExp Int32
1
AtomicUpdate GPUMem KernelEnv
_ -> TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
hist_M_min (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TPrimExp Double VName -> TPrimExp Int64 VName)
-> TPrimExp Double VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Double VName
hist_C
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Elements/thread in L2 cache (k_max)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Double VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Double VName
hist_k_max
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Double VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Double VName
hist_C
TV Int64
num_subhistos TV Int64 -> TPrimExp Int64 VName -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M
[VName]
dests <- [(VName, SubhistosInfo)]
-> ((VName, SubhistosInfo) -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [SubhistosInfo] -> [(VName, SubhistosInfo)]
forall a b. [a] -> [b] -> [(a, b)]
zip (HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest HistOp GPUMem
op) [SubhistosInfo]
subhisto_info) (((VName, SubhistosInfo) -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName])
-> ((VName, SubhistosInfo) -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubhistosInfo
info) -> do
MemLoc
dest_mem <- ArrayEntry -> MemLoc
entryArrayLoc (ArrayEntry -> MemLoc)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp MemLoc
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
dest
VName
sub_mem <-
(MemLoc -> VName)
-> ImpM GPUMem HostEnv HostOp MemLoc
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap MemLoc -> VName
memLocName (ImpM GPUMem HostEnv HostOp MemLoc
-> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp MemLoc
-> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
ArrayEntry -> MemLoc
entryArrayLoc
(ArrayEntry -> MemLoc)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp MemLoc
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (SubhistosInfo -> VName
subhistosArray SubhistosInfo
info)
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
sub_mem (MemLoc -> VName
memLocName MemLoc
dest_mem) (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$
[Char] -> Space
Space [Char]
"device"
multiHistoCase :: CallKernelGen ()
multiHistoCase = SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TExp Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
1) CallKernelGen ()
unitHistoCase CallKernelGen ()
multiHistoCase
VName -> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> ImpM GPUMem HostEnv HostOp VName)
-> VName -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ SubhistosInfo -> VName
subhistosArray SubhistosInfo
info
(Maybe Locking
l', [TPrimExp Int64 VName] -> InKernelGen ()
do_op') <- Maybe Locking
-> [VName]
-> SegHistSlug
-> CallKernelGen
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
prepareAtomicUpdateGlobal Maybe Locking
l [VName]
dests SegHistSlug
slug
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
-> CallKernelGen
(Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Locking
l', [TPrimExp Int64 VName] -> InKernelGen ()
do_op')
histKernelGlobalPass ::
[PatElem LetDecMem] ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegHistSlug] ->
KernelBody GPUMem ->
[[Imp.TExp Int64] -> InKernelGen ()] ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelGlobalPass :: [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> [[TPrimExp Int64 VName] -> InKernelGen ()]
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelGlobalPass [PatElem LParamMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms TExp Int32
hist_S TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
space_sizes_64 :: [TPrimExp Int64 VName]
space_sizes_64 = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> (SubExp -> TPrimExp Int64 VName)
-> SubExp
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TPrimExp Int64 VName
pe64) [SubExp]
space_sizes
total_w_64 :: TPrimExp Int64 VName
total_w_64 = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
space_sizes_64
[TPrimExp Int64 VName]
hist_H_chks <- [TPrimExp Int64 VName]
-> (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp [TPrimExp Int64 VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> TPrimExp Int64 VName)
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TPrimExp Int64 VName
histSize (HistOp GPUMem -> TPrimExp Int64 VName)
-> (SegHistSlug -> HistOp GPUMem)
-> SegHistSlug
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs) ((TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp [TPrimExp Int64 VName])
-> (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
w ->
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H_chk" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
[Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_global" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
[TExp Int32]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32])
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"subhisto_ind" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` ( KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp (SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug))
)
let gtid :: TPrimExp Int64 VName
gtid = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
num_threads :: TPrimExp Int64 VName
num_threads = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TPrimExp Int64 VName
gtid TPrimExp Int64 VName
num_threads TPrimExp Int64 VName
total_w_64 ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
offset -> do
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> InKernelGen ()
forall rep r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 VName] -> [(VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
space_is [TPrimExp Int64 VName]
space_sizes_64) TPrimExp Int64 VName
offset
let input_in_bounds :: TExp Bool
input_in_bounds = TPrimExp Int64 VName
offset TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
total_w_64
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, KernelResult)]
-> ((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem]
-> [KernelResult] -> [(PatElem LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [KernelResult]
map_res) (((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, KernelResult
res) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
(((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> VName)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TPrimExp Int64 VName])
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
[]
let red_res_split :: [([SubExp], [SubExp])]
red_res_split =
[HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) ([SubExp] -> [([SubExp], [SubExp])])
-> [SubExp] -> [([SubExp], [SubExp])]
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
red_res
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)]
-> ((HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [[TPrimExp Int64 VName] -> InKernelGen ()]
-> [([SubExp], [SubExp])]
-> [TExp Int32]
-> [TPrimExp Int64 VName]
-> [(HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)]
forall a b c d e.
[a] -> [b] -> [c] -> [d] -> [e] -> [(a, b, c, d, e)]
zip5 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms [([SubExp], [SubExp])]
red_res_split [TExp Int32]
subhisto_inds [TPrimExp Int64 VName]
hist_H_chks) (((HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
[TPrimExp Int64 VName] -> InKernelGen ()
do_op,
([SubExp]
bucket, [SubExp]
vs'),
TExp Int32
subhisto_ind,
TPrimExp Int64 VName
hist_H_chk
) -> do
let chk_beg :: TPrimExp Int64 VName
chk_beg = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk
bucket' :: [TPrimExp Int64 VName]
bucket' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
bucket
dest_shape' :: [TPrimExp Int64 VName]
dest_shape' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
flat_bucket :: TPrimExp Int64 VName
flat_bucket = [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TPrimExp Int64 VName]
dest_shape' [TPrimExp Int64 VName]
bucket'
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
flat_bucket
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
flat_bucket TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
hist_H_chk)
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds ([DimIndex (TPrimExp Int64 VName)] -> Slice (TPrimExp Int64 VName)
forall d. [DimIndex d] -> Slice d
Slice ((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix [TPrimExp Int64 VName]
bucket')) [TPrimExp Int64 VName]
dest_shape'
vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let bucket_is :: [TPrimExp Int64 VName]
bucket_is =
(VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is)
[TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_ind]
[TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
dest_shape' TPrimExp Int64 VName
flat_bucket
[LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
Shape
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
res) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
res [TPrimExp Int64 VName]
is
[TPrimExp Int64 VName] -> InKernelGen ()
do_op ([TPrimExp Int64 VName]
bucket_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)
-- | Compile a segmented histogram using subhistograms in global memory.
--
-- The thread count passed on is @num_groups * group_size@.
-- 'prepareIntermediateArraysGlobal' sets up the intermediate arrays and
-- returns two things: the number of passes @hist_S@, and one update
-- function per histogram operator.  A host-side 'sFor' loop then emits one
-- 'histKernelGlobalPass' per pass index @chk_i@.
histKernelGlobal ::
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  let num_threads :: Imp.TExp Int32
      num_threads =
        sExt32 $ pe64 (unCount num_groups) * pe64 (unCount group_size)

      -- Size of the innermost dimension of the segment space.
      -- NOTE(review): 'last' assumes the space has at least one dimension.
      segment_size :: Imp.TExp Int64
      segment_size = pe64 $ snd $ last $ unSegSpace space

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal
      (bodyPassage kbody)
      num_threads
      segment_size
      slugs

  sFor "chk_i" hist_S $ \chk_i ->
    histKernelGlobalPass
      map_pes
      num_groups
      group_size
      space
      slugs
      kbody
      histograms
      hist_S
      chk_i
-- | Per-histogram-operator setup for the local-memory strategy, as
-- produced by 'prepareIntermediateArraysLocal'.
--
-- Each element pairs two things:
--
-- * the names of the global-memory subhistogram arrays;
-- * a kernel-side initialiser.  Given the chunk size @hist_H_chk@, it
--   allocates the local-memory subhistograms.  It returns their names
--   together with an update function that takes bucket indices.
type InitLocalHistograms =
  [ ( [VName],
      SubExp ->
      InKernelGen
        ( [VName],
          [Imp.TExp Int64] -> InKernelGen ()
        )
    )
  ]
-- | Prepare, for every histogram operator, the data needed by the
-- local-memory strategy.
--
-- For each slug, on the host, this does three things in order:
--
-- 1. Sets the slug's number of global subhistograms to
--    @groups_per_segment@.
-- 2. Emits a debug print of that number.
-- 3. Allocates the global subhistogram arrays via 'subhistosAlloc'.
--
-- It returns those arrays together with a kernel-side initialiser.  The
-- initialiser takes the chunk size @hist_H_chk@ and allocates
-- @num_subhistos_per_group@ local-memory subhistograms, each with its
-- outer 'histRank' dimensions replaced by @hist_H_chk@.  It returns them
-- with the atomic update function that operates on them.
prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_group groups_per_segment =
  mapM onOp
  where
    onOp ::
      SegHistSlug ->
      CallKernelGen
        ( [VName],
          SubExp -> InKernelGen ([VName], [Imp.TExp Int64] -> InKernelGen ())
        )
    onOp (SegHistSlug op num_subhistos subhisto_info do_op) = do
      num_subhistos <-- sExt64 (unCount groups_per_segment)

      emit $
        Imp.DebugPrint "Number of subhistograms in global memory per segment" $
          Just $ untyped $ tvExp num_subhistos

      let -- Build the atomic update primitive for a given chunk size.
          -- Only the locking variant needs per-kernel setup; the others
          -- ignore the chunk size.
          mk_op :: SubExp -> InKernelGen (DoAtomicUpdate GPUMem KernelEnv)
          mk_op = case do_op of
            AtomicPrim f -> const $ pure f
            AtomicCAS f -> const $ pure f
            AtomicLocking f -> \hist_H_chk -> do
              -- One lock per subhistogram element, in local memory.
              let lock_shape =
                    Shape $
                      tvSize num_subhistos_per_group
                        : (shapeDims (histOpShape op) ++ [hist_H_chk])
                  dims = map pe64 $ shapeDims lock_shape

              locks <- sAllocArray "locks" int32 lock_shape $ Space "local"

              sComment "All locks start out unlocked" $
                groupCoverSpace dims $ \is ->
                  copyDWIMFix locks is (intConst Int32 0) []

              -- NOTE(review): presumably the fields are (array, unlocked
              -- value, value to lock, value to unlock, index mapping).
              -- The zero-initialisation above supports 0 = unlocked.
              pure $ f $ Locking locks 0 1 0 id

          -- Allocate the local-memory subhistograms for one chunk.
          -- Allocation happens before the update function is built.
          init_local_subhistos hist_H_chk = do
            local_subhistos <- forM (histType op) $ \t ->
              sAllocArray
                "subhistogram_local"
                (elemType t)
                ( Shape [tvSize num_subhistos_per_group]
                    <> setOuterDims (arrayShape t) (histRank op) (Shape [hist_H_chk])
                )
                (Space "local")

            do_op' <- mk_op hist_H_chk

            pure (local_subhistos, do_op' (Space "local") local_subhistos)

      -- Allocate the global-memory subhistograms on the host.
      glob_subhistos <- forM subhisto_info $ \info -> do
        subhistosAlloc info
        pure $ subhistosArray info

      pure (glob_subhistos, init_local_subhistos)
histKernelLocalPass ::
TV Int32 ->
Count NumGroups (Imp.TExp Int64) ->
[PatElem LetDecMem] ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegHistSlug] ->
KernelBody GPUMem ->
InitLocalHistograms ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumGroups (TPrimExp Int64 VName)
-> [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
TV Int32
num_subhistos_per_group_var
Count NumGroups (TPrimExp Int64 VName)
groups_per_segment
[PatElem LParamMem]
map_pes
Count NumGroups SubExp
num_groups
Count GroupSize SubExp
group_size
SegSpace
space
[SegHistSlug]
slugs
KernelBody GPUMem
kbody
InitLocalHistograms
init_histograms
TExp Int32
hist_S
TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
(VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
num_subhistos_per_group :: TExp Int32
num_subhistos_per_group = TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
num_subhistos_per_group_var
segment_size' :: TPrimExp Int64 VName
segment_size' = SubExp -> TPrimExp Int64 VName
pe64 SubExp
segment_size
TPrimExp Int64 VName
num_segments <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"num_segments" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
[TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
(SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
segment_dims
[TV Int64]
hist_H_chks <- [HistOp GPUMem]
-> (HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) ((HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64])
-> (HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall a b. (a -> b) -> a -> b
$ \HistOp GPUMem
op ->
[Char]
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_H_chk" (TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> TPrimExp Int64 VName
histSize HistOp GPUMem
op TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
[([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
histo_sizes <- [(SegHistSlug, TV Int64)]
-> ((SegHistSlug, TV Int64)
-> ImpM
GPUMem
HostEnv
HostOp
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> ImpM
GPUMem
HostEnv
HostOp
[([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SegHistSlug] -> [TV Int64] -> [(SegHistSlug, TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [TV Int64]
hist_H_chks) (((SegHistSlug, TV Int64)
-> ImpM
GPUMem
HostEnv
HostOp
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> ImpM
GPUMem
HostEnv
HostOp
[([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)])
-> ((SegHistSlug, TV Int64)
-> ImpM
GPUMem
HostEnv
HostOp
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> ImpM
GPUMem
HostEnv
HostOp
[([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
let histo_dims :: [TPrimExp Int64 VName]
histo_dims =
TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk
TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
TPrimExp Int64 VName
histo_size <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
histo_dims
let group_hists_size :: TPrimExp Int64 VName
group_hists_size =
TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_group TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
histo_size
TExp Int32
init_per_thread <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"init_per_thread" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
group_hists_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TPrimExp Int64 VName
pe64 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)
-> ImpM
GPUMem
HostEnv
HostOp
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([TPrimExp Int64 VName]
histo_dims, TPrimExp Int64 VName
histo_size, TExp Int32
init_per_thread)
let attrs :: KernelAttrs
attrs = (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) {kAttrCheckLocalMemory :: Bool
kAttrCheckLocalMemory = Bool
False}
[Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_local" (SegSpace -> VName
segFlat SegSpace
space) KernelAttrs
attrs (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
num_segments) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
TExp Int32
flat_segment_id <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"flat_segment_id" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment)
TExp Int32
gid_in_segment <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"gid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment)
TExp Int32
pgtid_in_segment <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"pgtid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
threads_per_segment <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"threads_per_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
(VName -> TPrimExp Int64 VName -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
segment_is ([TPrimExp Int64 VName] -> InKernelGen ())
-> [TPrimExp Int64 VName] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
segment_dims) (TPrimExp Int64 VName -> [TPrimExp Int64 VName])
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
[([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
histograms <- [(([VName],
SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
TV Int64)]
-> ((([VName],
SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
TV Int64)
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()))
-> ImpM
GPUMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [TV Int64]
-> [(([VName],
SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [TV Int64]
hist_H_chks) (((([VName],
SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
TV Int64)
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()))
-> ImpM
GPUMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())])
-> ((([VName],
SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
TV Int64)
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()))
-> ImpM
GPUMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
\(([VName]
glob_subhistos, SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
([VName]
local_subhistos, [TPrimExp Int64 VName] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())
init_local_subhistos (SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ()))
-> SubExp
-> ImpM
GPUMem
KernelEnv
KernelOp
([VName], [TPrimExp Int64 VName] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_H_chk
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, TV Int64
hist_H_chk, [TPrimExp Int64 VName] -> InKernelGen ()
do_op)
TExp Int32
thread_local_subhisto_i <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"thread_local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_subhistos_per_group
let onSlugs :: (SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ()
f =
[(SegHistSlug,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))]
-> ((SegHistSlug,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
-> [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
-> [(SegHistSlug,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
histograms [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
histo_sizes) (((SegHistSlug,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> InKernelGen ())
-> InKernelGen ())
-> ((SegHistSlug,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TPrimExp Int64 VName] -> InKernelGen ()
_), ([TPrimExp Int64 VName]
histo_dims, TPrimExp Int64 VName
histo_size, TExp Int32
init_per_thread)) ->
SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TPrimExp Int64 VName]
histo_dims TPrimExp Int64 VName
histo_size TExp Int32
init_per_thread
let onAllHistograms :: (VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ()
f =
(SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TPrimExp Int64 VName
hist_H_chk [TPrimExp Int64 VName]
histo_dims TPrimExp Int64 VName
histo_size TExp Int32
init_per_thread -> do
let group_hists_size :: TExp Int32
group_hists_size = TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size
[((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\((VName
dest_global, VName
dest_local), SubExp
ne) ->
[Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
init_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
j_offset <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j_offset" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
j
TExp Int32
local_subhisto_i <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size
let local_bucket_is :: [TPrimExp Int64 VName]
local_bucket_is = [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
histo_dims (TPrimExp Int64 VName -> [TPrimExp Int64 VName])
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size
nested_hist_size :: [TPrimExp Int64 VName]
nested_hist_size =
(SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
global_bucket_is :: [TPrimExp Int64 VName]
global_bucket_is =
[TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
[TPrimExp Int64 VName]
nested_hist_size
([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
head [TPrimExp Int64 VName]
local_bucket_is TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk)
[TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a]
tail [TPrimExp Int64 VName]
local_bucket_is
TExp Int32
global_subhisto_i <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"global_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ()
f
VName
dest_local
VName
dest_global
(SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)
SubExp
ne
TExp Int32
local_subhisto_i
TExp Int32
global_subhisto_i
[TPrimExp Int64 VName]
local_bucket_is
[TPrimExp Int64 VName]
global_bucket_is
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ())
-> InKernelGen ())
-> (VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp GPUMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TPrimExp Int64 VName]
local_bucket_is [TPrimExp Int64 VName]
global_bucket_is ->
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TPrimExp Int64 VName]
global_is = (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName
0] [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
global_bucket_is
local_is :: [TPrimExp Int64 VName]
local_is = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: [TPrimExp Int64 VName]
local_bucket_is
TExp Bool -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int32
global_subhisto_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
(VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest_local [TPrimExp Int64 VName]
local_is (VName -> SubExp
Var VName
dest_global) [TPrimExp Int64 VName]
global_is)
( Shape
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op) (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest_local ([TPrimExp Int64 VName]
local_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is) SubExp
ne []
)
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
TExp Int32
-> TExp Int32
-> TExp Int32
-> (TExp Int32 -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int32
pgtid_in_segment TExp Int32
threads_per_segment (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
segment_size') ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
ie -> do
VName -> TPrimExp Int64 VName -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
i_in_segment (TPrimExp Int64 VName -> InKernelGen ())
-> TPrimExp Int64 VName -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ie
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chk_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
se) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
space_is)
SubExp
se
[]
let red_res_split :: [([SubExp], [SubExp])]
red_res_split = [HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [SubExp]
red_res
[(HistOp GPUMem,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([SubExp], [SubExp]))]
-> ((HistOp GPUMem,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([SubExp], [SubExp]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
-> [([SubExp], [SubExp])]
-> [(HistOp GPUMem,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ())]
histograms [([SubExp], [SubExp])]
red_res_split) (((HistOp GPUMem,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([SubExp], [SubExp]))
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp GPUMem,
([(VName, VName)], TV Int64,
[TPrimExp Int64 VName] -> InKernelGen ()),
([SubExp], [SubExp]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
([(VName, VName)]
_, TV Int64
hist_H_chk, [TPrimExp Int64 VName] -> InKernelGen ()
do_op),
([SubExp]
bucket, [SubExp]
vs')
) -> do
let chk_beg :: TPrimExp Int64 VName
chk_beg = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk
bucket' :: [TPrimExp Int64 VName]
bucket' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
bucket
dest_shape' :: [TPrimExp Int64 VName]
dest_shape' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
flat_bucket :: TPrimExp Int64 VName
flat_bucket = [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TPrimExp Int64 VName]
dest_shape' [TPrimExp Int64 VName]
bucket'
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds ([DimIndex (TPrimExp Int64 VName)] -> Slice (TPrimExp Int64 VName)
forall d. [DimIndex d] -> Slice d
Slice ((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix [TPrimExp Int64 VName]
bucket')) [TPrimExp Int64 VName]
dest_shape'
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
flat_bucket
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
flat_bucket TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk)
bucket_is :: [TPrimExp Int64 VName]
bucket_is =
[TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TPrimExp Int64 VName
flat_bucket TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
chk_beg]
vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
Shape
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TPrimExp Int64 VName]
is
[TPrimExp Int64 VName] -> InKernelGen ()
do_op ([TPrimExp Int64 VName]
bucket_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TPrimExp Int64 VName
hist_H_chk [TPrimExp Int64 VName]
histo_dims TPrimExp Int64 VName
_histo_size TExp Int32
bins_per_thread -> do
TV Int64
trunc_H <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"trunc_H" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TPrimExp Int64 VName
hist_H_chk (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
HistOp GPUMem -> TPrimExp Int64 VName
histSize (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
head [TPrimExp Int64 VName]
histo_dims
let trunc_histo_dims :: [TPrimExp Int64 VName]
trunc_histo_dims =
TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
trunc_H
TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
TExp Int32
trunc_histo_size <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
trunc_histo_dims
[Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
bins_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let local_bucket_is :: [TPrimExp Int64 VName]
local_bucket_is = [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
histo_dims (TPrimExp Int64 VName -> [TPrimExp Int64 VName])
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
nested_hist_size :: [TPrimExp Int64 VName]
nested_hist_size =
(SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
global_bucket_is :: [TPrimExp Int64 VName]
global_bucket_is =
[TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
[TPrimExp Int64 VName]
nested_hist_size
([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
head [TPrimExp Int64 VName]
local_bucket_is TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk)
[TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a]
tail [TPrimExp Int64 VName]
local_bucket_is
[LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
([Param LParamMem]
xparams, [Param LParamMem]
yparams) =
Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$
Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$
HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$
SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
subhisto) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp)
[]
(VName -> SubExp
Var VName
subhisto)
(TPrimExp Int64 VName
0 TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: [TPrimExp Int64 VName]
local_bucket_is)
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"subhisto_id" (TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
subhisto_id -> do
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
yparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
yp, VName
subhisto) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
yp)
[]
(VName -> SubExp
Var VName
subhisto)
(TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1 TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: [TPrimExp Int64 VName]
local_bucket_is)
[Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
xparams (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody (Lambda GPUMem -> Body GPUMem) -> Lambda GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
[Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TPrimExp Int64 VName]
global_is =
(VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is
[TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment]
[TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
global_bucket_is
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
global_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
global_dest) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
global_dest [TPrimExp Int64 VName]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp) []
-- | Drive the local-memory histogram strategy.
--
-- The work happens in three steps:
--
-- 1. Emit a debug print of the number of local subhistograms per group.
-- 2. Call 'prepareIntermediateArraysLocal' once for all slugs.
-- 3. Run 'histKernelLocalPass' once per chunk index @chk_i@, for @hist_S@
--    iterations of an 'sFor' loop.
--
-- Every pass receives the same prepared histograms and the loop bound
-- @hist_S@ itself, as well as its own chunk index.
histKernelLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelLocal subhistos_var groups_per_seg pes ngroups gsize kspace hist_S slugs kbody = do
  -- Report how many local subhistograms each group will hold.
  emit . Imp.DebugPrint "Number of local subhistograms per group" . Just . untyped $
    tvExp subhistos_var

  -- The intermediate arrays are set up once and reused by every pass.
  init_hists <- prepareIntermediateArraysLocal subhistos_var groups_per_seg slugs

  -- One pass per chunk; the loop supplies the chunk index as the last
  -- argument to 'histKernelLocalPass'.
  sFor "chk_i" hist_S $
    histKernelLocalPass
      subhistos_var
      groups_per_seg
      pes
      ngroups
      gsize
      kspace
      slugs
      kbody
      init_hists
      hist_S
-- | A per-slug integer limit, chosen by the slug's atomic update strategy.
--
-- Going by the name, this limits how many local-memory passes the slug may
-- use. The use site is outside this view, so confirm that reading there.
--
-- The values per strategy are:
--
-- * primitive atomics: 3
-- * compare-and-swap: 4
-- * locking: 6
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses = passesFor . slugAtomicUpdate
  where
    passesFor AtomicPrim {} = 3
    passesFor AtomicCAS {} = 4
    passesFor AtomicLocking {} = 6
localMemoryCase ::
[PatElem LetDecMem] ->
Imp.TExp Int32 ->
SegSpace ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int32 ->
[SegHistSlug] ->
KernelBody GPUMem ->
CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase :: [PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
hist_N TExp Int32
_ [SegHistSlug]
slugs KernelBody GPUMem
kbody = do
let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims
TV Int64
hist_L <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"hist_L" PrimType
int32
HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_L) SizeClass
Imp.SizeLocalMemory
TV Any
max_group_size <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Any)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"max_group_size" PrimType
int32
HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size) SizeClass
Imp.SizeGroup
let group_size :: Count GroupSize SubExp
group_size = SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size
Count NumGroups SubExp
num_groups <-
(TV Int64 -> Count NumGroups SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize) (ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp))
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$
[Char]
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"num_groups" (TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TPrimExp Int64 VName
pe64 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> Count NumGroups SubExp -> Count NumGroups (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> Count GroupSize SubExp -> Count GroupSize (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size
let r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int64 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
TPrimExp Double VName
hist_m' <-
[Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_m_prime" (TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64
( TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
(TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_L TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName
hist_el_size))
(TPrimExp Int64 VName
hist_N TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups'))
)
TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Int64 VName
hist_H
let hist_B :: TPrimExp Int64 VName
hist_B = Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'
TPrimExp Int64 VName
hist_M0 <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M0" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TPrimExp Int64 VName
1 (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 TPrimExp Double VName
hist_m') TPrimExp Int64 VName
hist_B
let q_small :: TPrimExp Int64 VName
q_small = TPrimExp Int64 VName
2
TPrimExp Int64 VName
hist_Nout <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nout" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
segment_dims
TPrimExp Int64 VName
hist_Nin <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nin" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
space_sizes
TPrimExp Int64 VName
work_asymp_M_max <-
if Bool
segmented
then do
TExp Int32
hist_T_hist_min <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_T_hist_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
hist_Nin TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
hist_Nout) (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
hist_Nout
let r :: TExp Int32
r = TExp Int32
hist_T_hist_min TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
hist_B
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
hist_Nin TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
r TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H)
else
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
(TPrimExp Int64 VName
hist_Nout TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_N)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` ( (TPrimExp Int64 VName
q_small TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TPrimExp Int64 VName
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
)
TV Int32
hist_M <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_M" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TPrimExp Int64 VName
hist_M0 TPrimExp Int64 VName
work_asymp_M_max
let hist_M_nonzero :: TExp Int32
hist_M_nonzero = TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
TPrimExp Int64 VName
hist_C <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
hist_B TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_nonzero
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_M0
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
work_asymp_M_max
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_C
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_B
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)
TPrimExp Int64 VName
local_mem_needed <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_mem_needed" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int32
hist_S <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_S" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
(TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
local_mem_needed) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_L
let max_S :: TExp Int32
max_S = case KernelBody GPUMem -> Passage
bodyPassage KernelBody GPUMem
kbody of
Passage
MustBeSinglePass -> TExp Int32
1
Passage
MayBeMultiPass -> Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int32) -> Int -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs
Count NumGroups (TPrimExp Int64 VName)
groups_per_segment <-
if Bool
segmented
then
(TPrimExp Int64 VName -> Count NumGroups (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
-> ImpM
GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TPrimExp Int64 VName -> Count NumGroups (TPrimExp Int64 VName)
forall u e. e -> Count u e
Count (ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
-> ImpM
GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName)))
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
-> ImpM
GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"groups_per_segment" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName
hist_Nout
else Count NumGroups (TPrimExp Int64 VName)
-> ImpM
GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Count NumGroups (TPrimExp Int64 VName)
num_groups'
let pick_local :: TExp Bool
pick_local =
TPrimExp Int64 VName
hist_Nin TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TPrimExp Int64 VName
hist_H
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TPrimExp Int64 VName
local_mem_needed TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_L)
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int32
hist_S TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
max_S)
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
hist_C TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
hist_B
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0
run :: CallKernelGen ()
run = do
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"## Using local memory" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_H
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_C
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_S
Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Groups per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment
TV Int32
-> Count NumGroups (TPrimExp Int64 VName)
-> [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelLocal
TV Int32
hist_M
Count NumGroups (TPrimExp Int64 VName)
groups_per_segment
[PatElem LParamMem]
map_pes
Count NumGroups SubExp
num_groups
Count GroupSize SubExp
group_size
SegSpace
space
TExp Int32
hist_S
[SegHistSlug]
slugs
KernelBody GPUMem
kbody
(TExp Bool, CallKernelGen ())
-> CallKernelGen (TExp Bool, CallKernelGen ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TExp Bool
pick_local, CallKernelGen ()
run)
compileSegHist ::
Pat LetDecMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[HistOp GPUMem] ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegHist :: Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [HistOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegHist (Pat [PatElem LParamMem]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [HistOp GPUMem]
ops KernelBody GPUMem
kbody = do
let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = (SubExp -> TPrimExp Int64 VName)
-> Count NumGroups SubExp -> Count NumGroups (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = (SubExp -> TPrimExp Int64 VName)
-> Count GroupSize SubExp -> Count GroupSize (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count GroupSize SubExp
group_size
dims :: [TPrimExp Int64 VName]
dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
num_red_res :: Int
num_red_res = [HistOp GPUMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp GPUMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp GPUMem]
ops)
([PatElem LParamMem]
all_red_pes, [PatElem LParamMem]
map_pes) = Int
-> [PatElem LParamMem]
-> ([PatElem LParamMem], [PatElem LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem LParamMem]
pes
segment_size :: TPrimExp Int64 VName
segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
last [TPrimExp Int64 VName]
dims
([Count Bytes (TPrimExp Int64 VName)]
op_hs, [Count Bytes (TPrimExp Int64 VName)]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes (TPrimExp Int64 VName),
Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
-> ([Count Bytes (TPrimExp Int64 VName)],
[Count Bytes (TPrimExp Int64 VName)], [SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes (TPrimExp Int64 VName),
Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
-> ([Count Bytes (TPrimExp Int64 VName)],
[Count Bytes (TPrimExp Int64 VName)], [SegHistSlug]))
-> ImpM
GPUMem
HostEnv
HostOp
[(Count Bytes (TPrimExp Int64 VName),
Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
-> ImpM
GPUMem
HostEnv
HostOp
([Count Bytes (TPrimExp Int64 VName)],
[Count Bytes (TPrimExp Int64 VName)], [SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp GPUMem
-> CallKernelGen
(Count Bytes (TPrimExp Int64 VName),
Count Bytes (TPrimExp Int64 VName), SegHistSlug))
-> [HistOp GPUMem]
-> ImpM
GPUMem
HostEnv
HostOp
[(Count Bytes (TPrimExp Int64 VName),
Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegSpace
-> HistOp GPUMem
-> CallKernelGen
(Count Bytes (TPrimExp Int64 VName),
Count Bytes (TPrimExp Int64 VName), SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp GPUMem]
ops
TPrimExp Int64 VName
h <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"h" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName)
-> Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TPrimExp Int64 VName)]
op_hs
TPrimExp Int64 VName
seg_h <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"seg_h" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName)
-> Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TPrimExp Int64 VName)]
op_seg_hs
TExp Bool -> CallKernelGen () -> CallKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless (TPrimExp Int64 VName
seg_h TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let hist_B :: TPrimExp Int64 VName
hist_B = Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'
TPrimExp Int64 VName
hist_H <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (HistOp GPUMem -> TPrimExp Int64 VName)
-> [HistOp GPUMem] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map HistOp GPUMem -> TPrimExp Int64 VName
histSize [HistOp GPUMem]
ops
let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate GPUMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicLocking {} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
AtomicUpdate GPUMem KernelEnv
_ -> Maybe a
forall a. Maybe a
Nothing
TPrimExp Int64 VName
hist_el_size <-
[Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_el_size" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
(TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
(+) (TPrimExp Int64 VName
h TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName
hist_H) ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
(SegHistSlug -> Maybe (TPrimExp Int64 VName))
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe (TPrimExp Int64 VName)
forall {a}. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs
TPrimExp Int64 VName
hist_N <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_N" TPrimExp Int64 VName
segment_size
TExp Int32
hist_RF <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RF" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
[TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TPrimExp Int64 VName)
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TPrimExp Int64 VName
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
let hist_T :: TExp Int32
hist_T = TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_T
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Desired group size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_B
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_H
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_N
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
[TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_el_size
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_RF
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
h
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
seg_h
(TExp Bool
use_local_memory, CallKernelGen ()
run_in_local_memory) <-
[PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
hist_N TExp Int32
hist_RF [SegHistSlug]
slugs KernelBody GPUMem
kbody
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf TExp Bool
use_local_memory CallKernelGen ()
run_in_local_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelGlobal [PatElem LParamMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody
let pes_per_op :: [[PatElem LParamMem]]
pes_per_op = [Int] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp GPUMem -> [VName]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp GPUMem]
ops) [PatElem LParamMem]
all_red_pes
[(SegHistSlug, [PatElem LParamMem], HistOp GPUMem)]
-> ((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElem LParamMem]]
-> [HistOp GPUMem]
-> [(SegHistSlug, [PatElem LParamMem], HistOp GPUMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElem LParamMem]]
pes_per_op [HistOp GPUMem]
ops) (((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
-> CallKernelGen ())
-> CallKernelGen ())
-> ((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElem LParamMem]
red_pes, HistOp GPUMem
op) -> do
let num_histos :: TV Int64
num_histos = SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug
subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
[(PatElem LParamMem, VName)]
-> ((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [VName] -> [(PatElem LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes [VName]
subhistos) (((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ())
-> ((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
subhisto) -> do
VName
pe_mem <-
MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
(ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
VName
subhisto_mem <-
MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
(ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
subhisto
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"device"
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
num_histos TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[VName]
bucket_ids <-
Int
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op)) ([Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"bucket_id")
VName
subhistogram_id <- [Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"subhistogram_id"
[VName]
vector_ids <-
Int
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"vector_id")
VName
flat_gtid <- [Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"flat_gtid"
let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegVirt
segred_space :: SegSpace
segred_space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
segment_dims
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op))
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
num_histos)]
let segred_op :: SegBinOp GPUMem
segred_op = Commutativity
-> Lambda GPUMem -> [SubExp] -> Shape -> SegBinOp GPUMem
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Commutative (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op) (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp GPUMem
op) Shape
forall a. Monoid a => a
mempty
Pat LParamMem
-> SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElem LParamMem] -> Pat LParamMem
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem LParamMem]
red_pes) SegLevel
lvl SegSpace
segred_space [SegBinOp GPUMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TPrimExp Int64 VName])] -> InKernelGen ()
red_cont ->
[(SubExp, [TPrimExp Int64 VName])] -> InKernelGen ()
red_cont ([(SubExp, [TPrimExp Int64 VName])] -> InKernelGen ())
-> ((VName -> (SubExp, [TPrimExp Int64 VName]))
-> [(SubExp, [TPrimExp Int64 VName])])
-> (VName -> (SubExp, [TPrimExp Int64 VName]))
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((VName -> (SubExp, [TPrimExp Int64 VName]))
-> [VName] -> [(SubExp, [TPrimExp Int64 VName])])
-> [VName]
-> (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [(SubExp, [TPrimExp Int64 VName])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [VName] -> [(SubExp, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [TPrimExp Int64 VName])) -> InKernelGen ())
-> (VName -> (SubExp, [TPrimExp Int64 VName])) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
( VName -> SubExp
Var VName
subhisto,
(VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [TPrimExp Int64 VName])
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id]
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids
)
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"" Maybe Exp
forall a. Maybe a
Nothing
where
segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space