{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.StreamKernel
( segThreadCapped,
streamRed,
streamMap,
)
where
import Control.Monad
import Control.Monad.Writer
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Kernels hiding
( BasicOp,
Body,
Exp,
FParam,
FunDef,
LParam,
Lambda,
PatElem,
Pattern,
Prog,
RetType,
Stm,
)
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.Tools
import Prelude hiding (quot)
-- | The sizes chosen for a blocked (chunked) kernel, as computed by
-- 'blockedKernelSize'.
data KernelSize = KernelSize
  { -- | How many input elements each thread processes.
    -- 'blockedKernelSize' computes this as the input width divided by
    -- the number of threads using @SDivUp Int64@.
    kernelElementsPerThread :: SubExp,
    -- | Total number of threads.  'blockedKernelSize' takes this from
    -- 'numberOfGroups', where it is the number of groups multiplied by
    -- the group size (an @Int64@ multiplication).
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)
-- | Bind the number of groups and the total number of threads for a
-- kernel.
--
-- Arguments, in order: a description used to derive the name of the
-- tuning key, the width @w@ handed to 'CalcNumGroups', and the group
-- size.
--
-- Returns @(num_groups, num_threads)@, where @num_threads@ is bound to
-- @num_groups * group_size@.
numberOfGroups ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  -- The key is the pretty-printed form of a fresh variable name, so
  -- every call yields a distinct key.
  max_num_groups_key <-
    nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $ SizeOp $ CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $ BinOp (Mul Int64 OverflowUndef) num_groups group_size
  return (num_groups, num_threads)
-- | Compute the 'KernelSize' for a blocked kernel over an input of
-- width @w@.
--
-- The description is used as a prefix for the names of the size
-- parameters: the group size is requested under @desc ++ "_group_size"@
-- and the number of groups is obtained through 'numberOfGroups'.  The
-- number of elements per thread is @w@ divided by the number of
-- threads, using @SDivUp Int64 Unsafe@.
blockedKernelSize ::
  (MonadBinder m, Lore m ~ Kernels) =>
  String ->
  SubExp ->
  m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  -- Only the total thread count is needed here; the group count itself
  -- is discarded.
  (_, num_threads) <- numberOfGroups desc w group_size

  per_thread_elements <-
    letSubExp "per_thread_elements"
      =<< eBinOp (SDivUp Int64 Unsafe) (eSubExp w) (eSubExp num_threads)

  return $ KernelSize per_thread_elements num_threads
-- | Bind the chunk of each input array that belongs to iteration @i@.
--
-- Arguments, in order:
--
-- * @chunk_size@: the name that is bound to the result of the
--   'SplitSpace' size operation.
--
-- * @split_bound@: the names to which the array slices are bound.
--
-- * @ordering@: whether chunks are contiguous or strided.
--
-- * @w@, @i@, @elems_per_i@: passed on to 'SplitSpace'; @i@ and
--   @elems_per_i@ are also used to compute the slice offsets.
--
-- * @arrs@: the arrays to slice.
--
-- @split_bound@ and @arrs@ are paired up with 'zipWithM_', so surplus
-- elements of the longer list are ignored.
splitArrays ::
  (MonadBinder m, Lore m ~ Kernels) =>
  VName ->
  [VName] ->
  SplitOrdering ->
  SubExp ->
  SubExp ->
  SubExp ->
  [VName] ->
  m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames [chunk_size] $
    Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous -> do
      -- Contiguous chunks: iteration @i@ starts at @i * elems_per_i@.
      offset <-
        letSubExp "slice_offset" $
          BasicOp $ BinOp (Mul Int64 OverflowUndef) i elems_per_i
      zipWithM_ (contiguousSlice offset) split_bound arrs
    SplitStrided stride ->
      zipWithM_ (stridedSlice stride) split_bound arrs
  where
    -- Slice the outermost dimension as @DimSlice offset chunk_size 1@;
    -- 'fullSlice' fills in the remaining dimensions from the array type.
    contiguousSlice offset slice_name arr = do
      arr_t <- lookupType arr
      let slice =
            fullSlice
              arr_t
              [DimSlice offset (Var chunk_size) (constant (1 :: Int64))]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

    -- Slice the outermost dimension as @DimSlice i chunk_size stride@.
    stridedSlice stride slice_name arr = do
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimSlice i (Var chunk_size) stride]
      letBindNames [slice_name] $ BasicOp $ Index arr slice
