{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fast single-pass algorithm, which only works on NVIDIA GPUs and
-- imposes some constraints on the operator.  We use it whenever we can.
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where

import Control.Monad.Except
import Data.List (zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, quot)
import Prelude hiding (mod, quot, rem)

-- | Split the parameters of a scan operator's lambda at the number of
-- neutral elements of the operator.  'xParams' yields the leading
-- parameters (one per neutral element) and 'yParams' the rest.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan = leading
  where
    numNeutral = length (segBinOpNeutral scan)
    (leading, _) = splitAt numNeutral (lambdaParams (segBinOpLambda scan))
yParams scan = trailing
  where
    numNeutral = length (segBinOpNeutral scan)
    (_, trailing) = splitAt numNeutral (lambdaParams (segBinOpLambda scan))

-- | Round @x@ up to a multiple of @alignment@.  Does a ceiling division
-- ('divUp') and then scales back up.  Presumably @alignment@ is positive.
alignTo :: IntegralExp a => a -> a -> a
alignTo x alignment = multiples * alignment
  where
    multiples = x `divUp` alignment

-- | Allocate the block of local memory used by the single-pass scan
-- kernel and create array views into it.
--
-- A single @local_mem@ allocation backs every view.  Its size is the
-- maximum (not the sum) of three layouts, which indicates that the views
-- are reused across different phases of the kernel:
--
-- * the transposition arrays: one per element of @types@, each holding
--   @m * groupSize@ elements; only the largest one determines the size;
-- * the prefix arrays: one per type, @groupSize@ elements each, packed
--   back to back with every start offset aligned to its element size;
-- * the look-back area: a @warpSize@-byte @warpscan@ array followed by
--   one @warpSize@-element exchange array per type, packed likewise.
--
-- The views made with 'sArrayInMem' (@shared_id@, @shared_read_offset@,
-- the transposition arrays and @warpscan@) are given no explicit offset
-- and presumably start at byte 0 of @local_mem@ (NOTE(review): confirm),
-- so they alias each other and the start of the packed layouts.  Callers
-- presumably need to separate their uses, e.g. with barriers.
--
-- @m@ times the group size is the length of each transposition array
-- (presumably @m@ is the number of elements handled per thread).
--
-- Returns @(sharedId, transposedArrays, prefixArrays, sharedReadOffset,
-- warpscan, warpExchanges)@.
createLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, VName, [VName])
createLocalArrays (Count groupSize) m types = do
  let groupSizeE = toInt64Exp groupSize
      workSize = toInt64Exp m * groupSizeE
      elemSizes = map primByteSize types

      warpSize :: Num a => a
      warpSize = 32

      -- Round @off@ up to a multiple of @tySize@, then reserve room for
      -- @count@ elements of @tySize@ bytes.  Used by both the size and the
      -- offset computations below so that the two cannot disagree.
      packStep count off tySize = alignTo off tySize + tySize * count

      prefixArraysSize = foldl (packStep groupSizeE) 0 elemSizes
      -- Fold from 0 rather than using foldl1 so this stays total when
      -- 'types' is empty; for non-negative byte counts the max is unchanged.
      maxTransposedArraySize = foldl sMax64 0 $ map (workSize *) elemSizes
      maxWarpExchangeSize = foldl (packStep warpSize) 0 elemSizes
      -- The exchange arrays come after the warpSize-byte warpscan array.
      maxLookbackSize = maxWarpExchangeSize + warpSize
      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      -- Start byte offset of each per-type array in a packed layout
      -- beginning at @start@.  scanl also yields the end offset of the last
      -- array; it is never used, so drop it rather than declaring a dead
      -- kernel variable for it.
      startOffsets count start =
        take (length types) $ scanl (packStep count) start elemSizes

      varTE :: TV Int64 -> TPrimExp Int64 VName
      varTE = le64 . tvVar

  byteOffsets <-
    mapM (fmap varTE . dPrimV "byte_offsets") $ startOffsets groupSizeE 0

  warpByteOffsets <-
    mapM (fmap varTE . dPrimV "warp_byte_offset") $
      startOffsets warpSize warpSize

  -- 'sComment' only accepts a unit action, so this emits a standalone
  -- marker comment ahead of the allocations below rather than wrapping them.
  sComment "Allocate reused shared memory" $ return ()

  localMem <- sAlloc "local_mem" size (Space "local")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem
  sharedReadOffset <-
    sArrayInMem "shared_read_offset" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  let -- View @len@ elements of type @ty@ inside @local_mem@, starting at
      -- byte offset @off@.  The offsets are aligned to the element size
      -- by 'packStep', so the division gives an exact element offset.
      arrayAt name len lenE (off, ty) =
        sArray name ty (Shape [len]) $
          ArrayIn localMem $
            IxFun.iotaOffset (off `quot` primByteSize ty) [lenE]

  prefixArrays <-
    mapM (arrayAt "local_prefix_arr" groupSize (pe64 groupSize)) $
      zip byteOffsets types

  warpscan <-
    sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem
  warpExchanges <-
    mapM (arrayAt "warp_exchange" (constant (warpSize :: Int64)) warpSize) $
      zip warpByteOffsets types

  return
    ( sharedId,
      transposedArrays,
      prefixArrays,
      sharedReadOffset,
      warpscan,
      warpExchanges
    )

