{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where
import Control.Monad.Except
import Data.List (foldl', zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, quot)
import Prelude hiding (mod, quot, rem)
-- | The leading parameters of a scan operator's lambda: the first
-- @length (segBinOpNeutral scan)@ entries of its 'lambdaParams'.
-- These are presumably bound to the left (accumulated) operand.
xParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | The trailing parameters of a scan operator's lambda: everything in its
-- 'lambdaParams' after the first @length (segBinOpNeutral scan)@ entries.
-- These are presumably bound to the right (incoming) operand.
yParams :: SegBinOp GPUMem -> [LParam GPUMem]
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | @alignTo x a@ computes @(x \`divUp\` a) * a@, i.e. the smallest multiple
-- of @a@ that is not less than @x@ (assuming 'divUp' is ceiling division).
-- The caller must ensure @a@ is non-zero.
alignTo :: IntegralExp a => a -> a -> a
alignTo x a = (x `divUp` a) * a
-- | Allocate one block of reused shared memory (memory space @"local"@) for
-- the single-pass scan kernel, and carve it into the array views the kernel
-- works with.
--
-- Arguments:
--
-- * the group size;
-- * @m@: the transpose buffers hold @m * groupSize@ elements of each type
--   (presumably @m@ elements per thread; TODO confirm against the caller);
-- * the primitive types of the scanned values, one per scan operand.
--
-- Returns @(sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)@:
--
-- * @sharedId@: a one-element @int32@ array;
-- * @transposedArrays@: one array of @m * groupSize@ elements per type;
-- * @prefixArrays@: one array of @groupSize@ elements per type, placed back
--   to back with each start aligned to its element size;
-- * @warpscan@: an @int8@ array of @warpSize@ (32) elements;
-- * @warpExchanges@: one @warpSize@-element array per type, placed after the
--   first @warpSize@ bytes with each start aligned to its element size.
--
-- All views live in the same allocation, so the layouts overlap and the
-- allocation only needs to be as large as the largest of them. Views created
-- with 'sArrayInMem' carry no explicit offset (presumably offset 0).
createLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (Count groupSize) m types = do
  let groupSizeE :: Imp.TExp Int64
      groupSizeE = pe64 groupSize

      workSize :: Imp.TExp Int64
      workSize = pe64 m * groupSizeE

      -- Element size in bytes of each scanned type.
      tySizes :: [Imp.TExp Int64]
      tySizes = map primByteSize types

      -- Advance byte offset @off@ past one array of @n@ elements of
      -- @tySize@ bytes each. The offset is first aligned to @tySize@, so the
      -- array's start is always a whole number of elements.
      step :: Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Int64
      step n off tySize = alignTo off tySize + tySize * n

      warpSize :: Num a => a
      warpSize = 32

      -- Prefix-array layout: consecutive groupSize-element arrays.
      prefixArraysSize = foldl' (step groupSizeE) 0 tySizes

      -- Transpose layout: every buffer starts at the same place, so only the
      -- largest matters. 'types' may be empty; avoid the partial 'foldl1'.
      maxTransposedArraySize =
        case map (workSize *) tySizes of
          [] -> 0
          sizes -> foldl1 sMax64 sizes

      -- Lookback layout: warpSize bytes for the int8 warpscan array,
      -- followed by the consecutive warpSize-element exchange arrays.
      maxWarpExchangeSize = foldl' (step warpSize) 0 tySizes
      maxLookbackSize = maxWarpExchangeSize + warpSize

      size :: Count Bytes (Imp.TExp Int64)
      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      varTE :: TV Int64 -> Imp.TExp Int64
      varTE = le64 . tvVar

  -- Bind the byte offset at which each array starts as a kernel variable.
  -- scanl also yields a final end offset, which the later zips ignore.
  byteOffsets <-
    mapM (fmap varTE . dPrimV "byte_offsets") $
      scanl (step groupSizeE) 0 tySizes
  warpByteOffsets <-
    mapM (fmap varTE . dPrimV "warp_byte_offset") $
      scanl (step warpSize) warpSize tySizes

  sComment "Allocate reused shared memory" $ pure ()

  localMem <- sAlloc "local_mem" size (Space "local")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  sharedId <-
    sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  -- The index functions take offsets in elements, so convert from bytes.
  -- The division is exact because each byte offset was aligned by 'step'.
  prefixArrays <-
    forM (zip byteOffsets types) $ \(off, ty) ->
      sArray "local_prefix_arr" ty (Shape [groupSize]) localMem $
        IxFun.iotaOffset (off `quot` primByteSize ty) [groupSizeE]

  warpscan <-
    sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem

  warpExchanges <-
    forM (zip warpByteOffsets types) $ \(off, ty) ->
      sArray "warp_exchange" ty (Shape [constant (warpSize :: Int64)]) localMem $
        IxFun.iotaOffset (off `quot` primByteSize ty) [warpSize]

  pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)
inBlockScanLookback ::
KernelConstants ->
Imp.TExp Int64 ->
VName ->
[VName] ->
Lambda GPUMem ->
InKernelGen ()
inBlockScanLookback :: KernelConstants
-> TPrimExp Int64 VName
-> VName
-> [VName]
-> Lambda GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
inBlockScanLookback KernelConstants
constants TPrimExp Int64 VName
arrs_full_size VName
flag_arr [VName]
arrs Lambda GPUMem
scan_lam = ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TV Any
flg_x <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"flg_x" PrimType
p_int8
TV Int8
flg_y <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"flg_y" PrimType
p_int8
let flg_param_x :: Param (MemInfo SubExp NoUniqueness MemBind)
flg_param_x = Attrs
-> VName
-> MemInfo SubExp NoUniqueness MemBind
-> Param (MemInfo SubExp NoUniqueness MemBind)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
flg_x) (PrimType -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
p_int8)
flg_param_y :: Param (MemInfo SubExp NoUniqueness MemBind)
flg_param_y = Attrs
-> VName
-> MemInfo SubExp NoUniqueness MemBind
-> Param (MemInfo SubExp NoUniqueness MemBind)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flg_y) (PrimType -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
p_int8)
flg_y_exp :: TPrimExp Int8 VName
flg_y_exp = TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flg_y
statusP :: TPrimExp Int8 VName
statusP = (TPrimExp Int8 VName
2 :: Imp.TExp Int8)
statusX :: TPrimExp Int8 VName
statusX = (TPrimExp Int8 VName
0 :: Imp.TExp Int8)
[LParam GPUMem] -> ImpM GPUMem KernelEnv KernelOp ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams (Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_lam)
TV Int32
skip_threads <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"skip_threads" PrimType
int32
let in_block_thread_active :: TPrimExp Bool VName
in_block_thread_active =
TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
in_block_id
actual_params :: [LParam GPUMem]
actual_params = Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_lam
([Param (MemInfo SubExp NoUniqueness MemBind)]
x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) =
Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param (MemInfo SubExp NoUniqueness MemBind)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LParam GPUMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
actual_params Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2) [LParam GPUMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
actual_params
y_to_x :: ImpM GPUMem KernelEnv KernelOp ()
y_to_x =
[(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) (((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x, Param (MemInfo SubExp NoUniqueness MemBind)
y) ->
Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x)) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) [] (VName -> SubExp
Var (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y)) []
y_to_x_flg :: ImpM GPUMem KernelEnv KernelOp ()
y_to_x_flg =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
flg_x) [] (VName -> SubExp
Var (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flg_y)) []
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"read input for in-block scan" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
(Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> ImpM GPUMem KernelEnv KernelOp ()
readInitial (Param (MemInfo SubExp NoUniqueness MemBind)
flg_param_y Param (MemInfo SubExp NoUniqueness MemBind)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. a -> [a] -> [a]
: [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) (VName
flag_arr VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs)
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
in_block_id TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
ImpM GPUMem KernelEnv KernelOp ()
y_to_x
ImpM GPUMem KernelEnv KernelOp ()
y_to_x_flg
Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier
let op_to_x :: ImpM GPUMem KernelEnv KernelOp ()
op_to_x = do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Int8 VName
flg_y_exp TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
statusP TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TPrimExp Int8 VName
flg_y_exp TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
statusX)
( do
ImpM GPUMem KernelEnv KernelOp ()
y_to_x_flg
ImpM GPUMem KernelEnv KernelOp ()
y_to_x
)
([Param (MemInfo SubExp NoUniqueness MemBind)]
-> Body GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params (Body GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> Body GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_lam)
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"in-block scan (hopefully no barriers needed)" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TV Int32
skip_threads TV Int32 -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32
1
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhile (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
block_size) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen TPrimExp Bool VName
in_block_thread_active (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"read operands" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
(Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_
(TPrimExp Int64 VName
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
readParam (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads))
(Param (MemInfo SubExp NoUniqueness MemBind)
flg_param_x Param (MemInfo SubExp NoUniqueness MemBind)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. a -> [a] -> [a]
: [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params)
(VName
flag_arr VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs)
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform operation" ImpM GPUMem KernelEnv KernelOp ()
op_to_x
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"write result" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[ImpM GPUMem KernelEnv KernelOp ()]
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, Monad m) =>
t (m a) -> m ()
sequence_ ([ImpM GPUMem KernelEnv KernelOp ()]
-> ImpM GPUMem KernelEnv KernelOp ())
-> [ImpM GPUMem KernelEnv KernelOp ()]
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
(Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [ImpM GPUMem KernelEnv KernelOp ()]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3
Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
writeResult
(Param (MemInfo SubExp NoUniqueness MemBind)
flg_param_x Param (MemInfo SubExp NoUniqueness MemBind)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. a -> [a] -> [a]
: [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params)
(Param (MemInfo SubExp NoUniqueness MemBind)
flg_param_y Param (MemInfo SubExp NoUniqueness MemBind)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. a -> [a] -> [a]
: [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params)
(VName
flag_arr VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs)
TV Int32
skip_threads TV Int32 -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
2
where
p_int8 :: PrimType
p_int8 = IntType -> PrimType
IntType IntType
Int8
block_size :: TExp Int32
block_size = TExp Int32
32
block_id :: TExp Int32
block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. IntegralExp a => a -> a -> a
`quot` TExp Int32
block_size
in_block_id :: TExp Int32
in_block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size
ltid32 :: TExp Int32
ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
ltid :: TPrimExp Int64 VName
ltid = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
gtid :: TPrimExp Int64 VName
gtid = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda GPUMem
scan_lam
barrier :: ImpM GPUMem KernelEnv KernelOp ()
barrier
| Bool
array_scan =
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal
| Bool
otherwise =
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
readInitial :: Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> ImpM GPUMem KernelEnv KernelOp ()
readInitial Param (MemInfo SubExp NoUniqueness MemBind)
p VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
ltid]
| Bool
otherwise =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
gtid]
readParam :: TPrimExp Int64 VName
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
readParam TPrimExp Int64 VName
behind Param (MemInfo SubExp NoUniqueness MemBind)
p VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
ltid TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
behind]
| Bool
otherwise =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
gtid TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
behind TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
arrs_full_size]
writeResult :: Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
writeResult Param (MemInfo SubExp NoUniqueness MemBind)
x Param (MemInfo SubExp NoUniqueness MemBind)
y VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x = do
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
arr [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
| Bool
otherwise =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
compileSegScan ::
Pat LetDecMem ->
SegLevel ->
SegSpace ->
SegBinOp GPUMem ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan :: Pat (MemInfo SubExp NoUniqueness MemBind)
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat (MemInfo SubExp NoUniqueness MemBind)
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scanOp KernelBody GPUMem
kbody = do
let Pat [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes = Pat (MemInfo SubExp NoUniqueness MemBind)
pat
scanOpNe :: [SubExp]
scanOpNe = SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scanOp
tys :: [PrimType]
tys = (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map (\(Prim PrimType
pt) -> PrimType
pt) ([TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType])
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp
n :: TPrimExp Int64 VName
n = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
sumT :: Integer
maxT :: Integer
sumT :: Integer
sumT = (Integer -> PrimType -> Integer)
-> Integer -> [PrimType] -> Integer
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize PrimType
typ) Integer
0 [PrimType]
tys
primByteSize' :: PrimType -> Integer
primByteSize' = Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
4 (Integer -> Integer)
-> (PrimType -> Integer) -> PrimType -> Integer
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize
sumT' :: Integer
sumT' = (Integer -> PrimType -> Integer)
-> Integer -> [PrimType] -> Integer
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ PrimType -> Integer
primByteSize' PrimType
typ) Integer
0 [PrimType]
tys Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` Integer
4
maxT :: Integer
maxT = [Integer] -> Integer
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ((PrimType -> Integer) -> [PrimType] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Integer
forall a. Num a => PrimType -> a
primByteSize [PrimType]
tys)
m :: Num a => a
m :: forall a. Num a => a
m = Integer -> a
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Integer -> a) -> Integer -> a
forall a b. (a -> b) -> a -> b
$ Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
1 (Integer -> Integer) -> Integer -> Integer
forall a b. (a -> b) -> a -> b
$ Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
min Integer
mem_constraint Integer
reg_constraint
k_reg :: Integer
k_reg = Integer
64
k_mem :: Integer
k_mem = Integer
95
mem_constraint :: Integer
mem_constraint = Integer -> Integer -> Integer
forall a. Ord a => a -> a -> a
max Integer
k_mem Integer
sumT Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` Integer
maxT
reg_constraint :: Integer
reg_constraint = (Integer
k_reg Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
- Integer
1 Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
- Integer
sumT') Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`div` (Integer
2 Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
* Integer
sumT')
group_size :: Count GroupSize SubExp
group_size = SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
group_size' :: TPrimExp Int64 VName
group_size' = SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size
Count NumGroups SubExp
num_groups <-
SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (SubExp -> Count NumGroups SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize (TV Int64 -> Count NumGroups SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_groups" (TPrimExp Int64 VName
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. IntegralExp a => a -> a -> a
`divUp` (TPrimExp Int64 VName
group_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
forall a. Num a => a
m))
let num_groups' :: TPrimExp Int64 VName
num_groups' = SubExp -> TPrimExp Int64 VName
pe64 (Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount Count NumGroups SubExp
num_groups)
TPrimExp Int64 VName
num_threads <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"num_threads" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
group_size'
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
segmented :: Bool
segmented = [TPrimExp Int64 VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims' Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1
not_segmented_e :: TPrimExp Bool VName
not_segmented_e = if Bool
segmented then TPrimExp Bool VName
forall v. TPrimExp Bool v
false else TPrimExp Bool VName
forall v. TPrimExp Bool v
true
segment_size :: TPrimExp Int64 VName
segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
statusX, statusA, statusP :: Num a => a
statusX :: forall a. Num a => a
statusX = a
0
statusA :: forall a. Num a => a
statusA = a
1
statusP :: forall a. Num a => a
statusP = a
2
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Sequential elements per thread (m):" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
forall a. Num a => a
m :: Imp.TExp Int32)
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory constraint " (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (Integer -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
mem_constraint :: Imp.TExp Int32)
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Register constraint" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (Integer -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
reg_constraint :: Imp.TExp Int32)
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"sumT'" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (Integer -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
sumT' :: Imp.TExp Int32)
VName
globalId <- String
-> Space
-> PrimType
-> ArrayContents
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String -> Space -> PrimType -> ArrayContents -> ImpM rep r op VName
sStaticArray String
"id_counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM GPUMem HostEnv HostOp VName)
-> ArrayContents -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Int -> ArrayContents
Imp.ArrayZeros Int
1
VName
statusFlags <- String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"status_flags" PrimType
int8 ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount Count NumGroups SubExp
num_groups]) (String -> Space
Space String
"device")
([VName]
aggregateArrays, [VName]
incprefixArrays) <-
([(VName, VName)] -> ([VName], [VName]))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip (ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName]))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
-> ImpM GPUMem HostEnv HostOp ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
[PrimType]
-> (PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)])
-> (PrimType -> ImpM GPUMem HostEnv HostOp (VName, VName))
-> ImpM GPUMem HostEnv HostOp [(VName, VName)]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
(,)
(VName -> VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount Count NumGroups SubExp
num_groups]) (String -> Space
Space String
"device")
ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName, VName)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount Count NumGroups SubExp
num_groups]) (String -> Space
Space String
"device")
VName -> SubExp -> CallKernelGen ()
sReplicate VName
statusFlags (SubExp -> CallKernelGen ()) -> SubExp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusX
String
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread String
"segscan" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) (ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
(VName
sharedId, [VName]
transposedArrays, [VName]
prefixArrays, VName
warpscan, [VName]
exchanges) <-
Count GroupSize SubExp
-> SubExp
-> [PrimType]
-> InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m) [PrimType]
tys
TV Int64
dynamicId <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"dynamic_id" PrimType
int32
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
(VName
globalIdMem, Space
_, Count Elements (TPrimExp Int64 VName)
globalIdOff) <- VName
-> [TPrimExp Int64 VName]
-> ImpM
GPUMem
KernelEnv
KernelOp
(VName, Space, Count Elements (TPrimExp Int64 VName))
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> ImpM
rep r op (VName, Space, Count Elements (TPrimExp Int64 VName))
fullyIndexArray VName
globalId [TPrimExp Int64 VName
0]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
IntType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
Imp.AtomicAdd
IntType
Int32
(TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId)
VName
globalIdMem
(TPrimExp Int64 VName -> Count Elements (TPrimExp Int64 VName)
forall u e. e -> Count u e
Count (TPrimExp Int64 VName -> Count Elements (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> Count Elements (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ Count Elements (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count Elements (TPrimExp Int64 VName)
globalIdOff)
(TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
1 :: Imp.TExp Int32))
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
sharedId [TPrimExp Int64 VName
0] (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
dynamicId) []
let localBarrier :: KernelOp
localBarrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
localFence :: KernelOp
localFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
globalFence :: KernelOp
globalFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId) [] (VName -> SubExp
Var VName
sharedId) [TPrimExp Int64 VName
0]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
TV Int64
blockOff <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"blockOff" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
forall a. Num a => a
m TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
TPrimExp Int64 VName
sgmIdx <- String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sgm_idx" (TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. IntegralExp a => a -> a -> a
`mod` TPrimExp Int64 VName
segment_size
TExp Int32
boundary <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"boundary" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Int64 VName
forall a. Num a => a
m TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
group_size') (TPrimExp Int64 VName
segment_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
sgmIdx)
TExp Int32
segsize_compact <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segsize_compact" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TPrimExp Int64 VName
forall a. Num a => a
m TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
group_size') TPrimExp Int64 VName
segment_size
[VName]
privateArrays <-
[PrimType]
-> (PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName])
-> (PrimType -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
String
"private"
PrimType
ty
([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m])
([SubExp] -> PrimType -> Space
ScalarSpace [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m] PrimType
ty)
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Load and map" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
String
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 VName
forall a. Num a => a
m ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
TPrimExp Int64 VName
phys_tid <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"phys_tid" (TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
blockOff
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 VName] -> [(VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TPrimExp Int64 VName]
dims') TPrimExp Int64 VName
phys_tid
let in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
in_bounds =
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
all_scan_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scanOp]) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
[(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) (((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElem (MemInfo SubExp NoUniqueness MemBind)
dest, KernelResult
src) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
dest) ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 VName
i] SubExp
src []
out_of_bounds :: ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds =
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [SubExp]
scanOpNe) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 VName
i] SubExp
ne []
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TPrimExp Int64 VName
phys_tid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
n) ImpM GPUMem KernelEnv KernelOp ()
in_bounds ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Transpose scan inputs" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 VName
forall a. Num a => a
m ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
TPrimExp Int64 VName
sharedIdx <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" (TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
trans [TPrimExp Int64 VName
sharedIdx] (VName -> SubExp
Var VName
priv) [TPrimExp Int64 VName
i]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TExp Int32
-> (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int32
forall a. Num a => a
m ((TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TV Int32
sharedIdx <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
priv [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i] (VName -> SubExp
Var VName
trans) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Per thread scan" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TExp Int32
globalIdx <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
(KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1
String
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" (TPrimExp Int64 VName
forall a. Num a => a
m TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1) ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
let xs :: [VName]
xs = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scanOp
ys :: [VName]
ys = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scanOp
TPrimExp Bool VName
new_sgm <-
if Bool
segmented
then String
-> TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"new_sgm" (TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName))
-> TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall a b. (a -> b) -> a -> b
$ (TExp Int32
globalIdx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. IntegralExp a => a -> a -> a
`mod` TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
else TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure TPrimExp Bool VName
forall v. TPrimExp Bool v
false
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sUnless TPrimExp Bool VName
new_sgm (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, (VName, VName, PrimType))]
-> ((VName, (VName, VName, PrimType))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([(VName, VName, PrimType)] -> [(VName, (VName, VName, PrimType))])
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName] -> [PrimType] -> [(VName, VName, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [VName]
ys [PrimType]
tys) (((VName, (VName, VName, PrimType))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, (VName, VName, PrimType))
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, (VName
x, VName
y, PrimType
ty)) -> do
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TPrimExp Int64 VName
i]
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TPrimExp Int64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1]
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody (Lambda GPUMem -> Body GPUMem) -> Lambda GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody (Lambda GPUMem -> Body GPUMem) -> Lambda GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
res) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1] SubExp
res []
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Publish results in shared memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
prefixArrays [VName]
privateArrays) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (VName -> SubExp
Var VName
src) [TPrimExp Int64 VName
forall a. Num a => a
m TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
let crossesSegment :: Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
crossesSegment = do
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
(TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName))
-> (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
let from' :: TExp Int32
from' = (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
to' :: TExp Int32
to' = (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
in (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
from') TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. IntegralExp a => a -> a -> a
`mod` TExp Int32
segsize_compact
Lambda GPUMem
scanOp' <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem))
-> Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp
[TV Any]
accs <- (PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> [PrimType] -> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"acc") [PrimType]
tys
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Scan results (with warp scan)" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
crossesSegment
TPrimExp Int64 VName
num_threads
(KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
Lambda GPUMem
scanOp'
[VName]
prefixArrays
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
let firstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread TV Any
acc VName
prefixes =
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
notFirstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread TV Any
acc VName
prefixes =
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
((TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread [TV Any]
accs [VName]
prefixArrays)
((TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread [TV Any]
accs [VName]
prefixArrays)
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
[TV Any]
prefixes <- [(SubExp, PrimType)]
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
TPrimExp Bool VName
blockNewSgm <- String
-> TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"block_new_sgm" (TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName))
-> TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
sgmIdx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Perform lookback" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool VName
blockNewSgm TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
accs [VName]
incprefixArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, VName
incprefixArray) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
[(SubExp, TV Any)]
-> ((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [TV Any] -> [(SubExp, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [TV Any]
accs) (((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((SubExp, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, TV Any
acc) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] SubExp
ne []
let warpSize :: TExp Int32
warpSize = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool VName
blockNewSgm TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
warpSize) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Bool VName
not_segmented_e TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Int32
boundary TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName
group_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
forall a. Num a => a
m))
( do
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
aggregateArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
aggregateArray, TV Any
acc) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusA) []
)
( do
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
)
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
warpscan [TPrimExp Int64 VName
0] (VName -> SubExp
Var VName
statusFlags) [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence
TV Int8
status <- String -> PrimType -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"status" PrimType
int8 :: InKernelGen (TV Int8)
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
status) [] (VName -> SubExp
Var VName
warpscan) [TPrimExp Int64 VName
0]
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
status TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
forall a. Num a => a
statusP)
( TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [VName]
incprefixArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
prefix, VName
incprefixArray) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
)
( do
TV Int32
readOffset <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"readOffset" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants)
let loopStop :: TExp Int32
loopStop = TExp Int32
warpSize TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (-TExp Int32
1)
sameSegment :: TV Int32 -> TPrimExp Bool VName
sameSegment TV Int32
readIdx
| Bool
segmented =
let startIdx :: TPrimExp Int64 VName
startIdx = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readIdx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
forall a. Num a => a
m TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
in TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
startIdx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp Int64 VName
sgmIdx
| Bool
otherwise = TPrimExp Bool VName
forall v. TPrimExp Bool v
true
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhile (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
loopStop) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TV Int32
readI <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"read_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
[TV Any]
aggrs <- [(SubExp, PrimType)]
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType) -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> ImpM GPUMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
TV Int8
flag <- String
-> TPrimExp Int8 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int8)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag" (TPrimExp Int8 VName
forall a. Num a => a
statusX :: Imp.TExp Int8)
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int32 -> TPrimExp Bool VName
sameSegment TV Int32
readI)
( do
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
statusFlags) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
forall a. Num a => a
statusP)
( [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefix, TV Any
aggr) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
incprefix) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
)
( TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
forall a. Num a => a
statusA) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
aggregateArrays) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
aggregate) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
aggregate) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
)
)
(VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) [])
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
aggr) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
aggr) []
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
warpscan [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (TV Int8 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int8
flag) []
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
warpscan) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TPrimExp Int8 VName
2 :: Imp.TExp Int8)) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
Lambda GPUMem
lam' <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
KernelConstants
-> TPrimExp Int64 VName
-> VName
-> [VName]
-> Lambda GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
inBlockScanLookback
KernelConstants
constants
TPrimExp Int64 VName
num_threads
VName
warpscan
[VName]
exchanges
Lambda GPUMem
lam'
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
warpscan) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
[(TV Any, VName)]
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
exchanges) (((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
exchange) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
forall a. Num a => a
statusP)
(TV Int32
readOffset TV Int32 -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32
loopStop)
( TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int8 VName
forall a. Num a => a
statusA) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TV Int32
readOffset TV Int32 -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
zExt32 TExp Int32
warpSize
)
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int8 -> TPrimExp Int8 VName
forall t. TV t -> TExp t
tvExp TV Int8
flag TPrimExp Int8 VName -> TPrimExp Int8 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TPrimExp Int8 VName
forall a. Num a => a
statusX :: Imp.TExp Int8)) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
Lambda GPUMem
lam <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
aggrs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
aggr) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
aggr)
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
prefix) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix)
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
prefixes [PrimType]
tys ([SubExp] -> [(TV Any, PrimType, SubExp)])
-> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam) (((TV Any, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(TV Any
prefix, PrimType
ty, SubExp
res) -> TV Any
prefix TV Any -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res)
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localFence
)
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
Lambda GPUMem
scanOp'''' <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
let xs :: [VName]
xs = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
take ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp''''
ys :: [VName]
ys = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
drop ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp''''
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
boundary TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName
group_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
forall a. Num a => a
m)) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
prefix) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
acc) -> VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''') (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
incprefixArray, SubExp
res) -> VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] SubExp
res []
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
exchange [TPrimExp Int64 VName
0] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
prefix) []
[(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
accs [PrimType]
tys [SubExp]
scanOpNe) (((TV Any, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, PrimType
ty, SubExp
ne) ->
TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
[(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
exchange) [TPrimExp Int64 VName
0]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
Lambda GPUMem
scanOp''''' <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
Lambda GPUMem
scanOp'''''' <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Distribute results" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp'''''
([VName]
xs', [VName]
ys') = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp''''''
[((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(TV Any, TV Any)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [PrimType]
-> [((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ([TV Any] -> [TV Any] -> [(TV Any, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [TV Any]
accs) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [VName]
xs') ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [VName]
ys') [PrimType]
tys) ((((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\((TV Any
prefix, TV Any
acc), (VName
x, VName
x'), (VName
y, VName
y'), PrimType
ty) -> do
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
VName -> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' (TExp Any -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
boundary TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool VName
blockNewSgm)
( Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''''') (((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
x, PrimType
ty, SubExp
res) -> VName
x VName -> Exp -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
)
([(VName, TV Any)]
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
accs) (((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
acc) -> VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [])
TExp Int32
stop <-
String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"stopping_point" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. IntegralExp a => a -> a -> a
`rem` TExp Int32
segsize_compact
String
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 VName
forall a. Num a => a
m ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
i TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
stop TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [VName]
ys) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TPrimExp Int64 VName
i]
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp''''') (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp''''') (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
dest, SubExp
res) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest [TPrimExp Int64 VName
i] SubExp
res []
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Transpose scan output and Write it to global memory in coalesced fashion" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName, VName)]
-> ((VName, VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [VName] -> [(VName, VName, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
transposedArrays [VName]
privateArrays ([VName] -> [(VName, VName, VName)])
-> [VName] -> [(VName, VName, VName)]
forall a b. (a -> b) -> a -> b
$ (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) (((VName, VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
locmem, VName
priv, VName
dest) -> do
String
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 VName
forall a. Num a => a
m ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
TV Int64
sharedIdx <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
i
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
locmem [TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TPrimExp Int64 VName
i]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 VName
forall a. Num a => a
m ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
TPrimExp Int64 VName
flat_idx <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" (TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
blockOff
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
i
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 VName] -> [(VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TPrimExp Int64 VName]
dims') TPrimExp Int64 VName
flat_idx
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 VName
flat_idx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
n) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
VName
dest
((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids)
(VName -> SubExp
Var VName
locmem)
[TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
flat_idx TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
blockOff]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
String
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"If this is the last block, reset the dynamicId" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
globalId [TPrimExp Int64 VName
0] (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0 :: Int32)) []