-- | Split the parameters of a chunked kernel fold lambda into
-- @(index name, chunk parameter, accumulator parameters, array
-- parameters)@.
--
-- The first parameter is taken to be the index parameter (only its
-- name is returned), the second the chunk parameter, the next
-- @num_accs@ the accumulators, and the rest the array parameters.
--
-- Calls 'error' if the list has fewer than two parameters.
partitionChunkedKernelFoldParameters ::
  Int ->
  [Param dec] ->
  (VName, Param dec, [Param dec], [Param dec])
partitionChunkedKernelFoldParameters num_accs (i_param : chunk_param : params) =
  let (acc_params, arr_params) = splitAt num_accs params
   in (paramName i_param, chunk_param, acc_params, arr_params)
partitionChunkedKernelFoldParameters _ _ =
  error "partitionChunkedKernelFoldParameters: lambda takes too few parameters"
-- | Insert the per-thread part of a chunked stream into the current
-- binder: split the input arrays for this thread, add the statements
-- of the lambda body, and bind the lambda results to fresh names.
--
-- Arguments, in order: the thread index variable, the input width, the
-- kernel size, the stream ordering, the (already kernelised) lambda,
-- the number of leading non-concatenated results, and the input
-- arrays.
--
-- Returns the pattern elements bound to the first @num_nonconcat@
-- lambda results, and those bound to the remaining (per-chunk array)
-- results.
blockedPerThread ::
  (MonadBinder m, Lore m ~ Kernels) =>
  VName ->
  SubExp ->
  KernelSize ->
  StreamOrd ->
  Lambda (Lore m) ->
  Int ->
  [VName] ->
  m ([PatElemT Type], [PatElemT Type])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  -- Partitioning with zero accumulators always yields an empty
  -- accumulator list, hence the @[]@ pattern.
  let (_, chunk_size, [], arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam

      ordering' =
        case ordering of
          InOrder -> SplitContiguous
          Disorder -> SplitStrided $ kernelNumThreads kernel_size

      red_ts = take num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType $ drop num_nonconcat $ lambdaReturnType lam

  per_thread <- asIntS Int64 $ kernelElementsPerThread kernel_size
  splitArrays
    (paramName chunk_size)
    (map paramName arr_params)
    ordering'
    w
    (Var thread_gtid)
    per_thread
    arrs

  chunk_red_pes <- forM red_ts $ \red_t -> do
    pe_name <- newVName "chunk_fold_red"
    return $ PatElem pe_name red_t
  -- The mapped results are arrays with one row per element of the
  -- chunk, so their outer size is the chunk size.
  chunk_map_pes <- forM map_ts $ \map_t -> do
    pe_name <- newVName "chunk_fold_map"
    return $ PatElem pe_name $ map_t `arrayOfRow` Var (paramName chunk_size)

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  -- Emit the lambda body, followed by one binding per result that
  -- copies the result sub-expression into its fresh pattern element.
  addStms $
    bodyStms (lambdaBody lam)
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip chunk_red_pes chunk_red_ses
        ]
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip chunk_map_pes chunk_map_ses
        ]

  return (chunk_red_pes, chunk_map_pes)
-- | Turn a chunked fold lambda into the form used inside a kernel.
--
-- The accumulator parameters are removed from the parameter list and
-- instead bound at the top of the lambda body to the given neutral
-- elements @nes@.  A fresh @int64@ parameter named @thread_index@ is
-- prepended, so the resulting parameters are: thread index, chunk
-- parameter, input parameters.
kerneliseLambda ::
  MonadFreshNames m =>
  [SubExp] ->
  Lambda Kernels ->
  m (Lambda Kernels)
