{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fairly inefficient two-pass algorithm, but can handle anything.
module Futhark.CodeGen.ImpGen.Kernels.SegScan.TwoPass (compileSegScan) where

import Control.Monad.Except
import Control.Monad.State
import Data.List (delete, find, foldl', zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | Create the scratch arrays for every 'SegBinOp': one array per
-- accumulator (x) parameter of each operator's lambda.
--
-- The operators are run one after another, so memory is reused
-- aggressively between them.  Scalar accumulators get a
-- @group_size@-element array in a pooled local-memory block; every pooled
-- block remembers each byte size it has been asked to hold and is finally
-- allocated with the largest of them (never less than one byte).
-- Accumulators whose parameter is already an array get an array with a
-- leading @num_threads@ dimension, placed in the memory block that the
-- lambda parameter itself is bound to.
makeLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [SegBinOp KernelsMem] ->
  InKernelGen [[VName]]
makeLocalArrays (Count group_size) num_threads scans = do
  (per_scan_arrs, pool) <- runStateT (mapM arraysForScan scans) []
  forM_ pool $ \(sizes, mem) ->
    sAlloc_ mem (largestSize sizes) (Space "local")
  return per_scan_arrs
  where
    -- The largest of the requested sizes, and at least one byte.
    largestSize sizes =
      Imp.bytes $ foldl' sMax64 1 $ map Imp.unCount sizes

    -- Scratch arrays for a single operator.  Pooled memory claimed while
    -- building them is handed back to the pool afterwards, so operators
    -- that run later may share it.
    arraysForScan (SegBinOp _ scan_op nes _) = do
      let x_params = take (length nes) (lambdaParams scan_op)
      (arrs, claimed) <- unzip <$> mapM arrayForParam x_params
      modify (<> concat claimed)
      return arrs

    -- Scratch array for one accumulator parameter, paired with the
    -- pooled memory (if any) that it claimed.
    arrayForParam p =
      case paramDec p of
        MemArray pt shape _ (ArrayIn mem _) -> do
          let full_shape = Shape [num_threads] <> shape
              ixfun = IxFun.iota $ map pe64 $ shapeDims full_shape
          arr <- lift $ sArray "scan_arr" pt full_shape (ArrayIn mem ixfun)
          return (arr, [])
        _ -> do
          let pt = elemType $ paramType p
              shape = Shape [group_size]
          (sizes, mem) <- claimMem pt shape
          arr <- lift $ sArrayInMem "scan_arr" pt shape mem
          return (arr, [(sizes, mem)])

    -- Remove a memory block from the pool for a value of the given type.
    -- Preference order: a block already used for exactly this size; the
    -- first block in the pool (recording the new size); a freshly
    -- declared local-memory block.
    claimMem pt shape = do
      let needed = typeSize $ Array pt shape NoUniqueness
      pool <- get
      case find ((needed `elem`) . fst) pool of
        Just hit -> do
          modify (delete hit)
          return hit
        Nothing ->
          case pool of
            (known_sizes, mem) : rest -> do
              put rest
              return (needed : known_sizes, mem)
            [] -> do
              mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "local"
              return ([needed], mem)

-- | Optional predicate on two flat element indices @from@ and @to@.
-- As built in 'scanStage1', it is 'Nothing' when the segment space has
-- fewer than two dimensions (no segments to cross), and otherwise reports
-- whether the interval from @from@ to @to@ spans a boundary between
-- segments of the innermost dimension.
type CrossesSegment = Maybe (Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool)

-- | The position at which the current thread accesses a scan scratch array
-- holding values of type @t@.  For primitive types this is the thread's
-- local (in-group) id; for any other type it is the global thread id.
-- Both are sign-extended to 64 bits.
localArrayIndex :: KernelConstants -> Type -> Imp.TExp Int64
localArrayIndex constants t = sExt64 (threadId constants)
  where
    threadId
      | primType t = kernelLocalThreadId
      | otherwise = kernelGlobalThreadId

-- | Barrier configuration for a scan operator.
--
-- Returns a triple of:
--
-- * whether the operator is an array scan, i.e. at least one of its
--   return types is not primitive;
-- * the memory fence to use: 'Imp.FenceGlobal' for array scans and
--   'Imp.FenceLocal' otherwise;
-- * an action that emits a barrier with that fence.
barrierFor :: Lambda KernelsMem -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op =
  let is_array_scan = any (not . primType) (lambdaReturnType scan_op)
      chosen_fence =
        if is_array_scan
          then Imp.FenceGlobal
          else Imp.FenceLocal
   in (is_array_scan, chosen_fence, sOp (Imp.Barrier chosen_fence))

-- | Split an operator's lambda parameters after as many parameters as the
-- operator has neutral elements.
splitScanLambdaParams ::
  SegBinOp KernelsMem -> ([LParam KernelsMem], [LParam KernelsMem])
splitScanLambdaParams scan =
  splitAt (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | The leading parameters ('xParams', one per neutral element) and the
-- remaining parameters ('yParams') of an operator's lambda.
xParams, yParams :: SegBinOp KernelsMem -> [LParam KernelsMem]
xParams = fst . splitScanLambdaParams
yParams = snd . splitScanLambdaParams

-- | Store the kernel results that feed one scan operator.
--
-- If the operator has a non-empty 'segBinOpShape', each result is copied
-- into the corresponding pattern element at index @gtids@.  Otherwise each
-- result is copied into the corresponding y parameter of the operator's
-- lambda.
writeToScanValues ::
  [VName] ->
  ([PatElem KernelsMem], SegBinOp KernelsMem, [KernelResult]) ->
  InKernelGen ()
writeToScanValues gtids (pes, scan, scan_res)
  | has_vec_shape = zipWithM_ intoPatElem pes scan_res
  | otherwise = zipWithM_ intoYParam (yParams scan) scan_res
  where
    has_vec_shape = shapeRank (segBinOpShape scan) > 0
    intoPatElem pe res =
      copyDWIMFix
        (patElemName pe)
        (map Imp.vi64 gtids)
        (kernelResultSubExp res)
        []
    intoYParam p res =
      copyDWIMFix (paramName p) [] (kernelResultSubExp res) []

-- | For an operator with a non-empty 'segBinOpShape', load each y
-- parameter of its lambda from the corresponding pattern element at index
-- @is@.  Operators with an empty shape are left untouched.
readToScanValues ::
  [Imp.TExp Int64] ->
  [PatElem KernelsMem] ->
  SegBinOp KernelsMem ->
  InKernelGen ()
readToScanValues is pes scan =
  when (shapeRank (segBinOpShape scan) > 0) $
    zipWithM_ loadParam (yParams scan) pes
  where
    loadParam p pe =
      copyDWIMFix (paramName p) [] (Var (patElemName pe)) is

-- | Initialise the x parameters (the carries) of an operator with a
-- non-empty 'segBinOpShape'.
--
-- When @chunk_offset@ is positive, the thread with local id 0 reloads the
-- carries from the pattern elements at the element just before the chunk.
-- That element's index is @chunk_offset - 1@ unflattened against @dims'@,
-- followed by @vec_is@.  In every other case the carries are reset to the
-- operator's neutral elements.  Operators with an empty shape are left
-- untouched.
readCarries ::
  Imp.TExp Int64 ->
  [Imp.TExp Int64] ->
  [Imp.TExp Int64] ->
  [PatElem KernelsMem] ->
  SegBinOp KernelsMem ->
  InKernelGen ()
readCarries chunk_offset dims' vec_is pes scan =
  when (shapeRank (segBinOpShape scan) > 0) $ do
    ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
    let after_first_chunk = chunk_offset .>. 0 .&&. ltid .==. 0
        prev_is = unflattenIndex dims' (chunk_offset - 1)
        fromPrevious p pe =
          copyDWIMFix (paramName p) [] (Var (patElemName pe)) (prev_is ++ vec_is)
        fromNeutral p ne =
          copyDWIMFix (paramName p) [] ne []
    -- Carries may have to be reloaded from the output of the previous
    -- chunk; otherwise start over from the neutral elements.
    sIf
      after_first_chunk
      (zipWithM_ fromPrevious (xParams scan) pes)
      (zipWithM_ fromNeutral (xParams scan) (segBinOpNeutral scan))

-- | Produce partially scanned intervals; one per workgroup.
scanStage1 ::
  Pattern KernelsMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  KernelBody KernelsMem ->
  CallKernelGen (TV Int32, Imp.TExp Int64, CrossesSegment)
scanStage1 :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> KernelBody KernelsMem
-> CallKernelGen (TV Int32, TExp Int64, CrossesSegment)
scanStage1 (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
scans KernelBody KernelsMem
kbody = do
  let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
  TV Int32
num_threads <- SpaceId
-> TPrimExp Int32 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int32)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TV t)
dPrimV SpaceId
"num_threads" (TPrimExp Int32 ExpLeaf
 -> ImpM KernelsMem HostEnv HostOp (TV Int32))
-> TPrimExp Int32 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 ExpLeaf)
-> TExp Int64 -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'

  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
  let num_elements :: TExp Int64
num_elements = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims'
      elems_per_thread :: TExp Int64
elems_per_thread = TExp Int64
num_elements TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
num_threads)
      elems_per_group :: TExp Int64
