module Futhark.CodeGen.ImpGen.Multicore.SegHist
  ( compileSegHist,
  )
where

import Control.Monad
import Data.List (zip4)
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.CodeGen.ImpGen.Multicore.SegRed (compileSegRed')
import Futhark.IR.MCMem
import Futhark.MonadFreshNames
import Futhark.Transform.Rename (renameLambda)
import Futhark.Util (chunks, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (rem)
import Prelude hiding (quot, rem)

-- | Compile a 'SegHist' for the multicore backend.
--
-- A segment space with exactly one dimension is compiled by
-- 'nonsegmentedHist', which also receives @nsubtasks@.  Every other
-- segment space is compiled by 'segmentedHist', which does not use
-- @nsubtasks@.
compileSegHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegHist pat space histops kbody nsubtasks =
  case unSegSpace space of
    [_] -> nonsegmentedHist pat space histops kbody nsubtasks
    _ -> segmentedHist pat space histops kbody

-- | Split a list into consecutive chunks, one per 'HistOp'.  Each chunk
-- is as long as the corresponding operator's list of neutral elements
-- ('histNeutral').
segHistOpChunks :: [HistOp rep] -> [a] -> [[a]]
segHistOpChunks ops xs = chunks [length (histNeutral op) | op <- ops] xs

-- | Number of buckets of a histogram operator: the product of the
-- dimensions in its 'histShape'.
histSize :: HistOp MCMem -> Imp.TExp Int64
histSize op = product [pe64 d | d <- shapeDims (histShape op)]

-- | Declare the parameters of a histogram operator's lambda in the
-- current code-generation scope via 'dScope' (with no defining
-- expression).
genHistOpParams :: HistOp MCMem -> MulticoreGen ()
genHistOpParams op =
  let params = lambdaParams (histOp op)
   in dScope Nothing (scopeOfLParams params)

-- | Return a copy of the operator whose lambda has been given fresh
-- names by 'renameLambda'.  All other fields are unchanged.
renameHistop :: HistOp MCMem -> MulticoreGen (HistOp MCMem)
renameHistop histop =
  withLambda <$> renameLambda (histOp histop)
  where
    withLambda lam = histop {histOp = lam}

-- | Compile a 'SegHist' over a one-dimensional segment space.
--
-- The generated code does nothing when the input is empty.  Otherwise it
-- picks a strategy at run time:
--
-- * 'subHistogram' when @num_histos * hist_width <= n@, where @n@ is the
--   number of input elements and @hist_width@ is the bucket count of the
--   /first/ operator;
--
-- * 'atomicHistogram' otherwise.
nonsegmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
nonsegmentedHist pat space histops kbody num_histos = do
  let ns = map snd $ unSegSpace space
      ns_64 = map pe64 ns
      num_histos' = tvExp num_histos
      -- The heuristic only inspects the first operator.  A SegHist with no
      -- operators is not expected here; fail with a descriptive message
      -- instead of the opaque exception from the partial 'head'.
      hist_width = case histops of
        op : _ -> histSize op
        [] -> error "nonsegmentedHist: SegHist has no histogram operators"
      use_subhistogram = sExt64 num_histos' * hist_width .<=. product ns_64

  -- Both branches below compile the operator lambdas.  The atomic branch
  -- gets a renamed copy, presumably so that names bound in the two
  -- branches do not clash.
  histops' <- renameHistOpLambda histops

  -- Only do something if there is actually input.
  collect $
    sUnless (product ns_64 .==. 0) $
      sIf
        use_subhistogram
        (subHistogram pat space histops num_histos kbody)
        (atomicHistogram pat space histops' kbody)

-- | Pick the atomic update strategy for a histogram operator.
--
-- 'atomicUpdateLocking' chooses one of three sub-strategies based on the
-- operator:
--
-- 1. If the values are integral scalars, a directly supported atomic
--    update is used ('AtomicPrim').
--
-- 2. If the values occupy a single memory location (e.g. a float), a
--    compare-and-swap loop is used, with the value cast to an integral
--    scalar ('AtomicCAS').
--
-- 3. Otherwise a locking-based approach is used ('AtomicLocking').
--
-- Strategies 1 and 2 currently only work for 32-bit and 64-bit types,
-- although GCC also supports 8-, 16- and 128-bit types.
--
-- The returned function takes the destination arrays and the index to
-- update.
onOpAtomic :: HistOp MCMem -> MulticoreGen ([VName] -> [Imp.TExp Int64] -> MulticoreGen ())
onOpAtomic op = do
  atomics <- hostAtomics <$> askEnv
  case atomicUpdateLocking atomics (histOp op) of
    AtomicPrim update -> pure update
    AtomicCAS update -> pure update
    AtomicLocking update -> update <$> mkLocking
  where
    -- Number of locks in the static lock array; this number is taken
    -- from the GPU backend.
    num_locks :: Int
    num_locks = 100151

    lock_dims = map pe64 $ shapeDims $ histOpShape op <> histShape op

    -- Allocate a zero-initialised static array of locks (as in the GPU
    -- backend).  Each flattened index is mapped onto a lock modulo
    -- 'num_locks', so distinct indices may share a lock.
    mkLocking = do
      locks <-
        sStaticArray "hist_locks" DefaultSpace int32 $
          Imp.ArrayZeros num_locks
      pure $
        Locking locks 0 1 0 $
          pure . (`rem` fromIntegral num_locks) . flattenIndex lock_dims

-- | Compute the histograms in one parallel loop in which every input
-- element updates the shared destination histograms directly, using the
-- per-operator update functions from 'onOpAtomic'.
--
-- Map-out results (the pattern elements after the histogram
-- destinations) are written at the element's own index.  Histogram
-- updates whose bucket falls outside the destination shape are skipped.
atomicHistogram ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
atomicHistogram pat space histops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map pe64 ns
      dests_per_op = map (length . histDest) histops
      -- The pattern holds one element per histogram destination array,
      -- followed by the map-out results.  (The kernel *result* list
      -- additionally carries bucket indices, so its length must not be
      -- used to split the pattern.)
      num_red_res = sum dests_per_op
      (all_red_pes, map_pes) = splitAt num_red_res $ patElems pat
      pes_per_op = chunks dests_per_op all_red_pes

  atomicOps <- mapM onOpAtomic histops

  body <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)

    generateChunkLoop "SegHist" Scalar $ \flat_idx -> do
      zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx
      compileStms mempty (kernelBodyStms kbody) $ do
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $ kernelBodyResult kbody
            red_res_split =
              splitHistResults histops $ map kernelResultSubExp red_res

        -- Map-out results do not depend on the operator, so write them
        -- once per element rather than once per operator.
        sComment "save map-out results" $
          forM_ (zip map_pes map_res) $ \(pe, res) ->
            copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp res) []

        forM_ (zip4 histops red_res_split atomicOps pes_per_op) $
          \(HistOp dest_shape _ _ _ shape lam, (bucket, vs'), do_op, dest_res) -> do
            -- The trailing parameters of the lambda receive the new values.
            let (_, vs_params) = splitAt (length vs') $ lambdaParams lam
                dest_shape' = map pe64 $ shapeDims dest_shape
                bucket' = map pe64 bucket
                bucket_in_bounds =
                  inBounds (Slice (map DimFix bucket')) dest_shape'

            sComment "perform updates" $
              sWhen bucket_in_bounds $ do
                let bucket_is = map Imp.le64 (init is) ++ bucket'
                dLParams $ lambdaParams lam
                sLoopNest shape $ \is' -> do
                  forM_ (zip vs_params vs') $ \(p, res) ->
                    copyDWIMFix (paramName p) [] res is'
                  do_op (map patElemName dest_res) (bucket_is ++ is')

  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "atomic_seg_hist" body free_params

-- | Apply a histogram operator's lambda to a single bucket.
--
-- The parameters @uni_acc@ are loaded from the arrays @arrs@ at index
-- @bucket@.  The lambda body is then compiled, and each body result is
-- written back to the corresponding array at @bucket@.  Each write-back
-- is wrapped in 'extractVectorLane' with lane @j@.
updateHisto ::
  HistOp MCMem ->
  [VName] ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  [Param LParamMem] ->
  MulticoreGen ()
updateHisto op arrs bucket j uni_acc =
  sComment "Start of body" $ do
    forM_ (zip uni_acc arrs) $ \(acc_param, src) ->
      copyDWIMFix (paramName acc_param) [] (Var src) bucket
    compileBody' [] lam_body
    zipWithM_ writeBack arrs $ map resSubExp $ bodyResult lam_body
  where
    lam_body = lambdaBody $ histOp op
    writeBack dst v =
      extractVectorLane j $ collect $ copyDWIMFix dst bucket v []

-- Generates num_histos sub-histograms, each the size of the
-- destination histogram.
-- Then, for each chunk of the input, a sub-histogram is computed,
-- and the sub-histograms are finally combined through a segmented
-- reduction across the histogram indices.
-- This is expected to be fast if len(histDest) is small.
subHistogram ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  TV Int32 ->
  KernelBody MCMem ->
  MulticoreGen ()
subHistogram :: Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> TV Int32
-> KernelBody MCMem
-> MulticoreGen ()
subHistogram Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops TV Int32
num_histos KernelBody MCMem
kbody = do
  Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> Code Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code Multicore
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"subHistogram segHist" Maybe Exp
forall a. Maybe a
Nothing

  let ([VName]
is, [SubExp]
ns) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      ns_64 :: [TPrimExp Int64 VName]
ns_64 = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
ns

  let pes :: [PatElem LParamMem]
pes = Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
      num_red_res :: Int
num_red_res = [HistOp MCMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp MCMem -> Int) -> [HistOp MCMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp MCMem -> [SubExp]) -> HistOp MCMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
      map_pes :: [PatElem LParamMem]
map_pes = Int -> [PatElem LParamMem] -> [PatElem LParamMem]
forall a. Int -> [a] -> [a]
drop Int
num_red_res [PatElem LParamMem]
pes
      per_red_pes :: [[PatElem LParamMem]]
per_red_pes = [HistOp MCMem] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall rep a. [HistOp rep] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops ([PatElem LParamMem] -> [[PatElem LParamMem]])
-> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall a b. (a -> b) -> a -> b
$ Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat

  -- Allocate the array of subhistograms in the calling thread.  Each
  -- task will work in its own private allocations (to avoid false
  -- sharing), but this is where they will ultimately copy their
  -- results.
  [[VName]]
global_subhistograms <- [HistOp MCMem]
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp MCMem]
histops ((HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
 -> ImpM MCMem HostEnv Multicore [[VName]])
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \HistOp MCMem
histop ->
    [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
 -> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
      let shape :: Shape
shape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_histos] Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t
      String
-> PrimType -> Shape -> Space -> ImpM MCMem HostEnv Multicore VName
forall rep r op.
String -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) Shape
shape Space
DefaultSpace

  let tid' :: TPrimExp Int64 VName
tid' = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TPrimExp Int64 VName) -> VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space

  -- Generate loop body of parallel function
  Code Multicore
body <- MulticoreGen () -> MulticoreGen (Code Multicore)
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (MulticoreGen () -> MulticoreGen (Code Multicore))
-> MulticoreGen () -> MulticoreGen (Code Multicore)
forall a b. (a -> b) -> a -> b
$ do
    VName -> PrimType -> MulticoreGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64
    Multicore -> MulticoreGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (Multicore -> MulticoreGen ()) -> Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetTaskId (SegSpace -> VName
segFlat SegSpace
space)

    [[VName]]
local_subhistograms <- [([PatElem LParamMem], HistOp MCMem)]
-> (([PatElem LParamMem], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[PatElem LParamMem]]
-> [HistOp MCMem] -> [([PatElem LParamMem], HistOp MCMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElem LParamMem]]
per_red_pes [HistOp MCMem]
histops) ((([PatElem LParamMem], HistOp MCMem)
  -> ImpM MCMem HostEnv Multicore [VName])
 -> ImpM MCMem HostEnv Multicore [[VName]])
-> (([PatElem LParamMem], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \([PatElem LParamMem]
pes', HistOp MCMem
histop) -> do
      [VName]
op_local_subhistograms <- [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
 -> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
        String
-> PrimType -> Shape -> Space -> ImpM MCMem HostEnv Multicore VName
forall rep r op.
String -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) Space
DefaultSpace

      [(PatElem LParamMem, VName, SubExp)]
-> ((PatElem LParamMem, VName, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem]
-> [VName] -> [SubExp] -> [(PatElem LParamMem, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem LParamMem]
pes' [VName]
op_local_subhistograms (HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp MCMem
histop)) (((PatElem LParamMem, VName, SubExp) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((PatElem LParamMem, VName, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
hist, SubExp
ne) ->
        -- First thread initializes histogram with dest vals. Others
        -- initialize with neutral element
        TPrimExp Bool VName
-> MulticoreGen () -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
          (TPrimExp Int64 VName
tid' TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
          (VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> MulticoreGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
hist [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) [])
          ( Shape
-> ([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp MCMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp MCMem
histop) (([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
shape_is ->
              Shape
-> ([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp MCMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp MCMem
histop) (([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
vec_is ->
                VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> MulticoreGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
hist ([TPrimExp Int64 VName]
shape_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. Semigroup a => a -> a -> a
<> [TPrimExp Int64 VName]
vec_is) SubExp
ne []
          )

      [VName] -> ImpM MCMem HostEnv Multicore [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
op_local_subhistograms

    MulticoreGen () -> MulticoreGen ()
inISPC (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
      String
-> ChunkLoopVectorization
-> (TPrimExp Int64 VName -> MulticoreGen ())
-> MulticoreGen ()
generateChunkLoop String
"SegRed" ChunkLoopVectorization
Vectorized ((TPrimExp Int64 VName -> MulticoreGen ()) -> MulticoreGen ())
-> (TPrimExp Int64 VName -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
        (VName -> TPrimExp Int64 VName -> MulticoreGen ())
-> [VName] -> [TPrimExp Int64 VName] -> MulticoreGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> MulticoreGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
is ([TPrimExp Int64 VName] -> MulticoreGen ())
-> [TPrimExp Int64 VName] -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
ns_64 TPrimExp Int64 VName
i
        Names -> Stms MCMem -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody MCMem -> Stms MCMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody MCMem
kbody) (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
          let ([SubExp]
red_res, [SubExp]
map_res) =
                Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                  (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
                    KernelBody MCMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody

          String -> MulticoreGen () -> MulticoreGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save map-out results" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
            [(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
res) ->
              VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> MulticoreGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
is) SubExp
res []

          [(HistOp MCMem, [VName], ([SubExp], [SubExp]))]
-> ((HistOp MCMem, [VName], ([SubExp], [SubExp]))
    -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp MCMem]
-> [[VName]]
-> [([SubExp], [SubExp])]
-> [(HistOp MCMem, [VName], ([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [HistOp MCMem]
histops [[VName]]
local_subhistograms ([HistOp MCMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults [HistOp MCMem]
histops [SubExp]
red_res)) (((HistOp MCMem, [VName], ([SubExp], [SubExp])) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((HistOp MCMem, [VName], ([SubExp], [SubExp]))
    -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
            \( histop :: HistOp MCMem
histop@(HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda MCMem
_),
               [VName]
histop_subhistograms,
               ([SubExp]
bucket, [SubExp]
vs')
               ) -> do
                HistOp MCMem
histop' <- HistOp MCMem -> MulticoreGen (HistOp MCMem)
renameHistop HistOp MCMem
histop

                let bucket' :: [TPrimExp Int64 VName]
bucket' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
bucket
                    dest_shape' :: [TPrimExp Int64 VName]
dest_shape' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
                    acc_params' :: [Param LParamMem]
acc_params' = (Lambda MCMem -> [Param LParamMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda MCMem -> [Param LParamMem])
-> (HistOp MCMem -> Lambda MCMem)
-> HistOp MCMem
-> [Param LParamMem]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp) HistOp MCMem
histop'
                    vs_params' :: [Param LParamMem]
vs_params' = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda MCMem -> [LParam MCMem]) -> Lambda MCMem -> [LParam MCMem]
forall a b. (a -> b) -> a -> b
$ HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp MCMem
histop'

                (TPrimExp Int64 VName -> MulticoreGen ()) -> MulticoreGen ()
generateUniformizeLoop ((TPrimExp Int64 VName -> MulticoreGen ()) -> MulticoreGen ())
-> (TPrimExp Int64 VName -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
j ->
                  String -> MulticoreGen () -> MulticoreGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform updates" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
                    -- Create new set of uniform buckets
                    -- That is extract each bucket from a SIMD vector lane
                    [TV Int64]
extract_buckets <- (TPrimExp Int64 VName -> ImpM MCMem HostEnv Multicore (TV Int64))
-> [TPrimExp Int64 VName]
-> ImpM MCMem HostEnv Multicore [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Int64)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"extract_bucket" (PrimType -> ImpM MCMem HostEnv Multicore (TV Int64))
-> (TPrimExp Int64 VName -> PrimType)
-> TPrimExp Int64 VName
-> ImpM MCMem HostEnv Multicore (TV Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType (Exp -> PrimType)
-> (TPrimExp Int64 VName -> Exp)
-> TPrimExp Int64 VName
-> PrimType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped)) [TPrimExp Int64 VName]
bucket'
                    [(TV Int64, TPrimExp Int64 VName)]
-> ((TV Int64, TPrimExp Int64 VName) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Int64]
-> [TPrimExp Int64 VName] -> [(TV Int64, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Int64]
extract_buckets [TPrimExp Int64 VName]
bucket') (((TV Int64, TPrimExp Int64 VName) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((TV Int64, TPrimExp Int64 VName) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(TV Int64
x, TPrimExp Int64 VName
y) ->
                      Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> Code Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Multicore -> Code Multicore
forall a. a -> Code a
Imp.Op (Multicore -> Code Multicore) -> Multicore -> Code Multicore
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Exp -> Multicore
Imp.ExtractLane (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
x) (TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
y) (TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
j)
                    let bucket'' :: [TPrimExp Int64 VName]
bucket'' = (TV Int64 -> TPrimExp Int64 VName)
-> [TV Int64] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp [TV Int64]
extract_buckets
                        bucket_in_bounds :: TPrimExp Bool VName
bucket_in_bounds =
                          Slice (TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Bool VName
inBounds ([DimIndex (TPrimExp Int64 VName)] -> Slice (TPrimExp Int64 VName)
forall d. [DimIndex d] -> Slice d
Slice ((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix [TPrimExp Int64 VName]
bucket'')) [TPrimExp Int64 VName]
dest_shape'
                    TPrimExp Bool VName -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen TPrimExp Bool VName
bucket_in_bounds (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
                      HistOp MCMem -> MulticoreGen ()
genHistOpParams HistOp MCMem
histop'
                      Shape
-> ([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TPrimExp Int64 VName] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is' -> do
                        -- read values vs and perform lambda writing result back to is
                        [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params' [SubExp]
vs') (((Param LParamMem, SubExp) -> MulticoreGen ()) -> MulticoreGen ())
-> ((Param LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
res) ->
                          Type -> (PrimType -> MulticoreGen ()) -> MulticoreGen ()
forall {f :: * -> *} {shape} {u}.
Applicative f =>
TypeBase shape u -> (PrimType -> f ()) -> f ()
ifPrimType (Param LParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LParamMem
p) ((PrimType -> MulticoreGen ()) -> MulticoreGen ())
-> (PrimType -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \PrimType
pt -> do
                            -- Hack to copy varying load into uniform result variable
                            TV Any
tmp <- String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"tmp" PrimType
pt
                            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> MulticoreGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
tmp) [] SubExp
res [TPrimExp Int64 VName]
is'
                            TPrimExp Int64 VName
-> MulticoreGen (Code Multicore) -> MulticoreGen ()
extractVectorLane TPrimExp Int64 VName
j (MulticoreGen (Code Multicore) -> MulticoreGen ())
-> MulticoreGen (Code Multicore) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
                              Code Multicore -> MulticoreGen (Code Multicore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Code Multicore -> MulticoreGen (Code Multicore))
-> Code Multicore -> MulticoreGen (Code Multicore)
forall a b. (a -> b) -> a -> b
$
                                VName -> Exp -> Code Multicore
forall a. VName -> Exp -> Code a
Imp.SetScalar (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) (VName -> PrimType -> Exp
forall v. v -> PrimType -> PrimExp v
Imp.LeafExp (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
tmp) PrimType
pt)
                        HistOp MCMem
-> [VName]
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> [Param LParamMem]
-> MulticoreGen ()
updateHisto HistOp MCMem
histop' [VName]
histop_subhistograms ([TPrimExp Int64 VName]
bucket'' [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is') TPrimExp Int64 VName
j [Param LParamMem]
acc_params'

    -- Copy the task-local subhistograms to the global subhistograms,
    -- where they will be combined.
    [(VName, VName)]
-> ((VName, VName) -> MulticoreGen ()) -> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
global_subhistograms) ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
local_subhistograms)) (((VName, VName) -> MulticoreGen ()) -> MulticoreGen ())
-> ((VName, VName) -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
      \(VName
global, VName
local) -> VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> MulticoreGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
global [TPrimExp Int64 VName
tid'] (VName -> SubExp
Var VName
local) []

  [Param]
free_params <- Code Multicore -> MulticoreGen [Param]
forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams Code Multicore
body
  Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> Code Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Multicore -> Code Multicore
forall a. a -> Code a
Imp.Op (Multicore -> Code Multicore) -> Multicore -> Code Multicore
forall a b. (a -> b) -> a -> b
$ String -> Code Multicore -> [Param] -> Multicore
Imp.ParLoop String
"seghist_stage_1" Code Multicore
body [Param]
free_params

  -- Perform a segmented reduction over the subhistograms
  [([PatElem LParamMem], [VName], HistOp MCMem)]
-> (([PatElem LParamMem], [VName], HistOp MCMem)
    -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElem LParamMem]]
-> [[VName]]
-> [HistOp MCMem]
-> [([PatElem LParamMem], [VName], HistOp MCMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_red_pes [[VName]]
global_subhistograms [HistOp MCMem]
histops) ((([PatElem LParamMem], [VName], HistOp MCMem) -> MulticoreGen ())
 -> MulticoreGen ())
-> (([PatElem LParamMem], [VName], HistOp MCMem)
    -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \([PatElem LParamMem]
red_pes, [VName]
hists, HistOp MCMem
op) -> do
    [VName]
bucket_ids <-
      Int
-> ImpM MCMem HostEnv Multicore VName
-> ImpM MCMem HostEnv Multicore [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp MCMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp MCMem
op)) (String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id")
    VName
subhistogram_id <- String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"

    let segred_space :: SegSpace
segred_space =
          VName -> [(VName, SubExp)] -> SegSpace
SegSpace (SegSpace -> VName
segFlat SegSpace
space) ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
            [(VName, SubExp)]
segment_dims
              [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp MCMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp MCMem
op))
              [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_histos)]

        segred_op :: SegBinOp MCMem
segred_op = Commutativity
-> Lambda MCMem -> [SubExp] -> Shape -> SegBinOp MCMem
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative (HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp MCMem
op) (HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp MCMem
op) (HistOp MCMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp MCMem
op)

    Code Multicore
red_code <- MulticoreGen () -> MulticoreGen (Code Multicore)
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (MulticoreGen () -> MulticoreGen (Code Multicore))
-> MulticoreGen () -> MulticoreGen (Code Multicore)
forall a b. (a -> b) -> a -> b
$ do
      TV Int32
nsubtasks <- String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Int32)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"nsubtasks" PrimType
int32
      Multicore -> MulticoreGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (Multicore -> MulticoreGen ()) -> Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetNumTasks (VName -> Multicore) -> VName -> Multicore
forall a b. (a -> b) -> a -> b
$ TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
nsubtasks
      Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> (DoSegBody -> MulticoreGen (Code Multicore))
-> DoSegBody
-> MulticoreGen ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< Pat LParamMem
-> SegSpace
-> [SegBinOp MCMem]
-> TV Int32
-> DoSegBody
-> MulticoreGen (Code Multicore)
compileSegRed' ([PatElem LParamMem] -> Pat LParamMem
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem LParamMem]
red_pes) SegSpace
segred_space [SegBinOp MCMem
segred_op] TV Int32
nsubtasks (DoSegBody -> MulticoreGen ()) -> DoSegBody -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[[(SubExp, [TPrimExp Int64 VName])]] -> MulticoreGen ()
red_cont ->
        [[(SubExp, [TPrimExp Int64 VName])]] -> MulticoreGen ()
red_cont ([[(SubExp, [TPrimExp Int64 VName])]] -> MulticoreGen ())
-> [[(SubExp, [TPrimExp Int64 VName])]] -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
          [SegBinOp MCMem]
-> [(SubExp, [TPrimExp Int64 VName])]
-> [[(SubExp, [TPrimExp Int64 VName])]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp MCMem
segred_op] ([(SubExp, [TPrimExp Int64 VName])]
 -> [[(SubExp, [TPrimExp Int64 VName])]])
-> [(SubExp, [TPrimExp Int64 VName])]
-> [[(SubExp, [TPrimExp Int64 VName])]]
forall a b. (a -> b) -> a -> b
$
            ((VName -> (SubExp, [TPrimExp Int64 VName]))
 -> [VName] -> [(SubExp, [TPrimExp Int64 VName])])
-> [VName]
-> (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [(SubExp, [TPrimExp Int64 VName])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [VName] -> [(SubExp, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
hists ((VName -> (SubExp, [TPrimExp Int64 VName]))
 -> [(SubExp, [TPrimExp Int64 VName])])
-> (VName -> (SubExp, [TPrimExp Int64 VName]))
-> [(SubExp, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
              ( VName -> SubExp
Var VName
subhisto,
                (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [TPrimExp Int64 VName])
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$
                  ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
              )

    let ns_red :: [TPrimExp Int64 VName]
ns_red = ((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(VName, SubExp)] -> [TPrimExp Int64 VName])
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
segred_space
        iterations :: TPrimExp Int64 VName
iterations = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a]
init [TPrimExp Int64 VName]
ns_red -- The segmented reduction is sequential over the inner most dimension
        scheduler_info :: SchedulerInfo
scheduler_info = Exp -> Scheduling -> SchedulerInfo
Imp.SchedulerInfo (TPrimExp Int64 VName -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
iterations) Scheduling
Imp.Static
        red_task :: ParallelTask
red_task = Code Multicore -> ParallelTask
Imp.ParallelTask Code Multicore
red_code
    [Param]
free_params_red <- Code Multicore -> MulticoreGen [Param]
forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams Code Multicore
red_code
    Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> Code Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Multicore -> Code Multicore
forall a. a -> Code a
Imp.Op (Multicore -> Code Multicore) -> Multicore -> Code Multicore
forall a b. (a -> b) -> a -> b
$ String
-> [Param]
-> ParallelTask
-> Maybe ParallelTask
-> [Param]
-> SchedulerInfo
-> Multicore
Imp.SegOp String
"seghist_red" [Param]
free_params_red ParallelTask
red_task Maybe ParallelTask
forall a. Maybe a
Nothing [Param]
forall a. Monoid a => a
mempty SchedulerInfo
scheduler_info
  where
    segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
    ifPrimType :: TypeBase shape u -> (PrimType -> f ()) -> f ()
ifPrimType (Prim PrimType
pt) PrimType -> f ()
f = PrimType -> f ()
f PrimType
pt
    ifPrimType TypeBase shape u
_ PrimType -> f ()
_ = () -> f ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

-- | Compile a 'SegHist' whose 'SegSpace' has more than one dimension.
--
-- 'compileSegHist' dispatches here whenever the space is not
-- one-dimensional, so this path is live for every segmented histogram.
-- Parallelism is only exploited across segments: the body produced by
-- 'compileSegHistBody' walks the innermost dimension of each segment
-- sequentially, updating the histogram in place.  That body is wrapped in
-- an 'Imp.ParLoop' named @segmented_hist@ whose parameters are the free
-- variables of the body.
segmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
segmentedHist pat space histops kbody = do
  emit $ Imp.DebugPrint "Segmented segHist" Nothing
  collect $ do
    loop_body <- compileSegHistBody pat space histops kbody
    loop_params <- freeParams loop_body
    emit . Imp.Op $ Imp.ParLoop "segmented_hist" loop_body loop_params

-- | Generate the per-task body of a segmented histogram.
--
-- All but the innermost dimension of @space@ identify a segment; the
-- flattened segment index is supplied by 'generateChunkLoop' and
-- unflattened into the outer thread indices.  The innermost dimension is
-- then iterated sequentially with an @sFor@ loop.  For each iteration the
-- kernel body is compiled, any map-out results are written to their
-- pattern elements, and every histogram operator's lambda is applied to
-- its destination bucket, but only when the bucket index lies inside the
-- destination shape.
compileSegHistBody ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegHistBody pat space histops kbody = collect $ do
  let (gtids, dims) = unzip $ unSegSpace space
      dims_64 = map pe64 dims
      -- Outer dimensions select the segment; the innermost one is walked
      -- sequentially inside each segment.
      segment_gtids = init gtids
      segment_dims = init dims_64
      inner_gtid = last gtids
      inner_bound = last dims_64
      segment_idx = map Imp.le64 segment_gtids

      all_pes = patElems pat
      -- NOTE(review): this counts a single bucket index per operator, while
      -- the per-operator @bucket@ below is a list checked against every
      -- dimension of the destination shape; confirm this is intended for
      -- multi-dimensional histograms.
      num_red_res = length histops + sum (map (length . histNeutral) histops)
      map_pes = drop num_red_res all_pes
      per_red_pes = segHistOpChunks histops all_pes

      -- Kernel results: histogram (bucket, value) results first, map-out
      -- results at the end.
      (red_res, map_res) =
        splitFromEnd (length map_pes) $
          map kernelResultSubExp (kernelBodyResult kbody)

  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)

  generateChunkLoop "SegHist" Scalar $ \segment_flat ->
    sFor "i" inner_bound $ \inner_i -> do
      zipWithM_ dPrimV_ segment_gtids $ unflattenIndex segment_dims segment_flat
      dPrimV_ inner_gtid inner_i

      compileStms mempty (kernelBodyStms kbody) $
        forM_ (zip3 per_red_pes histops (splitHistResults histops red_res)) $
          \(red_pes, HistOp dest_shape _ _ _ op_shape lam, (bucket, vals)) -> do
            let lam_params = lambdaParams lam
                lam_body = lambdaBody lam
                -- The first @length vals@ parameters receive the current
                -- histogram contents; the rest receive the new values.
                (acc_params, new_params) = splitAt (length vals) lam_params
                bucket_64 = map pe64 bucket
                in_bounds =
                  inBounds (Slice (map DimFix bucket_64)) $
                    map pe64 (shapeDims dest_shape)

            sComment "save map-out results" $
              forM_ (zip map_pes map_res) $ \(pe, res) ->
                copyDWIMFix (patElemName pe) (map Imp.le64 gtids) res []

            sComment "perform updates" $
              sWhen in_bounds $ do
                dLParams lam_params
                sLoopNest op_shape $ \vec_is -> do
                  let dest_idx = segment_idx ++ bucket_64 ++ vec_is
                  -- Read the current bucket contents.
                  forM_ (zip red_pes acc_params) $ \(pe, p) ->
                    copyDWIMFix (paramName p) [] (Var $ patElemName pe) dest_idx
                  -- Read the values being added to the bucket.
                  forM_ (zip new_params vals) $ \(p, v) ->
                    copyDWIMFix (paramName p) [] v vec_is
                  -- Apply the operator and write its result back.
                  let op_results = map resSubExp $ bodyResult lam_body
                  compileStms mempty (bodyStms lam_body) $
                    forM_ (zip red_pes op_results) $ \(pe, se) ->
                      copyDWIMFix (patElemName pe) dest_idx se []