kerneliseLambda nes lam = do
  thread_index <- newVName "thread_index"
  let thread_index_param = Param thread_index $ Prim int64
      (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      -- A neutral element that is a variable of non-primitive type is
      -- bound via 'Copy'; anything else is bound directly.
      mkAccInit p (Var v)
        | not $ primType $ paramType p =
          mkLet [] [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x = mkLet [] [paramIdent p] $ BasicOp $ SubExp x

      acc_init_bnds = stmsFromList $ zipWith mkAccInit fold_acc_params nes
  return
    lam
      { lambdaBody =
          insertStms acc_init_bnds $
            lambdaBody lam,
        lambdaParams =
          thread_index_param :
          fold_chunk_param :
          fold_inp_params
      }
-- | Build the thread space and kernel body for a chunked stream.
--
-- Arguments, in order: the kernel size, the outer index space, the
-- input width, the commutativity of the reduction, the fold lambda,
-- the neutral elements, and the input arrays.
--
-- A commutative stream is split with a stride equal to the number of
-- threads ('Disorder'); a non-commutative one is split contiguously
-- ('InOrder').
--
-- Returns @(num_threads, space, ts, kbody)@: the number of threads, the
-- segment space (the given index space extended with a fresh @gtid@
-- dimension of size @num_threads@), the result types (the reduction
-- result types followed by the row types of the mapped results), and
-- the kernel body.
prepareStream ::
  (MonadBinder m, Lore m ~ Kernels) =>
  KernelSize ->
  [(VName, SubExp)] ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream size ispace w comm fold_lam nes arrs = do
  let (KernelSize elems_per_thread num_threads) = size
  let (ordering, split_ordering) =
        case comm of
          Commutative -> (Disorder, SplitStrided num_threads)
          Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]

  -- Build the body in a separate binder so that its statements are
  -- collected rather than added to the enclosing one.
  (kres, kstms) <-
    runBinder $
      localScope (scopeOfSegSpace space) $ do
        (chunk_red_pes, chunk_map_pes) <-
          blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
        let concatReturns pe =
              ConcatReturns split_ordering w elems_per_thread $ patElemName pe
        return
          ( map (Returns ResultMaySimplify . Var . patElemName) chunk_red_pes
              ++ map concatReturns chunk_map_pes
          )
  let kbody = KernelBody () kstms kres

  let (redout_ts, mapout_ts) = splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  return (num_threads, space, ts, kbody)
streamRed ::
(MonadFreshNames m, HasScope Kernels m) =>
MkSegLevel Kernels m ->
Pattern Kernels ->
SubExp ->
Commutativity ->
Lambda Kernels ->
Lambda Kernels ->
[SubExp] ->
[VName] ->
m (Stms Kernels)
streamRed :: MkSegLevel Kernels m
-> Pattern Kernels
-> SubExp
-> Commutativity
-> LambdaT Kernels
-> LambdaT Kernels
-> [SubExp]
-> [VName]
-> m (Stms Kernels)
streamRed MkSegLevel Kernels m
mk_lvl Pattern Kernels
pat SubExp
w Commutativity
comm LambdaT Kernels
red_lam LambdaT Kernels
fold_lam [SubExp]
nes [VName]
arrs = BinderT Kernels m () -> m (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
BinderT lore m a -> m (Stms lore)
runBinderT'_ (BinderT Kernels m () -> m (Stms Kernels))
-> BinderT Kernels m () -> m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
KernelSize
size <- String -> SubExp -> BinderT Kernels m KernelSize
forall (m :: * -> *).
(MonadBinder m, Lore m ~ Kernels) =>
String -> SubExp -> m KernelSize
blockedKernelSize String
"stream_red" SubExp
w
let ([PatElemT Type]
redout_pes, [PatElemT Type]
mapout_pes) = Int -> [PatElemT Type] -> ([PatElemT Type], [PatElemT Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([PatElemT Type] -> ([PatElemT Type], [PatElemT Type]))
-> [PatElemT Type] -> ([PatElemT Type], [PatElemT Type])
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternElements PatternT Type
Pattern Kernels
pat
(PatternT Type
redout_pat, [(VName, SubExp)]
ispace, BinderT Kernels m ()
read_dummy) <- Pattern (Lore (BinderT Kernels m))
-> BinderT
Kernels
m
(Pattern (Lore (BinderT Kernels m)), [(VName, SubExp)],
BinderT Kernels m ())
forall (m :: * -> *).
(MonadFreshNames m, MonadBinder m, DistLore (Lore m)) =>
Pattern (Lore m) -> m (Pattern (Lore m), [(VName, SubExp)], m ())
dummyDim (Pattern (Lore (BinderT Kernels m))
-> BinderT
Kernels
m
(Pattern (Lore (BinderT Kernels m)), [(VName, SubExp)],
BinderT Kernels m ()))
-> Pattern (Lore (BinderT Kernels m))
-> BinderT
Kernels
m
(Pattern (Lore (BinderT Kernels m)), [(VName, SubExp)],
BinderT Kernels m ())
forall a b. (a -> b) -> a -> b
$ [PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT Type]
redout_pes
let pat' :: PatternT Type
pat' = [PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] ([PatElemT Type] -> PatternT Type)
-> [PatElemT Type] -> PatternT Type
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternElements PatternT Type
redout_pat [PatElemT Type] -> [PatElemT Type] -> [PatElemT Type]
forall a. [a] -> [a] -> [a]
++ [PatElemT Type]
mapout_pes
(SubExp
_, SegSpace
kspace, [Type]
ts, KernelBody Kernels
kbody) <- KernelSize
-> [(VName, SubExp)]
-> SubExp
-> Commutativity
-> LambdaT Kernels
-> [SubExp]
-> [VName]
-> BinderT Kernels m (SubExp, SegSpace, [Type], KernelBody Kernels)
forall (m :: * -> *).
(MonadBinder m, Lore m ~ Kernels) =>
KernelSize
-> [(VName, SubExp)]
-> SubExp
-> Commutativity
-> LambdaT Kernels
-> [SubExp]
-> [VName]
-> m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream KernelSize
size [(VName, SubExp)]
ispace SubExp
w Commutativity
comm LambdaT Kernels
fold_lam [SubExp]
nes [VName]
arrs
SegLevel
lvl <- MkSegLevel Kernels m
mk_lvl [SubExp
w] String
"stream_red" (ThreadRecommendation -> BinderT Kernels m (SegOpLevel Kernels))
-> ThreadRecommendation -> BinderT Kernels m (SegOpLevel Kernels)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
Pattern (Lore (BinderT Kernels m))
-> Exp (Lore (BinderT Kernels m)) -> BinderT Kernels m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind PatternT Type
Pattern (Lore (BinderT Kernels m))
pat' (Exp (Lore (BinderT Kernels m)) -> BinderT Kernels m ())
-> Exp (Lore (BinderT Kernels m)) -> BinderT Kernels m ()
forall a b. (a -> b) -> a -> b
$
Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$
SegLevel
-> SegSpace
-> [SegBinOp Kernels]
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegRed
SegLevel
lvl
SegSpace
kspace
[Commutativity
-> LambdaT Kernels -> [SubExp] -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
comm LambdaT Kernels
red_lam [SubExp]
nes Shape
forall a. Monoid a => a
mempty]
[Type]
ts
KernelBody Kernels
kbody
BinderT Kernels m ()
read_dummy
-- | Emit a chunked stream as a single 'SegMap'; no reduction operator is
-- applied to the per-thread results here.
--
-- The statements are built inside 'runBinderT'', so the result pairs the
-- value computed by the builder with the statements it produced.
--
-- Returns @((threads, names), stms)@ where @threads@ is the first
-- component produced by 'prepareStream' and @names@ are the freshly
-- created names of the accumulator outputs, each of which is typed as an
-- array with an outer dimension of size @threads@.
streamMap ::
  (MonadFreshNames m, HasScope Kernels m) =>
  MkSegLevel Kernels m ->
  [String] ->
  [PatElem Kernels] ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m ((SubExp, [VName]), Stms Kernels)
streamMap mk_lvl out_desc mapout_pes w comm fold_lam nes arrs = runBinderT' $ do
  size <- blockedKernelSize "stream_map" w

  -- No extra index-space dimensions are passed here (empty list).
  (threads, kspace, ts, kbody) <- prepareStream size [] w comm fold_lam nes arrs

  -- The leading 'length nes' result types belong to the accumulators.
  let redout_ts = take (length nes) ts

  -- One fresh pattern element per (description, accumulator type) pair.
  -- NOTE(review): 'zip' truncates to the shorter list, so this presumably
  -- relies on 'out_desc' having one entry per element of 'nes' -- confirm
  -- against callers.
  redout_pes <- forM (zip out_desc redout_ts) $ \(desc, t) ->
    PatElem <$> newVName desc <*> pure (t `arrayOfRow` threads)

  let pat = Pattern [] $ redout_pes ++ mapout_pes
  lvl <- mk_lvl [w] "stream_map" $ NoRecommendation SegNoVirt
  letBind pat $ Op $ SegOp $ SegMap lvl kspace ts kbody

  return (threads, map patElemName redout_pes)
-- | A 'MkSegLevel' that always produces a 'SegThread' level whose group
-- count is derived from the total size of the given dimensions.
--
-- The dimensions @ws@ are multiplied together (64-bit, starting from 1)
-- and bound as @nest_size@.  The group size is obtained with 'getSize'
-- under the key @desc ++ "_group_size"@ in the 'SizeGroup' class.
--
-- * For 'ManyThreads', the number of groups is @nest_size@ divided by the
--   group size, rounded up, and virtualisation is fixed to 'SegNoVirt'.
--
-- * For 'NoRecommendation', the number of groups comes from
--   'numberOfGroups' and the supplied virtualisation setting is passed
--   through unchanged.
segThreadCapped :: MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped ws desc r = do
  -- Product of all dimensions; the fold starts from the constant 1.
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      -- ceil(w / group_size), with the group size converted to Int64
      -- first so that both operands have the same integer type.
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 group_size)
      return $ SegThread (Count usable_groups) (Count group_size) SegNoVirt
    NoRecommendation v -> do
      -- Only the first component of the result is needed here.
      (num_groups, _) <- numberOfGroups desc w group_size
      return $ SegThread (Count num_groups) (Count group_size) v