{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpGen.Kernels.SegScan
  ( compileSegScan )
  where

import Control.Monad.Except
import Data.Maybe
import Data.List ()

import Prelude hiding (quot, rem)

import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)

-- | Create one scratch array per scan accumulator parameter of
-- @scan_op@, returning the array names in parameter order.
--
-- Only the first @length nes@ parameters of the operator are
-- considered (the "x" parameters); the remaining "y" parameters are
-- ignored here.
--
-- * A parameter that is itself an array already bound to a memory
--   block is given an array named @scan_arr@ in that same memory
--   block, with @num_threads@ prepended to its shape and a fresh
--   'IxFun.iota' index function over the extended shape.
--
-- * Any other parameter is given a newly allocated array named
--   @scan_arr@ of @group_size@ elements in the @"local"@ space.
makeLocalArrays :: Count GroupSize SubExp -> SubExp -> [SubExp] -> Lambda ExplicitMemory
                -> InKernelGen [VName]
makeLocalArrays (Count group_size) num_threads nes scan_op = do
  let (scan_x_params, _scan_y_params) =
        splitAt (length nes) $ lambdaParams scan_op
  forM scan_x_params $ \p ->
    case paramAttr p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        -- One row per thread, laid out in the parameter's memory block.
        let shape' = Shape [num_threads] <> shape
        sArray "scan_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape'
      _ -> do
        -- One element per thread of the workgroup.
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "scan_arr" pt shape $ Space "local"

-- | An optional function that, given a @from@ and a @to@ index,
-- builds a boolean expression over them.  As constructed in
-- 'scanStage1', it is 'Nothing' when the segment space has fewer than
-- two dimensions; otherwise the function produces
-- @(to - from) .>. (to \`rem\` segment_size)@, where @segment_size@ is
-- the innermost dimension.  Judging by the name, this presumably
-- tests whether the two indices lie in different segments.
type CrossesSegment = Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)

-- | The index at which the current thread accesses one of the
-- scratch arrays, chosen from the type @t@ of the scanned value:
-- the local (in-group) thread id when @t@ is a primitive type, and
-- the global thread id otherwise.
--
-- NOTE(review): this appears to mirror the two layouts produced by
-- 'makeLocalArrays' (@group_size@ elements versus a leading
-- @num_threads@ dimension) -- confirm the two stay in sync.
localArrayIndex :: KernelConstants -> Type -> Imp.Exp
localArrayIndex constants t =
  if primType t
  then kernelLocalThreadId constants
  else kernelGlobalThreadId constants

-- | Decide which barrier is used for the scan operator @scan_op@.
--
-- Returns a triple of:
--
-- 1. whether this is an "array scan", i.e. whether at least one of
--    the operator's return types is not a primitive type;
--
-- 2. the fence to use: 'Imp.FenceGlobal' for an array scan,
--    'Imp.FenceLocal' otherwise;
--
-- 3. an action that emits an 'Imp.Barrier' with that fence.
barrierFor :: Lambda ExplicitMemory -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence)
  where array_scan = not $ all primType $ lambdaReturnType scan_op
        fence | array_scan = Imp.FenceGlobal
              | otherwise = Imp.FenceLocal

-- | Produce partially scanned intervals; one per workgroup.
scanStage1 :: Pattern ExplicitMemory
           -> Count NumGroups SubExp -> Count GroupSize SubExp -> SegSpace
           -> Lambda ExplicitMemory -> [SubExp]
           -> KernelBody ExplicitMemory
           -> CallKernelGen (VName, Imp.Exp, CrossesSegment)
scanStage1 :: Pattern ExplicitMemory
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> Lambda ExplicitMemory
-> [SubExp]
-> KernelBody ExplicitMemory
-> CallKernelGen (VName, Exp, CrossesSegment)
scanStage1 (Pattern [PatElemT (LetAttr ExplicitMemory)]
_ [PatElemT (LetAttr ExplicitMemory)]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space Lambda ExplicitMemory
scan_op [SubExp]
nes KernelBody ExplicitMemory
kbody = do
  Count NumGroups Exp
num_groups' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
  Count GroupSize Exp
group_size' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
  VName
num_threads <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_threads" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
                 Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'

  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      rets :: [TypeBase (ShapeBase SubExp) NoUniqueness]
rets = Lambda ExplicitMemory -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda ExplicitMemory
scan_op
  [Exp]
dims' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> [SubExp] -> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
  let num_elements :: Exp
num_elements = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims'
      elems_per_thread :: Exp
elems_per_thread = Exp
num_elements Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quotRoundingUp` VName -> Exp
Imp.vi32 VName
num_threads
      elems_per_group :: Exp
elems_per_group = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_thread

  -- Squirrel away a copy of the operator with unique names that we
  -- can pass to groupScan.
  Lambda ExplicitMemory
scan_op_renamed <- Lambda ExplicitMemory
-> ImpM ExplicitMemory HostEnv HostOp (Lambda ExplicitMemory)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda ExplicitMemory
scan_op

  let crossesSegment :: CrossesSegment
crossesSegment =
        case [Exp] -> [Exp]
forall a. [a] -> [a]
reverse [Exp]
dims' of
          Exp
segment_size : Exp
_ : [Exp]
_ -> (Exp -> Exp -> Exp) -> CrossesSegment
forall a. a -> Maybe a
Just ((Exp -> Exp -> Exp) -> CrossesSegment)
-> (Exp -> Exp -> Exp) -> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \Exp
from Exp
to ->
            (Exp
toExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
from) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. (Exp
to Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`rem` Exp
segment_size)
          [Exp]
_ -> CrossesSegment
forall a. Maybe a
Nothing

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"scan_stage1" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
    [VName]
local_arrs <- Count GroupSize SubExp
-> SubExp
-> [SubExp]
-> Lambda ExplicitMemory
-> InKernelGen [VName]
makeLocalArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var VName
num_threads) [SubExp]
nes Lambda ExplicitMemory
scan_op

    -- The variables from scan_op will be used for the carry and such
    -- in the big chunking loop.
    Maybe (Exp ExplicitMemory)