elems_per_group = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_thread

  let crossesSegment :: CrossesSegment
crossesSegment =
        case [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
reverse [TExp Int64]
dims' of
          TExp Int64
segment_size : TExp Int64
_ : [TExp Int64]
_ -> (TExp Int64 -> TExp Int64 -> TExp Bool) -> CrossesSegment
forall a. a -> Maybe a
Just ((TExp Int64 -> TExp Int64 -> TExp Bool) -> CrossesSegment)
-> (TExp Int64 -> TExp Int64 -> TExp Bool) -> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \TExp Int64
from TExp Int64
to ->
            (TExp Int64
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
from) TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int64
to TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64
segment_size)
          [TExp Int64]
_ -> CrossesSegment
forall a. Maybe a
Nothing

  SpaceId
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> ImpM KernelsMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage1" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
    [[VName]]
all_local_arrs <- Count GroupSize SubExp
-> SubExp -> [SegBinOp KernelsMem] -> InKernelGen [[VName]]
makeLocalArrays Count GroupSize SubExp
group_size (TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_threads) [SegBinOp KernelsMem]
scans

    -- The variables from scan_op will be used for the carry and such
    -- in the big chunking loop.
    [SegBinOp KernelsMem]
-> (SegBinOp KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOp KernelsMem]
scans ((SegBinOp KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (SegBinOp KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \SegBinOp KernelsMem
scan -> do
      Maybe (Exp KernelsMem)
-> Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp KernelsMem)
forall a. Maybe a
Nothing (Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope KernelsMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LParamMem] -> Scope KernelsMem)
-> [Param LParamMem] -> Scope KernelsMem
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (Lambda KernelsMem -> [LParam KernelsMem])
-> Lambda KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scan
      [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp KernelsMem -> [LParam KernelsMem]
xParams SegBinOp KernelsMem
scan) (SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp KernelsMem
scan)) (((Param LParamMem, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param LParamMem, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
        VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []

    SpaceId
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
SpaceId
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor SpaceId
"j" TExp Int64
elems_per_thread ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
j -> do
      TV Int64
chunk_offset <-
        SpaceId
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TV t)
dPrimV SpaceId
"chunk_offset" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
          TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
j
            TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelGroupId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_group
      TV Int64
flat_idx <-
        SpaceId
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TV t)
dPrimV SpaceId
"flat_idx" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
          TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
      -- Construct segment indices.
      (VName -> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> [VName] -> [TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
-> [TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
flat_idx

      let per_scan_pes :: [[PatElemT LParamMem]]
per_scan_pes = [SegBinOp KernelsMem]
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks [SegBinOp KernelsMem]
scans [PatElem KernelsMem]
[PatElemT LParamMem]
all_pes

          in_bounds :: TExp Bool
in_bounds =
            (TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TExp Int64 -> TExp Int64 -> TExp Bool)
-> [TExp Int64] -> [TExp Int64] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids) [TExp Int64]
dims'

          when_in_bounds :: ImpM KernelsMem KernelEnv KernelOp ()
when_in_bounds = Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
                  Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem]
scans) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
                per_scan_res :: [[KernelResult]]
per_scan_res =
                  [SegBinOp KernelsMem] -> [KernelResult] -> [[KernelResult]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks [SegBinOp KernelsMem]
scans [KernelResult]
all_scan_res

            SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"write to-scan values to parameters" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              (([PatElemT LParamMem], SegBinOp KernelsMem, [KernelResult])
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> [([PatElemT LParamMem], SegBinOp KernelsMem, [KernelResult])]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([VName]
-> ([PatElem KernelsMem], SegBinOp KernelsMem, [KernelResult])
-> ImpM KernelsMem KernelEnv KernelOp ()
writeToScanValues [VName]
gtids) ([([PatElemT LParamMem], SegBinOp KernelsMem, [KernelResult])]
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> [([PatElemT LParamMem], SegBinOp KernelsMem, [KernelResult])]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                [[PatElemT LParamMem]]
-> [SegBinOp KernelsMem]
-> [[KernelResult]]
-> [([PatElemT LParamMem], SegBinOp KernelsMem, [KernelResult])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT LParamMem]]
per_scan_pes [SegBinOp KernelsMem]
scans [[KernelResult]]
per_scan_res

            SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"write mapped values results to global memory" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              [(PatElemT LParamMem, KernelResult)]
-> ((PatElemT LParamMem, KernelResult)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [KernelResult] -> [(PatElemT LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem KernelsMem]
[PatElemT LParamMem]
all_pes) [KernelResult]
map_res) (((PatElemT LParamMem, KernelResult)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((PatElemT LParamMem, KernelResult)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, KernelResult
se) ->
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                  (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
                  ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids)
                  (KernelResult -> SubExp
kernelResultSubExp KernelResult
se)
                  []

      SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"threads in bounds read input" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
in_bounds ImpM KernelsMem KernelEnv KernelOp ()
when_in_bounds

      [([PatElemT LParamMem], SegBinOp KernelsMem, [VName])]
-> (([PatElemT LParamMem], SegBinOp KernelsMem, [VName])
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT LParamMem]]
-> [SegBinOp KernelsMem]
-> [[VName]]
-> [([PatElemT LParamMem], SegBinOp KernelsMem, [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT LParamMem]]
per_scan_pes [SegBinOp KernelsMem]
scans [[VName]]
all_local_arrs) ((([PatElemT LParamMem], SegBinOp KernelsMem, [VName])
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (([PatElemT LParamMem], SegBinOp KernelsMem, [VName])
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        \([PatElemT LParamMem]
pes, scan :: SegBinOp KernelsMem
scan@(SegBinOp Commutativity
_ Lambda KernelsMem
scan_op [SubExp]
nes Shape
vec_shape), [VName]
local_arrs) ->
          SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"do one intra-group scan operation" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            let rets :: [TypeBase Shape NoUniqueness]
rets = Lambda KernelsMem -> [TypeBase Shape NoUniqueness]
forall lore. LambdaT lore -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda KernelsMem
scan_op
                scan_x_params :: [LParam KernelsMem]
scan_x_params = SegBinOp KernelsMem -> [LParam KernelsMem]
xParams SegBinOp KernelsMem
scan
                (Bool
array_scan, Fence
fence, ImpM KernelsMem KernelEnv KernelOp ()
barrier) = Lambda KernelsMem
-> (Bool, Fence, ImpM KernelsMem KernelEnv KernelOp ())
barrierFor Lambda KernelsMem
scan_op

            Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM KernelsMem KernelEnv KernelOp ()
barrier

            Shape
-> ([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
vec_shape (([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
              SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"maybe restore some to-scan values to parameters, or read neutral" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
                  TExp Bool
in_bounds
                  ( do
                      [TExp Int64]
-> [PatElem KernelsMem]
-> SegBinOp KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
readToScanValues ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) [PatElem KernelsMem]
[PatElemT LParamMem]
pes SegBinOp KernelsMem
scan
                      TExp Int64
-> [TExp Int64]
-> [TExp Int64]
-> [PatElem KernelsMem]
-> SegBinOp KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
readCarries (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset) [TExp Int64]
dims' [TExp Int64]
vec_is [PatElem KernelsMem]
[PatElemT LParamMem]
pes SegBinOp KernelsMem
scan
                  )
                  ( [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp KernelsMem -> [LParam KernelsMem]
yParams SegBinOp KernelsMem
scan) (SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp KernelsMem
scan)) (((Param LParamMem, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param LParamMem, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
                  )

              SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"combine with carry and write to local memory" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_op) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  [(TypeBase Shape NoUniqueness, VName, SubExp)]
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [VName]
local_arrs (BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_op)) (((TypeBase Shape NoUniqueness, VName, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                    \(TypeBase Shape NoUniqueness
t, VName
arr, SubExp
se) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
arr [KernelConstants -> TypeBase Shape NoUniqueness -> TExp Int64
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t] SubExp
se []

              let crossesSegment' :: Maybe
  (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
crossesSegment' = do
                    TExp Int64 -> TExp Int64 -> TExp Bool
f <- CrossesSegment
crossesSegment
                    (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> Maybe
     (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
forall a. a -> Maybe a
Just ((TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
 -> Maybe
      (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool))
-> (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> Maybe
     (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
from TPrimExp Int32 ExpLeaf
to ->
                      let from' :: TExp Int64
from' = TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
from TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                          to' :: TExp Int64
to' = TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                       in TExp Int64 -> TExp Int64 -> TExp Bool
f TExp Int64
from' TExp Int64
to'

              KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> ImpM KernelsMem KernelEnv KernelOp ())
-> KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

              -- We need to avoid parameter name clashes.
              Lambda KernelsMem
scan_op_renamed <- Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda KernelsMem
scan_op
              Maybe
  (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> TExp Int64
-> TExp Int64
-> Lambda KernelsMem
-> [VName]
-> ImpM KernelsMem KernelEnv KernelOp ()
groupScan
                Maybe
  (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
crossesSegment'
                (TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TExp Int64)
-> TPrimExp Int32 ExpLeaf -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
num_threads)
                (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
                Lambda KernelsMem
scan_op_renamed
                [VName]
local_arrs

              SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"threads in bounds write partial scan result" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
in_bounds (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  [(TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)]
-> ((TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [PatElemT LParamMem]
-> [VName]
-> [(TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [PatElemT LParamMem]
pes [VName]
local_arrs) (((TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase Shape NoUniqueness
t, PatElemT LParamMem
pe, VName
arr) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                      (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
                      ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
                      (VName -> SubExp
Var VName
arr)
                      [KernelConstants -> TypeBase Shape NoUniqueness -> TExp Int64
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t]

              ImpM KernelsMem KernelEnv KernelOp ()
barrier

              let load_carry :: ImpM KernelsMem KernelEnv KernelOp ()
load_carry =
                    [(VName, Param LParamMem)]
-> ((VName, Param LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LParamMem] -> [(VName, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [LParam KernelsMem]
[Param LParamMem]
scan_x_params) (((VName, Param LParamMem)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, Param LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LParamMem
p) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                        (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
                        []
                        (VName -> SubExp
Var VName
arr)
                        [ if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p
                            then TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
                            else
                              (TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelGroupId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1)
                                TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
                        ]
                  load_neutral :: ImpM KernelsMem KernelEnv KernelOp ()
load_neutral =
                    [(SubExp, Param LParamMem)]
-> ((SubExp, Param LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [Param LParamMem] -> [(SubExp, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [LParam KernelsMem]
[Param LParamMem]
scan_x_params) (((SubExp, Param LParamMem)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((SubExp, Param LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param LParamMem
p) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []

              SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"first thread reads last element as carry-in for next iteration" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                TExp Bool
crosses_segment <- SpaceId
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TExp t)
dPrimVE SpaceId
"crosses_segment" (TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
                  case CrossesSegment
crossesSegment of
                    CrossesSegment
Nothing -> TExp Bool
forall v. TPrimExp Bool v
false
                    Just TExp Int64 -> TExp Int64 -> TExp Bool
f ->
                      TExp Int64 -> TExp Int64 -> TExp Bool
f
                        ( TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                            TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
-TExp Int64
1
                        )
                        ( TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                            TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
                        )
                TExp Bool
should_load_carry <-
                  SpaceId
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TExp t)
dPrimVE SpaceId
"should_load_carry" (TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
                    KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
crosses_segment
                TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
should_load_carry ImpM KernelsMem KernelEnv KernelOp ()
load_carry
                Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM KernelsMem KernelEnv KernelOp ()
barrier
                TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless TExp Bool
should_load_carry ImpM KernelsMem KernelEnv KernelOp ()
load_neutral

              ImpM KernelsMem KernelEnv KernelOp ()
barrier

  (TV Int32, TExp Int64, CrossesSegment)
-> CallKernelGen (TV Int32, TExp Int64, CrossesSegment)
forall (m :: * -> *) a. Monad m => a -> m a
return (TV Int32
num_threads, TExp Int64
elems_per_group, CrossesSegment
crossesSegment)

-- | Stage 2 of the two-pass scan: scan the carries produced by stage 1.
--
-- The kernel is launched with a single group whose size is the number of
-- stage-1 groups, so local thread @i@ corresponds to stage-1 group @i@.
-- Thread @i@ looks at flat index @(i + 1) * elems_per_group - 1@, the last
-- element of the @i@th chunk of @elems_per_group@ elements.
--
-- For every 'SegBinOp' (and every index of its vector shape):
--
--   * threads whose element lies inside the iteration space copy the
--     stage-1 result at that element into the arrays obtained from
--     'makeLocalArrays'; all other threads write the neutral element;
--   * the arrays are scanned with 'groupScan', using the segment-crossing
--     test re-expressed in terms of stage-1 group indices;
--   * in-bounds threads write the scanned value back to the result arrays.
scanStage2 ::
  Pattern KernelsMem ->
  TV Int32 ->
  Imp.TExp Int64 ->
  Count NumGroups SubExp ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  CallKernelGen ()
scanStage2
  (Pattern _ all_pes)
  stage1_num_threads
  elems_per_group
  num_groups
  crossesSegment
  space
  scans = do
    let (gtids, dims) = unzip $ unSegSpace space
        dims' = map toInt64Exp dims

        -- Flat index of the final element in the chunk that belongs to
        -- stage-1 group number @g@.
        chunkLast :: Imp.TExp Int32 -> Imp.TExp Int64
        chunkLast g = (sExt64 g + 1) * elems_per_group - 1

        -- One thread per stage-1 group: our group size is the number of
        -- groups used by the stage 1 kernel.
        group_size :: Count GroupSize SubExp
        group_size = Count $ unCount num_groups
        group_size' = fmap toInt64Exp group_size

        -- The segment test works on flat element indices; lift it to
        -- stage-1 group indices by comparing the last elements of the
        -- corresponding chunks.
        crossesSegment' =
          (\crosses from to -> crosses (chunkLast from) (chunkLast to))
            <$> crossesSegment

    sKernelThread "scan_stage2" 1 group_size' (segFlat space) $ do
      constants <- kernelConstants <$> askEnv
      per_scan_local_arrs <-
        makeLocalArrays group_size (tvSize stage1_num_threads) scans
      let per_scan_rets = map (lambdaReturnType . segBinOpLambda) scans
          per_scan_pes = segBinOpChunks scans all_pes
          gtid_exps = map Imp.vi64 gtids

      flat_idx <- dPrimV "flat_idx" $ chunkLast $ kernelLocalThreadId constants
      -- Construct segment indices.
      zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ tvExp flat_idx

      forM_ (zip4 scans per_scan_local_arrs per_scan_rets per_scan_pes) $
        \(SegBinOp _ scan_op nes vec_shape, local_arrs, rets, pes) ->
          sLoopNest vec_shape $ \vec_is -> do
            let glob_is = gtid_exps ++ vec_is
                -- Every segment index must be below its dimension.
                in_bounds =
                  foldl1 (.&&.) [i .<. d | (i, d) <- zip gtid_exps dims']
                local_idx t = [localArrayIndex constants t]
                (_, _, barrier) = barrierFor scan_op

                -- Stage-1 result at our element -> local array.
                read_carries =
                  forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) ->
                    copyDWIMFix arr (local_idx t) (Var $ patElemName pe) glob_is
                -- Neutral element -> local array (out-of-bounds threads).
                read_neutrals =
                  forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) ->
                    copyDWIMFix arr (local_idx t) ne []
                -- Scanned local value -> result array.
                write_carries =
                  forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) ->
                    copyDWIMFix (patElemName pe) glob_is (Var arr) (local_idx t)

            sComment "threads in bound read carries; others get neutral element" $
              sIf in_bounds read_carries read_neutrals

            barrier

            groupScan
              crossesSegment'
              (sExt64 $ tvExp stage1_num_threads)
              (sExt64 $ kernelGroupSize constants)
              scan_op
              local_arrs

            sComment "threads in bounds write scanned carries" $
              sWhen in_bounds write_carries

scanStage3 ::
  Pattern KernelsMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Imp.TExp Int64 ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  CallKernelGen ()
scanStage3 :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TExp Int64
-> CrossesSegment
-> SegSpace
-> [SegBinOp KernelsMem]
-> CallKernelGen ()
scanStage3 (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TExp Int64
elems_per_group CrossesSegment
crossesSegment SegSpace
space [SegBinOp KernelsMem]
scans = do
  let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
      ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
  TPrimExp Int32 ExpLeaf
required_groups <-
    SpaceId
-> TPrimExp Int32 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TPrimExp Int32 ExpLeaf)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TExp t)
dPrimVE SpaceId
"required_groups" (TPrimExp Int32 ExpLeaf
 -> ImpM KernelsMem HostEnv HostOp (TPrimExp Int32 ExpLeaf))
-> TPrimExp Int32 ExpLeaf
-> ImpM KernelsMem HostEnv HostOp (TPrimExp Int32 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
      TExp Int64 -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 ExpLeaf)
-> TExp Int64 -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size')

  SpaceId
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> ImpM KernelsMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage3" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    SegVirt
-> TPrimExp Int32 ExpLeaf
-> (TPrimExp Int32 ExpLeaf
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
virtualiseGroups SegVirt
SegVirt TPrimExp Int32 ExpLeaf
required_groups ((TPrimExp Int32 ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TPrimExp Int32 ExpLeaf
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
virt_group_id -> do
      KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

      -- Compute our logical index.
      TExp Int64
flat_idx <-
        SpaceId
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TExp t)
dPrimVE SpaceId
"flat_idx" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
          TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
virt_group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size')
            TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
      (VName -> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> [VName] -> [TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
-> [TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' TExp Int64
flat_idx

      -- Figure out which group this element was originally in.
      TV Int64
orig_group <- SpaceId
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TV t)
dPrimV SpaceId
"orig_group" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
flat_idx TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
elems_per_group
      -- Then the index of the carry-in of the preceding group.
      TV Int64
carry_in_flat_idx <-
        SpaceId
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. SpaceId -> TExp t -> ImpM lore r op (TV t)
dPrimV SpaceId
"carry_in_flat_idx" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
          TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
orig_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
      -- Figure out the logical index of the carry-in.
      let carry_in_idx :: [TExp Int64]
carry_in_idx = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
carry_in_flat_idx

      -- Apply the carry if we are not in the scan results for the first
      -- group, and are not the last element in such a group (because
      -- then the carry was updated in stage 2), and we are not crossing
      -- a segment boundary.
      let in_bounds :: TExp Bool
in_bounds =
            (TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TExp Int64 -> TExp Int64 -> TExp Bool)
-> [TExp Int64] -> [TExp Int64] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids) [TExp Int64]
dims'
          crosses_segment :: TExp Bool
crosses_segment =
            TExp Bool -> Maybe (TExp Bool) -> TExp Bool
forall a. a -> Maybe a -> a
fromMaybe TExp Bool
forall v. TPrimExp Bool v
false (Maybe (TExp Bool) -> TExp Bool) -> Maybe (TExp Bool) -> TExp Bool
forall a b. (a -> b) -> a -> b
$
              CrossesSegment
crossesSegment
                CrossesSegment
-> Maybe (TExp Int64) -> Maybe (TExp Int64 -> TExp Bool)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TExp Int64 -> Maybe (TExp Int64)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
carry_in_flat_idx)
                Maybe (TExp Int64 -> TExp Bool)
-> Maybe (TExp Int64) -> Maybe (TExp Bool)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TExp Int64 -> Maybe (TExp Int64)
forall (f :: * -> *) a. Applicative f => a -> f a
pure TExp Int64
flat_idx
          is_a_carry :: TExp Bool
is_a_carry = TExp Int64
flat_idx TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
orig_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
          no_carry_in :: TExp Bool
no_carry_in = TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
orig_group TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
is_a_carry TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
crosses_segment

      let per_scan_pes :: [[PatElemT LParamMem]]
per_scan_pes = [SegBinOp KernelsMem]
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks [SegBinOp KernelsMem]
scans [PatElem KernelsMem]
[PatElemT LParamMem]
all_pes
      TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
in_bounds (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless TExp Bool
no_carry_in (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          [([PatElemT LParamMem], SegBinOp KernelsMem)]
-> (([PatElemT LParamMem], SegBinOp KernelsMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT LParamMem]]
-> [SegBinOp KernelsMem]
-> [([PatElemT LParamMem], SegBinOp KernelsMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElemT LParamMem]]
per_scan_pes [SegBinOp KernelsMem]
scans) ((([PatElemT LParamMem], SegBinOp KernelsMem)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (([PatElemT LParamMem], SegBinOp KernelsMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            \([PatElemT LParamMem]
pes, SegBinOp Commutativity
_ Lambda KernelsMem
scan_op [SubExp]
nes Shape
vec_shape) -> do
              Maybe (Exp KernelsMem)
-> Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp KernelsMem)
forall a. Maybe a
Nothing (Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope KernelsMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LParamMem] -> Scope KernelsMem)
-> [Param LParamMem] -> Scope KernelsMem
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
scan_op
              let ([Param LParamMem]
scan_x_params, [Param LParamMem]
scan_y_params) =
                    Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
scan_op

              Shape
-> ([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
vec_shape (([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ([TExp Int64] -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
                [(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_x_params [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param LParamMem, PatElemT LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                    (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []
                    (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
                    ([TExp Int64]
carry_in_idx [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)

                [(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_y_params [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param LParamMem, PatElemT LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                    (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []
                    (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
                    ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)

                [Param LParamMem]
-> BodyT KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LParamMem]
scan_x_params (BodyT KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> BodyT KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_op

                [(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_x_params [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param LParamMem, PatElemT LParamMem)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                    (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
                    ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
                    (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []

-- | Compile a 'SegScan' instance into host-level code that launches the
-- three scan stages as separate kernels, one after another.
--
-- Stage 1 runs with a group count capped at the maximum group size,
-- because stage 2 uses a group size equal to the number of stage 1
-- groups.  Stage 3 is given the level's original (uncapped) number of
-- groups.
compileSegScan ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = do
  let lvl_num_groups = segNumGroups lvl
      lvl_group_size = segGroupSize lvl

  -- Query the largest permitted group size; stage 2 needs a group as
  -- large as stage 1's group count, so that count may not exceed it.
  max_group_size <- dPrim "stage1_max_num_groups" int64
  sOp $ Imp.GetSizeMax (tvVar max_group_size) SizeGroup

  capped_groups <-
    dPrimV "stage1_num_groups" $
      sMin64
        (tvExp max_group_size)
        (toInt64Exp (Imp.unCount lvl_num_groups))
  let stage1_groups = Imp.Count (tvSize capped_groups)

  (stage1_threads, elems_per_group, crosses) <-
    scanStage1 pat stage1_groups lvl_group_size space scans kbody

  emit . Imp.DebugPrint "elems_per_group" . Just $ untyped elems_per_group

  scanStage2 pat stage1_threads elems_per_group stage1_groups crosses space scans

  -- Stage 3 deliberately uses the level's own group count, not the
  -- capped stage 1 count.
  scanStage3 pat lvl_num_groups lvl_group_size elems_per_group crosses space scans