module Futhark.CodeGen.ImpGen.Multicore.SegRed
  ( compileSegRed,
    compileSegRed',
  )
where

import Control.Monad
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Util (chunks)
import Prelude hiding (quot, rem)

-- | A body generator in continuation-passing style.  It is handed a
-- continuation that consumes the reduction operands, each paired with
-- a list of indices used when reading from it, and it is expected to
-- emit its own code and then invoke that continuation.
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> MulticoreGen ()) -> MulticoreGen ()

-- | Generate code for a SegRed construct.
--
-- The kernel body's results are split in two.  The first
-- @segBinOpResults reds@ results are reduction operands; they are
-- handed on to the reduction through the continuation.  The remaining
-- results are map-out results; they are written to the trailing
-- pattern elements with 'compileThreadResult'.
compileSegRed ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.Code
compileSegRed pat space reds kbody nsubtasks =
  compileSegRed' pat space reds nsubtasks $ \red_cont ->
    compileStms mempty (kernelBodyStms kbody) $ do
      -- Number of results that feed the reduction operators; the rest
      -- are map-out results.
      let num_red_res = segBinOpResults reds
          (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody

      sComment "save map-out results" $ do
        let map_arrs = drop num_red_res $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      -- Each reduction operand is paired with an empty list of extra
      -- indices.
      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []

-- | Like 'compileSegRed', but where the body is a monadic action.
--
-- A segment space with exactly one dimension is compiled as a
-- non-segmented reduction, which is given @nsubtasks@.  Any other
-- space is compiled as a segmented reduction, which does not use
-- @nsubtasks@.
compileSegRed' ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.Code
compileSegRed' pat space reds nsubtasks kbody
  | [_] <- unSegSpace space =
    nonsegmentedReduction pat space reds nsubtasks kbody
  | otherwise =
    segmentedReduction pat space reds kbody

-- | A SegBinOp with auxiliary information.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp MCMem,
    -- | The array in which we write the intermediate results, indexed
    -- by the flat/physical thread ID.
    slugResArrs :: [VName]
  }

-- | The body of the slug's reduction-operator lambda.
slugBody :: SegBinOpSlug -> Body MCMem
slugBody = lambdaBody . segBinOpLambda . slugOp

-- | All parameters of the slug's reduction-operator lambda.
slugParams :: SegBinOpSlug -> [LParam MCMem]
slugParams = lambdaParams . segBinOpLambda . slugOp

-- | The neutral elements of the slug's reduction operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp

-- | The (vectorisation) shape of the slug's reduction operator; empty
-- for a scalar operator.
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp

-- | Split the operator lambda's parameters by the number of neutral
-- elements.
--
-- * 'accParams' returns the first @length (slugNeutral slug)@
--   parameters.  These are the accumulator side.
-- * 'nextParams' returns the remaining parameters.  These receive the
--   incoming reduction operands.
accParams, nextParams :: SegBinOpSlug -> [LParam MCMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug

-- | Compile a reduction over a one-dimensional segment space in two
-- stages.
--
-- * 'reductionStage1' emits a parallel loop.  It writes per-thread
--   partial results into intermediate arrays whose size is given by
--   @nsubtasks@.
-- * 'reductionStage2' then combines those partial results into the
--   pattern.
nonsegmentedReduction ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.Code
nonsegmentedReduction pat space reds nsubtasks kbody = collect $ do
  thread_res_arrs <- groupResultArrays "reduce_stage_1_tid_res_arr" (tvSize nsubtasks) reds
  let slugs1 = zipWith SegBinOpSlug reds thread_res_arrs
      nsubtasks' = tvExp nsubtasks

  reductionStage1 space slugs1 kbody

  -- Stage 2 compiles the operators a second time, so it works on
  -- freshly renamed copies.  This is presumably to avoid declaring the
  -- lambdas' bound names twice.  Both stages share the same
  -- intermediate arrays.
  reds2 <- renameSegBinOp reds
  let slugs2 = zipWith SegBinOpSlug reds2 thread_res_arrs
  reductionStage2 pat space nsubtasks' slugs2

-- | Stage 1 of a non-segmented reduction.
--
-- Emits one @Imp.ParLoop@ named @"segred_stage_1"@, iterating over a
-- fresh @iter@ variable.  Three pieces of code are collected and passed
-- to it:
--
-- * @prebody@ declares a local accumulator for each slug and
--   neutral-initialises it.  It is preceded by any allocations
--   extracted from the loop body.
-- * @fbody@ recovers the multi-dimensional indices, runs the kernel
--   body, and folds each reduction operand into its accumulator.
-- * @postbody@ copies every local accumulator into the slug's
--   'slugResArrs' at index @segFlat space@.
reductionStage1 ::
  SegSpace ->
  [SegBinOpSlug] ->
  DoSegBody ->
  MulticoreGen ()
reductionStage1 space slugs kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map toInt64Exp ns
  flat_idx <- dPrim "iter" int64

  -- Create local accumulator variables in which we carry out the
  -- sequential reduction of this function.  If we are dealing with
  -- vectorised operators, then this implies a private allocation.  If
  -- the original operand type of the reduction is a memory block,
  -- then our hands are unfortunately tied, and we have to use exactly
  -- that memory.  This is likely to be slow.
  (slug_local_accs, prebody) <- collect' $ do
    -- Declare the parameters of every operator lambda.
    dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

    forM slugs $ \slug -> do
      let shape = slugShape slug

      forM (zip (accParams slug) (slugNeutral slug)) $ \(p, ne) -> do
        -- Declare accumulator variable.
        acc <-
          case paramType p of
            Prim pt
              | shape == mempty ->
                tvVar <$> dPrim "local_acc" pt
              | otherwise ->
                sAllocArray "local_acc" pt shape DefaultSpace
            _ ->
              pure $ paramName p

        -- Now neutral-initialise the accumulator.
        sLoopNest shape $ \vec_is ->
          copyDWIMFix acc vec_is ne []

        pure acc

  fbody <- collect $ do
    -- Recover the multi-dimensional indices from the flat counter.
    zipWithM_ dPrimV_ is $ unflattenIndex ns' $ tvExp flat_idx
    kbody $ \all_red_res -> do
      let all_red_res' = segBinOpChunks (map slugOp slugs) all_red_res
      forM_ (zip3 all_red_res' slugs slug_local_accs) $ \(red_res, slug, local_accs) ->
        sLoopNest (slugShape slug) $ \vec_is -> do
          let lamtypes = lambdaReturnType $ segBinOpLambda $ slugOp slug
          -- Load accum params.  Only primitive-typed ones are copied.
          -- NOTE(review): presumably because a non-primitive
          -- accumulator *is* the lambda parameter (see above) -- confirm.
          sComment "Load accum params" $
            forM_ (zip3 (accParams slug) local_accs lamtypes) $
              \(p, local_acc, t) ->
                when (primType t) $
                  copyDWIMFix (paramName p) [] (Var local_acc) vec_is

          sComment "Load next params" $
            forM_ (zip (nextParams slug) red_res) $ \(p, (res, res_is)) ->
              copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

          -- Run the operator and store its results back in the
          -- local accumulators.
          sComment "Red body" $
            compileStms mempty (bodyStms $ slugBody slug) $
              forM_ (zip local_accs (bodyResult $ slugBody slug)) $
                \(local_acc, se) ->
                  copyDWIMFix local_acc vec_is se []

  -- Publish each local accumulator at the flat thread ID.  forM_
  -- replaces the former forM, whose [()] result was discarded anyway.
  postbody <- collect $
    forM_ (zip slugs slug_local_accs) $ \(slug, local_accs) ->
      forM_ (zip (slugResArrs slug) local_accs) $ \(acc, local_acc) ->
        copyDWIMFix acc [Imp.vi64 $ segFlat space] (Var local_acc) []

  -- The flat ID and the iteration variable are both passed to the
  -- ParLoop directly.  They are presumably excluded here from its free
  -- parameters.
  free_params <- freeParams (prebody <> fbody <> postbody) [segFlat space, tvVar flat_idx]
  -- Allocations found in the loop body are placed in front of prebody.
  let (body_allocs, fbody') = extractAllocations fbody
  emit $
    Imp.Op $
      Imp.ParLoop
        "segred_stage_1"
        (tvVar flat_idx)
        (body_allocs <> prebody)
        fbody'
        postbody
        free_params
        (segFlat space)

reductionStage2 ::
  Pattern MCMem ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegBinOpSlug] ->
  MulticoreGen ()
reductionStage2 :: Pattern MCMem
-> SegSpace -> TExp Int32 -> [SegBinOpSlug] -> MulticoreGen ()
reductionStage2 Pattern MCMem
pat SegSpace
space TExp Int32
nsubtasks [SegBinOpSlug]
slugs = do
  let per_red_pes :: [[PatElemT LParamMem]]
per_red_pes = [SegBinOp MCMem] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks ((SegBinOpSlug -> SegBinOp MCMem)
-> [SegBinOpSlug] -> [SegBinOp MCMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOpSlug -> SegBinOp MCMem
slugOp [SegBinOpSlug]
slugs) ([PatElemT LParamMem] -> [[PatElemT LParamMem]])
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a b. (a -> b) -> a -> b
$ PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern MCMem
PatternT LParamMem
pat
      phys_id :: TExp Int64
phys_id = VName -> TExp Int64
Imp.vi64 (SegSpace -> VName
segFlat SegSpace
space)
  String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"neutral-initialise the output" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
    [(SegBinOp MCMem, [PatElemT LParamMem])]
-> ((SegBinOp MCMem, [PatElemT LParamMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp MCMem]
-> [[PatElemT LParamMem]]
-> [(SegBinOp MCMem, [PatElemT LParamMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip ((SegBinOpSlug -> SegBinOp MCMem)
-> [SegBinOpSlug] -> [SegBinOp MCMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOpSlug -> SegBinOp MCMem
slugOp [SegBinOpSlug]
slugs) [[PatElemT LParamMem]]
per_red_pes) (((SegBinOp MCMem, [PatElemT LParamMem]) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((SegBinOp MCMem, [PatElemT LParamMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp MCMem
red, [PatElemT LParamMem]
red_res) ->
      [(PatElemT LParamMem, SubExp)]
-> ((PatElemT LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem] -> [SubExp] -> [(PatElemT LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
red_res ([SubExp] -> [(PatElemT LParamMem, SubExp)])
-> [SubExp] -> [(PatElemT LParamMem, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp MCMem
red) (((PatElemT LParamMem, SubExp) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((PatElemT LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, SubExp
ne) ->
        Shape -> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOp MCMem -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp MCMem
red) (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
          VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) [TExp Int64]
vec_is SubExp
ne []

  Maybe (Exp MCMem) -> Scope MCMem -> MulticoreGen ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp MCMem)
forall a. Maybe a
Nothing (Scope MCMem -> MulticoreGen ()) -> Scope MCMem -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope MCMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LParamMem] -> Scope MCMem)
-> [Param LParamMem] -> Scope MCMem
forall a b. (a -> b) -> a -> b
$ (SegBinOpSlug -> [Param LParamMem])
-> [SegBinOpSlug] -> [Param LParamMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegBinOpSlug -> [LParam MCMem]
SegBinOpSlug -> [Param LParamMem]
slugParams [SegBinOpSlug]
slugs

  String
-> TExp Int32 -> (TExp Int32 -> MulticoreGen ()) -> MulticoreGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int32
nsubtasks ((TExp Int32 -> MulticoreGen ()) -> MulticoreGen ())
-> (TExp Int32 -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i' -> do
    VName -> PrimType -> TV Int32
forall t. VName -> PrimType -> TV t
mkTV (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64 TV Int32 -> TExp Int32 -> MulticoreGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int32
i'
    String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Apply main thread reduction" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
      [(SegBinOpSlug, [PatElemT LParamMem])]
-> ((SegBinOpSlug, [PatElemT LParamMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LParamMem]] -> [(SegBinOpSlug, [PatElemT LParamMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LParamMem]]
per_red_pes) (((SegBinOpSlug, [PatElemT LParamMem]) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((SegBinOpSlug, [PatElemT LParamMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LParamMem]
red_res) ->
        Shape -> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
          String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load acc params" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
            [(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam MCMem]
accParams SegBinOpSlug
slug) [PatElemT LParamMem]
red_res) (((Param LParamMem, PatElemT LParamMem) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((Param LParamMem, PatElemT LParamMem) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
              VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) [TExp Int64]
vec_is
          String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load next params" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
            [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> MulticoreGen ()) -> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam MCMem]
nextParams SegBinOpSlug
slug) (SegBinOpSlug -> [VName]
slugResArrs SegBinOpSlug
slug)) (((Param LParamMem, VName) -> MulticoreGen ()) -> MulticoreGen ())
-> ((Param LParamMem, VName) -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, VName
acc) ->
              VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
acc) (TExp Int64
phys_id TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is)
          String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"red body" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
            Names -> Stms MCMem -> MulticoreGen () -> MulticoreGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body MCMem -> Stms MCMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body MCMem -> Stms MCMem) -> Body MCMem -> Stms MCMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body MCMem
slugBody SegBinOpSlug
slug) (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
              [(PatElemT LParamMem, SubExp)]
-> ((PatElemT LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem] -> [SubExp] -> [(PatElemT LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
red_res (Body MCMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body MCMem -> [SubExp]) -> Body MCMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body MCMem
slugBody SegBinOpSlug
slug)) (((PatElemT LParamMem, SubExp) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((PatElemT LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
                \(PatElemT LParamMem
pe, SubExp
se') -> VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) [TExp Int64]
vec_is SubExp
se' []

-- | Compile a segmented reduction into a single 'Imp.ParLoop' named
-- @"segmented_segred"@.
--
-- A fresh 'Int64' variable (@segment_iter@) is handed to the parallel loop as
-- its iteration variable; 'compileSegRedBody' uses it as the index of the
-- segment to reduce and reduces that segment's innermost dimension
-- sequentially. Allocations found in the body are split out with
-- 'extractAllocations' and passed to the loop separately.
--
-- NOTE(review): the split between parallel (segments) and sequential (inner
-- dimension) work is fixed; it might be worth choosing it based on the number
-- of segments and the dimensions of the space.
segmentedReduction ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.Code
segmentedReduction pat space reds kbody = collect $ do
  seg_iter <- dPrim "segment_iter" $ IntType Int64
  let seg_var = tvVar seg_iter
      flat_var = segFlat space
  seg_code <- compileSegRedBody seg_iter pat space reds kbody
  -- Both loop-bound names are excluded from the free parameters.
  params <- freeParams seg_code [flat_var, seg_var]
  let (allocs, seg_code') = extractAllocations seg_code
  emit . Imp.Op $
    Imp.ParLoop "segmented_segred" seg_var allocs seg_code' mempty params flat_var

-- | Generate the code that reduces one segment of a segmented reduction.
--
-- @n_segments@ holds the index of the segment to reduce (despite its name):
-- the flat index of the segment's first element is computed as
-- @n_segments * inner_bound@, where @inner_bound@ is the size of the
-- innermost dimension of @space@. The generated code
--
--   1. binds the indices of @space@ to the start of that segment,
--   2. writes each operator's neutral element into the corresponding pattern
--      elements at the segment's position (these act as accumulators),
--   3. loops sequentially over the innermost dimension. For each element it
--      runs the kernel body (@kbody@) and combines the results with the
--      accumulators using the reduction lambdas. The result is written back
--      into the pattern elements.
--
-- Uses 'init' / 'last' on the space's indices, so it assumes the space has at
-- least one dimension (segmented reductions normally have two or more).
compileSegRedBody ::
  TV Int64 ->
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.Code
compileSegRedBody n_segments pat space reds kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
      inner_bound = last ns_64
      n_segments' = tvExp n_segments
      -- Outer (per-segment) indices, used to address the accumulators.
      segment_is = map Imp.vi64 (init is)
      per_red_pes = segBinOpChunks reds $ patternValueElements pat

  -- Perform sequential reduce on inner most dimension
  collect $ do
    flat_idx <- dPrimVE "flat_idx" $ n_segments' * inner_bound
    zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx

    sComment "neutral-initialise the accumulators" $
      forM_ (zip per_red_pes reds) $ \(pes, red) ->
        forM_ (zip pes (segBinOpNeutral red)) $ \(pe, ne) ->
          sLoopNest (segBinOpShape red) $ \vec_is ->
            copyDWIMFix (patElemName pe) (segment_is ++ vec_is) ne []

    sComment "main body" $ do
      dScope Nothing $
        scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) reds
      -- The outer indices depend only on the segment, not on the inner loop
      -- counter, so assign them once here instead of on every iteration.
      zipWithM_
        (<--)
        (map (`mkTV` int64) $ init is)
        (unflattenIndex (init ns_64) (sExt64 n_segments'))
      sFor "i" inner_bound $ \i -> do
        dPrimV_ (last is) i
        kbody $ \all_red_res -> do
          -- One chunk of kernel results per reduction operator.
          let red_res' = chunks (map (length . segBinOpNeutral) reds) all_red_res
          forM_ (zip3 per_red_pes reds red_res') $ \(pes, red, res') ->
            sLoopNest (segBinOpShape red) $ \vec_is -> do
              let lam = segBinOpLambda red
                  -- The lambda takes the accumulator parameters first and
                  -- the new-value parameters second.
                  (acc_params, next_params) =
                    splitAt (length (segBinOpNeutral red)) (lambdaParams lam)
                  lbody = lambdaBody lam
                  acc_is = segment_is ++ vec_is

              sComment "load accum" $
                forM_ (zip acc_params pes) $ \(p, pe) ->
                  copyDWIMFix (paramName p) [] (Var $ patElemName pe) acc_is

              sComment "load new val" $
                forM_ (zip next_params res') $ \(p, (res, res_is)) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

              sComment "apply reduction" $
                compileStms mempty (bodyStms lbody) $
                  sComment "write back to res" $
                    forM_ (zip pes (bodyResult lbody)) $ \(pe, se') ->
                      copyDWIMFix (patElemName pe) acc_is se' []