-> Scope ExplicitMemory -> InKernelGen ()
forall lore r op.
Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore r op ()
dScope Maybe (Exp ExplicitMemory)
forall a. Maybe a
Nothing (Scope ExplicitMemory -> InKernelGen ())
-> Scope ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> Scope ExplicitMemory)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
    let ([Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params) =
          Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
    [Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> ([Param (MemInfo SubExp NoUniqueness MemBind)],
     [Param (MemInfo SubExp NoUniqueness MemBind)]))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
    [Param (MemInfo SubExp NoUniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op

    [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params [SubExp]
nes) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
      VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

    String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"j" Exp
elems_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
j -> do
      VName
chunk_offset <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"chunk_offset" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                      KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
j Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
                      KernelConstants -> Exp
kernelGroupId KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_group
      VName
flat_idx <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"flat_idx" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                  VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
      -- Construct segment indices.
      (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
dims' (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
flat_idx PrimType
int32

      let in_bounds :: Exp
in_bounds =
            (Exp -> Exp -> Exp) -> [Exp] -> Exp
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.&&.) ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> [Exp] -> [Exp] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.<.) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids) [Exp]
dims'
          when_in_bounds :: InKernelGen ()
when_in_bounds = Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody ExplicitMemory -> Stms ExplicitMemory
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody ExplicitMemory
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            let ([KernelResult]
scan_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody ExplicitMemory -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody ExplicitMemory
kbody
            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write to-scan values to parameters" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [(Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params [KernelResult]
scan_res) (((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, KernelResult
se) ->
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
se) []
            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write mapped values results to global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) [KernelResult]
map_res) (((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, KernelResult
se) ->
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids)
              (KernelResult -> SubExp
kernelResultSubExp KernelResult
se) []
          when_out_of_bounds :: InKernelGen ()
when_out_of_bounds = [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params [SubExp]
nes) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"threads in bounds read input; others get neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf Exp
in_bounds InKernelGen ()
when_in_bounds InKernelGen ()
when_out_of_bounds

      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"combine with carry and write to local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT ExplicitMemory -> Stms ExplicitMemory
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT ExplicitMemory -> Stms ExplicitMemory)
-> BodyT ExplicitMemory -> Stms ExplicitMemory
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> BodyT ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_op) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [(TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)]
-> ((TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase (ShapeBase SubExp) NoUniqueness]
rets [VName]
local_arrs (BodyT ExplicitMemory -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT ExplicitMemory -> [SubExp])
-> BodyT ExplicitMemory -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> BodyT ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_op)) (((TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase (ShapeBase SubExp) NoUniqueness
t, VName
arr, SubExp
se) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
arr [KernelConstants -> TypeBase (ShapeBase SubExp) NoUniqueness -> Exp
localArrayIndex KernelConstants
constants TypeBase (ShapeBase SubExp) NoUniqueness
t] SubExp
se []

      let crossesSegment' :: CrossesSegment
crossesSegment' = do
            Exp -> Exp -> Exp
f <- CrossesSegment
crossesSegment
            (Exp -> Exp -> Exp) -> CrossesSegment
forall a. a -> Maybe a
Just ((Exp -> Exp -> Exp) -> CrossesSegment)
-> (Exp -> Exp -> Exp) -> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \Exp
from Exp
to ->
              let from' :: Exp
from' = Exp
from Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32
                  to' :: Exp
to' = Exp
to Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32
              in Exp -> Exp -> Exp
f Exp
from' Exp
to'

      KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

      CrossesSegment
-> Exp -> Exp -> Lambda ExplicitMemory -> [VName] -> InKernelGen ()
groupScan CrossesSegment
crossesSegment'
        (VName -> Exp
Imp.vi32 VName
num_threads)
        (KernelConstants -> Exp
kernelGroupSize KernelConstants
constants) Lambda ExplicitMemory
scan_op_renamed [VName]
local_arrs

      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"threads in bounds write partial scan result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(TypeBase (ShapeBase SubExp) NoUniqueness,
  PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((TypeBase (ShapeBase SubExp) NoUniqueness,
     PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(TypeBase (ShapeBase SubExp) NoUniqueness,
     PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase (ShapeBase SubExp) NoUniqueness]
rets [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes [VName]
local_arrs) (((TypeBase (ShapeBase SubExp) NoUniqueness,
   PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((TypeBase (ShapeBase SubExp) NoUniqueness,
     PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase (ShapeBase SubExp) NoUniqueness
t, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, VName
arr) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids)
        (VName -> SubExp
Var VName
arr) [KernelConstants -> TypeBase (ShapeBase SubExp) NoUniqueness -> Exp
localArrayIndex KernelConstants
constants TypeBase (ShapeBase SubExp) NoUniqueness
t]

      InKernelGen ()
barrier

      let load_carry :: InKernelGen ()
load_carry =
            [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((VName, Param (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr)
            [if TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall attr.
Typed attr =>
Param attr -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p
             then KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
             else (KernelConstants -> Exp
kernelGroupId KernelConstants
constantsExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
1) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]
          load_neutral :: InKernelGen ()
load_neutral =
            [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread reads last element as carry-in for next iteration" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        Exp
crosses_segment <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"crosses_segment" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
          case CrossesSegment
crossesSegment of
            CrossesSegment
Nothing -> Exp
forall v. PrimExp v
false
            Just Exp -> Exp -> Exp
f -> Exp -> Exp -> Exp
f (VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
                         KernelConstants -> Exp
kernelGroupSize KernelConstants
constantsExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1)
                        (VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
                         KernelConstants -> Exp
kernelGroupSize KernelConstants
constants)
        Exp
should_load_carry <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"should_load_carry" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
          KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
UnOpExp UnOp
Not Exp
crosses_segment
        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
should_load_carry InKernelGen ()
load_carry
        Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
should_load_carry InKernelGen ()
load_neutral

      InKernelGen ()
barrier

  (VName, Exp, CrossesSegment)
-> CallKernelGen (VName, Exp, CrossesSegment)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
num_threads, Exp
elems_per_group, CrossesSegment
crossesSegment)

  where (Bool
array_scan, Fence
fence, InKernelGen ()
barrier) = Lambda ExplicitMemory -> (Bool, Fence, InKernelGen ())
barrierFor Lambda ExplicitMemory
scan_op

-- | Generate the @scan_stage2@ kernel, which scans the per-group
-- carries in place.
--
-- A single group is launched whose size equals the number of groups
-- used by the stage 1 kernel.  Thread @i@ addresses the flat element
-- @(i + 1) * elems_per_group - 1@, i.e. the last element of the
-- @i@'th chunk of @elems_per_group@ elements.  Threads for which that
-- element is inside the segment space copy it from the pattern arrays
-- into the local arrays; the remaining threads store the neutral
-- element instead.  After a barrier the local arrays are scanned with
-- 'groupScan', and in-bounds threads write their result back to the
-- same position in the pattern arrays.
--
-- Arguments, in order: the destination pattern (only its value
-- elements are used), the name holding the stage 1 thread count, the
-- number of elements per stage 1 group, the stage 1 group count, the
-- optional segment-crossing predicate on flat indices, the segment
-- space, the scan operator, and its neutral elements.
scanStage2 :: Pattern ExplicitMemory
           -> VName -> Imp.Exp -> Count NumGroups SubExp -> CrossesSegment -> SegSpace
           -> Lambda ExplicitMemory -> [SubExp]
           -> CallKernelGen ()
scanStage2 (Pattern _ pes) stage1_num_threads elems_per_group num_groups crossesSegment space scan_op nes = do
  -- Our group size is the number of groups for the stage 1 kernel.
  let group_size = Count $ unCount num_groups
  group_size' <- traverse toExp group_size

  let (gtids, dims) = unzip $ unSegSpace space
      rets = lambdaReturnType scan_op
  dims' <- mapM toExp dims

  -- 'groupScan' hands the predicate thread-local indices, whereas the
  -- incoming predicate works on flat element indices.  Translate each
  -- index to the flat position of the carry that thread holds.
  let crossesSegment' = do
        f <- crossesSegment
        Just $ \from to ->
          f ((from + 1) * elems_per_group - 1)
            ((to + 1) * elems_per_group - 1)

  sKernelThread "scan_stage2" 1 group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    local_arrs <- makeLocalArrays group_size (Var stage1_num_threads) nes scan_op

    -- Flat index of the last element in this thread's chunk.
    flat_idx <- dPrimV "flat_idx" $
      (kernelLocalThreadId constants + 1) * elems_per_group - 1
    -- Construct segment indices.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ Imp.var flat_idx int32

    let in_bounds =
          foldl1 (.&&.) $ zipWith (.<.) (map (`Imp.var` int32) gtids) dims'
        -- Read the carry from the pattern array into the local array.
        when_in_bounds =
          forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) ->
            copyDWIMFix arr [localArrayIndex constants t]
              (Var $ patElemName pe) $ map (`Imp.var` int32) gtids
        -- Pad the local array with the neutral element.
        when_out_of_bounds =
          forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) ->
            copyDWIMFix arr [localArrayIndex constants t] ne []

    sComment "threads in bound read carries; others get neutral element" $
      sIf in_bounds when_in_bounds when_out_of_bounds

    barrier

    groupScan crossesSegment'
      (Imp.vi32 stage1_num_threads) (kernelGroupSize constants) scan_op local_arrs

    sComment "threads in bounds write scanned carries" $
      sWhen in_bounds $ forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) ->
        copyDWIMFix (patElemName pe) (map (`Imp.var` int32) gtids)
          (Var arr) [localArrayIndex constants t]

  where (_, _, barrier) = barrierFor scan_op

scanStage3 :: Pattern ExplicitMemory
           -> Count NumGroups SubExp -> Count GroupSize SubExp
           -> Imp.Exp -> CrossesSegment -> SegSpace
           -> Lambda ExplicitMemory -> [SubExp]
           -> CallKernelGen ()
scanStage3 :: Pattern ExplicitMemory
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> Exp
-> CrossesSegment
-> SegSpace
-> Lambda ExplicitMemory
-> [SubExp]
-> CallKernelGen ()
scanStage3 (Pattern [PatElemT (LetAttr ExplicitMemory)]
_ [PatElemT (LetAttr ExplicitMemory)]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size Exp
elems_per_group CrossesSegment
crossesSegment SegSpace
space Lambda ExplicitMemory
scan_op [SubExp]
nes = do
  Count NumGroups Exp
num_groups' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
  Count GroupSize Exp
group_size' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
  [Exp]
dims' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> [SubExp] -> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
  Exp
required_groups <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"required_groups" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
                     [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims' Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quotRoundingUp` Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"scan_stage3" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt Exp
required_groups ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
virt_group_id -> do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

    -- Compute our logical index.
    Exp
flat_idx <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"flat_idx" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                VName -> Exp
Imp.vi32 VName
virt_group_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
                KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
    (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
dims' Exp
flat_idx

    -- Figure out which group this element was originally in.
    VName
orig_group <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"orig_group" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ Exp
flat_idx Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
elems_per_group
    -- Then the index of the carry-in of the preceding group.
    VName
carry_in_flat_idx <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"carry_in_flat_idx" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                         VName -> PrimType -> Exp
Imp.var VName
orig_group PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
    -- Figure out the logical index of the carry-in.
    let carry_in_idx :: [Exp]
carry_in_idx = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
dims' (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
carry_in_flat_idx PrimType
int32

    -- Apply the carry if we are not in the scan results for the first
    -- group, and are not the last element in such a group (because
    -- then the carry was updated in stage 2), and we are not crossing
    -- a segment boundary.
    let in_bounds :: Exp
in_bounds =
          (Exp -> Exp -> Exp) -> [Exp] -> Exp
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.&&.) ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> [Exp] -> [Exp] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.<.) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids) [Exp]
dims'
        crosses_segment :: Exp
crosses_segment = Exp -> Maybe Exp -> Exp
forall a. a -> Maybe a -> a
fromMaybe Exp
forall v. PrimExp v
false (Maybe Exp -> Exp) -> Maybe Exp -> Exp
forall a b. (a -> b) -> a -> b
$
          CrossesSegment
crossesSegment CrossesSegment -> Maybe Exp -> Maybe (Exp -> Exp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*>
            Exp -> Maybe Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> PrimType -> Exp
Imp.var VName
carry_in_flat_idx PrimType
int32) Maybe (Exp -> Exp) -> Maybe Exp -> Maybe Exp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*>
            Exp -> Maybe Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure Exp
flat_idx
        is_a_carry :: Exp
is_a_carry = Exp
flat_idx Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==.
                     (VName -> PrimType -> Exp
Imp.var VName
orig_group PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
1) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
        no_carry_in :: Exp
no_carry_in = VName -> PrimType -> Exp
Imp.var VName
orig_group PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.||. Exp
is_a_carry Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.||. Exp
crosses_segment

    Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      Maybe (Exp ExplicitMemory)
-> Scope ExplicitMemory -> InKernelGen ()
forall lore r op.
Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore r op ()
dScope Maybe (Exp ExplicitMemory)
forall a. Maybe a
Nothing (Scope ExplicitMemory -> InKernelGen ())
-> Scope ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> Scope ExplicitMemory)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
      let ([Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params) =
            Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
    [Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> ([Param (MemInfo SubExp NoUniqueness MemBind)],
     [Param (MemInfo SubExp NoUniqueness MemBind)]))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
    [Param (MemInfo SubExp NoUniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
      [(Param (MemInfo SubExp NoUniqueness MemBind),
  PatElemT (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
   PatElemT (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) [Exp]
carry_in_idx
      [(Param (MemInfo SubExp NoUniqueness MemBind),
  PatElemT (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
   PatElemT (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids
      [Param (MemInfo SubExp NoUniqueness MemBind)]
-> BodyT ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params (BodyT ExplicitMemory -> InKernelGen ())
-> BodyT ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> BodyT ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_op
      [(Param (MemInfo SubExp NoUniqueness MemBind),
  PatElemT (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
   PatElemT (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElemT (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- The generated code consists of a debug print followed by three
-- stages, emitted in this order:
--
-- 1. @scanStage1@ receives the pattern, the number of groups and the
--    group size taken from the 'SegLevel', the 'SegSpace', the scan
--    operator, @nes@ and the kernel body.  It yields the name of the
--    stage-one thread count, the number of elements handled per group,
--    and a 'CrossesSegment' value; all three are passed on to the later
--    stages.
--
-- 2. @scanStage2@ receives the stage-one thread count, the elements
--    per group and the number of groups.
--
-- 3. @scanStage3@ receives the number of groups, the group size and
--    the elements per group.
--
-- The kernel body is only handed to the first stage.  @nes@ has one
-- entry per value being scanned (presumably the neutral elements of
-- the operator -- confirm against the callers).
compileSegScan :: Pattern ExplicitMemory
               -> SegLevel -> SegSpace
               -> Lambda ExplicitMemory -> [SubExp]
               -> KernelBody ExplicitMemory
               -> CallKernelGen ()
compileSegScan pat lvl space scan_op nes kbody = do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing

  (stage1_num_threads, elems_per_group, crossesSegment) <-
    scanStage1 pat num_groups group_size space scan_op nes kbody

  emit $ Imp.DebugPrint "elems_per_group" $ Just elems_per_group

  -- Stages two and three each get their own freshly renamed copy of
  -- the operator; stage one used the original.  Presumably this keeps
  -- the names bound inside the lambda from being declared more than
  -- once across the generated kernels.
  scan_op' <- renameLambda scan_op
  scan_op'' <- renameLambda scan_op

  scanStage2 pat stage1_num_threads elems_per_group num_groups
    crossesSegment space scan_op' nes
  scanStage3 pat num_groups group_size elems_per_group
    crossesSegment space scan_op'' nes
  where
    -- Launch configuration shared by the stages, read off the level.
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl