{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where
import Control.Monad
import Data.List (zip4, zip7)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, nextMul, quot)
import Prelude hiding (mod, quot, rem)
-- | Split the lambda parameters of a scan operator into its two operand
-- groups: 'xParams' yields the first @length (segBinOpNeutral scan)@
-- parameters (one per neutral element) and 'yParams' the ones after them.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan = fst $ splitScanParams scan
yParams scan = snd $ splitScanParams scan

-- | Split the operator's lambda parameters after one parameter per
-- neutral element.  Equivalent to pairing 'take' and 'drop' of that count.
splitScanParams :: SegBinOp GPUMem -> ([LParam GPUMem], [LParam GPUMem])
splitScanParams scan =
  splitAt (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | Allocate the scan kernel's scratch buffer in the @"local"@ memory space
-- and carve the arrays used by the single-pass scan out of it.
--
-- Arguments: the group size, the per-thread chunk size, and the element
-- type of each scan operand.  Returns, in order:
--
--   * @shared_id@: a one-element @i32@ array;
--   * one @local_transpose_arr@ of @chunk * group_size@ elements per type;
--   * one @local_prefix_arr@ of @group_size@ elements per type;
--   * @warpscan@: a 32-element @i8@ array;
--   * one @warp_exchange@ array of 32 elements per type.
--
-- Every array is a view into the single @local_mem@ allocation, whose size
-- is the maximum (not the sum) of the three layouts, so the views overlap.
-- NOTE(review): this presumably relies on callers never using overlapping
-- views at the same time -- confirm at the call sites.
createLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (Count groupSize) chunk types = do
  let groupSizeE = pe64 groupSize
      workSize = pe64 chunk * groupSizeE

      warpSize :: (Num a) => a
      warpSize = 32

      -- Byte size of one element of each operand type.
      tySizes :: [Imp.TExp Int64]
      tySizes = map primByteSize types

      -- Bytes needed to place one array of @count@ elements per type back
      -- to back, aligning each array's start to its own element size.
      packedSize :: Imp.TExp Int64 -> Imp.TExp Int64
      packedSize count =
        foldl (\acc tySize -> nextMul acc tySize + tySize * count) 0 tySizes

      prefixArraysSize = packedSize groupSizeE
      -- Only the largest transposition array is accounted for; each one is
      -- created below with 'sArrayInMem' on @local_mem@ without an offset.
      -- 'foldl1' is partial: 'types' must be non-empty.
      maxTransposedArraySize = foldl1 sMax64 $ map (workSize *) tySizes
      -- The warp exchange arrays are placed after the 'warpSize'-byte
      -- @warpscan@ array, hence the extra 'warpSize' bytes.
      maxLookbackSize = packedSize warpSize + warpSize
      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      -- Same layout as 'packedSize', starting at byte @start@; binds each
      -- array's end to a kernel variable named @desc@ and returns the
      -- aligned starting byte offset of every array.
      arrayOffsets ::
        String -> Imp.TExp Int64 -> Imp.TExp Int64 -> InKernelGen [Imp.TExp Int64]
      arrayOffsets desc start count =
        snd
          <$> mapAccumLM
            ( \off tySize -> do
                -- Return the aligned start rather than the previous end:
                -- 'elemOffset' divides by the element size, and an
                -- unaligned start would be rounded down into the
                -- preceding array.
                let aligned = nextMul off tySize
                off' <- dPrimVE desc $ aligned + count * tySize
                pure (off', aligned)
            )
            start
            tySizes

      -- Convert a byte offset into an element offset for type @ty@.  Exact,
      -- because 'arrayOffsets' only returns offsets aligned to @ty@'s size.
      elemOffset :: Imp.TExp Int64 -> PrimType -> Imp.TExp Int64
      elemOffset off ty = off `quot` primByteSize ty

  byteOffsets <- arrayOffsets "byte_offsets" 0 groupSizeE
  warpByteOffsets <- arrayOffsets "warp_byte_offset" warpSize warpSize

  sComment "Allocate reusable shared memory" $ pure ()

  localMem <- sAlloc "local_mem" size (Space "local")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  prefixArrays <-
    forM (zip byteOffsets types) $ \(off, ty) ->
      sArray "local_prefix_arr" ty (Shape [groupSize]) localMem $
        LMAD.iota (elemOffset off ty) [groupSizeE]

  warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem
  warpExchanges <-
    forM (zip warpByteOffsets types) $ \(off, ty) ->
      sArray "warp_exchange" ty (Shape [constant (warpSize :: Int64)]) localMem $
        LMAD.iota (elemOffset off ty) [warpSize]

  pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)
-- | Values stored in the per-virtual-group status flags.  'statusX' is the
-- value the @status_flags@ array is initialised with.
-- NOTE(review): these presumably follow the decoupled look-back convention
-- (X: nothing published yet, A: aggregate available, P: inclusive prefix
-- available), matching the @aggregates@ and @incprefixes@ arrays -- confirm.
statusX, statusA, statusP :: (Num a) => a
statusX = 0
statusA = 1
statusP = 2
-- | Scan flag/operand pairs within each block of 32 consecutive threads
-- (block index @ltid `quot` 32@), with all memory accesses made volatile.
--
-- Each thread first reads its flag from @flag_arr@ and its operands from
-- @arrs@ into the y parameters.  Primitive-typed parameters are read at the
-- local thread id, others at the global thread id.  The thread at in-block
-- index 0 copies them into x.  Then, for strides @s = 1, 2, 4, ...@ while
-- below 32, every thread whose in-block index is at least @s@:
--
--   1. reads the element @s@ positions back into the x parameters;
--   2. combines x with y: when y's flag is 'statusP' or 'statusX', y (flag
--      and primitive operands) is copied into x; otherwise the scan lambda
--      body is evaluated into x and x's flag is kept;
--   3. writes x back to @arr[ltid]@ (primitive parameters only) and copies
--      x into y.
--
-- For non-primitive parameters the read index is
-- @gtid - s + arrs_full_size@.  NOTE(review): the meaning of that offset is
-- not visible here -- confirm against the caller's array layout.
inBlockScanLookback ::
  KernelConstants ->
  Imp.TExp Int64 ->
  VName ->
  [VName] ->
  Lambda GPUMem ->
  InKernelGen ()
inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everythingVolatile $ do
  flg_x <- dPrim "flg_x" p_int8 :: InKernelGen (TV Int8)
  flg_y <- dPrim "flg_y" p_int8 :: InKernelGen (TV Int8)
  let flg_param_x = Param mempty (tvVar flg_x) (MemPrim p_int8)
      flg_param_y = Param mempty (tvVar flg_y) (MemPrim p_int8)
      flg_y_exp = tvExp flg_y
      statusP_e = statusP :: Imp.TExp Int8
      statusX_e = statusX :: Imp.TExp Int8

  dLParams (lambdaParams scan_lam)

  skip_threads <- dPrim "skip_threads" int32
  let in_block_thread_active =
        tvExp skip_threads .<=. in_block_id
      actual_params = lambdaParams scan_lam
      -- The lambda takes the x operands first, then the y operands.
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
          when (primType (paramType x)) $
            copyDWIM (paramName x) [] (Var (paramName y)) []
      y_to_x_flg =
        copyDWIM (tvVar flg_x) [] (Var (tvVar flg_y)) []

  sComment "read input for in-block scan" $ do
    zipWithM_ readInitial (flg_param_y : y_params) (flag_arr : arrs)
    -- The first thread of each block has no predecessor, so its own input
    -- is its result.
    sWhen (in_block_id .==. 0) $ do
      y_to_x
      y_to_x_flg

  -- A (global) barrier is only needed when the operator returns arrays.
  when array_scan $ sOp $ Imp.Barrier Imp.FenceGlobal

  let op_to_x =
        sIf
          (flg_y_exp .==. statusP_e .||. flg_y_exp .==. statusX_e)
          ( do
              y_to_x_flg
              y_to_x
          )
          (compileBody' x_params $ lambdaBody scan_lam)

  -- NOTE(review): no barrier separates the reads at @ltid - s@ from the
  -- writes at @ltid@ within one round, so this relies on the 32 threads of
  -- a block executing in lockstep; confirm on targets with independent
  -- thread scheduling.
  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1
    sWhile (tvExp skip_threads .<. block_size) $ do
      sWhen in_block_thread_active $ do
        sComment "read operands" $
          zipWithM_
            (readParam (sExt64 $ tvExp skip_threads))
            (flg_param_x : x_params)
            (flag_arr : arrs)
        sComment "perform operation" op_to_x
        sComment "write result" $
          sequence_ $
            zipWith3
              writeResult
              (flg_param_x : x_params)
              (flg_param_y : y_params)
              (flag_arr : arrs)
      skip_threads <-- tvExp skip_threads * 2
  where
    p_int8 = IntType Int8
    -- Number of threads scanned together.
    block_size = 32 :: Imp.TExp Int32
    block_id = ltid32 `quot` block_size
    in_block_id = ltid32 - block_id * block_size
    ltid32 = kernelLocalThreadId constants
    ltid = sExt64 ltid32
    gtid = sExt64 $ kernelGlobalThreadId constants
    array_scan = not $ all primType $ lambdaReturnType scan_lam

    -- Load this thread's own element into parameter @p@.
    readInitial p arr
      | primType $ paramType p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid]

    -- Load the element @behind@ positions earlier into parameter @p@.
    readParam behind p arr
      | primType $ paramType p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid - behind]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size]

    -- Publish x to @arr[ltid]@ (primitive parameters only) and make it the
    -- new y for the next round.
    writeResult x y arr = do
      when (isPrimParam x) $
        copyDWIMFix arr [ltid] (Var $ paramName x) []
      copyDWIM (paramName y) [] (Var $ paramName x) []