-- | Compile 'SegScan' instance to host-level code with calls to a
-- single-pass kernel.
compileSegScan ::
  Pattern GPUMem ->
  SegLevel ->
  SegSpace ->
  SegBinOp GPUMem ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan :: Pattern GPUMem
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pattern GPUMem
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scanOp KernelBody GPUMem
kbody = do
  let Pattern [PatElemT LetDecMem]
_ [PatElemT LetDecMem]
all_pes = Pattern GPUMem
PatternT LetDecMem
pat
      group_size :: Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size = SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count GroupSize SubExp
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
      n :: TPrimExp Int64 ExpLeaf
n = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf)
-> [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 ExpLeaf)
-> [SubExp] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp ([SubExp] -> [TPrimExp Int64 ExpLeaf])
-> [SubExp] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
      num_groups :: Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups = TPrimExp Int64 ExpLeaf -> Count NumGroups (TPrimExp Int64 ExpLeaf)
forall u e. e -> Count u e
Count (TPrimExp Int64 ExpLeaf
n TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. IntegralExp a => a -> a -> a
`divUp` (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
forall a. Num a => a
m))
      num_threads :: TPrimExp Int64 ExpLeaf
num_threads = Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size
      ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 ExpLeaf]
dims' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> [SubExp] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp [SubExp]
dims
      segmented :: Bool
segmented = [TPrimExp Int64 ExpLeaf] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 ExpLeaf]
dims' Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1
      not_segmented_e :: TPrimExp Bool ExpLeaf
not_segmented_e = if Bool
segmented then TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
false else TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
true
      segment_size :: TPrimExp Int64 ExpLeaf
segment_size = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a. [a] -> a
last [TPrimExp Int64 ExpLeaf]
dims'
      scanOpNe :: [SubExp]
scanOpNe = SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scanOp
      tys :: [PrimType]
tys = (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map (\(Prim PrimType
pt) -> PrimType
pt) ([TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType])
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
LambdaT rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (LambdaT GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> LambdaT GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp

      statusX, statusA, statusP :: Num a => a
      statusX :: forall a. Num a => a
statusX = a
0
      statusA :: forall a. Num a => a
statusA = a
1
      statusP :: forall a. Num a => a
statusP = a
2
      makeStatusUsed :: TV t -> TV t -> TPrimExp t ExpLeaf
makeStatusUsed TV t
flag TV t
used = TV t -> TPrimExp t ExpLeaf
forall t. TV t -> TExp t
tvExp TV t
flag TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.|. (TV t -> TPrimExp t ExpLeaf
forall t. TV t -> TExp t
tvExp TV t
used TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.<<. TPrimExp t ExpLeaf
2)
      unmakeStatusUsed :: TV Int8 -> TV Int8 -> TV Int8 -> InKernelGen ()
      unmakeStatusUsed :: TV Int8 -> TV Int8 -> TV Int8 -> ImpM GPUMem KernelEnv KernelOp ()
unmakeStatusUsed TV Int8
flagUsed TV Int8
flag TV Int8
used = do
        TV Int8
used TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flagUsed TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Int8 ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.>>. TPrimExp Int8 ExpLeaf
2
        TV Int8
flag TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flagUsed TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Int8 ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. TPrimExp Int8 ExpLeaf
3

      sumT :: Integer
      maxT :: Integer
      sumT :: Integer
sumT = (Integer -> PrimType -> Integer)
-> Integer -> [PrimType] -> Integer
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize PrimType
typ) Integer
0 [PrimType]
tys
      primByteSize' :: PrimType -> Integer
primByteSize' = Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
4 (Integer -> Integer)
-> (PrimType -> Integer) -> PrimType -> Integer
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize
      sumT' :: Integer
sumT' = (Integer -> PrimType -> Integer)
-> Integer -> [PrimType] -> Integer
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ PrimType -> Integer
primByteSize' PrimType
typ) Integer
0 [PrimType]
tys Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` Integer
4
      maxT :: Integer
maxT = [Integer] -> Integer
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ((PrimType -> Integer) -> [PrimType] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize [PrimType]
tys)
      -- TODO: Make these constants dynamic by querying device
      -- RTX 2080 Ti constants (CC 7.5)
      k_reg :: Integer
k_reg = Integer
64
      k_mem :: Integer
k_mem = Integer
48 --12*4
      mem_constraint :: Integer
mem_constraint = Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
k_mem Integer
sumT Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` Integer
maxT
      reg_constraint :: Integer
reg_constraint = (Integer
k_reg Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
-Integer
1 Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
- Integer
sumT') Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` (Integer
2 Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
* Integer
sumT' Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ Integer
3)
      m :: Num a => a
      m :: forall a. Num a => a
m = Integer -> a
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Integer -> a) -> Integer -> a
forall a b. (a -> b) -> a -> b
$ Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
1 (Integer -> Integer) -> Integer -> Integer
forall a b. (a -> b) -> a -> b
$ Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
min Integer
mem_constraint Integer
reg_constraint

  -- Allocate the shared memory for output component
  TV Int64
numThreads <- String
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"numThreads" TPrimExp Int64 ExpLeaf
num_threads
  TV Int64
numGroups <- String
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"numGroups" (TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups

  VName
globalId <- String
-> Space
-> PrimType
-> ArrayContents
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String -> Space -> PrimType -> ArrayContents -> ImpM rep r op VName
sStaticArray String
"id_counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM GPUMem HostEnv HostOp VName)
-> ArrayContents -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Int -> ArrayContents
Imp.ArrayZeros Int
1
  VName
statusFlags <- String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"status_flags" PrimType
int8 ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
  ([VName]
aggregateArrays, [VName]
incprefixArrays) <-
    ([(VName, VName)] -> ([VName], [VName]))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip (ImpM GPUMem HostEnv HostOp [(VName, VName)]
 -> ImpM GPUMem HostEnv HostOp ([VName], [VName]))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
      [PrimType]
-> (PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
 -> ImpM GPUMem HostEnv HostOp [(VName, VName)])
-> (PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
        (,) (VName -> VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
          ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName, VName)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")

  VName -> SubExp -> CallKernelGen ()
sReplicate VName
statusFlags (SubExp -> CallKernelGen ()) -> SubExp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusX

  String
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread String
"segscan" Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size (SegSpace -> VName
segFlat SegSpace
space) (ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

    (VName
sharedId, [VName]
transposedArrays, [VName]
prefixArrays, VName
sharedReadOffset, VName
warpscan, [VName]
exchanges) <-
      Count GroupSize SubExp
-> SubExp
-> [PrimType]
-> InKernelGen (VName, [VName], [VName], VName, VName, [VName])
createLocalArrays (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m) [PrimType]
tys

    TV Int64
dynamicId <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"dynamic_id" PrimType
int32
    TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      (VName
globalIdMem, Space
_, Count Elements (TPrimExp Int64 ExpLeaf)
globalIdOff) <- VName
-> [TPrimExp Int64 ExpLeaf]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 ExpLeaf))
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 ExpLeaf))
fullyIndexArray VName
globalId [TPrimExp Int64 ExpLeaf
0]
      KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
          IntType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> Exp
-> AtomicOp
Imp.AtomicAdd
            IntType
Int32
            (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId)
            VName
globalIdMem
            (TPrimExp Int64 ExpLeaf -> Count Elements (TPrimExp Int64 ExpLeaf)
forall u e. e -> Count u e
Count (TPrimExp Int64 ExpLeaf -> Count Elements (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$ Count Elements (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count Elements (TPrimExp Int64 ExpLeaf)
globalIdOff)
            (TPrimExp Int32 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int32 ExpLeaf
1 :: Imp.TExp Int32))
      VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
sharedId [TPrimExp Int64 ExpLeaf
0] (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
dynamicId) []

    let localBarrier :: KernelOp
localBarrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
        localFence :: KernelOp
localFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
        globalFence :: KernelOp
globalFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal

    KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
    VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId) [] (VName -> SubExp
Var VName
sharedId) [TPrimExp Int64 ExpLeaf
0]
    KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    TV Int64
blockOff <-
      String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"blockOff" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
forall a. Num a => a
m TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants
    TPrimExp Int64 ExpLeaf
sgmIdx <- String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sgm_idx" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. IntegralExp a => a -> a -> a
`mod` TPrimExp Int64 ExpLeaf
segment_size
    TPrimExp Int32 ExpLeaf
boundary <-
      String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"boundary" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf)
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Int64 ExpLeaf
forall a. Num a => a
m TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size) (TPrimExp Int64 ExpLeaf
segment_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
sgmIdx)
    TPrimExp Int32 ExpLeaf
segsize_compact <-
      String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segsize_compact" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
        TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf)
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Int64 ExpLeaf
forall a. Num a => a
m TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size) TPrimExp Int64 ExpLeaf
segment_size
    [VName]
privateArrays <-
      [PrimType]
-> (PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
 -> ImpM GPUMem KernelEnv KernelOp [VName])
-> (PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
        String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
          String
"private"
          PrimType
ty
          ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m])
          ([SubExp] -> PrimType -> Space
ScalarSpace [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m] PrimType
ty)

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Load and map" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
        -- The map's input index
        TPrimExp Int64 ExpLeaf
phys_tid <-
          String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"phys_tid" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
            TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
              TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
i TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants
        [(VName, TPrimExp Int64 ExpLeaf)]
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
[(VName, TPrimExp Int64 ExpLeaf)]
-> TPrimExp Int64 ExpLeaf -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 ExpLeaf] -> [(VName, TPrimExp Int64 ExpLeaf)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TPrimExp Int64 ExpLeaf]
dims') TPrimExp Int64 ExpLeaf
phys_tid
        -- Perform the map
        let in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
in_bounds =
              Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                let ([KernelResult]
all_scan_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scanOp]) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody

                -- Write map results to their global memory destinations
                [(PatElemT LetDecMem, KernelResult)]
-> ((PatElemT LetDecMem, KernelResult)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [KernelResult] -> [(PatElemT LetDecMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [PatElemT LetDecMem] -> [PatElemT LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElemT LetDecMem]
all_pes) [KernelResult]
map_res) (((PatElemT LetDecMem, KernelResult)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElemT LetDecMem, KernelResult)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
dest, KernelResult
src) ->
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
dest) ((VName -> TPrimExp Int64 ExpLeaf)
-> [VName] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []

                -- Write to-scan results to private memory.
                [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 ExpLeaf
i] SubExp
src []

            out_of_bounds :: ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds =
              [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [SubExp]
scanOpNe) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
                VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 ExpLeaf
i] SubExp
ne []

        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TPrimExp Int64 ExpLeaf
phys_tid TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 ExpLeaf
n) ImpM GPUMem KernelEnv KernelOp ()
in_bounds ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Transpose scan inputs" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
          TPrimExp Int64 ExpLeaf
sharedIdx <-
            String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
              TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
                TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
i TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
trans [TPrimExp Int64 ExpLeaf
sharedIdx] (VName -> SubExp
Var VName
priv) [TPrimExp Int64 ExpLeaf
i]
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        String
-> TPrimExp Int32 ExpLeaf
-> (TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int32 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
i -> do
          TV Int32
sharedIdx <- String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
i
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
priv [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
i] (VName -> SubExp
Var VName
trans) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
      KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Per thread scan" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      -- We don't need to touch the first element, so only m-1
      -- iterations here.
      TPrimExp Int32 ExpLeaf
globalIdx <-
        String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
          (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m) TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
1
      String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" (TPrimExp Int64 ExpLeaf
forall a. Num a => a
m TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
-TPrimExp Int64 ExpLeaf
1) ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
        let xs :: [VName]
