{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Segmented
( regularSegmentedRedomap
, regularSegmentedScan
)
where
import Control.Monad
import qualified Data.Map.Strict as M
import Data.Semigroup ((<>))
import Futhark.Transform.Rename
import Futhark.Representation.Kernels
import Futhark.Representation.SOACS.SOAC (nilFn)
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass.ExtractKernels.BlockedKernel
-- | Selects which flavour of "large" segmented kernel is generated by
-- 'largeKernel' (and how 'calcGroupsPerSegmentAndElementsPerThread' sizes it).
data SegmentedVersion = OneGroupOneSegment
                        -- ^ Exactly one group per segment: the number of
                        -- groups per segment is the constant 1 and the
                        -- number of groups equals the number of segments.
                      | ManyGroupsOneSegment
                        -- ^ Several groups cooperate on each segment: the
                        -- number of groups is the number of segments times
                        -- the computed number of groups per segment.
                      deriving (Eq, Ord, Show)
-- | Emit, into the current binder, the statements computing a regular
-- segmented redomap and bind its results to @pat@.
--
-- The generated code selects a strategy at run time:
--
-- * if @group_size / 2 < segment_size@ and the computed number of groups per
--   segment equals 1: a single large kernel with one group per segment
--   (@useLargeOnePerSeg@);
--
-- * if @group_size / 2 < segment_size@ otherwise: a large kernel with many
--   groups per segment followed by a final reduction of the per-group partial
--   results (@useLargeMultiRecursiveReduce@);
--
-- * otherwise: the small kernel (@useSmallKernel@).
--
-- The chosen branch's results are first bound to an intermediate pattern and
-- then reshaped into the element types of @pat@.
--
-- Fails (via 'fail') when @pat@ has context elements, when an element of
-- @arrs_flat@ is not an array, when its outermost dimension differs from @w@,
-- or when a value element of @pat@ is not an array.
regularSegmentedRedomap :: (HasScope Kernels m, MonadBinder m, Lore m ~ Kernels) =>
                           SubExp            -- segment_size
                        -> SubExp            -- num_segments
                        -> [SubExp]          -- nest_sizes: sizes of the outer dimensions that get replaced/dropped when (un)flattening
                        -> Pattern Kernels   -- flat_pat: supplies the pattern elements for the map results of the kernels
                        -> Pattern Kernels   -- pat: the final result pattern
                        -> SubExp            -- w: required outermost size of every array in arrs_flat
                        -> Commutativity     -- comm
                        -> Lambda InKernel   -- reduce_lam
                        -> Lambda InKernel   -- fold_lam
                        -> [(VName, SubExp)] -- ispace: (index variable, size) pairs
                        -> [KernelInput]     -- inps
                        -> [SubExp]          -- nes: one value per reduction result; its length is num_redres
                        -> [VName]           -- arrs_flat: the input arrays
                        -> m ()
regularSegmentedRedomap segment_size num_segments nest_sizes flat_pat
                        pat w comm reduce_lam fold_lam ispace inps nes arrs_flat = do
  -- Context elements in the result pattern are not supported.
  unless (null $ patternContextElements pat) $ fail "regularSegmentedRedomap result pattern contains context elements, and Rasmus did not think this would ever happen."

  -- One output array per map result (the pattern identifiers after the first
  -- num_redres): a Scratch of the result's element type and dimensions, then
  -- reshaped with 'reshapeOuter' over the outer (length nest_sizes + 1)
  -- dimensions using the single new dimension w.
  map_out_arrs <- forM (drop num_redres $ patternIdents pat) $ \(Ident name t) -> do
    tmp <- letExp (baseString name <> "_out_in") $
           BasicOp $ Scratch (elemType t) (arrayDims t)
    letExp (baseString name ++ "_out_in") $
      BasicOp $ Reshape (reshapeOuter [DimNew w] (length nest_sizes+1) $ arrayShape t) tmp

  -- Sanity check: every input must be an array whose outermost size is w.
  forM_ arrs_flat $ \arr -> do
    tp <- lookupType arr
    case tp of
      Array _primtp (Shape (flatsize:_)) _uniqness ->
        when (flatsize /= w) $
          fail$ "regularSegmentedRedomap: first dimension of array has incorrect size " ++ pretty arr ++ ":" ++ pretty tp
      _ ->
        fail $ "regularSegmentedRedomap: non array encountered " ++ pretty arr ++ ":" ++ pretty tp

  -- A freshly named copy of pat's value elements in which every array type
  -- keeps only its outermost dimension.
  chunk_pat <- fmap (Pattern []) $ forM (patternValueElements pat) $ \pat_e ->
    case patElemType pat_e of
      Array ty (Shape (dim0:_)) u -> do
        vn' <- newName $ patElemName pat_e
        return $ PatElem vn' $ Array ty (Shape [dim0]) u
      _ -> fail $ "segmentedRedomap: result pattern is not array " ++ pretty pat_e

  -- Chunked and kernelised versions of the fold and the reduce lambdas.
  chunk_fold_lam <- chunkLambda chunk_pat nes fold_lam
  kern_chunk_fold_lam <- kerneliseLambda nes chunk_fold_lam

  let chunk_red_pat = Pattern [] $ take num_redres $ patternValueElements chunk_pat
  kern_chunk_reduce_lam <- kerneliseLambda nes =<< chunkLambda chunk_red_pat nes reduce_lam

  -- Both the plain and the flagged reduction lambda get two extra leading
  -- int32 parameters (my_index, other_offset) prepended.
  my_index <- newVName "my_index"
  other_offset <- newVName "other_offset"
  let my_index_param = Param my_index (Prim int32)
  let other_offset_param = Param other_offset (Prim int32)
  let reduce_lam' = reduce_lam { lambdaParams = my_index_param :
                                                other_offset_param :
                                                lambdaParams reduce_lam
                               }
  flag_reduce_lam <- addFlagToLambda nes reduce_lam
  let flag_reduce_lam' = flag_reduce_lam { lambdaParams = my_index_param :
                                                          other_offset_param :
                                                          lambdaParams flag_reduce_lam
                                         }

  group_size <- getSize "group_size" SizeGroup
  num_groups_hint <- getSize "num_groups_hint" SizeNumGroups

  -- Only the number of groups per segment is needed here; it decides between
  -- the two large-kernel variants below.
  (num_groups_per_segment, _) <-
    calcGroupsPerSegmentAndElementsPerThread
    segment_size num_segments num_groups_hint group_size ManyGroupsOneSegment

  let all_arrs = arrs_flat ++ map_out_arrs

  -- Each alternative is built inside its own 'runBinder', so its statements
  -- are collected and later placed in the matching branch body.
  (large_1_ses, large_1_stms) <- runBinder $
    useLargeOnePerSeg group_size all_arrs reduce_lam' kern_chunk_fold_lam

  (large_m_ses, large_m_stms) <- runBinder $
    useLargeMultiRecursiveReduce group_size all_arrs reduce_lam' kern_chunk_fold_lam
    kern_chunk_reduce_lam flag_reduce_lam'

  -- Large case: one group per segment iff num_groups_per_segment == 1.
  let e_large_seg = eIf (eCmpOp (CmpEq $ IntType Int32) (eSubExp num_groups_per_segment)
                                                        (eSubExp one))
                        (mkBodyM large_1_stms large_1_ses)
                        (mkBodyM large_m_stms large_m_ses)

  (small_ses, small_stms) <- runBinder $ useSmallKernel group_size map_out_arrs flag_reduce_lam'

  -- Top-level choice: large kernels iff group_size / 2 < segment_size.
  e <- eIf (eCmpOp (CmpSlt Int32) (eBinOp (SQuot Int32) (eSubExp group_size) (eSubExp two))
                                  (eSubExp segment_size))
         (eBody [e_large_seg])
         (mkBodyM small_stms small_ses)

  -- Intermediate pattern for the branch results: freshly named reduction
  -- results with the segment dimensions replaced by num_segments, followed by
  -- the map-result elements of flat_pat.
  redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
    vn' <- newName $ patElemName pe
    return $ PatElem vn' $ replaceSegmentDims num_segments $ patElemType pe

  let mapres_pes = drop num_redres $ patternValueElements flat_pat
  let unreshaped_pat = Pattern [] $ redres_pes ++ mapres_pes

  letBind_ unreshaped_pat e

  -- Reshape every intermediate result to the dimensions of the corresponding
  -- element of pat, binding the element of pat itself.
  forM_ (zip (patternValueElements unreshaped_pat)
             (patternValueElements pat)) $ \(kpe, pe) ->
    letBind_ (Pattern [] [pe]) $
      BasicOp $ Reshape [DimNew se | se <- arrayDims $ patElemAttr pe]
                        (patElemName kpe)

  where
    -- Replace the first (length nest_sizes) dimensions of t by the single
    -- dimension d.
    replaceSegmentDims d t =
      t `setArrayDims` (d : drop (length nest_sizes) (arrayDims t))

    one = constant (1 :: Int32)
    two = constant (2 :: Int32)

    -- Number of reduction results; the remaining results are map results.
    num_redres = length nes

    -- A single large kernel with one group per segment.  Returns the names
    -- bound by the kernel statement as sub-expressions.
    useLargeOnePerSeg group_size all_arrs reduce_lam' kern_chunk_fold_lam = do
      mapres_pes <- forM (drop num_redres $ patternValueElements flat_pat) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ patElemType pe

      (kernel, _, _) <-
        largeKernel group_size segment_size num_segments nest_sizes
        all_arrs comm reduce_lam' kern_chunk_fold_lam
        nes w OneGroupOneSegment
        ispace inps

      kernel_redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ replaceSegmentDims num_segments $ patElemType pe

      let kernel_pat = Pattern [] $ kernel_redres_pes ++ mapres_pes

      -- The statement is renamed before being added.
      addStm =<< renameStm (Let kernel_pat (defAux ()) $ Op kernel)
      return $ map (Var . patElemName) $ patternValueElements kernel_pat

    -- Two steps: a first large kernel using many groups per segment, whose
    -- reduction results have outer size num_groups_used, followed by
    -- 'performFinalReduction' over those partial results.  The map results
    -- come directly from the first kernel.
    useLargeMultiRecursiveReduce group_size all_arrs reduce_lam' kern_chunk_fold_lam kern_chunk_reduce_lam flag_reduce_lam' = do
      mapres_pes <- forM (drop num_redres $ patternValueElements flat_pat) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ patElemType pe

      (firstkernel, num_groups_used, num_groups_per_segment) <-
        largeKernel group_size segment_size num_segments nest_sizes
        all_arrs comm reduce_lam' kern_chunk_fold_lam
        nes w ManyGroupsOneSegment
        ispace inps

      firstkernel_redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ replaceSegmentDims num_groups_used $ patElemType pe

      let first_pat = Pattern [] $ firstkernel_redres_pes ++ mapres_pes
      addStm =<< renameStm (Let first_pat (defAux ()) $ Op firstkernel)

      -- In the second step each segment consists of the partial results of
      -- the groups that worked on it.
      let new_segment_size = num_groups_per_segment
      let new_total_elems = num_groups_used
      let tmp_redres = map patElemName firstkernel_redres_pes

      (finalredres, part_two_stms) <- runBinder $ performFinalReduction
        new_segment_size new_total_elems tmp_redres
        reduce_lam' kern_chunk_reduce_lam flag_reduce_lam'

      mapM_ (addStm <=< renameStm) part_two_stms

      return $ finalredres ++ map (Var . patElemName) mapres_pes

    -- Reduce the partial results in tmp_redres.  Chooses at run time between
    -- a large one-group-per-segment kernel (group_size / 2 < new_segment_size)
    -- and the small kernel.
    performFinalReduction new_segment_size new_total_elems tmp_redres
                          reduce_lam' kern_chunk_reduce_lam flag_reduce_lam' = do
      group_size <- getSize "group_size" SizeGroup

      (large_ses, large_stms) <- runBinder $ do
        (large_kernel, _, _) <- largeKernel group_size new_segment_size num_segments nest_sizes
          tmp_redres comm reduce_lam' kern_chunk_reduce_lam
          nes new_total_elems OneGroupOneSegment
          ispace inps
        letTupExp' "kernel_result" $ Op large_kernel

      (small_ses, small_stms) <- runBinder $ do
        -- Destination arrays for the small kernel's reduction results.
        red_scratch_arrs <- forM (take num_redres $ patternIdents pat) $ \(Ident name t) -> do
          tmp <- letExp (baseString name <> "_redres_scratch") $
                 BasicOp $ Scratch (elemType t) (arrayDims t)
          let reshape = reshapeOuter [DimNew num_segments] (length nest_sizes) $ arrayShape t
          letExp (baseString name ++ "_redres_scratch") $
            BasicOp $ Reshape reshape tmp
        -- Note: the unmodified reduce_lam is passed as the fold function.
        kernel <- smallKernel group_size new_segment_size num_segments
                              tmp_redres red_scratch_arrs
                              comm flag_reduce_lam' reduce_lam
                              nes new_total_elems ispace inps
        letTupExp' "kernel_result" $ Op kernel

      e <- eIf (eCmpOp (CmpSlt Int32)
                 (eBinOp (SQuot Int32) (eSubExp group_size) (eSubExp two))
                 (eSubExp new_segment_size))
           (mkBodyM large_stms large_ses)
           (mkBodyM small_stms small_ses)

      letTupExp' "step_two_kernel_result" e

    -- The small kernel applied directly to the flat inputs.  The scratch
    -- arrays are the reduction destinations followed by the map output arrays.
    useSmallKernel group_size map_out_arrs flag_reduce_lam' = do
      red_scratch_arrs <-
        forM (take num_redres $ patternIdents pat) $ \(Ident name t) -> do
          tmp <- letExp (baseString name <> "_redres_scratch") $
                 BasicOp $ Scratch (elemType t) (arrayDims t)
          let shape_change = reshapeOuter [DimNew num_segments]
                             (length nest_sizes) (arrayShape t)
          letExp (baseString name ++ "_redres_scratch") $
            BasicOp $ Reshape shape_change tmp

      let scratch_arrays = red_scratch_arrs ++ map_out_arrs

      kernel <- smallKernel group_size segment_size num_segments
                            arrs_flat scratch_arrays
                            comm flag_reduce_lam' fold_lam
                            nes w ispace inps
      letTupExp' "kernel_result" $ Op kernel
