{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- | Our compilation strategy for 'SegHist' is based around avoiding
-- bin conflicts.  We do this by splitting the input into chunks, and
-- for each chunk computing a single subhistogram.  Then we combine
-- the subhistograms using an ordinary segmented reduction ('SegRed').
--
-- There are some branches around to efficiently handle the case where
-- we use only a single subhistogram (because it's large), so that we
-- respect the asymptotics, and do not copy the destination array.
--
-- We also use a heuristic strategy for computing subhistograms in
-- local memory when possible.  Given:
--
-- H: total size of histograms in bytes, including any lock arrays.
--
-- G: group size
--
-- T: number of bytes of local memory each thread can be given without
-- impacting occupancy (determined experimentally, e.g. 32).
--
-- LMAX: maximum amount of local memory per workgroup (hard limit).
--
-- We wish to compute:
--
-- COOP: cooperation level (number of threads per subhistogram)
--
-- LH: number of local memory subhistograms
--
-- We do this as:
--
-- COOP = ceil(H / T)
-- LH = ceil((G*T)/H)
-- if COOP <= G && H <= LMAX then
--   use local memory
-- else
--   use global memory
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where

import Control.Monad.Except
import Data.List (foldl', genericLength, zip5)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The intermediate subhistogram array that backs one destination
-- array of a histogram operator, plus the deferred action that makes
-- its memory usable.
data SubhistosInfo = SubhistosInfo
  { -- | Name of the array holding all subhistograms for one
    -- destination.
    subhistosArray :: VName,
    -- | Host-side code generation action that prepares the memory
    -- behind 'subhistosArray'.  As built by 'computeHistoUsage', it
    -- branches at run time on whether the number of subhistograms is
    -- one: if so, the subhistogram memory is set to the destination's
    -- own memory; otherwise fresh memory is allocated, filled with
    -- the neutral element, and the destination is written into
    -- subhistogram 0.
    subhistosAlloc :: CallKernelGen ()
  }

-- | Everything the code generator tracks for a single histogram
-- operator while compiling a 'SegHist'.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Variable holding the number of subhistograms used for this
    -- operator.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination array of the operator.
    slugSubhistos :: [SubhistosInfo],
    -- | How bins are updated atomically for this operator's
    -- combining function.
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }

-- | The number of bytes occupied by one histogram of the given
-- operator: the sum, over every return type of the operator lambda,
-- of the size of an array of that type with shape
-- @histShape op <> histOpShape op@.
histSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histSpaceUsage op =
  sum . map (typeSize . (`arrayOfShape` (histShape op <> histOpShape op))) $
    lambdaReturnType $
      histOp op

-- | The product of the dimensions of the operator's 'histShape', as a
-- 64-bit expression.
histSize :: HistOp GPUMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape

-- | The rank (number of dimensions) of the operator's 'histShape'.
histRank :: HistOp GPUMem -> Int
histRank = shapeRank . histShape

-- | Figure out how much memory is needed per histogram, both
-- segmented and unsegmented, and compute some other auxiliary
-- information.
--
-- The result is a triple of:
--
-- * the space usage of one histogram ('histSpaceUsage');
--
-- * that usage multiplied by the sizes of all but the last dimension
--   of the segment space;
--
-- * a 'SegHistSlug' carrying the operator, a freshly declared
--   @num_subhistos@ variable, one 'SubhistosInfo' per destination
--   array, and the atomic update strategy for the operator.
--
-- No memory is allocated here; allocation is deferred to the
-- 'subhistosAlloc' action of each 'SubhistosInfo'.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- Every dimension of the space except the last one is treated as a
  -- segment dimension.
  -- NOTE(review): 'init' is partial; this assumes the segment space
  -- always has at least one dimension -- confirm.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Create names for the intermediate array memory blocks,
  -- memory block sizes, arrays, and number of subhistograms.
  --
  -- NOTE(review): the variable is declared with primitive type int32
  -- even though it is stored in the slug as a 'TV Int64' and used as
  -- an array dimension below -- confirm whether int64 was intended.
  num_subhistos <- dPrim "num_subhistos" int32
  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Segment dimensions, then the subhistogram dimension, then the
    -- destination's own shape with its leading segment dimensions
    -- removed.
    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_mem
        $ IxFun.iota
        $ map pe64
        $ shapeDims subhistos_shape

    pure $
      SubhistosInfo subhistos $ do
        -- With a single subhistogram we make the subhistogram memory
        -- refer to the destination's memory instead of allocating.
        let unitHistoCase =
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            -- Otherwise: allocate, fill everything with the neutral
            -- element, and write the destination into subhistogram 0.
            multiHistoCase = do
              let num_elems = product $ map pe64 $ shapeDims subhistos_shape
                  subhistos_mem_size =
                    Imp.bytes $
                      Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map pe64 $ arrayDims subhistos_t) $
                      map (unitSlice 0 . pe64 . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . pe64) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  pure
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $
          histOp op
    )

-- | Produce the function that performs an atomic update of the given
-- destination arrays in the @"global"@ space, for the operator in the
-- slug.
--
-- The first argument is a lock array that may already exist.  The
-- returned 'Maybe Locking' is that same value, unless the operator
-- requires locking and none was supplied, in which case a new static
-- lock array is created and returned so that callers can pass it on
-- to subsequent invocations.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l dests slug =
  -- We need a separate lock array if the operators are not all of a
  -- particularly simple form that permits pure atomic operations.
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> pure (l, f (Space "global") dests)
    (_, AtomicCAS f) -> pure (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> pure (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      -- The number of locks used here is too low, but since we are
      -- currently forced to inline a huge list, I'm keeping it down
      -- for now.  Some quick experiments suggested that it has little
      -- impact anyway (maybe the locking case is just too slow).
      --
      -- A fun solution would also be to use a simple hashing
      -- algorithm to ensure good distribution of locks.
      let num_locks = 100151
          dims =
            map pe64 $
              shapeDims (histOpShape (slugOp slug))
                ++ [tvSize (slugNumSubhistos slug)]
                ++ shapeDims (histShape (slugOp slug))

      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
          Imp.ArrayZeros num_locks
      -- A bin index is mapped to a lock by flattening it with respect
      -- to 'dims' and taking the remainder modulo the number of locks.
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      pure (Just l', f l' (Space "global") dests)

-- | Some kernel bodies are not safe (or efficient) to execute
-- multiple times.  'bodyPassage' classifies a kernel body as one or
-- the other.
data Passage = MustBeSinglePass | MayBeMultiPass deriving (Eq, Ord)

-- | Classify a kernel body: if alias analysis (run with an empty
-- alias table) finds that the body consumes nothing, it is
-- 'MayBeMultiPass'; otherwise it is 'MustBeSinglePass'.
bodyPassage :: KernelBody GPUMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody mempty kbody) =
      MayBeMultiPass
  | otherwise =
      MustBeSinglePass

prepareIntermediateArraysGlobal ::
  Passage ->
  Imp.TExp Int32 ->
  Imp.TExp Int64 ->
  [SegHistSlug] ->
  CallKernelGen
    ( Imp.TExp Int32,
      [[Imp.TExp Int64] -> InKernelGen ()]
    )
prepareIntermediateArraysGlobal :: Passage
-> TExp Int32
-> TPrimExp Int64 VName
-> [SegHistSlug]
-> CallKernelGen
     (TExp Int32, [[TPrimExp Int64 VName] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage TExp Int32
hist_T TPrimExp Int64 VName
hist_N [SegHistSlug]
slugs = do
  -- The paper formulae assume there is only one histogram, but in our
  -- implementation there can be multiple that have been horisontally
  -- fused.  We do a bit of trickery with summings and averages to
  -- pretend there is really only one.  For the case of a single
  -- histogram, the actual calculations should be the same as in the
  -- paper.

  -- The sum of all Hs.
  TPrimExp Int64 VName
hist_H <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TPrimExp Int64 VName)
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TPrimExp Int64 VName
histSize (HistOp GPUMem -> TPrimExp Int64 VName)
-> (SegHistSlug -> HistOp GPUMem)
-> SegHistSlug
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs

  TPrimExp Double VName
hist_RF <-
    [Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RF" (TPrimExp Double VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
      [TPrimExp Double VName] -> TPrimExp Double VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TPrimExp Double VName)
-> [SegHistSlug] -> [TPrimExp Double VName]
forall a b. (a -> b) -> [a] -> [b]
map (TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TPrimExp Int64 VName -> TPrimExp Double VName)
-> (SegHistSlug -> TPrimExp Int64 VName)
-> SegHistSlug
-> TPrimExp Double VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
        TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ [SegHistSlug] -> TPrimExp Double VName
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs

  TExp Int32
hist_el_size <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_el_size" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ [TExp Int32] -> TExp Int32
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int32] -> TExp Int32) -> [TExp Int32] -> TExp Int32
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int32) -> [SegHistSlug] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> TExp Int32
slugElAvgSize [SegHistSlug]
slugs

  TPrimExp Double VName
hist_C_max <-
    [Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C_max" (TPrimExp Double VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TPrimExp Double VName -> TPrimExp Double VName)
-> TPrimExp Double VName -> TPrimExp Double VName
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Int64 VName
hist_H TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Double VName
hist_k_ct_min

  TExp Int32
hist_M_min <-
    [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
      TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
          TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TPrimExp Double VName -> TPrimExp Int64 VName)
-> TPrimExp Double VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
            TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Double VName
hist_C_max

  -- Querying L2 cache size is not reliable.  Instead we provide a
  -- tunable knob with a hopefully sane default.
  let hist_L2_def :: Int64
hist_L2_def = Int64
4 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024
  TV Any
hist_L2 <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Any)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"L2_size" PrimType
int32
  Maybe Name
entry <- ImpM GPUMem HostEnv HostOp (Maybe Name)
forall rep r op. ImpM rep r op (Maybe Name)
askFunction
  -- Equivalent to F_L2*L2 in paper.
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp
    (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Name -> SizeClass -> HostOp
Imp.GetSize
      (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)
      (Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ [Char] -> Name
nameFromString (VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)))
    (SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$ Name -> Int64 -> SizeClass
Imp.SizeBespoke ([Char] -> Name
nameFromString [Char]
"L2_for_histogram") Int64
hist_L2_def

  let hist_L2_ln_sz :: TPrimExp Double VName
hist_L2_ln_sz = TPrimExp Double VName
16 TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
4 -- L2 cache line size approximation
  TPrimExp Double VName
hist_RACE_exp <-
    [Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RACE_exp" (TPrimExp Double VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 TPrimExp Double VName
1 (TPrimExp Double VName -> TPrimExp Double VName)
-> TPrimExp Double VName -> TPrimExp Double VName
forall a b. (a -> b) -> a -> b
$
        (TPrimExp Double VName
hist_k_RF TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
hist_RF)
          TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ (TPrimExp Double VName
hist_L2_ln_sz TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_el_size)

  TV Int32
hist_S <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"hist_S" PrimType
int32

  -- For sparse histograms (H exceeds N) we only want a single chunk.
  TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
    (TPrimExp Int64 VName
hist_N TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
hist_H)
    (TV Int32
hist_S TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32))
    (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ TV Int32
hist_S
      TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- case Passage
passage of
        Passage
MayBeMultiPass ->
          TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
            (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_min TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_el_size)
              TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TPrimExp Double VName
hist_F_L2 TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Any VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TV Any -> TPrimExp Any VName
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Num a => a -> a -> a
* TPrimExp Double VName
hist_RACE_exp)
        Passage
MustBeSinglePass ->
          TExp Int32
1

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Double VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Double VName
hist_RACE_exp
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S

  [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms <-
    (Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> [[TPrimExp Int64 VName] -> InKernelGen ()]
forall a b. (a, b) -> b
snd
      ((Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
 -> [[TPrimExp Int64 VName] -> InKernelGen ()])
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> ImpM
     GPUMem HostEnv HostOp [[TPrimExp Int64 VName] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
 -> SegHistSlug
 -> CallKernelGen
      (Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (Maybe Locking, [[TPrimExp Int64 VName] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM
        (TPrimExp Any VName
-> TExp Int32
-> TExp Int32
-> TPrimExp Double VName
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen
     (Maybe Locking, [TPrimExp Int64 VName] -> InKernelGen ())
onOp (TV Any -> TPrimExp Any VName
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TExp Int32
hist_M_min (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S) TPrimExp Double VName
hist_RACE_exp)
        Maybe Locking
forall a. Maybe a
Nothing
        [SegHistSlug]
slugs

  (TExp Int32, [[TPrimExp Int64 VName] -> InKernelGen ()])
-> CallKernelGen
     (TExp Int32, [[TPrimExp Int64 VName] -> InKernelGen ()])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S, [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms)
  where
    -- Tuning constants for the global-memory histogram heuristics.  All
    -- three values were chosen experimentally.
    --
    -- hist_k_ct_min: divisor applied to the summed histogram size when
    -- bounding the cooperation level (hist_C_max).
    --
    -- hist_k_RF: weight applied to the average race factor when
    -- computing the race expansion factor (hist_RACE_exp).
    --
    -- hist_F_L2: factor applied to the L2 size figure in the chunk
    -- count and k_max formulas.
    hist_k_ct_min, hist_k_RF, hist_F_L2 :: Imp.TExp Double
    hist_k_ct_min = 2
    hist_k_RF = 0.75
    hist_F_L2 = 0.4
    -- View an integer expression as a Float64 expression by wrapping it
    -- in a signed-integer-to-float conversion node.
    --
    -- NOTE(review): the conversion is always tagged as Int32 -> Float64,
    -- whatever the type of the argument, yet this helper is also applied
    -- to Int64 expressions (e.g. hist_H and hist_N) -- confirm that this
    -- is intended.
    r64 e = isF64 (ConvOpExp (SIToFP Int32 Float64) (untyped e))
    -- View a Float64 expression as an Int64 expression by wrapping it in
    -- a float-to-signed-integer conversion node.
    t64 e = isInt64 (ConvOpExp (FPToSI Float64 Int64) (untyped e))

    -- "Average element size" as computed by a formula that also takes
    -- locking into account.
    slugElAvgSize slug =
      slugElSize slug `quot` num_components
      where
        -- Number of values produced by the combining operator.
        num_values = genericLength $ lambdaReturnType $ histOp $ slugOp slug
        -- With a locking update there is one more component to average
        -- over than there are operator results.
        num_components =
          case slugAtomicUpdate slug of
            AtomicLocking {} -> 1 + num_values
            _ -> num_values

    -- Total element size in bytes for one operator: the sum of
    -- 'typeSize' over every result type of the combining operator,
    -- each first extended with the operator shape ('histOpShape').
    -- When the update is done by locking, one extra int32 component is
    -- included in the sum.  The result is narrowed to 32 bits.
    slugElSize slug =
      sExt32 . unCount . sum $ map shapedBytes component_types
      where
        op = slugOp slug
        -- Byte size of the given type after 'arrayOfShape' has been
        -- applied with the operator shape.
        shapedBytes = typeSize . (`arrayOfShape` histOpShape op)
        value_types = lambdaReturnType (histOp op)
        -- The only difference between the locking and non-locking cases
        -- is the additional int32 component.
        component_types =
          case slugAtomicUpdate slug of
            AtomicLocking {} -> Prim int32 : value_types
            _ -> value_types

    -- Host-side setup for a single histogram operator.
    --
    -- Declares the per-operator quantities (chunk size H_chk, k_max, u,
    -- cooperation level C and multiplication degree M), stores M in the
    -- slug's subhistogram-count variable, prepares every subhistogram
    -- array, and finally hands over to 'prepareAtomicUpdateGlobal',
    -- whose result (updated locking state plus the update action) is
    -- returned as-is.
    --
    -- Arguments, in order: the L2 size expression, the minimum
    -- multiplication degree, the number of chunks S, the race expansion
    -- factor, the locking state accumulated so far, and the slug.
    onOp hist_L2 hist_M_min hist_S hist_RACE_exp lock_in slug = do
      let SegHistSlug op num_subhistos subhisto_info do_op = slug
          -- Emit a debug print of a single (untyped) expression.
          debug what = emit . Imp.DebugPrint what . Just

      -- H_chk = ceil(H / S).
      hist_H_chk <- dPrimVE "hist_H_chk" $ histSize op `divUp` sExt64 hist_S
      debug "Chunk size (H_chk)" $ untyped hist_H_chk

      -- L2 size figure divided by the total element size of this slug.
      let l2_per_el = r64 hist_L2 / r64 (slugElSize slug)
      hist_k_max <-
        dPrimVE "hist_k_max" $
          fMin64 (hist_F_L2 * l2_per_el * hist_RACE_exp) (r64 hist_N)
            / r64 hist_T

      hist_u <- dPrimVE "hist_u" $ case do_op of
        AtomicPrim {} -> 2
        _ -> 1

      hist_C <-
        dPrimVE "hist_C" . fMin64 (r64 hist_T) $
          r64 (hist_u * hist_H_chk) / hist_k_max

      -- Number of subhistograms per result histogram.  Primitive
      -- atomics always get exactly one.
      hist_M <- dPrimVE "hist_M" $ case do_op of
        AtomicPrim {} -> 1
        _ -> sMax32 hist_M_min . sExt32 . t64 $ r64 hist_T / hist_C

      debug "Elements/thread in L2 cache (k_max)" $ untyped hist_k_max
      debug "Multiplication degree (M)" $ untyped hist_M
      debug "Cooperation level (C)" $ untyped hist_C

      -- num_subhistos is the variable we use to communicate back.
      num_subhistos <-- sExt64 hist_M

      -- Initialise sub-histograms.
      --
      -- If hist_M is 1, then we just reuse the original destination:
      -- the subhistogram memory is pointed at the destination memory
      -- instead of being allocated.  The idea is to avoid a copy if we
      -- are writing a small number of values into a very large prior
      -- histogram.
      dests <- forM (zip (histDest op) subhisto_info) $ \(dest, info) -> do
        dest_mem <- memLocName . entryArrayLoc <$> lookupArray dest
        sub_mem <- memLocName . entryArrayLoc <$> lookupArray (subhistosArray info)

        sIf
          (hist_M .==. 1)
          (emit $ Imp.SetMem sub_mem dest_mem $ Space "device")
          (subhistosAlloc info)

        pure $ subhistosArray info

      prepareAtomicUpdateGlobal lock_in dests slug

histKernelGlobalPass ::
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  [[Imp.TExp Int64] -> InKernelGen ()] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelGlobalPass :: [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> [[TPrimExp Int64 VName] -> InKernelGen ()]
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelGlobalPass [PatElem LParamMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms TExp Int32
hist_S TExp Int32
chk_i = do
  let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      space_sizes_64 :: [TPrimExp Int64 VName]
space_sizes_64 = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> (SubExp -> TPrimExp Int64 VName)
-> SubExp
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TPrimExp Int64 VName
pe64) [SubExp]
space_sizes
      total_w_64 :: TPrimExp Int64 VName
total_w_64 = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
space_sizes_64

  [TPrimExp Int64 VName]
hist_H_chks <- [TPrimExp Int64 VName]
-> (TPrimExp Int64 VName
    -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp [TPrimExp Int64 VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> TPrimExp Int64 VName)
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TPrimExp Int64 VName
histSize (HistOp GPUMem -> TPrimExp Int64 VName)
-> (SegHistSlug -> HistOp GPUMem)
-> SegHistSlug
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs) ((TPrimExp Int64 VName
  -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
 -> ImpM GPUMem HostEnv HostOp [TPrimExp Int64 VName])
-> (TPrimExp Int64 VName
    -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
w ->
    [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H_chk" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S

  [Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_global" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

    -- Compute subhistogram index for each thread, per histogram.
    [TExp Int32]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
 -> ImpM GPUMem KernelEnv KernelOp [TExp Int32])
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
      [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"subhisto_ind" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
        KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
          TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` ( KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
                     TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp (SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug))
                 )

    -- Loop over flat offsets into the input and output.  The
    -- calculation is done with 64-bit integers to avoid overflow,
    -- but the final unflattened segment indexes are 32 bit.
    let gtid :: TPrimExp Int64 VName
gtid = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
        num_threads :: TPrimExp Int64 VName
num_threads = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
    TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TPrimExp Int64 VName
gtid TPrimExp Int64 VName
num_threads TPrimExp Int64 VName
total_w_64 ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
offset -> do
      -- Construct segment indices.
      [(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> InKernelGen ()
forall rep r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 VName] -> [(VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
space_is [TPrimExp Int64 VName]
space_sizes_64) TPrimExp Int64 VName
offset

      -- We execute the bucket function once and update each histogram serially.
      -- We apply the bucket function if j=offset+ltid is less than
      -- num_elements.  This also involves writing to the mapout
      -- arrays.
      let input_in_bounds :: TExp Bool
input_in_bounds = TPrimExp Int64 VName
offset TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
total_w_64

      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody

          [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(PatElem LParamMem, KernelResult)]
-> ((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem]
-> [KernelResult] -> [(PatElem LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [KernelResult]
map_res) (((PatElem LParamMem, KernelResult) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, KernelResult
res) ->
              VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                (((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> VName)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TPrimExp Int64 VName])
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
                (KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
                []

          let red_res_split :: [([SubExp], [SubExp])]
red_res_split =
                [HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) ([SubExp] -> [([SubExp], [SubExp])])
-> [SubExp] -> [([SubExp], [SubExp])]
forall a b. (a -> b) -> a -> b
$
                  (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
red_res

          [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
  ([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)]
-> ((HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
     ([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [[TPrimExp Int64 VName] -> InKernelGen ()]
-> [([SubExp], [SubExp])]
-> [TExp Int32]
-> [TPrimExp Int64 VName]
-> [(HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
     ([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)]
forall a b c d e.
[a] -> [b] -> [c] -> [d] -> [e] -> [(a, b, c, d, e)]
zip5 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [[TPrimExp Int64 VName] -> InKernelGen ()]
histograms [([SubExp], [SubExp])]
red_res_split [TExp Int32]
subhisto_inds [TPrimExp Int64 VName]
hist_H_chks) (((HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
   ([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp GPUMem, [TPrimExp Int64 VName] -> InKernelGen (),
     ([SubExp], [SubExp]), TExp Int32, TPrimExp Int64 VName)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
                 [TPrimExp Int64 VName] -> InKernelGen ()
do_op,
                 ([SubExp]
bucket, [SubExp]
vs'),
                 TExp Int32
subhisto_ind,
                 TPrimExp Int64 VName
hist_H_chk
                 ) -> do
                  let chk_beg :: TPrimExp Int64 VName
chk_beg = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk
                      bucket' :: [TPrimExp Int64 VName]
bucket' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
bucket
                      dest_shape' :: [TPrimExp Int64 VName]
dest_shape' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
                      flat_bucket :: TPrimExp Int64 VName
flat_bucket = [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TPrimExp Int64 VName]
dest_shape' [TPrimExp Int64 VName]
bucket'
                      bucket_in_bounds :: TExp Bool
bucket_in_bounds =
                        TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
flat_bucket
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
flat_bucket TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
hist_H_chk)
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds ([DimIndex (TPrimExp Int64 VName)] -> Slice (TPrimExp Int64 VName)
forall d. [DimIndex d] -> Slice d
Slice ((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix [TPrimExp Int64 VName]
bucket')) [TPrimExp Int64 VName]
dest_shape'
                      vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

                  TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                    let bucket_is :: [TPrimExp Int64 VName]
bucket_is =
                          (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is)
                            [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_ind]
                            [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
dest_shape' TPrimExp Int64 VName
flat_bucket
                    [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
                    Shape
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
                      [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
res) ->
                        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
res [TPrimExp Int64 VName]
is
                      [TPrimExp Int64 VName] -> InKernelGen ()
do_op ([TPrimExp Int64 VName]
bucket_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)

-- | Generate code for a 'SegHist' whose subhistograms are kept in
-- global memory.
--
-- The generated host code first emits a debug message, then sets up
-- the intermediate arrays with 'prepareIntermediateArraysGlobal', and
-- finally runs a host-level loop (loop variable @chk_i@, bound
-- @hist_S@) that calls 'histKernelGlobalPass' once per iteration.
--
-- The arguments are forwarded unchanged to 'histKernelGlobalPass';
-- only the kernel body, the number of groups, the group size, and the
-- segment space are inspected here.
histKernelGlobal ::
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  emit $ Imp.DebugPrint "## Using global memory" Nothing

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal passage total_threads innermost_size slugs

  -- The loop variable is supplied as the final argument of
  -- 'histKernelGlobalPass'.
  sFor "chk_i" hist_S $
    histKernelGlobalPass
      map_pes
      num_groups
      group_size
      space
      slugs
      kbody
      histograms
      hist_S
  where
    passage = bodyPassage kbody

    -- Number of groups times group size, multiplied as 64-bit
    -- expressions and then converted with 'sExt32'.
    total_threads =
      sExt32 (pe64 (unCount num_groups) * pe64 (unCount group_size))

    -- Size of the last dimension of the segment space.  'last' is
    -- partial, so this assumes the space has at least one dimension.
    innermost_size = pe64 . snd . last $ unSegSpace space

-- | Result of 'prepareIntermediateArraysLocal': one element per
-- histogram operator.  Each element pairs a list of array names with
-- an in-kernel action that, given a 'SubExp', produces a further list
-- of array names together with a function from index expressions to
-- an in-kernel action.
--
-- NOTE(review): the outer names appear to be the global-memory
-- subhistogram arrays, the 'SubExp' the per-chunk histogram size, and
-- the inner names the local-memory subhistograms -- confirm against
-- 'prepareIntermediateArraysLocal' and its callers.
type InitLocalHistograms =
  [([VName], SubExp -> InKernelGen ([VName], [Imp.TExp Int64] -> InKernelGen ()))]

prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal :: TV Int32
-> Count NumGroups (TPrimExp Int64 VName)
-> [SegHistSlug]
-> CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal TV Int32
num_subhistos_per_group Count NumGroups (TPrimExp Int64 VName)
groups_per_segment =
  (SegHistSlug
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      ([VName],
       SubExp
       -> ImpM
            GPUMem
            KernelEnv
            KernelOp
            ([VName], [TPrimExp Int64 VName] -> InKernelGen ())))
-> [SegHistSlug] -> CallKernelGen InitLocalHistograms
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SegHistSlug
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TPrimExp Int64 VName] -> InKernelGen ()))
onOp
  where
    onOp :: SegHistSlug
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TPrimExp Int64 VName] -> InKernelGen ()))
onOp (SegHistSlug HistOp GPUMem
op TV Int64
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate GPUMem KernelEnv
do_op) = do
      TV Int64
num_subhistos TV Int64 -> TPrimExp Int64 VName -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment)

      Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of subhistograms in global memory per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
          Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
            TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
              TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
num_subhistos

      SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
mk_op <-
        case AtomicUpdate GPUMem KernelEnv
do_op of
          AtomicPrim DoAtomicUpdate GPUMem KernelEnv
f -> (SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((SubExp
  -> ImpM
       GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      (SubExp
       -> ImpM
            GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)))
-> (SubExp
    -> ImpM
         GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. a -> b -> a
const (ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
 -> SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ DoAtomicUpdate GPUMem KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall (f :: * -> *) a. Applicative f => a -> f a
pure DoAtomicUpdate GPUMem KernelEnv
f
          AtomicCAS DoAtomicUpdate GPUMem KernelEnv
f -> (SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((SubExp
  -> ImpM
       GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      (SubExp
       -> ImpM
            GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)))
-> (SubExp
    -> ImpM
         GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. a -> b -> a
const (ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
 -> SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ DoAtomicUpdate GPUMem KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall (f :: * -> *) a. Applicative f => a -> f a
pure DoAtomicUpdate GPUMem KernelEnv
f
          AtomicLocking Locking -> DoAtomicUpdate GPUMem KernelEnv
f -> (SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((SubExp
  -> ImpM
       GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      (SubExp
       -> ImpM
            GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)))
-> (SubExp
    -> ImpM
         GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ \SubExp
hist_H_chk -> do
            let lock_shape :: Shape
lock_shape =
                  [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$
                    TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_subhistos_per_group
                      SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)
                      [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp
hist_H_chk]

            let dims :: [TPrimExp Int64 VName]
dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
lock_shape

            VName
locks <- [Char]
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
[Char] -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray [Char]
"locks" PrimType
int32 Shape
lock_shape (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"local"

            [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"All locks start out unlocked" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [TPrimExp Int64 VName]
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall t.
IntExp t =>
[TExp t] -> ([TExp t] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace [TPrimExp Int64 VName]
dims (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is ->
                VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
locks [TPrimExp Int64 VName]
is (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) []

            DoAtomicUpdate GPUMem KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DoAtomicUpdate GPUMem KernelEnv
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> DoAtomicUpdate GPUMem KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ Locking -> DoAtomicUpdate GPUMem KernelEnv
f (Locking -> DoAtomicUpdate GPUMem KernelEnv)
-> Locking -> DoAtomicUpdate GPUMem KernelEnv
forall a b. (a -> b) -> a -> b
$ VName
-> TExp Int32
-> TExp Int32
-> TExp Int32
-> ([TPrimExp Int64 VName] -> [TPrimExp Int64 VName])
-> Locking
Locking VName
locks TExp Int32
0 TExp Int32
1 TExp Int32
0 [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> a
id

      -- Initialise local-memory sub-histograms.  These are
      -- represented as two-dimensional arrays.
      --
      -- Given the chunk size @hist_H_chk@, this allocates one
      -- "subhistogram_local" array in the "local" space for every type
      -- in @histType op@, and returns those arrays paired with the
      -- update action obtained by applying the result of
      -- @mk_op hist_H_chk@ to the "local" space and the same arrays.
      let init_local_subhistos :: SubExp
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
init_local_subhistos SubExp
hist_H_chk = do
            [VName]
local_subhistos <-
              [Type]
-> (Type -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp GPUMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp GPUMem
op) ((Type -> ImpM GPUMem KernelEnv KernelOp VName)
 -> ImpM GPUMem KernelEnv KernelOp [VName])
-> (Type -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
                -- num_subhistos_per_group becomes the outermost dimension;
                -- the remaining dimensions come from setOuterDims applied
                -- to the type's shape, @histRank op@ and @[hist_H_chk]@
                -- (presumably replacing the outer histRank dimensions by
                -- the chunk size -- TODO confirm against setOuterDims).
                let sub_local_shape :: Shape
sub_local_shape =
                      [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_subhistos_per_group]
                        Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Shape -> Int -> Shape -> Shape
forall d. ShapeBase d -> Int -> ShapeBase d -> ShapeBase d
setOuterDims (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) (HistOp GPUMem -> Int
histRank HistOp GPUMem
op) ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
hist_H_chk])
                [Char]
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
[Char] -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray
                  [Char]
"subhistogram_local"
                  (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)
                  Shape
sub_local_shape
                  ([Char] -> Space
Space [Char]
"local")

            -- Build the update operation for this particular chunk size.
            DoAtomicUpdate GPUMem KernelEnv
do_op' <- SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
mk_op SubExp
hist_H_chk

            ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName]
local_subhistos, DoAtomicUpdate GPUMem KernelEnv
do_op' ([Char] -> Space
Space [Char]
"local") [VName]
local_subhistos)

      -- Initialise global-memory sub-histograms.
      [VName]
glob_subhistos <- [SubhistosInfo]
-> (SubhistosInfo -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SubhistosInfo]
subhisto_info ((SubhistosInfo -> ImpM GPUMem HostEnv HostOp VName)
 -> ImpM GPUMem HostEnv HostOp [VName])
-> (SubhistosInfo -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \SubhistosInfo
info -> do
        SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info
        VName -> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> ImpM GPUMem HostEnv HostOp VName)
-> VName -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ SubhistosInfo -> VName
subhistosArray SubhistosInfo
info

      ([VName],
 SubExp
 -> ImpM
      GPUMem
      KernelEnv
      KernelOp
      ([VName], [TPrimExp Int64 VName] -> InKernelGen ()))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TPrimExp Int64 VName] -> InKernelGen ()))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName]
glob_subhistos, SubExp
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
init_local_subhistos)

histKernelLocalPass ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  InitLocalHistograms ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumGroups (TPrimExp Int64 VName)
-> [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
  TV Int32
num_subhistos_per_group_var
  Count NumGroups (TPrimExp Int64 VName)
groups_per_segment
  [PatElem LParamMem]
map_pes
  Count NumGroups SubExp
num_groups
  Count GroupSize SubExp
group_size
  SegSpace
space
  [SegHistSlug]
slugs
  KernelBody GPUMem
kbody
  InitLocalHistograms
init_histograms
  TExp Int32
hist_S
  TExp Int32
chk_i = do
    let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
        segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
        segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
        (VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
        num_subhistos_per_group :: TExp Int32
num_subhistos_per_group = TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
num_subhistos_per_group_var
        segment_size' :: TPrimExp Int64 VName
segment_size' = SubExp -> TPrimExp Int64 VName
pe64 SubExp
segment_size

    TPrimExp Int64 VName
num_segments <-
      [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"num_segments" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
        [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
          (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
segment_dims

    [TV Int64]
hist_H_chks <- [HistOp GPUMem]
-> (HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) ((HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
 -> ImpM GPUMem HostEnv HostOp [TV Int64])
-> (HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall a b. (a -> b) -> a -> b
$ \HistOp GPUMem
op ->
      [Char]
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_H_chk" (TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> TPrimExp Int64 VName
histSize HistOp GPUMem
op TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S

    [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
histo_sizes <- [(SegHistSlug, TV Int64)]
-> ((SegHistSlug, TV Int64)
    -> ImpM
         GPUMem
         HostEnv
         HostOp
         ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SegHistSlug] -> [TV Int64] -> [(SegHistSlug, TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [TV Int64]
hist_H_chks) (((SegHistSlug, TV Int64)
  -> ImpM
       GPUMem
       HostEnv
       HostOp
       ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)])
-> ((SegHistSlug, TV Int64)
    -> ImpM
         GPUMem
         HostEnv
         HostOp
         ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
      let histo_dims :: [TPrimExp Int64 VName]
histo_dims =
            TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk
              TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
      TPrimExp Int64 VName
histo_size <-
        [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
histo_dims
      let group_hists_size :: TPrimExp Int64 VName
group_hists_size =
            TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_group TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
histo_size
      TExp Int32
init_per_thread <-
        [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"init_per_thread" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
group_hists_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TPrimExp Int64 VName
pe64 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
      ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([TPrimExp Int64 VName]
histo_dims, TPrimExp Int64 VName
histo_size, TExp Int32
init_per_thread)

    let attrs :: KernelAttrs
attrs = (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) {kAttrCheckLocalMemory :: Bool
kAttrCheckLocalMemory = Bool
False}
    [Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_local" (SegSpace -> VName
segFlat SegSpace
space) KernelAttrs
attrs (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
num_segments) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
        KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

        TExp Int32
flat_segment_id <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"flat_segment_id" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment)
        TExp Int32
gid_in_segment <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"gid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment)
        -- This pgtid is kind of a "virtualised physical" gtid - not the
        -- same thing as the gtid used for the SegHist itself.
        TExp Int32
pgtid_in_segment <-
          [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"pgtid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
        TExp Int32
threads_per_segment <-
          [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"threads_per_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
              Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants

        -- Set segment indices.
        (VName -> TPrimExp Int64 VName -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
segment_is ([TPrimExp Int64 VName] -> InKernelGen ())
-> [TPrimExp Int64 VName] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
segment_dims) (TPrimExp Int64 VName -> [TPrimExp Int64 VName])
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$
            TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id

        [([(VName, VName)], TV Int64,
  [TPrimExp Int64 VName] -> InKernelGen ())]
histograms <- [(([VName],
   SubExp
   -> ImpM
        GPUMem
        KernelEnv
        KernelOp
        ([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
  TV Int64)]
-> ((([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
     TV Int64)
    -> ImpM
         GPUMem
         KernelEnv
         KernelOp
         ([(VName, VName)], TV Int64,
          [TPrimExp Int64 VName] -> InKernelGen ()))
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     [([(VName, VName)], TV Int64,
       [TPrimExp Int64 VName] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [TV Int64]
-> [(([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
     TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [TV Int64]
hist_H_chks) (((([VName],
    SubExp
    -> ImpM
         GPUMem
         KernelEnv
         KernelOp
         ([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
   TV Int64)
  -> ImpM
       GPUMem
       KernelEnv
       KernelOp
       ([(VName, VName)], TV Int64,
        [TPrimExp Int64 VName] -> InKernelGen ()))
 -> ImpM
      GPUMem
      KernelEnv
      KernelOp
      [([(VName, VName)], TV Int64,
        [TPrimExp Int64 VName] -> InKernelGen ())])
-> ((([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TPrimExp Int64 VName] -> InKernelGen ())),
     TV Int64)
    -> ImpM
         GPUMem
         KernelEnv
         KernelOp
         ([(VName, VName)], TV Int64,
          [TPrimExp Int64 VName] -> InKernelGen ()))
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     [([(VName, VName)], TV Int64,
       [TPrimExp Int64 VName] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
          \(([VName]
glob_subhistos, SubExp
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
            ([VName]
local_subhistos, [TPrimExp Int64 VName] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
init_local_subhistos (SubExp
 -> ImpM
      GPUMem
      KernelEnv
      KernelOp
      ([VName], [TPrimExp Int64 VName] -> InKernelGen ()))
-> SubExp
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([VName], [TPrimExp Int64 VName] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_H_chk
            ([(VName, VName)], TV Int64,
 [TPrimExp Int64 VName] -> InKernelGen ())
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, TV Int64
hist_H_chk, [TPrimExp Int64 VName] -> InKernelGen ()
do_op)

        -- Find index of local subhistograms updated by this thread.  We
        -- try to ensure, as much as possible, that threads in the same
        -- warp use different subhistograms, to avoid conflicts.
        TExp Int32
thread_local_subhisto_i <-
          [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"thread_local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_subhistos_per_group

        -- Run @f@ once per slug.  @slugs@, @histograms@ and
        -- @histo_sizes@ are zipped positionally (so the shortest list
        -- bounds the iteration), and @f@ is handed the slug, its
        -- (global, local) destination array pairs, the chunk size
        -- @hist_H_chk@ as an expression, and the histogram's dimensions,
        -- flat size and per-thread initialisation count.  The update
        -- action stored in @histograms@ is deliberately not passed on.
        let onSlugs :: (SegHistSlug
 -> [(VName, VName)]
 -> TPrimExp Int64 VName
 -> [TPrimExp Int64 VName]
 -> TPrimExp Int64 VName
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ()
f =
              [(SegHistSlug,
  ([(VName, VName)], TV Int64,
   [TPrimExp Int64 VName] -> InKernelGen ()),
  ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))]
-> ((SegHistSlug,
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ()),
     ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], TV Int64,
     [TPrimExp Int64 VName] -> InKernelGen ())]
-> [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
-> [(SegHistSlug,
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ()),
     ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64,
  [TPrimExp Int64 VName] -> InKernelGen ())]
histograms [([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32)]
histo_sizes) (((SegHistSlug,
   ([(VName, VName)], TV Int64,
    [TPrimExp Int64 VName] -> InKernelGen ()),
   ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegHistSlug,
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ()),
     ([TPrimExp Int64 VName], TPrimExp Int64 VName, TExp Int32))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                \(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TPrimExp Int64 VName] -> InKernelGen ()
_), ([TPrimExp Int64 VName]
histo_dims, TPrimExp Int64 VName
histo_size, TExp Int32
init_per_thread)) ->
                  SegHistSlug
-> [(VName, VName)]
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TPrimExp Int64 VName]
histo_dims TPrimExp Int64 VName
histo_size TExp Int32
init_per_thread

        -- Apply @f@ to the elements of the local subhistograms of this
        -- group.  For each slug (via @onSlugs@) and each (global, local)
        -- destination pair zipped with the operator's neutral elements,
        -- a "local_i" loop of @init_per_thread@ iterations is emitted;
        -- its body calls @f@ only when the flat element index @j@ is
        -- below @group_hists_size@.  @f@ receives, in order: the local
        -- array, the global array, the operator, the neutral element,
        -- the local and global subhistogram indices, and the local and
        -- global bucket indices.
        let onAllHistograms :: (VName
 -> VName
 -> HistOp GPUMem
 -> SubExp
 -> TExp Int32
 -> TExp Int32
 -> [TPrimExp Int64 VName]
 -> [TPrimExp Int64 VName]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ()
f =
              (SegHistSlug
 -> [(VName, VName)]
 -> TPrimExp Int64 VName
 -> [TPrimExp Int64 VName]
 -> TPrimExp Int64 VName
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)]
  -> TPrimExp Int64 VName
  -> [TPrimExp Int64 VName]
  -> TPrimExp Int64 VName
  -> TExp Int32
  -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)]
    -> TPrimExp Int64 VName
    -> [TPrimExp Int64 VName]
    -> TPrimExp Int64 VName
    -> TExp Int32
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TPrimExp Int64 VName
hist_H_chk [TPrimExp Int64 VName]
histo_dims TPrimExp Int64 VName
histo_size TExp Int32
init_per_thread -> do
                -- num_subhistos_per_group * histo_size, computed in 32 bits.
                -- NOTE(review): the host-side code computes the same product
                -- in 64 bits before narrowing init_per_thread; confirm that
                -- histo_size (and this product) always fits in Int32 here.
                let group_hists_size :: TExp Int32
group_hists_size = TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size

                [((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  \((VName
dest_global, VName
dest_local), SubExp
ne) ->
                    [Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
init_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
                      -- Flat element index handled by this thread in
                      -- iteration i: i * group size + local thread id, so
                      -- consecutive local thread ids map to consecutive
                      -- elements.
                      TExp Int32
j <-
                        [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                          TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                            TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
                      -- j offset by gid_in_segment blocks of
                      -- num_subhistos_per_group * histo_size elements.
                      TExp Int32
j_offset <-
                        [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j_offset" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                          TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
j

                      -- Split j by histo_size: the quotient selects the
                      -- local subhistogram, the remainder (unflattened over
                      -- histo_dims) gives the bucket indices inside it.
                      TExp Int32
local_subhisto_i <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size
                      let local_bucket_is :: [TPrimExp Int64 VName]
local_bucket_is = [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
histo_dims (TPrimExp Int64 VName -> [TPrimExp Int64 VName])
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size
                          nested_hist_size :: [TPrimExp Int64 VName]
nested_hist_size =
                            (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                          -- The outermost local bucket index is shifted by
                          -- chk_i * hist_H_chk and unflattened over the
                          -- operator's histShape dimensions; the remaining
                          -- local indices are appended unchanged.
                          -- NOTE(review): head/tail are partial.  histo_dims
                          -- is built with (:) in the enclosing function, so
                          -- local_bucket_is is presumably never empty --
                          -- confirm against unflattenIndex.
                          global_bucket_is :: [TPrimExp Int64 VName]
global_bucket_is =
                            [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                              [TPrimExp Int64 VName]
nested_hist_size
                              ([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
head [TPrimExp Int64 VName]
local_bucket_is TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk)
                              [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a]
tail [TPrimExp Int64 VName]
local_bucket_is
                      TExp Int32
global_subhisto_i <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"global_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
histo_size

                      -- init_per_thread is rounded up (divUp) on the host
                      -- side, so the last iteration may run past the end;
                      -- only indices below group_hists_size are processed.
                      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                        VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> InKernelGen ()
f
                          VName
dest_local
                          VName
dest_global
                          (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)
                          SubExp
ne
                          TExp Int32
local_subhisto_i
                          TExp Int32
global_subhisto_i
                          [TPrimExp Int64 VName]
local_bucket_is
                          [TPrimExp Int64 VName]
global_bucket_is

        [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (VName
 -> VName
 -> HistOp GPUMem
 -> SubExp
 -> TExp Int32
 -> TExp Int32
 -> [TPrimExp Int64 VName]
 -> [TPrimExp Int64 VName]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
  -> VName
  -> HistOp GPUMem
  -> SubExp
  -> TExp Int32
  -> TExp Int32
  -> [TPrimExp Int64 VName]
  -> [TPrimExp Int64 VName]
  -> InKernelGen ())
 -> InKernelGen ())
-> (VName
    -> VName
    -> HistOp GPUMem
    -> SubExp
    -> TExp Int32
    -> TExp Int32
    -> [TPrimExp Int64 VName]
    -> [TPrimExp Int64 VName]
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp GPUMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TPrimExp Int64 VName]
local_bucket_is [TPrimExp Int64 VName]
global_bucket_is ->
            [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              let global_is :: [TPrimExp Int64 VName]
global_is = (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName
0] [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
global_bucket_is
                  local_is :: [TPrimExp Int64 VName]
local_is = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: [TPrimExp Int64 VName]
local_bucket_is
              TExp Bool -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                (TExp Int32
global_subhisto_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
                (VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest_local [TPrimExp Int64 VName]
local_is (VName -> SubExp
Var VName
dest_global) [TPrimExp Int64 VName]
global_is)
                ( Shape
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op) (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is ->
                    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest_local ([TPrimExp Int64 VName]
local_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is) SubExp
ne []
                )

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

        TExp Int32
-> TExp Int32
-> TExp Int32
-> (TExp Int32 -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int32
pgtid_in_segment TExp Int32
threads_per_segment (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
segment_size') ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
ie -> do
          VName -> TPrimExp Int64 VName -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
i_in_segment (TPrimExp Int64 VName -> InKernelGen ())
-> TPrimExp Int64 VName -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ie

          -- We execute the bucket function once and update each histogram
          -- serially.  This also involves writing to the map-out arrays if
          -- this is the first chunk (i.e. when chk_i is 0).

          Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            let ([SubExp]
red_res, [SubExp]
map_res) =
                  Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                    (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
                      KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody

            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chk_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
se) ->
                  VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                    (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                    ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
space_is)
                    SubExp
se
                    []

            let red_res_split :: [([SubExp], [SubExp])]
red_res_split = [HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [SubExp]
red_res
            [(HistOp GPUMem,
  ([(VName, VName)], TV Int64,
   [TPrimExp Int64 VName] -> InKernelGen ()),
  ([SubExp], [SubExp]))]
-> ((HistOp GPUMem,
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ()),
     ([SubExp], [SubExp]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [([(VName, VName)], TV Int64,
     [TPrimExp Int64 VName] -> InKernelGen ())]
-> [([SubExp], [SubExp])]
-> [(HistOp GPUMem,
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ()),
     ([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64,
  [TPrimExp Int64 VName] -> InKernelGen ())]
histograms [([SubExp], [SubExp])]
red_res_split) (((HistOp GPUMem,
   ([(VName, VName)], TV Int64,
    [TPrimExp Int64 VName] -> InKernelGen ()),
   ([SubExp], [SubExp]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp GPUMem,
     ([(VName, VName)], TV Int64,
      [TPrimExp Int64 VName] -> InKernelGen ()),
     ([SubExp], [SubExp]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
                 ([(VName, VName)]
_, TV Int64
hist_H_chk, [TPrimExp Int64 VName] -> InKernelGen ()
do_op),
                 ([SubExp]
bucket, [SubExp]
vs')
                 ) -> do
                  let chk_beg :: TPrimExp Int64 VName
chk_beg = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk
                      bucket' :: [TPrimExp Int64 VName]
bucket' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
bucket
                      dest_shape' :: [TPrimExp Int64 VName]
dest_shape' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
                      flat_bucket :: TPrimExp Int64 VName
flat_bucket = [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TPrimExp Int64 VName]
dest_shape' [TPrimExp Int64 VName]
bucket'
                      bucket_in_bounds :: TExp Bool
bucket_in_bounds =
                        Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds ([DimIndex (TPrimExp Int64 VName)] -> Slice (TPrimExp Int64 VName)
forall d. [DimIndex d] -> Slice d
Slice ((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix [TPrimExp Int64 VName]
bucket')) [TPrimExp Int64 VName]
dest_shape'
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
flat_bucket
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
flat_bucket TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TPrimExp Int64 VName
chk_beg TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk)
                      bucket_is :: [TPrimExp Int64 VName]
bucket_is =
                        [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TPrimExp Int64 VName
flat_bucket TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
chk_beg]
                      vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

                  [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                      [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
                      Shape
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
                        [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
                          VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TPrimExp Int64 VName]
is
                        [TPrimExp Int64 VName] -> InKernelGen ()
do_op ([TPrimExp Int64 VName]
bucket_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal

        [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (SegHistSlug
 -> [(VName, VName)]
 -> TPrimExp Int64 VName
 -> [TPrimExp Int64 VName]
 -> TPrimExp Int64 VName
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)]
  -> TPrimExp Int64 VName
  -> [TPrimExp Int64 VName]
  -> TPrimExp Int64 VName
  -> TExp Int32
  -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)]
    -> TPrimExp Int64 VName
    -> [TPrimExp Int64 VName]
    -> TPrimExp Int64 VName
    -> TExp Int32
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TPrimExp Int64 VName
hist_H_chk [TPrimExp Int64 VName]
histo_dims TPrimExp Int64 VName
_histo_size TExp Int32
bins_per_thread -> do
            TV Int64
trunc_H <-
              [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"trunc_H" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TPrimExp Int64 VName
hist_H_chk (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
                HistOp GPUMem -> TPrimExp Int64 VName
histSize (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
head [TPrimExp Int64 VName]
histo_dims
            let trunc_histo_dims :: [TPrimExp Int64 VName]
trunc_histo_dims =
                  TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
trunc_H
                    TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
            TExp Int32
trunc_histo_size <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
trunc_histo_dims

            [Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
bins_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
              TExp Int32
j <-
                [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                  TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                    TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
              TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                -- We are responsible for compacting the flat bin 'j', which
                -- we immediately unflatten.
                let local_bucket_is :: [TPrimExp Int64 VName]
local_bucket_is = [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
histo_dims (TPrimExp Int64 VName -> [TPrimExp Int64 VName])
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
                    nested_hist_size :: [TPrimExp Int64 VName]
nested_hist_size =
                      (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
                    global_bucket_is :: [TPrimExp Int64 VName]
global_bucket_is =
                      [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                        [TPrimExp Int64 VName]
nested_hist_size
                        ([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
head [TPrimExp Int64 VName]
local_bucket_is TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H_chk)
                        [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a]
tail [TPrimExp Int64 VName]
local_bucket_is
                [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
                let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
                    ([Param LParamMem]
xparams, [Param LParamMem]
yparams) =
                      Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$
                        Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$
                          HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$
                            SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
subhisto) ->
                    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                      (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp)
                      []
                      (VName -> SubExp
Var VName
subhisto)
                      (TPrimExp Int64 VName
0 TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: [TPrimExp Int64 VName]
local_bucket_is)

                [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  [Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"subhisto_id" (TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
subhisto_id -> do
                    [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
yparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
yp, VName
subhisto) ->
                      VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                        (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
yp)
                        []
                        (VName -> SubExp
Var VName
subhisto)
                        (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1 TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: [TPrimExp Int64 VName]
local_bucket_is)
                    [Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
xparams (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody (Lambda GPUMem -> Body GPUMem) -> Lambda GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                [Char] -> InKernelGen () -> InKernelGen ()
forall rep r op. [Char] -> ImpM rep r op () -> ImpM rep r op ()
sComment [Char]
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                  let global_is :: [TPrimExp Int64 VName]
global_is =
                        (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is
                          [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment]
                          [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
global_bucket_is
                  [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
global_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
global_dest) ->
                    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
global_dest [TPrimExp Int64 VName]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp) []

-- | Host-side driver for the local-memory histogram strategy.
--
-- The function performs three steps, in order:
--
-- 1. Emits a debug print of the number of local subhistograms per
--    group (the current value of @num_subhistos_per_group_var@).
--
-- 2. Calls 'prepareIntermediateArraysLocal' once to obtain the
--    histogram initialisation action that is shared by all passes.
--
-- 3. Emits an 'sFor' loop named @chk_i@, bounded by @hist_S@, whose
--    body is a single 'histKernelLocalPass' invocation.
--
-- Arguments are forwarded unchanged to 'histKernelLocalPass'; the pass
-- additionally receives the initialisation action, the total pass
-- count @hist_S@ and the current loop variable @chk_i@.
--
-- NOTE(review): @hist_S@ is reported elsewhere in this module as
-- "Number of chunks (S)"; the loop presumably runs @chk_i@ from 0 up
-- to (but excluding) @hist_S@ - confirm against the definition of
-- 'sFor'.
histKernelLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelLocal num_subhistos_per_group_var groups_per_segment map_pes num_groups group_size space hist_S slugs kbody = do
  let num_subhistos_per_group = tvExp num_subhistos_per_group_var

  emit $
    Imp.DebugPrint "Number of local subhistograms per group" $
      Just $
        untyped num_subhistos_per_group

  -- Prepared once, outside the loop, and handed to every pass.
  init_histograms <-
    prepareIntermediateArraysLocal num_subhistos_per_group_var groups_per_segment slugs

  -- One pass per value of the loop variable.
  sFor "chk_i" hist_S $ \chk_i ->
    histKernelLocalPass
      num_subhistos_per_group_var
      groups_per_segment
      map_pes
      num_groups
      group_size
      space
      slugs
      kbody
      init_histograms
      hist_S
      chk_i

-- | The maximum number of passes we are willing to accept for this
-- kind of atomic update.
--
-- The limit depends only on which 'AtomicUpdate' constructor
-- 'slugAtomicUpdate' yields for the slug:
--
-- * 'AtomicPrim': 3 passes
--
-- * 'AtomicCAS': 4 passes
--
-- * 'AtomicLocking': 6 passes
--
-- NOTE(review): the specific constants look like tuned heuristics;
-- their derivation is not visible here.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _ -> 3
    AtomicCAS _ -> 4
    AtomicLocking _ -> 6

localMemoryCase ::
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  SegSpace ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase :: [PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
hist_N TExp Int32
_ [SegHistSlug]
slugs KernelBody GPUMem
kbody = do
  let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
      segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
      segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims

  TV Int64
hist_L <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"hist_L" PrimType
int32
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_L) SizeClass
Imp.SizeLocalMemory

  TV Any
max_group_size <- [Char] -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Any)
forall rep r op t. [Char] -> PrimType -> ImpM rep r op (TV t)
dPrim [Char]
"max_group_size" PrimType
int32
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size) SizeClass
Imp.SizeGroup
  let group_size :: Count GroupSize SubExp
group_size = SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size
  Count NumGroups SubExp
num_groups <-
    (TV Int64 -> Count NumGroups SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize) (ImpM GPUMem HostEnv HostOp (TV Int64)
 -> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp))
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$
      [Char]
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"num_groups" (TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
        TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TPrimExp Int64 VName
pe64 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
  let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> Count NumGroups SubExp -> Count NumGroups (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> Count GroupSize SubExp -> Count GroupSize (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size

  let r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int64 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
      t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped

  -- M approximation.
  TPrimExp Double VName
hist_m' <-
    [Char]
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_m_prime" (TPrimExp Double VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName))
-> TPrimExp Double VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Double VName)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64
        ( TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
            (TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_L TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName
hist_el_size))
            (TPrimExp Int64 VName
hist_N TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups'))
        )
        TPrimExp Double VName
-> TPrimExp Double VName -> TPrimExp Double VName
forall a. Fractional a => a -> a -> a
/ TPrimExp Int64 VName -> TPrimExp Double VName
forall {t} {v}. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Int64 VName
hist_H

  let hist_B :: TPrimExp Int64 VName
hist_B = Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'

  -- M in the paper, but not adjusted for asymptotic efficiency.
  TPrimExp Int64 VName
hist_M0 <-
    [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M0" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TPrimExp Int64 VName
1 (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Double VName -> TPrimExp Int64 VName
forall {t} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 TPrimExp Double VName
hist_m') TPrimExp Int64 VName
hist_B

  -- Minimal sequential chunking factor.
  let q_small :: TPrimExp Int64 VName
q_small = TPrimExp Int64 VName
2

  -- The number of segments/histograms produced.
  TPrimExp Int64 VName
hist_Nout <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nout" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
segment_dims

  TPrimExp Int64 VName
hist_Nin <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nin" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
space_sizes

  -- Maximum M for work efficiency.
  TPrimExp Int64 VName
work_asymp_M_max <-
    if Bool
segmented
      then do
        TExp Int32
hist_T_hist_min <-
          [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_T_hist_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
              TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
hist_Nin TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
hist_Nout) (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T)
                TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
hist_Nout

        -- Number of groups, rounded up.
        let r :: TExp Int32
r = TExp Int32
hist_T_hist_min TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
hist_B

        [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
hist_Nin TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
r TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H)
      else
        [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
          (TPrimExp Int64 VName
hist_Nout TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_N)
            TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` ( (TPrimExp Int64 VName
q_small TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_H)
                       TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TPrimExp Int64 VName
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
                   )

  -- Number of subhistograms per result histogram.
  TV Int32
hist_M <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_M" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TPrimExp Int64 VName
hist_M0 TPrimExp Int64 VName
work_asymp_M_max

  -- hist_M may be zero (which we'll check for below), but we need it
  -- for some divisions first, so crudely make a nonzero form.
  let hist_M_nonzero :: TExp Int32
hist_M_nonzero = TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M

  -- "Cooperation factor" - the number of threads cooperatively
  -- working on the same (sub)histogram.
  TPrimExp Int64 VName
hist_C <-
    [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 VName
hist_B TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_nonzero

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_M0
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
work_asymp_M_max
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_C
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_B
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
      Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
          TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)

  -- local_mem_needed is what we need to keep a single bucket in local
  -- memory - this is an absolute minimum.  We can fit anything else
  -- by doing multiple passes, although more than a few is
  -- (heuristically) not efficient.
  TPrimExp Int64 VName
local_mem_needed <-
    [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_mem_needed" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)
  TExp Int32
hist_S <-
    [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_S" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
        (TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
local_mem_needed) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_L
  let max_S :: TExp Int32
max_S = case KernelBody GPUMem -> Passage
bodyPassage KernelBody GPUMem
kbody of
        Passage
MustBeSinglePass -> TExp Int32
1
        Passage
MayBeMultiPass -> Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int32) -> Int -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs

  Count NumGroups (TPrimExp Int64 VName)
groups_per_segment <-
    if Bool
segmented
      then
        (TPrimExp Int64 VName -> Count NumGroups (TPrimExp Int64 VName))
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
-> ImpM
     GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TPrimExp Int64 VName -> Count NumGroups (TPrimExp Int64 VName)
forall u e. e -> Count u e
Count (ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
 -> ImpM
      GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName)))
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
-> ImpM
     GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$
          [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"groups_per_segment" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
            Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName
hist_Nout
      else Count NumGroups (TPrimExp Int64 VName)
-> ImpM
     GPUMem HostEnv HostOp (Count NumGroups (TPrimExp Int64 VName))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Count NumGroups (TPrimExp Int64 VName)
num_groups'

  -- We only use local memory if the number of updates per histogram
  -- at least matches the histogram size, as otherwise it is not
  -- asymptotically efficient.  This mostly matters for the segmented
  -- case.
  let pick_local :: TExp Bool
pick_local =
        TPrimExp Int64 VName
hist_Nin TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TPrimExp Int64 VName
hist_H
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TPrimExp Int64 VName
local_mem_needed TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
hist_L)
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int32
hist_S TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
max_S)
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
hist_C TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
hist_B
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0

      run :: CallKernelGen ()
run = do
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"## Using local memory" Maybe Exp
forall a. Maybe a
Nothing
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_H
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_C
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_S
        Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Groups per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
              Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
                TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
                  Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
groups_per_segment
        TV Int32
-> Count NumGroups (TPrimExp Int64 VName)
-> [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelLocal
          TV Int32
hist_M
          Count NumGroups (TPrimExp Int64 VName)
groups_per_segment
          [PatElem LParamMem]
map_pes
          Count NumGroups SubExp
num_groups
          Count GroupSize SubExp
group_size
          SegSpace
space
          TExp Int32
hist_S
          [SegHistSlug]
slugs
          KernelBody GPUMem
kbody

  (TExp Bool, CallKernelGen ())
-> CallKernelGen (TExp Bool, CallKernelGen ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TExp Bool
pick_local, CallKernelGen ()
run)

-- | Generate code for a segmented histogram called from the host.
compileSegHist ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [HistOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegHist :: Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [HistOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegHist (Pat [PatElem LParamMem]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [HistOp GPUMem]
ops KernelBody GPUMem
kbody = do
  -- Most of this function is not the histogram part itself, but
  -- rather figuring out whether to use a local or global memory
  -- strategy, as well as collapsing the subhistograms produced (which
  -- are always in global memory, but their number may vary).
  let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = (SubExp -> TPrimExp Int64 VName)
-> Count NumGroups SubExp -> Count NumGroups (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = (SubExp -> TPrimExp Int64 VName)
-> Count GroupSize SubExp -> Count GroupSize (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count GroupSize SubExp
group_size
      dims :: [TPrimExp Int64 VName]
dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space

      num_red_res :: Int
num_red_res = [HistOp GPUMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp GPUMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp GPUMem]
ops)
      ([PatElem LParamMem]
all_red_pes, [PatElem LParamMem]
map_pes) = Int
-> [PatElem LParamMem]
-> ([PatElem LParamMem], [PatElem LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem LParamMem]
pes
      segment_size :: TPrimExp Int64 VName
segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
last [TPrimExp Int64 VName]
dims

  ([Count Bytes (TPrimExp Int64 VName)]
op_hs, [Count Bytes (TPrimExp Int64 VName)]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes (TPrimExp Int64 VName),
  Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
-> ([Count Bytes (TPrimExp Int64 VName)],
    [Count Bytes (TPrimExp Int64 VName)], [SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes (TPrimExp Int64 VName),
   Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
 -> ([Count Bytes (TPrimExp Int64 VName)],
     [Count Bytes (TPrimExp Int64 VName)], [SegHistSlug]))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     [(Count Bytes (TPrimExp Int64 VName),
       Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([Count Bytes (TPrimExp Int64 VName)],
      [Count Bytes (TPrimExp Int64 VName)], [SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp GPUMem
 -> CallKernelGen
      (Count Bytes (TPrimExp Int64 VName),
       Count Bytes (TPrimExp Int64 VName), SegHistSlug))
-> [HistOp GPUMem]
-> ImpM
     GPUMem
     HostEnv
     HostOp
     [(Count Bytes (TPrimExp Int64 VName),
       Count Bytes (TPrimExp Int64 VName), SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegSpace
-> HistOp GPUMem
-> CallKernelGen
     (Count Bytes (TPrimExp Int64 VName),
      Count Bytes (TPrimExp Int64 VName), SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp GPUMem]
ops
  TPrimExp Int64 VName
h <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"h" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName)
-> Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TPrimExp Int64 VName)]
op_hs
  TPrimExp Int64 VName
seg_h <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"seg_h" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName)
-> Count Bytes (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TPrimExp Int64 VName)]
-> Count Bytes (TPrimExp Int64 VName)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TPrimExp Int64 VName)]
op_seg_hs

  -- Check for emptiness to avoid division-by-zero.
  TExp Bool -> CallKernelGen () -> CallKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless (TPrimExp Int64 VName
seg_h TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    -- Maximum group size (or actual, in this case).
    let hist_B :: TPrimExp Int64 VName
hist_B = Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'

    -- Size of a histogram.
    TPrimExp Int64 VName
hist_H <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (HistOp GPUMem -> TPrimExp Int64 VName)
-> [HistOp GPUMem] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map HistOp GPUMem -> TPrimExp Int64 VName
histSize [HistOp GPUMem]
ops

    -- Size of a single histogram element.  Actually the weighted
    -- average of histogram elements in cases where we have more than
    -- one histogram operation, plus any locks.
    let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate GPUMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
          AtomicLocking {} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
          AtomicUpdate GPUMem KernelEnv
_ -> Maybe a
forall a. Maybe a
Nothing
    TPrimExp Int64 VName
hist_el_size <-
      [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_el_size" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
        (TPrimExp Int64 VName
 -> TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
(+) (TPrimExp Int64 VName
h TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName
hist_H) ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
          (SegHistSlug -> Maybe (TPrimExp Int64 VName))
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe (TPrimExp Int64 VName)
forall {a}. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs

    -- Input elements contributing to each histogram.
    TPrimExp Int64 VName
hist_N <- [Char]
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_N" TPrimExp Int64 VName
segment_size

    -- Compute RF as the average RF over all the histograms.
    TExp Int32
hist_RF <-
      [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. [Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RF" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
          [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TPrimExp Int64 VName)
-> [SegHistSlug] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
            TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TPrimExp Int64 VName
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs

    let hist_T :: TExp Int32
hist_T = TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_T
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Desired group size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_B
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_H
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_N
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
        Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
          TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$
            [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
              ((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
hist_el_size
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_RF
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
h
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
seg_h

    (TExp Bool
use_local_memory, CallKernelGen ()
run_in_local_memory) <-
      [PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TPrimExp Int64 VName
hist_H TPrimExp Int64 VName
hist_el_size TPrimExp Int64 VName
hist_N TExp Int32
hist_RF [SegHistSlug]
slugs KernelBody GPUMem
kbody

    TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf TExp Bool
use_local_memory CallKernelGen ()
run_in_local_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [PatElem LParamMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelGlobal [PatElem LParamMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody

    let pes_per_op :: [[PatElem LParamMem]]
pes_per_op = [Int] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp GPUMem -> [VName]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp GPUMem]
ops) [PatElem LParamMem]
all_red_pes

    [(SegHistSlug, [PatElem LParamMem], HistOp GPUMem)]
-> ((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
    -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElem LParamMem]]
-> [HistOp GPUMem]
-> [(SegHistSlug, [PatElem LParamMem], HistOp GPUMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElem LParamMem]]
pes_per_op [HistOp GPUMem]
ops) (((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
  -> CallKernelGen ())
 -> CallKernelGen ())
-> ((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
    -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElem LParamMem]
red_pes, HistOp GPUMem
op) -> do
      let num_histos :: TV Int64
num_histos = SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug
          subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug

      let unitHistoCase :: CallKernelGen ()
unitHistoCase =
            -- This is OK because the memory blocks are at least as
            -- large as the ones we are supposed to use for the result.
            [(PatElem LParamMem, VName)]
-> ((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [VName] -> [(PatElem LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes [VName]
subhistos) (((PatElem LParamMem, VName) -> CallKernelGen ())
 -> CallKernelGen ())
-> ((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
subhisto) -> do
              VName
pe_mem <-
                MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
                  (ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
              VName
subhisto_mem <-
                MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
                  (ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
subhisto
              Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"device"

      TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
num_histos TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        -- For the segmented reduction, we keep the segment dimensions
        -- unchanged.  To these we add the bucket dimensions, the vector
        -- dimensions of the operator, and finally one dimension over the
        -- number of subhistograms.  That innermost dimension is the one
        -- that is collapsed in the reduction.
        [VName]
bucket_ids <-
          Int
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op)) ([Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"bucket_id")
        VName
subhistogram_id <- [Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"subhistogram_id"
        [VName]
vector_ids <-
          Int
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"vector_id")

        VName
flat_gtid <- [Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"flat_gtid"

        let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegVirt
            segred_space :: SegSpace
segred_space =
              VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
                [(VName, SubExp)]
segment_dims
                  [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op))
                  [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)
                  [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
num_histos)]

        let segred_op :: SegBinOp GPUMem
segred_op = Commutativity
-> Lambda GPUMem -> [SubExp] -> Shape -> SegBinOp GPUMem
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Commutative (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op) (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp GPUMem
op) Shape
forall a. Monoid a => a
mempty
        Pat LParamMem
-> SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElem LParamMem] -> Pat LParamMem
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem LParamMem]
red_pes) SegLevel
lvl SegSpace
segred_space [SegBinOp GPUMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TPrimExp Int64 VName])] -> InKernelGen ()
red_cont ->
          [(SubExp, [TPrimExp Int64 VName])] -> InKernelGen ()
red_cont ([(SubExp, [TPrimExp Int64 VName])] -> InKernelGen ())
-> ((VName -> (SubExp, [TPrimExp Int64 VName]))
    -> [(SubExp, [TPrimExp Int64 VName])])
-> (VName -> (SubExp, [TPrimExp Int64 VName]))
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((VName -> (SubExp, [TPrimExp Int64 VName]))
 -> [VName] -> [(SubExp, [TPrimExp Int64 VName])])
-> [VName]
-> (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [(SubExp, [TPrimExp Int64 VName])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [VName] -> [(SubExp, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [TPrimExp Int64 VName])) -> InKernelGen ())
-> (VName -> (SubExp, [TPrimExp Int64 VName])) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
            ( VName -> SubExp
Var VName
subhisto,
              (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [TPrimExp Int64 VName])
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$
                ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id]
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids
            )

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"" Maybe Exp
forall a. Maybe a
Nothing
  where
    -- The entries of the segment space with the final one dropped:
    -- 'unSegSpace' turns @space@ into a list of (VName, SubExp) pairs, and
    -- 'init' removes the last pair.  The first components of these pairs are
    -- used above as the leading indexes when reading each subhistogram.
    -- NOTE(review): 'init' is partial, so this presumably relies on @space@
    -- always having at least one dimension; confirm against the callers.
    segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space