xs = (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scanOp
            ys :: [VName]
ys = (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scanOp
        -- determine if start of segment
        TPrimExp Bool ExpLeaf
new_sgm <-
          if Bool
segmented
            then String
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"new_sgm" (TPrimExp Bool ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf))
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int32 ExpLeaf
globalIdx TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 ExpLeaf
i TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
boundary) TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. IntegralExp a => a -> a -> a
`mod` TPrimExp Int32 ExpLeaf
segsize_compact TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0
            else TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall (f :: * -> *) a. Applicative f => a -> f a
pure TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
false
        -- skip scan of first element in segment
        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sUnless TPrimExp Bool ExpLeaf
new_sgm (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          [(VName, (VName, VName, PrimType))]
-> ((VName, (VName, VName, PrimType))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([(VName, VName, PrimType)] -> [(VName, (VName, VName, PrimType))])
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName] -> [PrimType] -> [(VName, VName, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [VName]
ys [PrimType]
tys) (((VName, (VName, VName, PrimType))
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, (VName, VName, PrimType))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, (VName
x, VName
y, PrimType
ty)) -> do
            VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
            VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TPrimExp Int64 ExpLeaf
i]
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TPrimExp Int64 ExpLeaf
i TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
1]

          Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody (LambdaT GPUMem -> BodyT GPUMem) -> LambdaT GPUMem -> BodyT GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody (LambdaT GPUMem -> BodyT GPUMem) -> LambdaT GPUMem -> BodyT GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
res) ->
              VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 ExpLeaf
i TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
1] SubExp
res []

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Publish results in shared memory" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
prefixArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
        VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants] (VName -> SubExp
Var VName
src) [TPrimExp Int64 ExpLeaf
forall a. Num a => a
m TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
      KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    let crossesSegment :: Maybe
  (TPrimExp Int32 ExpLeaf
   -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
crossesSegment = do
          Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
          (TPrimExp Int32 ExpLeaf
 -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
-> Maybe
     (TPrimExp Int32 ExpLeaf
      -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
forall a. a -> Maybe a
Just ((TPrimExp Int32 ExpLeaf
  -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
 -> Maybe
      (TPrimExp Int32 ExpLeaf
       -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf))
-> (TPrimExp Int32 ExpLeaf
    -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
-> Maybe
     (TPrimExp Int32 ExpLeaf
      -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
from TPrimExp Int32 ExpLeaf
to ->
            let from' :: TPrimExp Int32 ExpLeaf
from' = (TPrimExp Int32 ExpLeaf
from TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
1) TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
1
                to' :: TPrimExp Int32 ExpLeaf
to' = (TPrimExp Int32 ExpLeaf
to TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
1) TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
1
             in (TPrimExp Int32 ExpLeaf
to' TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
from') TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TPrimExp Int32 ExpLeaf
to' TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
segsize_compact TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
boundary) TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. IntegralExp a => a -> a -> a
`mod` TPrimExp Int32 ExpLeaf
segsize_compact

    LambdaT GPUMem
scanOp' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem))
-> LambdaT GPUMem
-> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> LambdaT GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp

    [TV Any]
accs <- (PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> [PrimType] -> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"acc") [PrimType]
tys
    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Scan results (with warp scan)" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      Maybe
  (TPrimExp Int32 ExpLeaf
   -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf
-> LambdaT GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
        Maybe
  (TPrimExp Int32 ExpLeaf
   -> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf)
crossesSegment
        (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
numThreads)
        (KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants)
        LambdaT GPUMem
scanOp'
        [VName]
prefixArrays

      KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

      let firstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread TV Any
acc VName
prefixes =
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
          notFirstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread TV Any
acc VName
prefixes =
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
      TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
        (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0)
        ((TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread [TV Any]
accs [VName]
prefixArrays)
        ((TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread [TV Any]
accs [VName]
prefixArrays)

      KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    [TV Any]
prefixes <- [(SubExp, PrimType)]
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
 -> ImpM GPUMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
      String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
    TPrimExp Bool ExpLeaf
blockNewSgm <- String
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"block_new_sgm" (TPrimExp Bool ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf))
-> TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool ExpLeaf)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf
sgmIdx TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf
0
    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Perform lookback" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool ExpLeaf
blockNewSgm TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
        ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
        [(SubExp, TV Any)]
-> ((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [TV Any] -> [(SubExp, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [TV Any]
accs) (((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, TV Any
acc) ->
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] SubExp
ne []
      -- end sWhen

      let warpSize :: TPrimExp Int32 ExpLeaf
warpSize = KernelConstants -> TPrimExp Int32 ExpLeaf
kernelWaveSize KernelConstants
constants
      TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool ExpLeaf
blockNewSgm TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int32 ExpLeaf
warpSize) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
            (TPrimExp Bool ExpLeaf
not_segmented_e TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TPrimExp Int32 ExpLeaf
boundary TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
forall a. Num a => a
m))
            ( do
                ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
aggregateArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
aggregateArray, TV Any
acc) ->
                    VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
                KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
                ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusA) []
            )
            ( do
                ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
                    VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
                KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
                ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
            )
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
warpscan [TPrimExp Int64 ExpLeaf
0] (VName -> SubExp
Var VName
statusFlags) [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
        -- sWhen
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence

        TV Int8
status <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"status" PrimType
int8 :: InKernelGen (TV Int8)
        VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
status) [] (VName -> SubExp
Var VName
warpscan) [TPrimExp Int64 ExpLeaf
0]

        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
          (TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
status TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusP)
          ( TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [VName]
incprefixArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
prefix, VName
incprefixArray) ->
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
          )
          ( do
              TV Int32
readOffset <-
                String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"readOffset" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
                  TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf)
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelWaveSize KernelConstants
constants)
              let loopStop :: TPrimExp Int32 ExpLeaf
loopStop = TPrimExp Int32 ExpLeaf
warpSize TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* (-TPrimExp Int32 ExpLeaf
1)
                  sameSegment :: TV Int32 -> TPrimExp Bool ExpLeaf
sameSegment TV Int32
readIdx
                    | Bool
segmented =
                      let startIdx :: TPrimExp Int64 ExpLeaf
startIdx = TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readIdx TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
1) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
forall a. Num a => a
m TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1
                       in TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
startIdx TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 ExpLeaf
sgmIdx
                    | Bool
otherwise = TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v
true
              TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhile (TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TPrimExp Int32 ExpLeaf
loopStop) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                TV Int32
readI <- String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"read_i" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants
                [TV Any]
aggrs <- [(SubExp, PrimType)]
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
 -> ImpM GPUMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
                  String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
                TV Int8
flag <- String
-> TPrimExp Int8 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag" TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusX
                TV Int8
used <- String
-> TPrimExp Int8 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"used" TPrimExp Int8 ExpLeaf
0
                ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (ImpM GPUMem KernelEnv KernelOp ()
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readI TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                  TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                    (TV Int32 -> TPrimExp Bool ExpLeaf
sameSegment TV Int32
readI)
                    ( do
                        VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
statusFlags) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readI]
                        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                          (TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusP)
                          ( [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefix, TV Any
aggr) ->
                              VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
incprefix) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readI]
                          )
                          ( TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusA) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                              [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
aggregateArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
aggregate) ->
                                VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
aggregate) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readI]
                              TV Int8
used TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TPrimExp Int8 ExpLeaf
1
                          )
                    )
                    (VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) [])
                -- end sIf
                -- end sWhen
                [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
aggr) ->
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
exchange [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
aggr) []
                TV Int8
tmp <- String
-> TPrimExp Int8 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"tmp" (TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp (TV Int8))
-> TPrimExp Int8 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TV Int8 -> TPrimExp Int8 ExpLeaf
forall {t}. NumExp t => TV t -> TV t -> TPrimExp t ExpLeaf
makeStatusUsed TV Int8
flag TV Int8
used
                VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
warpscan [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants] (TV Int8 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int8
tmp) []
                KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence

                (VName
warpscanMem, Space
warpscanSpace, Count Elements (TPrimExp Int64 ExpLeaf)
warpscanOff) <-
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TPrimExp Int64 ExpLeaf))
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 ExpLeaf))
fullyIndexArray VName
warpscan [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
warpSize TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
                TV Int8
flag TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- Exp -> TPrimExp Int8 ExpLeaf
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (VName
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
warpscanMem Count Elements (TPrimExp Int64 ExpLeaf)
warpscanOff PrimType
int8 Space
warpscanSpace Volatility
Imp.Volatile)
                TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                  -- TODO: This is a single-threaded reduce
                  TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                    (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusP)
                    ( do
                        LambdaT GPUMem
scanOp'' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
                        let ([VName]
agg1s, [VName]
agg2s) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''

                        [(VName, SubExp, PrimType)]
-> ((VName, SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [PrimType] -> [(VName, SubExp, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [SubExp]
scanOpNe [PrimType]
tys) (((VName, SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg1, SubExp
ne, PrimType
ty) ->
                          VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
agg1 (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
                        (VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ())
-> [VName] -> [PrimType] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ [VName]
agg2s [PrimType]
tys

                        TV Int8
flag1 <- String
-> TPrimExp Int8 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag1" TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusX
                        TV Int8
flag2 <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"flag2" PrimType
int8
                        TV Int8
used1 <- String
-> TPrimExp Int8 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"used1" TPrimExp Int8 ExpLeaf
0
                        TV Int8
used2 <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"used2" PrimType
int8
                        String
-> TPrimExp Int32 ExpLeaf
-> (TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int32 ExpLeaf
warpSize ((TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
i -> do
                          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag2) [] (VName -> SubExp
Var VName
warpscan) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
i]
                          TV Int8 -> TV Int8 -> TV Int8 -> ImpM GPUMem KernelEnv KernelOp ()
unmakeStatusUsed TV Int8
flag2 TV Int8
flag2 TV Int8
used2
                          [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
agg2s [VName]
exchanges) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg2, VName
exchange) ->
                            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
agg2 [] (VName -> SubExp
Var VName
exchange) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
i]
                          TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                            (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag2 TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusA)
                            ( do
                                TV Int8
flag1 TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag2
                                TV Int8
used1 TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
used2
                                [(VName, PrimType, VName)]
-> ((VName, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [VName] -> [(VName, PrimType, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [PrimType]
tys [VName]
agg2s) (((VName, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg1, PrimType
ty, VName
agg2) ->
                                  VName
agg1 VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty (VName -> SubExp
Var VName
agg2)
                            )
                            ( do
                                TV Int8
used1 TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
used1 TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Int8 ExpLeaf
forall a. Num a => a -> a -> a
+ TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
used2
                                Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'') (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                                  [(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'') (((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                                    \(VName
agg1, PrimType
ty, SubExp
res) -> VName
agg1 VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
                            )
                        TV Int8
flag TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag1
                        TV Int8
used TV Int8
-> TPrimExp Int8 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
used1
                        [(TV Any, PrimType, VName)]
-> ((TV Any, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [VName] -> [(TV Any, PrimType, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
aggrs [PrimType]
tys [VName]
agg1s) (((TV Any, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, PrimType
ty, VName
agg1) ->
                          TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty (VName -> SubExp
Var VName
agg1)
                    )
                    -- else
                    ( [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
exchanges) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
exchange) ->
                        VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
exchange) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
warpSize TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1]
                    )
                  -- end sIf
                  TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                    (TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 ExpLeaf
-> TPrimExp Int8 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 ExpLeaf
forall a. Num a => a
statusP)
                    (TV Int32
readOffset TV Int32
-> TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TPrimExp Int32 ExpLeaf
loopStop)
                    (TV Int32
readOffset TV Int32
-> TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int8 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
zExt32 (TV Int8 -> TPrimExp Int8 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int8
used))
                  VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
sharedReadOffset [TPrimExp Int64 ExpLeaf
0] (TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
readOffset) []
                  LambdaT GPUMem
scanOp''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
                  let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp'''
                  [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
aggr) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
aggr)
                  [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
prefix) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix)
                  Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''') (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                    [(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
prefixes [PrimType]
tys ([SubExp] -> [(TV Any, PrimType, SubExp)])
-> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''') (((TV Any, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                      \(TV Any
prefix, PrimType
ty, SubExp
res) -> TV Any
prefix TV Any -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res)
                -- end sWhen
                KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence
                VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
readOffset) [] (VName -> SubExp
Var VName
sharedReadOffset) [TPrimExp Int64 ExpLeaf
0]
          )
        -- end sWhile
        -- end sIf
        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          LambdaT GPUMem
scanOp'''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
          let xs :: [VName]
xs = (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
take ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''''
              ys :: [VName]
ys = (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
drop ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''''
          TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int32 ExpLeaf
boundary TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
forall a. Num a => a
m)) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
prefix) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
            [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
acc) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc
            Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''') (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''') (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  \(VName
incprefixArray, SubExp
res) -> VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] SubExp
res []
            KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
            ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
          [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
exchange [TPrimExp Int64 ExpLeaf
0] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
prefix) []
          [(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
accs [PrimType]
tys [SubExp]
scanOpNe) (((TV Any, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, PrimType
ty, SubExp
ne) ->
            TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
      -- end sWhen
      -- end sWhen

      TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf)
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf
0) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
exchange) [TPrimExp Int64 ExpLeaf
0]
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
    -- end sWhen
    -- end sComment

    LambdaT GPUMem
scanOp''''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'
    LambdaT GPUMem
scanOp'''''' <- LambdaT GPUMem -> ImpM GPUMem KernelEnv KernelOp (LambdaT GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda LambdaT GPUMem
scanOp'

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Distribute results" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp'''''
          ([VName]
xs', [VName]
ys') = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LetDecMem -> VName) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName ([Param LetDecMem] -> [VName]) -> [Param LetDecMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
scanOp''''''

      [((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(TV Any, TV Any)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [PrimType]
-> [((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ([TV Any] -> [TV Any] -> [(TV Any, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [TV Any]
accs) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [VName]
xs') ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [VName]
ys') [PrimType]
tys) ((((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        \((TV Any
prefix, TV Any
acc), (VName
x, VName
x'), (VName
y, VName
y'), PrimType
ty) -> do
          VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
          VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
          VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
          VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc

      TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
        (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int32 ExpLeaf
boundary TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool ExpLeaf
blockNewSgm)
        ( Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''''') (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            [(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp'''''') (((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              \(VName
x, PrimType
ty, SubExp
res) -> VName
x VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
        )
        ([(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
acc) -> VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [])
      -- calculate where previous thread stopped, to determine number of
      -- elements left before new segment.
      TPrimExp Int32 ExpLeaf
stop <-
        String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"stopping_point" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
          TPrimExp Int32 ExpLeaf
segsize_compact TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
1 TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf
segsize_compact TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
boundary) TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. IntegralExp a => a -> a -> a
`rem` TPrimExp Int32 ExpLeaf
segsize_compact
      String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 ExpLeaf
i TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int32 ExpLeaf
stop TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int32 ExpLeaf
1) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [VName]
ys) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
            -- only include prefix for the first segment part per thread
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TPrimExp Int64 ExpLeaf
i]
          Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''''') (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT GPUMem
scanOp''''') (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              \(VName
dest, SubExp
res) ->
                VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 ExpLeaf
i] SubExp
res []

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Transpose scan output" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
          TV Int64
sharedIdx <-
            String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
              TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int32 ExpLeaf
forall a. Num a => a
m) TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf
i
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
trans [TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TPrimExp Int64 ExpLeaf
i]
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
          TV Int32
sharedIdx <-
            String
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TPrimExp Int32 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
              KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants
                TPrimExp Int32 ExpLeaf
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int64 ExpLeaf -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
i)
          VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
priv [TPrimExp Int64 ExpLeaf
i] (VName -> SubExp
Var VName
trans) [TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf)
-> TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
      KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Write block scan results to global memory" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      String
-> TPrimExp Int64 ExpLeaf
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 ExpLeaf
forall a. Num a => a
m ((TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 ExpLeaf
i -> do
        TPrimExp Int64 ExpLeaf
flat_idx <-
          String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
            TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ KernelConstants -> TPrimExp Int64 ExpLeaf
kernelGroupSize KernelConstants
constants TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
i
              TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
        [(VName, TPrimExp Int64 ExpLeaf)]
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
[(VName, TPrimExp Int64 ExpLeaf)]
-> TPrimExp Int64 ExpLeaf -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 ExpLeaf] -> [(VName, TPrimExp Int64 ExpLeaf)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TPrimExp Int64 ExpLeaf]
dims') TPrimExp Int64 ExpLeaf
flat_idx
        TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 ExpLeaf
flat_idx TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 ExpLeaf
n) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((PatElemT LetDecMem -> VName) -> [PatElemT LetDecMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName [PatElemT LetDecMem]
all_pes) [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
            VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
dest ((VName -> TPrimExp Int64 ExpLeaf)
-> [VName] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 [VName]
gtids) (VName -> SubExp
Var VName
src) [TPrimExp Int64 ExpLeaf
i]

    String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"If this is the last block, reset the dynamicId" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      TPrimExp Bool ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
- TPrimExp Int64 ExpLeaf
1) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix VName
globalId [TPrimExp Int64 ExpLeaf
0] (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0 :: Int32)) []