-- | Construct (but do not bind) a kernel in which every thread folds a chunk
-- of one segment with @kern_chunk_fold_lam@, after which each group combines
-- the per-thread reduction results with a 'GroupReduce' using @reduce_lam'@.
--
-- Reduction results are returned once per group; map results are returned
-- with 'ConcatReturns'.
--
-- Returns the kernel together with the total number of groups and the number
-- of groups per segment.  The size computations (@num_groups@ etc.) are
-- emitted into the surrounding binder.
largeKernel :: (MonadBinder m, Lore m ~ Kernels) =>
               SubExp            -- group_size
            -> SubExp            -- segment_size
            -> SubExp            -- num_segments
            -> [SubExp]          -- nest_sizes: leading dimensions used when reshaping the inputs
            -> [VName]           -- all_arrs: the arrays to be chunked for the fold lambda
            -> Commutativity     -- comm: selects strided or contiguous chunking
            -> Lambda InKernel   -- reduce_lam'
            -> Lambda InKernel   -- kern_chunk_fold_lam
            -> [SubExp]          -- nes
            -> SubExp            -- w: total size given to ConcatReturns for the map results
            -> SegmentedVersion  -- segver
            -> [(VName, SubExp)] -- ispace
            -> [KernelInput]     -- inps
            -> m (Kernel InKernel, SubExp, SubExp)
largeKernel group_size segment_size num_segments nest_sizes all_arrs comm
            reduce_lam' kern_chunk_fold_lam
            nes w segver ispace inps = do
  let num_redres = length nes -- number of reduction results

  num_groups_hint <- getSize "num_groups_hint" SizeNumGroups

  (num_groups_per_segment, elements_per_thread) <-
    calcGroupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size segver

  -- Total number of groups: one per segment, or num_groups_per_segment per
  -- segment.
  num_groups <- letSubExp "num_groups" $
    case segver of
      OneGroupOneSegment -> BasicOp $ SubExp num_segments
      ManyGroupsOneSegment -> BasicOp $ BinOp (Mul Int32) num_segments num_groups_per_segment

  num_threads <- letSubExp "num_threads" $
    BasicOp $ BinOp (Mul Int32) num_groups group_size

  threads_within_segment <- letSubExp "threads_within_segment" $
    BasicOp $ BinOp (Mul Int32) group_size num_groups_per_segment

  gtid_vn <- newVName "gtid"
  gtid_ln <- newVName "gtid"

  -- The thread space is ispace extended with (gtid_vn, num_groups_per_segment)
  -- and (gtid_ln, group_size).
  space <- newKernelSpace (num_groups, group_size, num_threads) $
    FlatThreadSpace $ ispace ++ [(gtid_vn, num_groups_per_segment),(gtid_ln,group_size)]

  let red_ts = take num_redres $ lambdaReturnType kern_chunk_fold_lam
  let map_ts = map rowType $ drop num_redres $ lambdaReturnType kern_chunk_fold_lam
  let kernel_return_types = red_ts ++ map_ts

  -- Commutative operators use strided chunks (stride = number of threads
  -- working on a segment); otherwise chunks are contiguous with stride 1.
  let ordering = case comm of Commutative -> SplitStrided threads_within_segment
                              Noncommutative -> SplitContiguous
  let stride = case ordering of SplitStrided s -> s
                                SplitContiguous -> one

  let each_thread = do
        -- segment_index = group id / num_groups_per_segment
        segment_index <- letSubExp "segment_index" $
          BasicOp $ BinOp (SQuot Int32) (Var $ spaceGroupId space) num_groups_per_segment

        -- index_within_segment =
        --   gtid_ln + group_size * (group id % num_groups_per_segment)
        index_within_segment <- letSubExp "index_within_segment" =<<
          eBinOp (Add Int32)
          (eSubExp $ Var gtid_ln)
          (eBinOp (Mul Int32)
            (eSubExp group_size)
            (eBinOp (SRem Int32) (eSubExp $ Var $ spaceGroupId space) (eSubExp num_groups_per_segment))
          )

        (in_segment_offset,offset) <-
          makeOffsetExp ordering index_within_segment elements_per_thread segment_index

        -- The pattern requires the fold lambda to have no accumulator
        -- parameters in the third component.
        let (_, chunksize, [], arr_params) =
              partitionChunkedKernelFoldParameters 0 $ lambdaParams kern_chunk_fold_lam
        let chunksize_se = Var $ paramName chunksize

        -- The array parameters of the fold lambda, typed with the chunk size
        -- as outer size; they are bound below by slicing the inputs.
        patelems_res_of_split <- forM arr_params $ \arr_param -> do
          let chunk_t = paramType arr_param `setOuterSize` Var (paramName chunksize)
          return $ PatElem (paramName arr_param) chunk_t

        -- Bind the chunk-size parameter of the fold lambda.
        letBind_ (Pattern [] [PatElem (paramName chunksize) $ paramType chunksize]) $
          Op $ SplitSpace ordering segment_size index_within_segment elements_per_thread

        addKernelInputStms inps

        -- Reshape every input to nest_sizes ++ (segment_size : row dims), fix
        -- the ispace indices, and slice out this thread's chunk.
        forM_ (zip all_arrs patelems_res_of_split) $ \(arr, pe) -> do
          let pe_t = patElemType pe
              segment_dims = nest_sizes ++ arrayDims (pe_t `setOuterSize` segment_size)
          arr_nested <- letExp (baseString arr ++ "_nested") $
            BasicOp $ Reshape (map DimNew segment_dims) arr
          arr_nested_t <- lookupType arr_nested
          let slice = fullSlice arr_nested_t $ map (DimFix . Var . fst) ispace ++
                      [DimSlice in_segment_offset chunksize_se stride]
          letBind_ (Pattern [] [pe]) $ BasicOp $ Index arr_nested slice

        red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "chunk_fold_red"
          return $ PatElem pe_name red_t
        map_pes <- forM map_ts $ \map_t -> do
          pe_name <- newVName "chunk_fold_map"
          return $ PatElem pe_name $ map_t `arrayOfRow` chunksize_se

        -- Inline the body of the fold lambda and name its results.
        addStms $ bodyStms (lambdaBody kern_chunk_fold_lam)
        addStms $ stmsFromList
          [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe,se) <- zip (red_pes ++ map_pes)
                       (bodyResult $ lambdaBody kern_chunk_fold_lam) ]

        -- Collect the per-thread reduction results into arrays of
        -- group_size elements.
        combine_red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "chunk_fold_red"
          return $ PatElem pe_name $ red_t `arrayOfRow` group_size
        cids <- replicateM (length red_pes) $ newVName "cid"
        addStms $ stmsFromList
          [ Let (Pattern [] [pe']) (defAux ()) $
            Op $ Combine (combineSpace [(cid, group_size)]) [patElemType pe] [] $
            Body () mempty [Var $ patElemName pe]
          | (cid, pe', pe) <- zip3 cids combine_red_pes red_pes ]

        -- Reduce the combined arrays within the group.
        final_red_pes <- forM (lambdaReturnType reduce_lam') $ \t -> do
          pe_name <- newVName "final_result"
          return $ PatElem pe_name t
        letBind_ (Pattern [] final_red_pes) $
          Op $ GroupReduce group_size reduce_lam' $
          zip nes (map patElemName combine_red_pes)

        return (final_red_pes, map_pes, offset)

  -- Collect the statements of the kernel body.
  ((final_red_pes, map_pes, offset), stms) <- runBinder each_thread

  red_returns <- forM final_red_pes $ \pe ->
    return $ ThreadsReturn OneResultPerGroup $ Var $ patElemName pe
  map_returns <- forM map_pes $ \pe ->
    return $ ConcatReturns ordering w elements_per_thread
                           (Just offset) $
                           patElemName pe
  let kernel_returns = red_returns ++ map_returns

  let kerneldebughints = KernelDebugHints kernelname
                         [ ("num_segment", num_segments)
                         , ("segment_size", segment_size)
                         , ("num_groups", num_groups)
                         , ("group_size", group_size)
                         , ("elements_per_thread", elements_per_thread)
                         , ("num_groups_per_segment", num_groups_per_segment)
                         ]

  let kernel = Kernel kerneldebughints space kernel_return_types $
                 KernelBody () stms kernel_returns

  return (kernel, num_groups, num_groups_per_segment)

  where
    one = constant (1 :: Int32)

    commname = case comm of Commutative -> "comm"
                            Noncommutative -> "nocomm"

    kernelname = case segver of
      OneGroupOneSegment -> "segmented_redomap__large_" ++ commname ++ "_one"
      ManyGroupsOneSegment -> "segmented_redomap__large_" ++ commname ++ "_many"

    -- Returns (offset within the segment, offset within the flat array).
    -- Contiguous: in-segment offset is elements_per_thread * index_within_segment.
    makeOffsetExp SplitContiguous index_within_segment elements_per_thread segment_index = do
      in_segment_offset <- letSubExp "in_segment_offset" $
        BasicOp $ BinOp (Mul Int32) elements_per_thread index_within_segment
      offset <- letSubExp "offset" =<< eBinOp (Add Int32) (eSubExp in_segment_offset)
                (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
      return (in_segment_offset, offset)
    -- Strided: the in-segment offset is index_within_segment itself.
    makeOffsetExp (SplitStrided _) index_within_segment _elements_per_thread segment_index = do
      offset <- letSubExp "offset" =<< eBinOp (Add Int32) (eSubExp index_within_segment)
                (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
      return (index_within_segment, offset)
-- | Compute, as bound sub-expressions, how many groups work on each segment
-- and how many elements each thread processes.
--
-- Three values are bound, in this order:
--
-- 1. @num_groups_per_segment_hint@: 1 for 'OneGroupOneSegment', otherwise
--    @num_groups_hint@ divided by @num_segments@, rounding up.
--
-- 2. @elements_per_thread@: @segment_size@ divided by
--    @group_size * num_groups_per_segment_hint@, rounding up.
--
-- 3. @num_groups_per_segment@: 1 for 'OneGroupOneSegment'; otherwise, when
--    @elements_per_thread@ is 1, @segment_size@ divided by @group_size@
--    rounding up, and the hint in all other cases.
--
-- Returns @(num_groups_per_segment, elements_per_thread)@.
calcGroupsPerSegmentAndElementsPerThread :: (MonadBinder m, Lore m ~ Kernels) =>
                                            SubExp           -- segment_size
                                         -> SubExp           -- num_segments
                                         -> SubExp           -- num_groups_hint
                                         -> SubExp           -- group_size
                                         -> SegmentedVersion -- segver
                                         -> m (SubExp, SubExp)
calcGroupsPerSegmentAndElementsPerThread segment_size num_segments
                                         num_groups_hint group_size segver = do
  groups_hint <- letSubExp "num_groups_per_segment_hint" =<< hintExp segver
  per_thread <- letSubExp "elements_per_thread" =<< perThreadExp groups_hint
  groups <- letSubExp "num_groups_per_segment" =<< groupsExp segver per_thread groups_hint
  return (groups, per_thread)
  where
    unit = constant (1 :: Int32)

    -- Initial guess for the number of groups per segment.
    hintExp OneGroupOneSegment =
      eSubExp unit
    hintExp ManyGroupsOneSegment =
      eDivRoundingUp Int32 (eSubExp num_groups_hint)
                           (eSubExp num_segments)

    -- Elements per thread given the guessed number of groups per segment.
    perThreadExp groups_hint =
      eDivRoundingUp Int32 (eSubExp segment_size)
                           (eBinOp (Mul Int32) (eSubExp group_size)
                                               (eSubExp groups_hint))

    -- Final number of groups per segment.  When every thread handles exactly
    -- one element, the hint is replaced by ceil(segment_size / group_size).
    groupsExp OneGroupOneSegment _ _ =
      eSubExp unit
    groupsExp ManyGroupsOneSegment per_thread groups_hint =
      eIf (eCmpOp (CmpEq $ IntType Int32) (eSubExp per_thread) (eSubExp unit))
          (eBody [eDivRoundingUp Int32 (eSubExp segment_size) (eSubExp group_size)])
          (mkBodyM mempty [groups_hint])
-- | Construct (but do not bind) a kernel in which every group processes
-- @group_size / segment_size@ whole segments.
--
-- Each active thread applies the (renamed) fold lambda to a single element;
-- the group then performs a 'GroupScan' with @flag_reduce_lam'@ over the
-- per-thread results, flagged at the first thread of each segment.  The last
-- thread of every segment writes the scanned value for its segment, and every
-- active thread writes its map results, using 'WriteReturn' into
-- @scratch_arrs@ (reduction destinations first, then map destinations).
--
-- The size computations are emitted into the surrounding binder; only the
-- kernel itself is returned.
smallKernel :: (MonadBinder m, Lore m ~ Kernels) =>
               SubExp            -- group_size
            -> SubExp            -- segment_size
            -> SubExp            -- num_segments
            -> [VName]           -- in_arrs: indexed at the flat offset to bind the fold lambda's non-accumulator parameters
            -> [VName]           -- scratch_arrs: destinations of the WriteReturns
            -> Commutativity     -- comm: only used for the kernel name
            -> Lambda InKernel   -- flag_reduce_lam'
            -> Lambda InKernel   -- fold_lam_unrenamed
            -> [SubExp]          -- nes
            -> SubExp            -- w: size given to the map-result WriteReturns
            -> [(VName, SubExp)] -- ispace
            -> [KernelInput]     -- inps
            -> m (Kernel InKernel)
smallKernel group_size segment_size num_segments in_arrs scratch_arrs
            comm flag_reduce_lam' fold_lam_unrenamed
            nes w ispace inps = do
  let num_redres = length nes -- number of reduction results

  -- The fold lambda's body is inlined below, so work on a renamed copy.
  fold_lam <- renameLambda fold_lam_unrenamed

  num_segments_per_group <- letSubExp "num_segments_per_group" $
    BasicOp $ BinOp (SQuot Int32) group_size segment_size

  num_groups <- letSubExp "num_groups" =<<
    eDivRoundingUp Int32 (eSubExp num_segments) (eSubExp num_segments_per_group)

  num_threads <- letSubExp "num_threads" $
    BasicOp $ BinOp (Mul Int32) num_groups group_size

  active_threads_per_group <- letSubExp "active_threads_per_group" $
    BasicOp $ BinOp (Mul Int32) segment_size num_segments_per_group

  -- An expression builder, not a bound value: it is used both in the
  -- condition and in the false branch below.
  let remainder_last_group = eBinOp (SRem Int32) (eSubExp num_segments) (eSubExp num_segments_per_group)

  -- The last group handles the remainder, or a full set of segments when the
  -- remainder is zero.
  segments_in_last_group <- letSubExp "seg_in_last_group" =<<
    eIf (eCmpOp (CmpEq $ IntType Int32) remainder_last_group
                                        (eSubExp zero))
        (eBody [eSubExp num_segments_per_group])
        (eBody [remainder_last_group])

  active_threads_in_last_group <- letSubExp "active_threads_last_group" $
    BasicOp $ BinOp (Mul Int32) segment_size segments_in_last_group

  -- The thread space carries no index variables; the ispace indices are
  -- computed by hand with 'addManualIspaceCalcStms' inside the kernel.
  space <- newKernelSpace (num_groups, group_size, num_threads) $
    FlatThreadSpace []

  let lid = Var $ spaceLocalId space

  let (red_ts, map_ts) = splitAt num_redres $ lambdaReturnType fold_lam
  let kernel_return_types = red_ts ++ map_ts

  -- Phase 1 for an inactive thread: offset -1 and a dummy value of each
  -- result type.
  let wasted_thread_part1 = do
        let create_dummy_val (Prim ty) = return $ Constant $ blankPrimValue ty
            create_dummy_val (Array ty sh _) = letSubExp "dummy" $ BasicOp $ Scratch ty (shapeDims sh)
            create_dummy_val Mem{} = fail "segredomap, 'Mem' used as result type"
        dummy_vals <- mapM create_dummy_val kernel_return_types
        return (negone : dummy_vals)

  -- Phase 1 for an active thread: apply the fold lambda to one element.
  -- Returns the element's flat offset followed by the reduction and map
  -- results.
  let normal_thread_part1 = do
        -- segment_index = lid / segment_size + group id * num_segments_per_group
        segment_index <- letSubExp "segment_index" =<<
          eBinOp (Add Int32)
            (eBinOp (SQuot Int32) (eSubExp $ Var $ spaceLocalId space) (eSubExp segment_size))
            (eBinOp (Mul Int32) (eSubExp $ Var $ spaceGroupId space) (eSubExp num_segments_per_group))

        index_within_segment <- letSubExp "index_within_segment" =<<
          eBinOp (SRem Int32) (eSubExp $ Var $ spaceLocalId space) (eSubExp segment_size)

        offset <- makeOffsetExp index_within_segment segment_index

        red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "fold_red"
          return $ PatElem pe_name red_t
        map_pes <- forM map_ts $ \map_t -> do
          pe_name <- newVName "fold_map"
          return $ PatElem pe_name map_t

        -- Bind the ispace index variables from the segment index.
        addManualIspaceCalcStms segment_index ispace

        addKernelInputStms inps

        -- Bind the non-accumulator parameters of the fold lambda to the
        -- elements of the input arrays at this thread's offset.
        let arr_params = drop num_redres $ lambdaParams fold_lam
        let nonred_lamparam_pes = map
              (\p -> PatElem (paramName p) (paramType p)) arr_params
        forM_ (zip in_arrs nonred_lamparam_pes) $ \(arr, pe) -> do
          tp <- lookupType arr
          let slice = fullSlice tp [DimFix offset]
          letBind_ (Pattern [] [pe]) $ BasicOp $ Index arr slice

        -- Bind the accumulator parameters of the fold lambda to nes.
        forM_ (zip nes (take num_redres $ lambdaParams fold_lam)) $ \(ne,param) -> do
          let pe = PatElem (paramName param) (paramType param)
          letBind_ (Pattern [] [pe]) $ BasicOp $ SubExp ne

        -- Inline the fold lambda's body and name its results.
        addStms $ bodyStms $ lambdaBody fold_lam
        addStms $ stmsFromList
          [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe,se) <- zip (red_pes ++ map_pes) (bodyResult $ lambdaBody fold_lam) ]

        let mapoffset = offset
        let mapret_elems = map (Var . patElemName) map_pes
        let redres_elems = map (Var . patElemName) red_pes
        return (mapoffset : redres_elems ++ mapret_elems)

  -- Executed by every thread: combine the per-thread reduction results
  -- (together with a flag that is true for the first thread of a segment)
  -- into arrays of group_size elements and scan them with the flagged
  -- operator.  Returns the scanned arrays without the flag array.
  let all_threads red_pes = do
        isfirstinsegment <- letExp "isfirstinsegment" =<<
          eCmpOp (CmpEq $ IntType Int32)
            (eBinOp (SRem Int32) (eSubExp lid) (eSubExp segment_size))
            (eSubExp zero)

        let red_pes_wflag = PatElem isfirstinsegment (Prim Bool) : red_pes
        let red_ts_wflag = Prim Bool : red_ts

        combine_red_pes' <- forM red_ts_wflag $ \red_t -> do
          pe_name <- newVName "chunk_fold_red"
          return $ PatElem pe_name $ red_t `arrayOfRow` group_size
        cids <- replicateM (length red_pes_wflag) $ newVName "cid"
        addStms $ stmsFromList [ Let (Pattern [] [pe']) (defAux ()) $ Op $
                                 Combine (combineSpace [(cid, group_size)]) [patElemType pe] [] $
                                 Body () mempty [Var $ patElemName pe]
                               | (cid, pe', pe) <- zip3 cids combine_red_pes' red_pes_wflag ]

        scan_red_pes_wflag <- forM red_ts_wflag $ \red_t -> do
          pe_name <- newVName "scanned"
          return $ PatElem pe_name $ red_t `arrayOfRow` group_size
        let scan_red_pes = drop 1 scan_red_pes_wflag
        letBind_ (Pattern [] scan_red_pes_wflag) $ Op $
          GroupScan group_size flag_reduce_lam' $
          zip (false:nes) (map patElemName combine_red_pes')

        return scan_red_pes

  -- Phase 2 for an active thread: the last thread of a segment reads the
  -- scanned value at its own position and returns the segment index as the
  -- write offset; every other thread returns offset -1 and nes.
  let normal_thread_part2 scan_red_pes = do
        segment_index <- letSubExp "segment_index" =<<
          eBinOp (Add Int32)
            (eBinOp (SQuot Int32) (eSubExp $ Var $ spaceLocalId space) (eSubExp segment_size))
            (eBinOp (Mul Int32) (eSubExp $ Var $ spaceGroupId space) (eSubExp num_segments_per_group))

        islastinsegment <- letExp "islastinseg" =<< eCmpOp (CmpEq $ IntType Int32)
          (eBinOp (SRem Int32) (eSubExp lid) (eSubExp segment_size))
          (eBinOp (Sub Int32) (eSubExp segment_size) (eSubExp one))

        redoffset <- letSubExp "redoffset" =<<
          eIf (eSubExp $ Var islastinsegment)
              (eBody [eSubExp segment_index])
              (mkBodyM mempty [negone])

        redret_elems <- fmap (map Var) $ letTupExp "red_return_elem" =<<
          eIf (eSubExp $ Var islastinsegment)
              (eBody [return $ BasicOp $ Index (patElemName pe) (fullSlice (patElemType pe) [DimFix lid])
                     | pe <- scan_red_pes])
              (mkBodyM mempty nes)

        return (redoffset : redret_elems)

  -- The complete kernel body: phase 1 (active or inactive), the group-wide
  -- scan, then phase 2 (active threads only).
  let picknchoose = do
        is_last_group <- letSubExp "islastgroup" =<<
          eCmpOp (CmpEq $ IntType Int32)
            (eSubExp $ Var $ spaceGroupId space)
            (eBinOp (Sub Int32) (eSubExp num_groups) (eSubExp one))

        active_threads_this_group <- letSubExp "active_thread_this_group" =<<
          eIf (eSubExp is_last_group)
              (eBody [eSubExp active_threads_in_last_group])
              (eBody [eSubExp active_threads_per_group])

        -- A thread is active when its local id is below the number of
        -- active threads of its group.
        isactive <- letSubExp "isactive" =<<
          eCmpOp (CmpSlt Int32) (eSubExp lid) (eSubExp active_threads_this_group)

        (normal_res1, normal_stms1) <- runBinder normal_thread_part1
        (wasted_res1, wasted_stms1) <- runBinder wasted_thread_part1

        mapoffset_pe <- (`PatElem` i32) <$> newVName "mapoffset"
        redtmp_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "redtmp_res"
          return $ PatElem pe_name red_t
        map_pes <- forM map_ts $ \map_t -> do
          pe_name <- newVName "map_res"
          return $ PatElem pe_name map_t

        e1 <- eIf (eSubExp isactive)
              (mkBodyM normal_stms1 normal_res1)
              (mkBodyM wasted_stms1 wasted_res1)
        letBind_ (Pattern [] (mapoffset_pe:redtmp_pes++map_pes)) e1

        scan_red_pes <- all_threads redtmp_pes

        (normal_res2, normal_stms2) <- runBinder $ normal_thread_part2 scan_red_pes

        redoffset_pe <- (`PatElem` i32) <$> newVName "redoffset"
        red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "red_res"
          return $ PatElem pe_name red_t

        e2 <- eIf (eSubExp isactive)
              (mkBodyM normal_stms2 normal_res2)
              (mkBodyM mempty (negone : nes))
        letBind_ (Pattern [] (redoffset_pe:red_pes)) e2

        return $ map (Var . patElemName) $ redoffset_pe:mapoffset_pe:red_pes++map_pes

  (redoffset:mapoffset:redmapres, stms) <- runBinder picknchoose
  let (finalredvals, finalmapvals) = splitAt num_redres redmapres

  -- Reduction results are written at redoffset, map results at mapoffset.
  -- NOTE(review): threads with nothing to write use offset -1; presumably
  -- WriteReturn discards out-of-bounds writes -- confirm against its
  -- definition.
  red_returns <- forM (zip finalredvals $ take num_redres scratch_arrs) $ \(se, scarr) ->
    return $ WriteReturn [num_segments] scarr [([redoffset], se)]
  map_returns <- forM (zip finalmapvals $ drop num_redres scratch_arrs) $ \(se, scarr) ->
    return $ WriteReturn [w] scarr [([mapoffset], se)]
  let kernel_returns = red_returns ++ map_returns

  let kerneldebughints = KernelDebugHints kernelname
                         [ ("num_segment", num_segments)
                         , ("segment_size", segment_size)
                         , ("num_groups", num_groups)
                         , ("group_size", group_size)
                         , ("num_segments_per_group", num_segments_per_group)
                         , ("active_threads_per_group", active_threads_per_group)
                         ]

  let kernel = Kernel kerneldebughints space kernel_return_types $
                 KernelBody () stms kernel_returns

  return kernel

  where
    i32 = Prim $ IntType Int32
    zero = constant (0 :: Int32)
    one = constant (1 :: Int32)
    negone = constant (-1 :: Int32)
    false = constant False

    commname = case comm of Commutative -> "comm"
                            Noncommutative -> "nocomm"
    kernelname = "segmented_redomap__small_" ++ commname

    -- Flat offset: index_within_segment + segment_size * segment_index.
    makeOffsetExp index_within_segment segment_index = do
      e <- eBinOp (Add Int32)
             (eSubExp index_within_segment)
             (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
      letSubExp "offset" e
-- | For every kernel input, emit a statement that binds the input's name
-- (with the input's type) to the input's source array indexed at the
-- input's indices.
addKernelInputStms :: (MonadBinder m, Lore m ~ InKernel) =>
                      [KernelInput]
                   -> m ()
addKernelInputStms inputs =
  forM_ inputs $ \input -> do
    source_t <- lookupType (kernelInputArray input)
    let target = Pattern [] [PatElem (kernelInputName input) (kernelInputType input)]
        fixed_dims = map DimFix (kernelInputIndices input)
    letBind target $
      BasicOp $ Index (kernelInputArray input) (fullSlice source_t fixed_dims)
-- | Emit statements that bind the index variables of @ispace@ from the flat
-- index @outer_index@.
--
-- The pairs are processed from the last to the first: each variable is bound
-- to the remainder of the running value by its size, and the running value
-- is replaced by the corresponding quotient.
addManualIspaceCalcStms :: (MonadBinder m, Lore m ~ InKernel) =>
                           SubExp
                        -> [(VName, SubExp)]
                        -> m ()
addManualIspaceCalcStms outer_index ispace =
  foldM_ peelIndex outer_index (reverse ispace)
  where
    -- Bind one index variable and hand the quotient on to the next one.
    peelIndex remaining (index_var, dim_size) = do
      letBind_ (Pattern [] [PatElem index_var (Prim $ IntType Int32)]) $
        BasicOp $ BinOp (SRem Int32) remaining dim_size
      letSubExp "tmp_val" $
        BasicOp $ BinOp (SQuot Int32) remaining dim_size
-- | Turn a binary operator into its flagged counterpart.
--
-- The resulting lambda takes a boolean flag in front of each of the two
-- groups of parameters of @lam@ (split after @length nes@ parameters) and
-- returns a boolean followed by the results of @lam@:
--
-- * the returned flag is the logical or of both flags;
--
-- * when the second flag is set, the first operands are replaced by @nes@
--   before a renamed copy of @lam@ is applied.
addFlagToLambda :: (MonadBinder m, Lore m ~ Kernels) =>
                   [SubExp] -> Lambda InKernel -> m (Lambda InKernel)
addFlagToLambda nes lam = do
  flag_x <- newVName "x_flag"
  flag_y <- newVName "y_flag"
  let (acc_params, elem_params) = splitAt (length nes) $ lambdaParams lam
      flagParam v = Param v $ Prim Bool
      all_params = flagParam flag_x : acc_params ++ flagParam flag_y : elem_params

  flagged_body <- runBodyBinder $ localScope (scopeOfLParams all_params) $ do
    flag_out <- letSubExp "new_flag" $
                BasicOp $ BinOp LogOr (Var flag_x) (Var flag_y)

    -- A set second flag restarts the accumulation from nes.
    let pick_left = If (Var flag_y)
                       (resultBody nes)
                       (resultBody $ map (Var . paramName) acc_params)
                       (ifCommon $ map paramType acc_params)
    left_args <- map Var <$> letTupExp "seg_lhs" pick_left
    let right_args = map (Var . paramName) elem_params

    -- Work on a renamed copy so the inlined body cannot shadow names.
    fresh_lam <- renameLambda lam
    op_res <- eLambda fresh_lam $ map eSubExp $ left_args ++ right_args

    return $ resultBody $ flag_out : op_res

  return Lambda { lambdaParams = all_params
                , lambdaBody = flagged_body
                , lambdaReturnType = Prim Bool : lambdaReturnType lam
                }
-- | Emit a regular segmented scan by reducing it to a flagged 'blockedScan'.
--
-- A map kernel first computes a boolean array of size @w@ that is true at
-- every index divisible by @segment_size@.  The scan operator is extended
-- with a flag using 'addFlagToLambda', the map function is changed to pass
-- its flag through, and the result pattern is extended with an extra leading
-- boolean array for the scanned flags.
regularSegmentedScan :: (MonadBinder m, Lore m ~ Kernels) =>
                        SubExp
                     -> Pattern Kernels
                     -> SubExp
                     -> Lambda InKernel
                     -> Lambda InKernel
                     -> [(VName, SubExp)] -> [KernelInput]
                     -> [SubExp] -> [VName]
                     -> m ()
regularSegmentedScan segment_size pat w lam map_lam ispace inps nes arrs = do
  flag_index <- newVName "flags_i"
  scanned_flags <- newVName "unused_flag_array"

  -- Body of the flag kernel: is this index the first of its segment?
  is_start_body <-
    runBodyBinder $ localScope (M.singleton flag_index $ IndexInfo Int32) $ do
      pos_in_segment <- letSubExp "segment_index" $
                        BasicOp $ BinOp (SRem Int32) (Var flag_index) segment_size
      at_start <- letSubExp "start_of_segment" $
                  BasicOp $ CmpOp (CmpEq int32) pos_in_segment i32_zero
      return $ resultBody [at_start]

  (flag_kernel_stms, flag_kernel) <-
    mapKernelFromBody w (FlatThreadSpace [(flag_index, w)]) [] [Prim Bool] is_start_body
  addStms flag_kernel_stms
  flags <- letExp "flags" $ Op flag_kernel

  flagged_lam <- addFlagToLambda nes lam

  -- The map function receives the flag as first parameter and returns it
  -- unchanged as first result.
  flag_p <- newParam "flag" $ Prim Bool
  let map_body = lambdaBody map_lam
      flagged_map_lam =
        map_lam { lambdaParams = flag_p : lambdaParams map_lam
                , lambdaBody =
                    map_body { bodyResult = Var (paramName flag_p) : bodyResult map_body }
                , lambdaReturnType = Prim Bool : lambdaReturnType map_lam
                }
      -- Extra pattern element receiving the scanned flags.
      scanned_flags_pe =
        PatElem scanned_flags (arrayOf (Prim Bool) (Shape [w]) NoUniqueness)
      flagged_pat =
        pat { patternValueElements = scanned_flags_pe : patternValueElements pat }

  void $ blockedScan flagged_pat w (flagged_lam, no_flag:nes) (Commutative, nilFn, mempty) flagged_map_lam segment_size ispace inps (flags:arrs)
  where
    i32_zero = constant (0 :: Int32)
    no_flag = constant False