compileSegScan ::
Pat LetDecMem ->
SegLevel ->
SegSpace ->
SegBinOp GPUMem ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan :: Pat (MemInfo SubExp NoUniqueness MemBind)
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat (MemInfo SubExp NoUniqueness MemBind)
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scan_op KernelBody GPUMem
map_kbody = do
KernelAttrs
attrs <- SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs SegLevel
lvl
let Pat [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes = Pat (MemInfo SubExp NoUniqueness MemBind)
pat
scanop_nes :: [SubExp]
scanop_nes = forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan_op
n :: TExp Int64
n = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
tys' :: [Type]
tys' = forall rep. Lambda rep -> [Type]
lambdaReturnType forall a b. (a -> b) -> a -> b
$ forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan_op
tys :: [PrimType]
tys = forall a b. (a -> b) -> [a] -> [b]
map forall shape u. TypeBase shape u -> PrimType
elemType [Type]
tys'
group_size_e :: TExp Int64
group_size_e = SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount forall a b. (a -> b) -> a -> b
$ KernelAttrs -> Count GroupSize SubExp
kAttrGroupSize KernelAttrs
attrs
num_physgroups_e :: TExp Int64
num_physgroups_e = SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount forall a b. (a -> b) -> a -> b
$ KernelAttrs -> Count NumGroups SubExp
kAttrNumGroups KernelAttrs
attrs
let chunk_const :: KernelConstExp
chunk_const = [Type] -> KernelConstExp
getChunkSize [Type]
tys'
TV Int64
chunk_v <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"chunk_size" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall v. PrimExp v -> TPrimExp Int64 v
isInt64 forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelConstExp -> CallKernelGen (PrimExp VName)
kernelConstToExp KernelConstExp
chunk_const
let chunk :: TExp Int64
chunk = forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_v
SubExp
num_virtgroups <-
forall {k} (t :: k). TV t -> SubExp
tvSize forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_virtgroups" (TExp Int64
n forall e. IntegralExp e => e -> e -> e
`divUp` (TExp Int64
group_size_e forall a. Num a => a -> a -> a
* TExp Int64
chunk))
let num_virtgroups_e :: TExp Int64
num_virtgroups_e = SubExp -> TExp Int64
pe64 SubExp
num_virtgroups
TExp Int64
num_virt_threads <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"num_virt_threads" forall a b. (a -> b) -> a -> b
$ TExp Int64
num_virtgroups_e forall a. Num a => a -> a -> a
* TExp Int64
group_size_e
let ([VName]
gtids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
dims
segmented :: Bool
segmented = forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dims' forall a. Ord a => a -> a -> Bool
> Int
1
not_segmented_e :: TPrimExp Bool VName
not_segmented_e = forall v. Bool -> TPrimExp Bool v
fromBool forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not Bool
segmented
segment_size :: TExp Int64
segment_size = forall a. [a] -> a
last [TExp Int64]
dims'
forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe (PrimExp VName) -> Code a
Imp.DebugPrint String
"Sequential elements per thread (chunk)" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
chunk
VName
statusFlags <- forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"status_flags" PrimType
int8 (forall d. [d] -> ShapeBase d
Shape [SubExp
num_virtgroups]) (String -> Space
Space String
"device")
VName -> SubExp -> CallKernelGen ()
sReplicate VName
statusFlags forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusX
([VName]
aggregateArrays, [VName]
incprefixArrays) <-
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
(,)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty (forall d. [d] -> ShapeBase d
Shape [SubExp
num_virtgroups]) (String -> Space
Space String
"device")
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty (forall d. [d] -> ShapeBase d
Shape [SubExp
num_virtgroups]) (String -> Space
Space String
"device")
VName
global_id <- String -> Int -> CallKernelGen VName
genZeroes String
"global_dynid" Int
1
let attrs' :: KernelAttrs
attrs' = KernelAttrs
attrs {kAttrConstExps :: Map VName KernelConstExp
kAttrConstExps = forall k a. k -> a -> Map k a
M.singleton (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
chunk_v) KernelConstExp
chunk_const}
String
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread String
"segscan" (SegSpace -> VName
segFlat SegSpace
space) KernelAttrs
attrs' forall a b. (a -> b) -> a -> b
$ do
TExp Int32
chunk32 <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"chunk_size_32b" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_v
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep r op. ImpM rep r op r
askEnv
let ltid32 :: TExp Int32
ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
ltid :: TExp Int64
ltid = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
(VName
sharedId, [VName]
transposedArrays, [VName]
prefixArrays, VName
warpscan, [VName]
exchanges) <-
Count GroupSize SubExp
-> SubExp
-> [PrimType]
-> InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (KernelAttrs -> Count GroupSize SubExp
kAttrGroupSize KernelAttrs
attrs) (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v) [PrimType]
tys
TV Int64
physgroup_id <- forall {k} rep r op (t :: k).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"physgroup_id" PrimType
int32
forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
Imp.GetGroupId (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
physgroup_id) Int
0
TExp Int64
iters <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"virtloop_bound" forall a b. (a -> b) -> a -> b
$
(TExp Int64
num_virtgroups_e forall a. Num a => a -> a -> a
- forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
physgroup_id)
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
num_physgroups_e
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"virtloop_i" TExp Int64
iters forall a b. (a -> b) -> a -> b
$ forall a b. a -> b -> a
const forall a b. (a -> b) -> a -> b
$ do
TV Int64
dyn_id <- forall {k} rep r op (t :: k).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"dynamic_id" PrimType
int32
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"First thread in block fetches this block's dynamic_id" forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
(VName
globalIdMem, Space
_, Count Elements (TExp Int64)
globalIdOff) <- forall rep r op.
VName
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
global_id [TExp Int64
0]
forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$
Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace forall a b. (a -> b) -> a -> b
$
IntType
-> VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp VName
-> AtomicOp
Imp.AtomicAdd
IntType
Int32
(forall {k} (t :: k). TV t -> VName
tvVar TV Int64
dyn_id)
VName
globalIdMem
(forall {k} (u :: k) e. e -> Count u e
Count forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count Elements (TExp Int64)
globalIdOff)
(forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
1 :: Imp.TExp Int32))
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Set dynamic id for this block" forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
sharedId [TExp Int64
0] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
dyn_id) []
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"First thread in last (virtual) block resets global dynamic_id" forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
num_virtgroups_e forall a. Num a => a -> a -> a
- TExp Int64
1) forall a b. (a -> b) -> a -> b
$
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global_id [TExp Int64
0] (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) []
let local_barrier :: KernelOp
local_barrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
local_fence :: KernelOp
local_fence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
global_fence :: KernelOp
global_fence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
dyn_id) [] (VName -> SubExp
Var VName
sharedId) [TExp Int64
0]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
TExp Int64
block_offset <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"block_offset" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id) forall a. Num a => a -> a -> a
* TExp Int64
chunk forall a. Num a => a -> a -> a
* TExp Int64
group_size_e
TExp Int64
sgm_idx <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sgm_idx" forall a b. (a -> b) -> a -> b
$ TExp Int64
block_offset forall e. IntegralExp e => e -> e -> e
`mod` TExp Int64
segment_size
TExp Int32
boundary <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"boundary" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64
chunk forall a. Num a => a -> a -> a
* TExp Int64
group_size_e) (TExp Int64
segment_size forall a. Num a => a -> a -> a
- TExp Int64
sgm_idx)
TExp Int32
segsize_compact <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segsize_compact" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64
chunk forall a. Num a => a -> a -> a
* TExp Int64
group_size_e) TExp Int64
segment_size
[VName]
private_chunks <-
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
String
"private"
PrimType
ty
(forall d. [d] -> ShapeBase d
Shape [forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v])
([SubExp] -> PrimType -> Space
ScalarSpace [forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v] PrimType
ty)
TExp Int64
thd_offset <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"thd_offset" forall a b. (a -> b) -> a -> b
$ TExp Int64
block_offset forall a. Num a => a -> a -> a
+ TExp Int64
ltid
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Load and map" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TExp Int64
virt_tid <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"virt_tid" forall a b. (a -> b) -> a -> b
$ TExp Int64
thd_offset forall a. Num a => a -> a -> a
+ TExp Int64
i forall a. Num a => a -> a -> a
* TExp Int64
group_size_e
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims') TExp Int64
virt_tid
let in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
in_bounds =
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
map_kbody) forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
forall a. Int -> [a] -> ([a], [a])
splitAt (forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scan_op]) forall a b. (a -> b) -> a -> b
$ forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
map_kbody
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) forall a b. (a -> b) -> a -> b
$ \(PatElem (MemInfo SubExp NoUniqueness MemBind)
dest, KernelResult
src) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
dest) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
src []
out_of_bounds :: ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks [SubExp]
scanop_nes) forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
ne []
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TExp Int64
virt_tid forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) ImpM GPUMem KernelEnv KernelOp ()
in_bounds ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds
forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Transpose scan inputs" forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
private_chunks) forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TExp Int64
sharedIdx <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid forall a. Num a => a -> a -> a
+ TExp Int64
i forall a. Num a => a -> a -> a
* TExp Int64
group_size_e
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
trans [TExp Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TV Int64
sharedIdx <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid forall a. Num a => a -> a -> a
* TExp Int64
chunk forall a. Num a => a -> a -> a
+ TExp Int64
i
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
priv [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
i] (VName -> SubExp
Var VName
trans) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
sharedIdx]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Per thread scan" forall a b. (a -> b) -> a -> b
$ do
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" (TExp Int64
chunk forall a. Num a => a -> a -> a
- TExp Int64
1) forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
let xs :: [VName]
xs = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan_op
ys :: [VName]
ys = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan_op
TPrimExp Bool VName
new_sgm <-
if Bool
segmented
then do
TExp Int32
gidx <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" forall a b. (a -> b) -> a -> b
$ (TExp Int32
ltid32 forall a. Num a => a -> a -> a
* TExp Int32
chunk32) forall a. Num a => a -> a -> a
+ TExp Int32
1
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"new_sgm" forall a b. (a -> b) -> a -> b
$ (TExp Int32
gidx forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i forall a. Num a => a -> a -> a
- TExp Int32
boundary) forall e. IntegralExp e => e -> e -> e
`mod` TExp Int32
segsize_compact forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
else forall (f :: * -> *) a. Applicative f => a -> f a
pure forall v. TPrimExp Bool v
false
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sUnless TPrimExp Bool VName
new_sgm forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [VName]
private_chunks [VName]
xs [VName]
ys [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
x, VName
y, PrimType
ty) -> do
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i forall a. Num a => a -> a -> a
+ TExp Int64
1]
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody forall a b. (a -> b) -> a -> b
$ forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan_op) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody forall a b. (a -> b) -> a -> b
$ forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan_op) forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
res) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i forall a. Num a => a -> a -> a
+ TExp Int64
1] SubExp
res []
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Publish results in shared memory" forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
prefixArrays [VName]
private_chunks) forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
ltid] (VName -> SubExp
Var VName
src) [TExp Int64
chunk forall a. Num a => a -> a -> a
- TExp Int64
1]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
let crossesSegment :: Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
crossesSegment = do
forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
let from' :: TExp Int32
from' = (TExp Int32
from forall a. Num a => a -> a -> a
+ TExp Int32
1) forall a. Num a => a -> a -> a
* TExp Int32
chunk32 forall a. Num a => a -> a -> a
- TExp Int32
1
to' :: TExp Int32
to' = (TExp Int32
to forall a. Num a => a -> a -> a
+ TExp Int32
1) forall a. Num a => a -> a -> a
* TExp Int32
chunk32 forall a. Num a => a -> a -> a
- TExp Int32
1
in (TExp Int32
to' forall a. Num a => a -> a -> a
- TExp Int32
from') forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32
to' forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact forall a. Num a => a -> a -> a
- TExp Int32
boundary) forall e. IntegralExp e => e -> e -> e
`mod` TExp Int32
segsize_compact
Lambda GPUMem
scan_op1 <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda forall a b. (a -> b) -> a -> b
$ forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan_op
[TV Any]
accs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} rep r op (t :: k).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"acc") [PrimType]
tys
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Scan results (with warp scan)" forall a b. (a -> b) -> a -> b
$ do
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> TExp Int64
-> TExp Int64
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
crossesSegment
TExp Int64
group_size_e
TExp Int64
num_virt_threads
Lambda GPUMem
scan_op1
[VName]
prefixArrays
forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
let firstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread TV Any
acc VName
prefixes =
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
group_size_e forall a. Num a => a -> a -> a
- TExp Int64
1]
notFirstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread TV Any
acc VName
prefixes =
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64
ltid forall a. Num a => a -> a -> a
- TExp Int64
1]
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
(forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread [TV Any]
accs [VName]
prefixArrays)
(forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread [TV Any]
accs [VName]
prefixArrays)
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
[TV Any]
prefixes <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanop_nes [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$ forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
TPrimExp Bool VName
blockNewSgm <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"block_new_sgm" forall a b. (a -> b) -> a -> b
$ TExp Int64
sgm_idx forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Perform lookback" forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool VName
blockNewSgm forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
accs [VName]
incprefixArrays) forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, VName
incprefixArray) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
acc) []
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
global_fence
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) []
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanop_nes [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, TV Any
acc) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [] SubExp
ne []
let warpSize :: TExp Int32
warpSize = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool VName
blockNewSgm forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
warpSize) forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TPrimExp Bool VName
not_segmented_e forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Int32
boundary forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64
group_size_e forall a. Num a => a -> a -> a
* TExp Int64
chunk))
( do
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
aggregateArrays [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
aggregateArray, TV Any
acc) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
acc) []
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
global_fence
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusA) []
)
( do
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
acc) []
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
global_fence
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) []
)
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
warpscan [TExp Int64
0] (VName -> SubExp
Var VName
statusFlags) [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id forall a. Num a => a -> a -> a
- TExp Int64
1]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_fence
TV Int8
status <- forall {k} rep r op (t :: k).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"status" PrimType
int8 :: InKernelGen (TV Int8)
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
status) [] (VName -> SubExp
Var VName
warpscan) [TExp Int64
0]
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
status forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusP)
( forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [VName]
incprefixArrays) forall a b. (a -> b) -> a -> b
$ \(TV Any
prefix, VName
incprefixArray) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id forall a. Num a => a -> a -> a
- TExp Int64
1]
)
( do
TV Int32
readOffset <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"readOffset" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants)
let loopStop :: TExp Int32
loopStop = TExp Int32
warpSize forall a. Num a => a -> a -> a
* (-TExp Int32
1)
sameSegment :: TV Int32 -> TPrimExp Bool VName
sameSegment TV Int32
readIdx
| Bool
segmented =
let startIdx :: TExp Int64
startIdx = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readIdx forall a. Num a => a -> a -> a
+ TExp Int32
1) forall a. Num a => a -> a -> a
* TExp Int64
group_size_e forall a. Num a => a -> a -> a
* TExp Int64
chunk forall a. Num a => a -> a -> a
- TExp Int64
1
in TExp Int64
block_offset forall a. Num a => a -> a -> a
- TExp Int64
startIdx forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
sgm_idx
| Bool
otherwise = forall v. TPrimExp Bool v
true
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhile (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readOffset forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
loopStop) forall a b. (a -> b) -> a -> b
$ do
TV Int32
readI <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"read_i" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readOffset forall a. Num a => a -> a -> a
+ TExp Int32
ltid32
[TV Any]
aggrs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanop_nes [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$ forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
TV Int8
flag <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag" (forall a. Num a => a
statusX :: Imp.TExp Int8)
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TV Int32 -> TPrimExp Bool VName
sameSegment TV Int32
readI)
( do
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
statusFlags) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusP)
( forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
aggrs) forall a b. (a -> b) -> a -> b
$ \(VName
incprefix, TV Any
aggr) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
incprefix) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
)
( forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusA) forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
aggregateArrays) forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
aggregate) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
aggregate) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
)
)
(forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) [])
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
aggrs) forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
aggr) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
ltid] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
aggr) []
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
warpscan [TExp Int64
ltid] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int8
flag) []
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
warpscan) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize forall a. Num a => a -> a -> a
- TExp Int64
1]
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. forall a. Num a => a
statusP) forall a b. (a -> b) -> a -> b
$ do
Lambda GPUMem
lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op1
KernelConstants
-> TExp Int64
-> VName
-> [VName]
-> Lambda GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
inBlockScanLookback
KernelConstants
constants
TExp Int64
num_virt_threads
VName
warpscan
[VName]
exchanges
Lambda GPUMem
lam'
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
warpscan) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize forall a. Num a => a -> a -> a
- TExp Int64
1]
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
exchanges) forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
exchange) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
exchange) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize forall a. Num a => a -> a -> a
- TExp Int64
1]
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusP)
(TV Int32
readOffset forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32
loopStop)
( forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusA) forall a b. (a -> b) -> a -> b
$ do
TV Int32
readOffset forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readOffset forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
zExt32 TExp Int32
warpSize
)
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. forall a. Num a => a
statusX) forall a b. (a -> b) -> a -> b
$ do
Lambda GPUMem
lam <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op1
let ([VName]
xs, [VName]
ys) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
aggrs) forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
aggr) -> forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
aggr)
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
prefix) -> forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
prefix)
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
prefixes [PrimType]
tys forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam) forall a b. (a -> b) -> a -> b
$
\(TV Any
prefix, PrimType
ty, SubExp
res) -> TV Any
prefix forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
res)
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_fence
)
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
Lambda GPUMem
scan_op2 <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op1
let xs :: [VName]
xs = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op2
ys :: [VName]
ys = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op2
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
boundary forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64
group_size_e forall a. Num a => a -> a -> a
* TExp Int64
chunk)) forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
prefix) -> forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
prefix
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
acc) -> forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
acc
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op2) forall a b. (a -> b) -> a -> b
$
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op2) forall a b. (a -> b) -> a -> b
$
\(VName
incprefixArray, SubExp
res) -> forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] SubExp
res []
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
global_fence
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$ forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) []
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
0] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
prefix) []
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
accs [PrimType]
tys [SubExp]
scanop_nes) forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, PrimType
ty, SubExp
ne) ->
forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc forall rep r op. VName -> PrimExp VName -> ImpM rep r op ()
<~~ forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) forall a b. (a -> b) -> a -> b
$ do
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
exchange) [TExp Int64
0]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
Lambda GPUMem
scan_op3 <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op1
Lambda GPUMem
scan_op4 <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op1
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Distribute results" forall a b. (a -> b) -> a -> b
$ do
let ([VName]
xs, [VName]
ys) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op3
([VName]
xs', [VName]
ys') = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op4
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7 [TV Any]
prefixes [TV Any]
accs [VName]
xs [VName]
xs' [VName]
ys [VName]
ys' [PrimType]
tys) forall a b. (a -> b) -> a -> b
$
\(TV Any
prefix, TV Any
acc, VName
x, VName
x', VName
y, VName
y', PrimType
ty) -> do
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
prefix
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
acc
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int32
ltid32 forall a. Num a => a -> a -> a
* TExp Int32
chunk32 forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
boundary forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool VName
blockNewSgm)
( forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op4) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [PrimType]
tys forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op4) forall a b. (a -> b) -> a -> b
$
\(VName
x, PrimType
ty, SubExp
res) -> VName
x forall rep r op. VName -> PrimExp VName -> ImpM rep r op ()
<~~ forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
res
)
(forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
acc) -> forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [])
TExp Int32
stop <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"stopping_point" forall a b. (a -> b) -> a -> b
$
TExp Int32
segsize_compact forall a. Num a => a -> a -> a
- (TExp Int32
ltid32 forall a. Num a => a -> a -> a
* TExp Int32
chunk32 forall a. Num a => a -> a -> a
- TExp Int32
1 forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact forall a. Num a => a -> a -> a
- TExp Int32
boundary) forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
segsize_compact
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
stop forall a. Num a => a -> a -> a
- TExp Int32
1) forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks [VName]
ys) forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op3) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op3) forall a b. (a -> b) -> a -> b
$
\(VName
dest, SubExp
res) ->
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
res []
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Transpose scan output and Write it to global memory in coalesced fashion" forall a b. (a -> b) -> a -> b
$ do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
transposedArrays [VName]
private_chunks forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. PatElem dec -> VName
patElemName [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) forall a b. (a -> b) -> a -> b
$ \(VName
locmem, VName
priv, VName
dest) -> do
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TV Int64
sharedIdx <-
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" forall a b. (a -> b) -> a -> b
$
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64
ltid forall a. Num a => a -> a -> a
* TExp Int64
chunk) forall a. Num a => a -> a -> a
+ TExp Int64
i
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
locmem [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
TExp Int64
flat_idx <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" forall a b. (a -> b) -> a -> b
$ TExp Int64
thd_offset forall a. Num a => a -> a -> a
+ TExp Int64
i forall a. Num a => a -> a -> a
* TExp Int64
group_size_e
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims') TExp Int64
flat_idx
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64
flat_idx forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) forall a b. (a -> b) -> a -> b
$ do
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
VName
dest
(forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids)
(VName -> SubExp
Var VName
locmem)
[forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ TExp Int64
flat_idx forall a. Num a => a -> a -> a
- TExp Int64
block_offset]
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
{-# NOINLINE compileSegScan #-}