{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpGen.Kernels.SegScan
( compileSegScan )
where
import Control.Monad.Except
import Data.Maybe
import Data.List ()
import Prelude hiding (quot, rem)
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
-- | Create one scratch array per accumulator (x-)parameter of the scan
-- operator, to serve as the workspace of the group-level scan.
--
-- The first @length nes@ lambda parameters of @scan_op@ are the accumulator
-- parameters; only those get storage.  For each of them:
--
-- * If the parameter is an array (its attribute is a 'MemArray' bound
--   'ArrayIn' some memory block @mem@), an array of shape
--   @[num_threads] <> shape@ is declared in that same block @mem@, with a
--   fresh 'IxFun.iota' index function over the extended shape (the
--   parameter's own index function is discarded).
--
-- * Otherwise (a scalar parameter), an array of @group_size@ elements of the
--   parameter's element type is allocated in the @\"local\"@ memory space.
--
-- Returns the names of the created arrays, in parameter order.
makeLocalArrays :: Count GroupSize SubExp -- ^ Number of threads per group.
                -> SubExp                 -- ^ Total number of threads.
                -> [SubExp]               -- ^ Neutral elements of the scan.
                -> Lambda ExplicitMemory  -- ^ The scan operator.
                -> InKernelGen [VName]
makeLocalArrays (Count group_size) num_threads nes scan_op =
  forM scan_x_params $ \p ->
    case paramAttr p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        -- One row per thread, laid out afresh in the parameter's memory.
        let shape' = Shape [num_threads] <> shape
        sArray "scan_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $
          map (primExpFromSubExp int32) $ shapeDims shape'
      _ -> do
        -- One slot per thread in the group.
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "scan_arr" pt shape $ Space "local"
  where
    -- The operator's parameters are the x-half followed by the y-half;
    -- only the x-half (accumulators) needs backing storage.
    scan_x_params = take (length nes) $ lambdaParams scan_op
-- | Optional segment-boundary test used by the segmented scan.
--
-- When present, the function takes two flat indices @from@ and @to@ and
-- yields a boolean expression that is true when they lie in different
-- segments.  'scanStage1' builds it only when the scan space has at least
-- two dimensions; for a single dimension it is 'Nothing' (unsegmented).
type CrossesSegment = Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
-- | Index at which the current thread accesses its own slot of a scan
-- scratch array holding values of type @t@.
--
-- Scalar ('primType') values are indexed by the thread's local (in-group)
-- id.  Non-scalar values are indexed by the thread's global id, matching
-- the per-thread rows used for array-typed accumulators.
localArrayIndex :: KernelConstants -> Type -> Imp.Exp
localArrayIndex constants t
  | primType t = kernelLocalThreadId constants
  | otherwise  = kernelGlobalThreadId constants
-- | Synchronisation settings for a scan with operator @scan_op@.
--
-- Returns a triple of:
--
-- 1. whether this is an /array scan/, i.e. at least one return type of
--    @scan_op@ is not a primitive type;
-- 2. the memory fence to use: 'Imp.FenceGlobal' for an array scan,
--    'Imp.FenceLocal' otherwise;
-- 3. an action that emits an 'Imp.Barrier' with that fence.
barrierFor :: Lambda ExplicitMemory -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, barrier)
  where
    array_scan = any (not . primType) $ lambdaReturnType scan_op

    fence | array_scan = Imp.FenceGlobal
          | otherwise  = Imp.FenceLocal

    barrier = sOp $ Imp.Barrier fence
scanStage1 :: Pattern ExplicitMemory
-> Count NumGroups SubExp -> Count GroupSize SubExp -> SegSpace
-> Lambda ExplicitMemory -> [SubExp]
-> KernelBody ExplicitMemory
-> CallKernelGen (VName, Imp.Exp, CrossesSegment)
scanStage1 :: Pattern ExplicitMemory
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> Lambda ExplicitMemory
-> [SubExp]
-> KernelBody ExplicitMemory
-> CallKernelGen (VName, Exp, CrossesSegment)
scanStage1 (Pattern [PatElemT (LetAttr ExplicitMemory)]
_ [PatElemT (LetAttr ExplicitMemory)]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space Lambda ExplicitMemory
scan_op [SubExp]
nes KernelBody ExplicitMemory
kbody = do
Count NumGroups Exp
num_groups' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
Count GroupSize Exp
group_size' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
VName
num_threads <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_threads" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
rets :: [TypeBase (ShapeBase SubExp) NoUniqueness]
rets = Lambda ExplicitMemory -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda ExplicitMemory
scan_op
[Exp]
dims' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> [SubExp] -> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
let num_elements :: Exp
num_elements = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims'
elems_per_thread :: Exp
elems_per_thread = Exp
num_elements Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quotRoundingUp` VName -> Exp
Imp.vi32 VName
num_threads
elems_per_group :: Exp
elems_per_group = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_thread
Lambda ExplicitMemory
scan_op_renamed <- Lambda ExplicitMemory
-> ImpM ExplicitMemory HostEnv HostOp (Lambda ExplicitMemory)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda ExplicitMemory
scan_op
let crossesSegment :: CrossesSegment
crossesSegment =
case [Exp] -> [Exp]
forall a. [a] -> [a]
reverse [Exp]
dims' of
Exp
segment_size : Exp
_ : [Exp]
_ -> (Exp -> Exp -> Exp) -> CrossesSegment
forall a. a -> Maybe a
Just ((Exp -> Exp -> Exp) -> CrossesSegment)
-> (Exp -> Exp -> Exp) -> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \Exp
from Exp
to ->
(Exp
toExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
from) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. (Exp
to Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`rem` Exp
segment_size)
[Exp]
_ -> CrossesSegment
forall a. Maybe a
Nothing
String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"scan_stage1" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
[VName]
local_arrs <- Count GroupSize SubExp
-> SubExp
-> [SubExp]
-> Lambda ExplicitMemory
-> InKernelGen [VName]
makeLocalArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var VName
num_threads) [SubExp]
nes Lambda ExplicitMemory
scan_op
Maybe (Exp ExplicitMemory)
-> Scope ExplicitMemory -> InKernelGen ()
forall lore r op.
Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore r op ()
dScope Maybe (Exp ExplicitMemory)
forall a. Maybe a
Nothing (Scope ExplicitMemory -> InKernelGen ())
-> Scope ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
let ([Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params) =
Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)]))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
[(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params [SubExp]
nes) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []
String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"j" Exp
elems_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
j -> do
VName
chunk_offset <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"chunk_offset" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
j Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
KernelConstants -> Exp
kernelGroupId KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_group
VName
flat_idx <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"flat_idx" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
(VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
dims' (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
flat_idx PrimType
int32
let in_bounds :: Exp
in_bounds =
(Exp -> Exp -> Exp) -> [Exp] -> Exp
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.&&.) ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> [Exp] -> [Exp] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.<.) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids) [Exp]
dims'
when_in_bounds :: InKernelGen ()
when_in_bounds = Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody ExplicitMemory -> Stms ExplicitMemory
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody ExplicitMemory
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
scan_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody ExplicitMemory -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody ExplicitMemory
kbody
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write to-scan values to parameters" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params [KernelResult]
scan_res) (((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, KernelResult
se) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
se) []
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write mapped values results to global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) [KernelResult]
map_res) (((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, KernelResult
se) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
se) []
when_out_of_bounds :: InKernelGen ()
when_out_of_bounds = [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params [SubExp]
nes) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"threads in bounds read input; others get neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf Exp
in_bounds InKernelGen ()
when_in_bounds InKernelGen ()
when_out_of_bounds
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"combine with carry and write to local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT ExplicitMemory -> Stms ExplicitMemory
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT ExplicitMemory -> Stms ExplicitMemory)
-> BodyT ExplicitMemory -> Stms ExplicitMemory
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> BodyT ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_op) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)]
-> ((TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase (ShapeBase SubExp) NoUniqueness]
rets [VName]
local_arrs (BodyT ExplicitMemory -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT ExplicitMemory -> [SubExp])
-> BodyT ExplicitMemory -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> BodyT ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_op)) (((TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)
-> InKernelGen ())
-> InKernelGen ())
-> ((TypeBase (ShapeBase SubExp) NoUniqueness, VName, SubExp)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase (ShapeBase SubExp) NoUniqueness
t, VName
arr, SubExp
se) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
arr [KernelConstants -> TypeBase (ShapeBase SubExp) NoUniqueness -> Exp
localArrayIndex KernelConstants
constants TypeBase (ShapeBase SubExp) NoUniqueness
t] SubExp
se []
let crossesSegment' :: CrossesSegment
crossesSegment' = do
Exp -> Exp -> Exp
f <- CrossesSegment
crossesSegment
(Exp -> Exp -> Exp) -> CrossesSegment
forall a. a -> Maybe a
Just ((Exp -> Exp -> Exp) -> CrossesSegment)
-> (Exp -> Exp -> Exp) -> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \Exp
from Exp
to ->
let from' :: Exp
from' = Exp
from Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32
to' :: Exp
to' = Exp
to Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32
in Exp -> Exp -> Exp
f Exp
from' Exp
to'
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence
CrossesSegment
-> Exp -> Exp -> Lambda ExplicitMemory -> [VName] -> InKernelGen ()
groupScan CrossesSegment
crossesSegment'
(VName -> Exp
Imp.vi32 VName
num_threads)
(KernelConstants -> Exp
kernelGroupSize KernelConstants
constants) Lambda ExplicitMemory
scan_op_renamed [VName]
local_arrs
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"threads in bounds write partial scan result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(TypeBase (ShapeBase SubExp) NoUniqueness,
PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((TypeBase (ShapeBase SubExp) NoUniqueness,
PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(TypeBase (ShapeBase SubExp) NoUniqueness,
PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase (ShapeBase SubExp) NoUniqueness]
rets [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes [VName]
local_arrs) (((TypeBase (ShapeBase SubExp) NoUniqueness,
PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ())
-> ((TypeBase (ShapeBase SubExp) NoUniqueness,
PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase (ShapeBase SubExp) NoUniqueness
t, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, VName
arr) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids)
(VName -> SubExp
Var VName
arr) [KernelConstants -> TypeBase (ShapeBase SubExp) NoUniqueness -> Exp
localArrayIndex KernelConstants
constants TypeBase (ShapeBase SubExp) NoUniqueness
t]
InKernelGen ()
barrier
let load_carry :: InKernelGen ()
load_carry =
[(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((VName, Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr)
[if TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall attr.
Typed attr =>
Param attr -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p
then KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
else (KernelConstants -> Exp
kernelGroupId KernelConstants
constantsExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
1) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]
load_neutral :: InKernelGen ()
load_neutral =
[(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread reads last element as carry-in for next iteration" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
Exp
crosses_segment <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"crosses_segment" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
case CrossesSegment
crossesSegment of
CrossesSegment
Nothing -> Exp
forall v. PrimExp v
false
Just Exp -> Exp -> Exp
f -> Exp -> Exp -> Exp
f (VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
KernelConstants -> Exp
kernelGroupSize KernelConstants
constantsExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1)
(VName -> PrimType -> Exp
Imp.var VName
chunk_offset PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
KernelConstants -> Exp
kernelGroupSize KernelConstants
constants)
Exp
should_load_carry <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"should_load_carry" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
UnOpExp UnOp
Not Exp
crosses_segment
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
should_load_carry InKernelGen ()
load_carry
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
should_load_carry InKernelGen ()
load_neutral
InKernelGen ()
barrier
(VName, Exp, CrossesSegment)
-> CallKernelGen (VName, Exp, CrossesSegment)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
num_threads, Exp
elems_per_group, CrossesSegment
crossesSegment)
where (Bool
array_scan, Fence
fence, InKernelGen ()
barrier) = Lambda ExplicitMemory -> (Bool, Fence, InKernelGen ())
barrierFor Lambda ExplicitMemory
scan_op
-- | Second stage of the three-stage scan: one workgroup scans the carries
-- left behind by stage 1.
--
-- The kernel runs with a single group whose size is the stage-1 group
-- count, so local thread @i@ owns the carry of stage-1 group @i@, which
-- sits at flat element index @(i + 1) * elems_per_group - 1@.  Threads
-- whose index lies outside the segment space put the neutral element in
-- their local-array slot instead.  After the group-wide scan, in-bounds
-- threads store the scanned carries back into the result arrays at the
-- same indices they read from.
scanStage2 :: Pattern ExplicitMemory
           -> VName -> Imp.Exp -> Count NumGroups SubExp -> CrossesSegment -> SegSpace
           -> Lambda ExplicitMemory -> [SubExp]
           -> CallKernelGen ()
scanStage2 (Pattern _ pes) stage1_num_threads elems_per_group num_groups crossesSegment space scan_op nes = do
  let -- One thread per stage-1 group: the stage-1 group count becomes
      -- the stage-2 group size.
      carry_group_size :: Count GroupSize SubExp
      carry_group_size = Count (unCount num_groups)

      (gtids, dims) = unzip (unSegSpace space)
      rets = lambdaReturnType scan_op
      (_, _, barrier) = barrierFor scan_op

      -- Flat index of the final element processed by stage-1 group @g@.
      lastElemOfGroup :: Imp.Exp -> Imp.Exp
      lastElemOfGroup g = (g + 1) * elems_per_group - 1

      -- The element-level segment test, lifted to operate on group
      -- indices by comparing the last elements of the two groups.
      carrySegmentTest :: CrossesSegment
      carrySegmentTest =
        fmap (\f from to -> f (lastElemOfGroup from) (lastElemOfGroup to))
             crossesSegment

  carry_group_size' <- traverse toExp carry_group_size
  dims' <- mapM toExp dims

  sKernelThread "scan_stage2" 1 carry_group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    local_arrs <- makeLocalArrays carry_group_size (Var stage1_num_threads) nes scan_op

    -- Thread i locates the carry produced by stage-1 group i and
    -- unflattens it into segment indices.
    flat_idx <- dPrimV "flat_idx" $ lastElemOfGroup (kernelLocalThreadId constants)
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' (Imp.var flat_idx int32)

    let gtid_exps = map (`Imp.var` int32) gtids
        in_bounds = foldl1 (.&&.) [ i .<. d | (i, d) <- zip gtid_exps dims' ]
        -- This thread's slot in a local array of element type @t@.
        slot t = [localArrayIndex constants t]

        readCarry = forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) ->
          copyDWIMFix arr (slot t) (Var (patElemName pe)) gtid_exps
        readNeutral = forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) ->
          copyDWIMFix arr (slot t) ne []
        writeCarry = forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) ->
          copyDWIMFix (patElemName pe) gtid_exps (Var arr) (slot t)

    sComment "threads in bound read carries; others get neutral element" $
      sIf in_bounds readCarry readNeutral

    barrier

    groupScan carrySegmentTest (Imp.vi32 stage1_num_threads)
      (kernelGroupSize constants) scan_op local_arrs

    sComment "threads in bounds write scanned carries" $
      sWhen in_bounds writeCarry
scanStage3 :: Pattern ExplicitMemory
-> Count NumGroups SubExp -> Count GroupSize SubExp
-> Imp.Exp -> CrossesSegment -> SegSpace
-> Lambda ExplicitMemory -> [SubExp]
-> CallKernelGen ()
scanStage3 :: Pattern ExplicitMemory
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> Exp
-> CrossesSegment
-> SegSpace
-> Lambda ExplicitMemory
-> [SubExp]
-> CallKernelGen ()
scanStage3 (Pattern [PatElemT (LetAttr ExplicitMemory)]
_ [PatElemT (LetAttr ExplicitMemory)]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size Exp
elems_per_group CrossesSegment
crossesSegment SegSpace
space Lambda ExplicitMemory
scan_op [SubExp]
nes = do
Count NumGroups Exp
num_groups' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
Count GroupSize Exp
group_size' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
[Exp]
dims' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> [SubExp] -> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
Exp
required_groups <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"required_groups" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
[Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims' Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quotRoundingUp` Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"scan_stage3" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt Exp
required_groups ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
virt_group_id -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
Exp
flat_idx <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"flat_idx" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
VName -> Exp
Imp.vi32 VName
virt_group_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
(VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
dims' Exp
flat_idx
VName
orig_group <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"orig_group" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ Exp
flat_idx Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
elems_per_group
VName
carry_in_flat_idx <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"carry_in_flat_idx" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
VName -> PrimType -> Exp
Imp.var VName
orig_group PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
let carry_in_idx :: [Exp]
carry_in_idx = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
dims' (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
carry_in_flat_idx PrimType
int32
let in_bounds :: Exp
in_bounds =
(Exp -> Exp -> Exp) -> [Exp] -> Exp
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.&&.) ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> [Exp] -> [Exp] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.<.) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids) [Exp]
dims'
crosses_segment :: Exp
crosses_segment = Exp -> Maybe Exp -> Exp
forall a. a -> Maybe a -> a
fromMaybe Exp
forall v. PrimExp v
false (Maybe Exp -> Exp) -> Maybe Exp -> Exp
forall a b. (a -> b) -> a -> b
$
CrossesSegment
crossesSegment CrossesSegment -> Maybe Exp -> Maybe (Exp -> Exp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*>
Exp -> Maybe Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> PrimType -> Exp
Imp.var VName
carry_in_flat_idx PrimType
int32) Maybe (Exp -> Exp) -> Maybe Exp -> Maybe Exp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*>
Exp -> Maybe Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure Exp
flat_idx
is_a_carry :: Exp
is_a_carry = Exp
flat_idx Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==.
(VName -> PrimType -> Exp
Imp.var VName
orig_group PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
1) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
elems_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
no_carry_in :: Exp
no_carry_in = VName -> PrimType -> Exp
Imp.var VName
orig_group PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.||. Exp
is_a_carry Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.||. Exp
crosses_segment
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
Maybe (Exp ExplicitMemory)
-> Scope ExplicitMemory -> InKernelGen ()
forall lore r op.
Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore r op ()
dScope Maybe (Exp ExplicitMemory)
forall a. Maybe a
Nothing (Scope ExplicitMemory -> InKernelGen ())
-> Scope ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
let ([Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params) =
Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)]))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_op
[(Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) [Exp]
carry_in_idx
[(Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_y_params [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids
[Param (MemInfo SubExp NoUniqueness MemBind)]
-> BodyT ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params (BodyT ExplicitMemory -> InKernelGen ())
-> BodyT ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> BodyT ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_op
[(Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params [PatElemT (LetAttr ExplicitMemory)]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
gtids) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []
-- | Generate host code for a segmented scan by running three kernel stages
-- in sequence: 'scanStage1', then 'scanStage2', then 'scanStage3'.
--
-- Stage one yields the thread count it used, the number of elements handled
-- per group, and a 'CrossesSegment' predicate. Later stages consume these.
-- Stages two and three each receive their own 'renameLambda' copy of the
-- scan operator. Presumably this is so the lambdas bound in the separate
-- kernels do not share names.
compileSegScan :: Pattern ExplicitMemory
               -> SegLevel -> SegSpace
               -> Lambda ExplicitMemory -> [SubExp]
               -> KernelBody ExplicitMemory
               -> CallKernelGen ()
compileSegScan pat lvl space scan_op nes kbody = do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing

  (stage1_num_threads, elems_per_group, crossesSegment) <-
    scanStage1 pat num_groups group_size space scan_op nes kbody

  emit $ Imp.DebugPrint "elems_per_group" $ Just elems_per_group

  -- Both renamings happen before stage two is generated, matching the
  -- original order in which fresh names are drawn.
  stage2_op <- renameLambda scan_op
  stage3_op <- renameLambda scan_op

  scanStage2 pat stage1_num_threads elems_per_group num_groups
    crossesSegment space stage2_op nes

  scanStage3 pat num_groups group_size elems_per_group
    crossesSegment space stage3_op nes
  where
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl