{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.TileLoops.RegTiling3D
( doRegTiling3D )
where
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.List
import Data.Maybe
import Futhark.MonadFreshNames
import Futhark.Representation.Kernels
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
-- | The monad of this pass: read access to the 'Kernels' scope plus a
-- supply of fresh names.
type TileM = ReaderT (Scope Kernels) (State VNameSource)
-- | Maps each variable to the set of names it depends on (built up
-- statement by statement by 'varianceInStms').
type VarianceTable = M.Map VName Names
-- | Total register-tile budget; 'doRegTiling3D' divides it evenly among
-- the stream accumulators.
maxRegTile :: Int32
maxRegTile = 30
-- | Turn a register-tile size into a constant 'SubExp'.
mkRegTileSe :: Int32 -> SubExp
mkRegTileSe n = constant n
-- | Try to apply three-dimensional register tiling to a kernel statement.
--
-- The kernel must have a flat thread space of at least three dimensions and
-- a body of the shape @code1; stream; code2@, where @code1@ consists solely
-- of one-dimensional indexing statements producing arrays consumed by the
-- 'GroupStream' @stream@.  The third-innermost dimension (@gidz@) is
-- register-tiled: each thread handles the @reg_tile@ z-indices
-- @gidz * reg_tile + k@, @k = 0 .. reg_tile-1@.  The stream is turned into a
-- sequential loop by 'translateStreamsToLoop', and the gidz-variant part of
-- @code2@ is unrolled by 'cloneVarCode2', which writes results in place.
--
-- On success, returns the statements that must be placed before the kernel
-- (thread-space sizes, scratch arrays and 'Manifest' transpositions) and the
-- rewritten kernel statement.  Returns 'Nothing' if the kernel does not fit.
doRegTiling3D :: Stm Kernels -> TileM (Maybe (Stms Kernels, Stm Kernels))
doRegTiling3D (Let pat aux (Op old_kernel))
  | Kernel kerhint space kertp (KernelBody () kstms kres) <- old_kernel,
    FlatThreadSpace gspace <- spaceStructure space,
    initial_variance <- M.map mempty $ scopeOfKernelSpace space,
    variance <- varianceInStms initial_variance kstms,
    local_tid <- spaceLocalId space,
    -- gidz is the third-innermost dimension and m_M its extent.
    (_,_) : (_,_) : (gidz,m_M) : _ <- reverse $ spaceDimensions space,
    (code1, Just stream_stmt, code2) <- matchCodeStreamCode kstms,
    Let pat_strm aux_strm (Op (GroupStream w w0 lam accs arrs)) <- stream_stmt,
    not (null accs),
    -- The register budget is split evenly among the stream accumulators.
    reg_tile <- maxRegTile `quot` fromIntegral (length accs),
    -- With more accumulators than 'maxRegTile' the quotient is 0: nothing
    -- would be unrolled and 'mkKerSpaceExtraStms' would divide by zero.
    reg_tile > 0,
    reg_tile_se <- mkRegTileSe reg_tile,
    w == w0,
    arr_chunk_params <- groupStreamArrParams lam,
    Just _ <- is3dTileable mempty space variance
                           arrs arr_chunk_params,
    -- Every statement before the stream must index out one of the
    -- stream's one-dimensional input arrays.
    Just arr_tab0 <- foldl (processIndirections $ S.fromList arrs)
                           (Just M.empty) code1,
    ker_res_nms <- mapMaybe retThreadInSpace kres,
    length ker_res_nms == length kres,
    Pattern [] ker_patels <- pat,
    all primType kertp,
    all (variantToOuterDim variance gidz) ker_res_nms = do
      mm <- newVName "mm"
      mask <- newVName "mask"
      -- mm = gidz * reg_tile: the first z-index handled by this thread.
      let mm_stmt = mkInKerIntMulStmt mm (Var gidz) reg_tile_se
      -- mask = 1 << 31; consumed by 'cloneVarStms' as @mask & loop_ind@.
      let mask_stm = mkLet [] [Ident mask $ Prim int32] $ BasicOp $
                     BinOp (Shl Int32)
                           (Constant $ IntValue $ Int32Value 1)
                           (Constant $ IntValue $ Int32Value 31)
      -- Arrays indexed by a gidz-variant index are redirected to transposed
      -- copies, recorded in trnsp_tab and materialised with 'Manifest'.
      (arr_tab, trnsp_tab) <- foldM (insertTranspose variance gidz)
                                    (M.empty, M.empty) $ M.toList arr_tab0
      let manifest (a_t, (a, i, tp)) =
            let perm = [i+1..arrayRank tp-1] ++ [0..i]
            in mkLet [] [Ident a_t tp] $ BasicOp $ Manifest perm a
          manif_stms = map manifest $ M.toList trnsp_tab
      (space_stms, space_struct, tiled_group_size, num_threads, num_groups) <-
        mkKerSpaceExtraStms reg_tile gspace
      let kspace' = space { spaceStructure = space_struct
                          , spaceGroupSize = tiled_group_size
                          , spaceNumThreads = num_threads
                          , spaceNumGroups = num_groups
                          }
      mb_myloop <-
        translateStreamsToLoop (reg_tile,mask,gidz,m_M,mm,local_tid,tiled_group_size)
                               variance arr_tab w lam accs arrs $
                               patternValueElements pat_strm
      case mb_myloop of
        Nothing -> return Nothing
        Just (myloop, strm_res_inv, strm_res_var) -> do
          -- reg_tile fresh copies of every gidz-variant stream result,
          -- grouped per result: [[r1_c0 .. r1_cT], [r2_c0 .. r2_cT], ..].
          loop_var_res <- forM strm_res_var $ \(PatElem nm attr) -> do
            clone_patel_nms <- replicateM (fromIntegral reg_tile) $ newVName $ baseString nm
            return $ map (`PatElem` attr) clone_patel_nms
          -- The loop yields its invariant results first and then the variant
          -- ones clone-major (all variant results of clone 0, then of clone
          -- 1, ...), so the per-result clone lists must be transposed for the
          -- pattern to line up with the loop's results.
          let pat_loop = Pattern [] $ strm_res_inv ++ concat (transpose loop_var_res)
          let stm_loop = Let pat_loop aux_strm myloop
          let ker_var_res_patels =
                filter (\(r,_) -> variantToOuterDim variance gidz r) $
                zip ker_res_nms ker_patels
              (ker_var_res, ker_var_patels) = unzip ker_var_res_patels
              -- Post-stream statements are split by whether their first
              -- result varies with gidz; only the variant ones are unrolled.
              (code2_var, code2_inv) =
                partition (variantToOuterDim variance gidz . patElemName .
                           head . patternValueElements . stmPattern) code2
          scratch_nms_stms <- mapM mkScratchStm ker_var_patels
          let (scratch_nms, scratch_stms) = unzip scratch_nms_stms
              -- Clone-indexed: entry k holds the clone-k names of all
              -- variant results.
              loop_var_nms_tr = transpose $ map (map patElemName) loop_var_res
              strm_var_nms = map patElemName strm_res_var
          (ip_out_nms, unrolled_code) <-
            foldM (cloneVarCode2 mm space strm_var_nms ker_var_res_patels code2_var)
                  (scratch_nms, []) $ zip [0..reg_tile-1] loop_var_nms_tr
          let ker_res_ip_tp_tab = M.fromList $ zip ker_var_res $ zip ip_out_nms $
                                  map patElemType ker_var_patels
              -- Variant results are now returned in place from the arrays
              -- built by the unrolled code; the others are unchanged.
              toKerRes r tp =
                case M.lookup r ker_res_ip_tp_tab of
                  Nothing -> (ThreadsReturn ThreadsInSpace (Var r), tp)
                  Just (ip_nm, ip_tp) -> (KernelInPlaceReturn ip_nm, ip_tp)
              (kres', kertp') = unzip $ zipWith toKerRes ker_res_nms kertp
              kstms' = stmsFromList $ mask_stm : mm_stmt : stm_loop : (code2_inv ++ unrolled_code)
              ker_body = KernelBody () kstms' kres'
              new_ker = Op $ Kernel kerhint kspace' kertp' ker_body
              extra_stms = space_stms <> stmsFromList (scratch_stms ++ manif_stms)
          return $ Just (extra_stms, Let pat aux new_ker)
  where
    -- Fold step over the pre-stream statements: accepts only statements
    -- @p = arr_nm[slc]@ whose single result @p@ is one of the stream's input
    -- arrays and is one-dimensional, recording p -> (arr_nm, slc, type).
    processIndirections :: S.Set VName
                        -> Maybe (M.Map VName (VName, Slice SubExp, Type))
                        -> Stm InKernel
                        -> Maybe (M.Map VName (VName, Slice SubExp, Type))
    processIndirections arrs acc (Let patt _ (BasicOp (Index arr_nm slc))) =
      case (acc, patternValueElements patt) of
        (Nothing, _) -> Nothing
        (Just tab, [p]) -> do
          let (p_nm, p_tp) = (patElemName p, patElemType p)
          case (S.member p_nm arrs, p_tp) of
            (True, Array _ (Shape [_]) _) ->
              Just $ M.insert p_nm (arr_nm,slc,p_tp) tab
            _ -> Nothing
        (_, _) -> Nothing
    processIndirections _ _ _ = Nothing

    -- If some slice index of the source array varies with gidz (at
    -- position i), redirect the entry to a fresh @<arr>_transp@ array and
    -- record (source array, i, source type) for materialisation.
    insertTranspose :: VarianceTable -> VName
                    -> (M.Map VName (VName, Slice SubExp, Type), M.Map VName (VName,Int,Type))
                    -> (VName, (VName, Slice SubExp, Type))
                    -> TileM (M.Map VName (VName, Slice SubExp, Type), M.Map VName (VName,Int,Type))
    insertTranspose variance gidz (tab, trnsp) (p_nm, (arr_nm,slc,p_tp)) =
      case findIndex (variantSliceDim variance gidz) slc of
        Nothing -> return (M.insert p_nm (arr_nm,slc,p_tp) tab, trnsp)
        Just i -> do
          arr_tp <- lookupType arr_nm
          arr_tr_nm <- newVName $ baseString arr_nm ++ "_transp"
          let tab' = M.insert p_nm (arr_tr_nm,slc,p_tp) tab
          let trnsp' = M.insert arr_tr_nm (arr_nm, i, arr_tp) trnsp
          return (tab', trnsp')

    -- A fixed slice index is variant if it is a variable that is gidz or
    -- depends on gidz.
    variantSliceDim :: VarianceTable -> VName -> DimIndex SubExp -> Bool
    variantSliceDim variance gidz (DimFix (Var vnm)) = variantToOuterDim variance gidz vnm
    variantSliceDim _ _ _ = False

    -- res_nm0 = op1_se * op2_se as a 32-bit integer multiplication.
    mkInKerIntMulStmt :: VName -> SubExp -> SubExp -> Stm InKernel
    mkInKerIntMulStmt res_nm0 op1_se op2_se =
      mkLet [] [Ident res_nm0 $ Prim int32] $
      BasicOp $ BinOp (Mul Int32) op1_se op2_se

    -- The variable returned by a plain per-thread kernel result, if any.
    retThreadInSpace (ThreadsReturn ThreadsInSpace (Var r)) = Just r
    retThreadInSpace _ = Nothing
doRegTiling3D _ = return Nothing
-- | Translate the kernel's two nested group streams into one sequential
-- 'DoLoop' whose body is register-tiled over @gidz@.
--
-- The first argument packs @(reg_tile, mask, gidz, m_M, mm, local_tid,
-- group_size)@.  The remaining arguments are the variance table, the table
-- mapping stream input arrays to their (source array, slice, type), the
-- outer stream's width, lambda, accumulator arguments and array arguments,
-- and the pattern elements bound by the outer stream.
--
-- The outer stream body must consist of exactly one inner 'GroupStream'
-- whose second size argument is the constant 1 and which consumes the outer
-- lambda's accumulator and array parameters.  The inner body is split into
-- gidz-invariant statements (rewritten by 'transfInvIndStm'), gidz-variant
-- indexing statements (turned into 'Combine's by 'transfVarIndStm'), and the
-- remaining gidz-variant statements, which are cloned @reg_tile@ times by
-- 'cloneVarStms'.
--
-- On success returns the loop, the stream pattern elements of the
-- gidz-invariant results, and those of the gidz-variant results.  The loop
-- yields the invariant results first, then the variant results clone-major
-- (all variant results of clone 0, then of clone 1, ...).  Returns 'Nothing'
-- if the structure does not match or if the new loop body still refers to
-- the eliminated stream arrays or parameters.
translateStreamsToLoop :: (Int32,VName,VName,SubExp,VName,VName,SubExp) ->
VarianceTable ->
M.Map VName (VName, Slice SubExp, Type) ->
SubExp -> GroupStreamLambda InKernel ->
[SubExp] -> [VName] -> [PatElem InKernel]
-> TileM (Maybe (Exp InKernel, [PatElem InKernel], [PatElem InKernel]))
translateStreamsToLoop (reg_tile, mask,gidz,m_M,mm,local_tid, group_size) variance
arr_tab w_o lam_o accs_o_p arrs_o_p strm_ress
|
accs_o_f <- groupStreamAccParams lam_o,
arrs_o_f <- groupStreamArrParams lam_o,
[Let _ _ (Op (GroupStream _ ct1i32 lam_i accs_i_p arrs_i_p))] <-
stmsToList $ bodyStms $ groupStreamLambdaBody lam_o,
ct1i32 == (Constant $ IntValue $ Int32Value 1),
accs_i_f <- groupStreamAccParams lam_i,
arrs_i_f <- groupStreamArrParams lam_i,
-- The inner stream must consume the outer lambda's accumulator and array
-- parameters.  NOTE(review): zipWith stops at the shorter list, so a
-- length mismatch is not rejected by these two checks.
and $ zipWith (==) (map subExpVar accs_i_p) (map (Just . paramName) accs_o_f),
and $ zipWith (==) arrs_i_p $ map paramName arrs_o_f,
all (primType . paramType) accs_o_f,
-- The inner chunk offset is rebound to the new loop index (bar_stmt).
loop_ind_nm <- groupStreamChunkOffset lam_i,
body_i <- groupStreamLambdaBody lam_i,
-- Give each inner array parameter the table entry of the outer array
-- argument it is bound to.
arr_tab' <- foldl (\ tab (a_o_p, a_o_f, a_i_p, a_i_f) ->
case (paramName a_o_f == a_i_p, M.lookup a_o_p tab) of
(True, Just info) -> M.insert (paramName a_i_f) info tab
_ -> tab
) arr_tab $ zip4 arrs_o_p arrs_o_f arrs_i_p arrs_i_f,
accs_i_f' <- map translParamToFParam accs_i_f,
-- Split the inner body into: gidz-invariant statements; gidz-variant 1-D
-- indexing into a tabled array; all other gidz-variant statements.
-- Folding over the reversed list with (:) keeps each group in order.
(invar_out_stms, var_ind_stms, var_out_stms) <-
foldl (\ (acc_inv, acc_inds, acc_var) stmt ->
let nm = patElemName $ head $ patternValueElements $ stmPattern stmt
in if not $ variantToOuterDim variance gidz nm
then (stmt : acc_inv,acc_inds,acc_var)
else case stmt of
Let _ _ (BasicOp (Index arr_nm [DimFix _])) ->
case M.lookup arr_nm arr_tab' of
Just _ -> (acc_inv,stmt:acc_inds,acc_var)
Nothing -> (acc_inv,acc_inds,stmt:acc_var)
_ -> (acc_inv,acc_inds,stmt:acc_var)
) ([],[],[]) $ reverse $ stmsToList $ bodyStms body_i,
-- The variant indexing statements must not use results of the other
-- variant statements: they are emitted before them in the new body.
var_nms <- concatMap (patternNames . stmPattern) var_out_stms,
null $ S.intersection (S.fromList var_nms) $
S.unions (map freeInStm var_ind_stms),
-- Accumulator initial values must not vary with gidz.
loop_ini_vs <- subExpVars accs_o_p,
all (not . variantToOuterDim variance gidz) loop_ini_vs,
-- Every inner body result must be a variable.
loop_res0 <- bodyResult body_i,
loop_res <- subExpVars loop_res0,
length loop_res == length loop_res0 = do
-- (param, initial value, body result, stream pattern element), split by
-- whether the body result varies with gidz.
let (loop_var_p_i_r, loop_inv_p_i_r) =
partition (\(_,_,r,_) -> variantToOuterDim variance gidz r) $
zip4 accs_i_f' accs_o_p loop_res strm_ress
inv_stms0 <- mapM (transfInvIndStm arr_tab' loop_ind_nm) invar_out_stms
let inv_stms = concat inv_stms0
-- m replaces gidz inside the Combines built by transfVarIndStm.
m <- newVName "m"
ind_stms0 <- foldM (transfVarIndStm arr_tab' (reg_tile,loop_ind_nm,local_tid,group_size,m,m_M))
(Just ([],M.empty)) $ reverse var_ind_stms
case ind_stms0 of
Nothing -> return Nothing
Just (ind_stms, subst_tab) -> do
-- m = mm + local_tid; gidz is replaced by m in the indexing statements.
let m_stmt = mkLet [] [Ident m $ Prim int32] $
BasicOp $ BinOp (Add Int32) (Var mm) (Var local_tid)
tab_z_m_comb = M.insert gidz m M.empty
ind_stms' = m_stmt : map (substituteNames tab_z_m_comb) ind_stms
let loop_var_p_i_r' = map (\(x,y,z,_)->(x,y,z)) loop_var_p_i_r
-- One clone of the variant statements per unroll index 0 .. reg_tile-1.
if_ress <- mapM (cloneVarStms subst_tab (mask,loop_ind_nm,mm,m_M,gidz)
loop_var_p_i_r' var_out_stms) [0..reg_tile-1]
-- Variant parameters, initial values and results are all laid out
-- clone-major, after the invariant ones.
let (if_stmt_clones0, var_ress_pars) = unzip if_ress
if_stmt_clones = concat if_stmt_clones0
(_, var_ini, _, strm_var_res) = unzip4 loop_var_p_i_r
var_inis = concat $ replicate (fromIntegral reg_tile) var_ini
(var_ress, var_pars) = unzip $ concat var_ress_pars
(inv_pars, inv_inis, inv_ress, strm_inv_res) = unzip4 loop_inv_p_i_r
loop_form_acc = inv_pars ++ var_pars
loop_inis_acc = inv_inis ++ var_inis
loop_ress = inv_ress ++ var_ress
ind_bar <- newVName "loop_ind"
-- Bind the inner stream's chunk offset name to the new loop index
-- through a 'Barrier' on it.
let bar_stmt = mkLet [] [Ident loop_ind_nm $ Prim int32] $ Op (Barrier [Var ind_bar])
stms_body_i' = bar_stmt : inv_stms ++ ind_stms' ++ if_stmt_clones
form = ForLoop ind_bar Int32 w_o []
body_i' = Body (bodyAttr body_i)
(stmsFromList stms_body_i') $
map Var loop_ress
myloop = DoLoop [] (zip loop_form_acc loop_inis_acc) form body_i'
-- Reject the result if the body still mentions the eliminated stream
-- arrays or parameters.
free_in_body = freeInBody body_i'
elim_vars = S.fromList $ arrs_i_p ++ arrs_o_p ++
map paramName arrs_i_f ++
map paramName accs_o_f
if null $ S.intersection free_in_body elim_vars
then return $ Just (myloop, strm_inv_res, strm_var_res)
else return Nothing
translateStreamsToLoop _ _ _ _ _ _ _ _ = return Nothing
-- | Build unrolled copy number @i@ of the gidz-variant loop-body statements.
--
-- Emits @z = mask & loop_ind@, @ii = z + i@ and @m = mm + ii@, then an @if@
-- that, when its condition holds, runs @var_out_stms@ (preceded by reloads
-- @scal = sh_arr[i]@ for every entry of @subst_tab@) with @gidz@ replaced by
-- @m@ and each variant loop parameter replaced by a fresh clone.  Otherwise
-- it yields the loop's initial values.  The condition is the constant True
-- when every statement of @var_out_stms@ is 'safeExp', and @m < m_M@
-- otherwise.
--
-- Returns the emitted statements and, for each variant loop result, the
-- pair (cloned body result name, cloned loop parameter).
cloneVarStms :: M.Map VName (VName,Type) -> (VName, VName, VName, SubExp, VName)
-> [(FParam InKernel, SubExp, VName)] -> [Stm InKernel]
-> Int32 -> TileM ([Stm InKernel], [(VName,FParam InKernel)])
cloneVarStms subst_tab (mask,loop_ind,mm,m_M,gidz) loop_info var_out_stms i = do
let (loop_par_origs, loop_inis, body_res_origs) = unzip3 loop_info
body_res_clones <- mapM (\x -> newVName $ baseString x ++ "_clone") body_res_origs
loop_par_nm_clones <- mapM (\x -> newVName $ baseString (paramName x) ++ "_clone") loop_par_origs
m <- newVName "m"
z <- newVName "zero"
ii<- newVName "unroll_ct"
let loop_par_clones = zipWith (\ p nm -> p { paramName = nm })
loop_par_origs loop_par_nm_clones
res_types = map paramType loop_par_origs
i_se = Constant $ IntValue $ Int32Value i
-- NOTE(review): mask is built as 1 << 31 in doRegTiling3D, so z is 0
-- whenever loop_ind is non-negative; presumably this keeps ii from
-- being treated as a constant - confirm.
stmt_zero = mkLet [] [Ident z $ Prim int32] $
BasicOp $ BinOp (And Int32) (Var mask) (Var loop_ind)
stmt_ii = mkLet [] [Ident ii $ Prim int32] $
BasicOp $ BinOp (Add Int32) (Var z) i_se
m_stmt_other =
mkLet [] [Ident m $ Prim int32] $
BasicOp $ BinOp (Add Int32) (Var mm) (Var ii)
-- Reload each scalar recorded in subst_tab from element i of its array.
read_sh_stms =
map (\ (scal,(sh_arr, el_tp)) ->
mkLet [] [Ident scal el_tp] $
BasicOp $ Index sh_arr [DimFix i_se]
) $ M.toList subst_tab
-- gidz becomes m, and each variant loop parameter becomes its clone.
tab_z_m_other = foldl (\tab (old,new) -> M.insert (paramName old) new tab)
(M.insert gidz m M.empty) $
zip loop_par_origs loop_par_nm_clones
var_out_stms' = map (substituteNames tab_z_m_other) $
read_sh_stms ++ var_out_stms
cond_nm <- newVName "out3_inbounds"
-- The bounds check m < m_M is only emitted if some statement is not
-- 'safeExp'; otherwise the condition is the constant True.
let simple = all simpleStm var_out_stms
let cond_stm = if simple
then mkLet [] [Ident cond_nm $ Prim Bool] $
BasicOp $ SubExp (Constant $ BoolValue True)
else mkCondStmt m_M m cond_nm
then_body <- renameBody $ Body () (stmsFromList var_out_stms') (map Var body_res_origs)
-- Out of bounds, the clone just yields the loop's initial values.
let else_body = Body () mempty loop_inis
if_stmt = mkLet [] (zipWith Ident body_res_clones res_types) $
If (Var cond_nm) then_body else_body $
IfAttr (staticShapes res_types) IfFallback
return ( [stmt_zero, stmt_ii, m_stmt_other, cond_stm, if_stmt]
, zip body_res_clones loop_par_clones )
-- | Bind @cond_nm@ to the signed 32-bit comparison @m < m_M@.
mkCondStmt :: SubExp -> VName -> VName -> Stm InKernel
mkCondStmt m_M m cond_nm =
  let in_bounds = CmpOp (CmpSlt Int32) (Var m) m_M
  in mkLet [] [Ident cond_nm (Prim Bool)] (BasicOp in_bounds)
-- | Whether the statement's expression is 'safeExp'.
simpleStm :: Stm InKernel -> Bool
simpleStm = safeExp . stmExp
-- | Create a 'Scratch' array with the type of the given kernel result,
-- named after it with a @_0@ suffix.  Returns the new name and the
-- statement binding it.
mkScratchStm :: PatElem Kernels -> TileM (VName, Stm Kernels)
mkScratchStm ker_patel = do
  scratch_nm <- newVName $ baseString (patElemName ker_patel) ++ "_0"
  let arr_tp = patElemType ker_patel
      scratch_exp = BasicOp $ Scratch (elemType arr_tp) (arrayDims arr_tp)
  return (scratch_nm, mkLet [] [Ident scratch_nm arr_tp] scratch_exp)
-- | Fold step that unrolls the gidz-variant post-stream statements for
-- unroll index @k@.
--
-- @loop_res_nms@ are the loop results of clone @k@, in the order of
-- @strm_res_nms@.  Emits @m = mm + k@, @m_cond = m < m_M@ and an @if@ that,
-- when @m_cond@ holds, runs @code2_var@ with @gidz@ replaced by @m@ and each
-- stream result replaced by its clone-@k@ loop result, and then writes each
-- kernel result into the current arrays at index
-- @outer_dims ++ [m, gidy, gidx]@.  Otherwise the arrays are returned
-- unchanged.  The accumulator carries the current array names and the
-- statements emitted so far.
cloneVarCode2 :: VName -> KernelSpace -> [VName]
-> [(VName, PatElem InKernel)] -> [Stm InKernel]
-> ([VName], [Stm InKernel]) -> (Int32, [VName])
-> TileM ([VName], [Stm InKernel])
cloneVarCode2 mm space strm_res_nms keres_patels code2_var
(ip_arr_nms, unroll_code) (k, loop_res_nms) = do
let (ker_nms, pat_els) = unzip keres_patels
arr_tps = map patElemType pat_els
root_strs = map (baseString . patElemName) pat_els
ip_inn_nms <- mapM (\s -> newVName $ s ++ "_inn_" ++ pretty (k+1)) root_strs
ip_out_nms <- mapM (\s -> newVName $ s ++ "_out_" ++ pretty (k+1)) root_strs
m <- newVName "m"
-- Lazy pattern: assumes at least three space dimensions (doRegTiling3D
-- checks this before calling).
let (gidx,_) : (gidy,_) : (gidz,m_M) : rev_outer_dims = reverse $ spaceDimensions space
(outer_dims, _) = unzip $ reverse rev_outer_dims
ip_stmts = map (mkInPlaceStmt (outer_dims++[m,gidy,gidx])) $
zip4 ip_arr_nms ip_inn_nms ker_nms arr_tps
cond_nm <- newVName "m_cond"
let i_se = Constant $ IntValue $ Int32Value k
m_stm = mkLet [] [Ident m $ Prim int32] $
BasicOp $ BinOp (Add Int32) (Var mm) i_se
c_stm = mkCondStmt m_M m cond_nm
else_body = Body () mempty (map Var ip_arr_nms)
strm_loop_tab = M.fromList $ (gidz, m) : zip strm_res_nms loop_res_nms
then_stms = stmsFromList $ map (substituteNames strm_loop_tab) $
code2_var ++ ip_stmts
then_body <- renameBody $ Body () then_stms $ map Var ip_inn_nms
let if_stm = mkLet [] (zipWith Ident ip_out_nms arr_tps) $
If (Var cond_nm) then_body else_body $
IfAttr (staticShapes arr_tps) IfFallback
return (ip_out_nms, unroll_code ++ [m_stm, c_stm, if_stm])
-- new_nm = cur_nm with [inds] <- ker_nm
where mkInPlaceStmt :: [VName] -> (VName, VName, VName, Type)
-> Stm InKernel
mkInPlaceStmt inds (cur_nm, new_nm, ker_nm, arr_tp) =
let upd_slc = map (DimFix . Var) inds
ipupd_exp = BasicOp $ Update cur_nm upd_slc (Var ker_nm)
in mkLet [] [Ident new_nm arr_tp] ipupd_exp
-- | Re-express an indexing statement as a direct index into @par_arr@.
--
-- Produces @tmp = loop_ind * strd@, @ind = beg + tmp@ and the original
-- pattern bound to @par_arr[outer ++ [ind]]@, where @outer@ is @par_slc@
-- without its last element.  The original right-hand side is discarded.
helper3Stms :: VName -> SubExp -> SubExp -> Slice SubExp
            -> VName -> Stm InKernel -> TileM [Stm InKernel]
helper3Stms loop_ind strd beg par_slc par_arr (Let ptt att _) = do
  scaled_nm <- newVName "tmp"
  index_nm <- newVName "ind"
  let int32Stm nm e = mkLet [] [Ident nm $ Prim int32] $ BasicOp e
      outer_slc = take (length par_slc - 1) par_slc
  return [ int32Stm scaled_nm $ BinOp (Mul Int32) (Var loop_ind) strd
         , int32Stm index_nm $ BinOp (Add Int32) beg (Var scaled_nm)
         , Let ptt att $ BasicOp $
             Index par_arr (outer_slc ++ [DimFix $ Var index_nm])
         ]
-- | Rewrite a gidz-invariant statement @x = a[i]@ whose array @a@ has an
-- entry in @tab@ with a 'DimSlice' as the last slice dimension into a direct
-- index of the source array (see 'helper3Stms').  Any other statement is
-- returned unchanged.
transfInvIndStm :: M.Map VName (VName, Slice SubExp, Type)
                -> VName -> Stm InKernel
                -> TileM [Stm InKernel]
transfInvIndStm tab loop_ind stm =
  case stm of
    Let _ _ (BasicOp (Index arr_nm [DimFix _]))
      | Just (par_arr, par_slc@(_:_), _) <- M.lookup arr_nm tab,
        DimSlice beg _ strd <- last par_slc ->
          helper3Stms loop_ind strd beg par_slc par_arr stm
    _ -> return [stm]
-- | Fold step for the gidz-variant indexing statements of the inner stream
-- body (see 'translateStreamsToLoop').
--
-- Accepts only @x = a[i]@ where @a@ has an entry in @tab@ whose last slice
-- dimension is a 'DimSlice' and @x@ is a single primitive value.  That
-- statement becomes one binding a new 1-D array @<par_arr>_sh_1d@ of length
-- @group_size@ through a 'Combine' whose body recomputes @x@ with
-- 'helper3Stms'.  NOTE(review): the pairs @(local_tid, reg_tile)@ and
-- @(m, m_M)@ given to 'Combine' look like active-thread bounds - confirm.
-- The new statement is prepended to the accumulated list, and @x@ is mapped
-- to (array name, element type) so that 'cloneVarStms' can reload it.
-- A non-matching statement, or an earlier failure, gives 'Nothing'.
transfVarIndStm :: M.Map VName (VName, Slice SubExp, Type)
-> (Int32,VName,VName,SubExp,VName,SubExp)
-> Maybe ([Stm InKernel],M.Map VName (VName,Type))
-> Stm InKernel
-> TileM (Maybe ([Stm InKernel],M.Map VName (VName,Type)))
transfVarIndStm tab (reg_tile,loop_ind,local_tid,group_size,m,m_M) acc
stm@(Let ptt _ (BasicOp (Index arr_nm [DimFix _])))
| Just (tstms,stab) <- acc,
Just (par_arr, par_slc@(_:_), _) <- M.lookup arr_nm tab,
DimSlice beg _ strd <- last par_slc,
[pat_el] <- patternValueElements ptt,
el_tp <- patElemType pat_el,
pat_el_nm <- patElemName pat_el,
Prim _ <- el_tp = do
stms3 <- helper3Stms loop_ind strd beg par_slc par_arr stm
let glb_ind_stms = stmsFromList stms3
sh_arr_1d <- newVName $ baseString par_arr ++ "_sh_1d"
cid <- newVName "cid"
let block_cspace = combineSpace [(cid,group_size)]
comb_exp = Op $ Combine block_cspace [el_tp]
[(local_tid, mkRegTileSe reg_tile), (m,m_M)] $
Body () glb_ind_stms [Var pat_el_nm]
sh_arr_pe = PatElem sh_arr_1d $
arrayOfShape el_tp $ Shape [group_size]
write_sh_arr_stmt =
Let (Pattern [] [sh_arr_pe]) (defAux ()) comb_exp
return $ Just (write_sh_arr_stmt:tstms, M.insert pat_el_nm (sh_arr_1d,el_tp) stab)
transfVarIndStm _ _ _ _ = return Nothing
-- | Turn a lambda parameter into a loop (function) parameter by giving its
-- type a non-unique declaration.
translParamToFParam :: LParam InKernel -> FParam InKernel
translParamToFParam param = fmap (\t -> toDecl t Nonunique) param
-- | Split a kernel body around its first 'GroupStream' statement.
--
-- Returns the statements before the first stream, that stream statement
-- (if any), and every statement after it, which may include further
-- streams.  When there is no stream, all statements end up in the first
-- component and the third is empty.
matchCodeStreamCode :: Stms InKernel ->
                       ([Stm InKernel], Maybe (Stm InKernel), [Stm InKernel])
matchCodeStreamCode kstms =
  -- A single 'break' is linear, unlike repeated @xs ++ [x]@ appends in a
  -- left fold, which are quadratic in the number of statements.
  case break isGroupStream $ stmsToList kstms of
    (cd1, strm : cd2) -> (cd1, Just strm, cd2)
    (cd1, [])         -> (cd1, Nothing, [])
  where isGroupStream :: Stm InKernel -> Bool
        isGroupStream (Let _ _ (Op GroupStream{})) = True
        isGroupStream _                            = False
-- | Check whether a stream is tileable over the three innermost dimensions.
--
-- Requires every block parameter's row type to be primitive, and every
-- array in @arrs@ to vary with exactly one of the three innermost thread
-- dimensions.  None of those dimensions may be in @branch_variant@.  Each of
-- the three dimensions must be hit by at least one array.  On success,
-- returns per array the index of its dimension: 0 for the third-innermost,
-- 1 for the second-innermost, 2 for the innermost.
is3dTileable :: Names -> KernelSpace -> VarianceTable -> [VName]
             -> [LParam InKernel] -> Maybe [Int]
is3dTileable branch_variant kspace variance arrs block_params
  | all (primType . rowType . paramType) block_params,
    Just inner_perm <- mapM innerDimOf arrs,
    all (`elem` inner_perm) [0, 1, 2] = Just inner_perm
  | otherwise = Nothing
  where
    innerDimOf :: VName -> Maybe Int
    innerDimOf arr =
      case reverse $ spaceDimensions kspace of
        (k,_) : (j,_) : (i,_) : _
          | not $ any (`S.member` branch_variant) [k, j, i] ->
              let deps = M.findWithDefault mempty arr variance
              in case map (`S.member` deps) [i, j, k] of
                   [True, False, False] -> Just 0
                   [False, True, False] -> Just 1
                   [False, False, True] -> Just 2
                   _ -> Nothing
        _ -> Nothing
-- | Compute the nested thread space of the register-tiled kernel.
--
-- The gidz dimension (third-innermost of @gspace@, extent @m_M@) is shrunk
-- to @(m_M + reg_tile - 1) / reg_tile@ with a group extent of 1.  The x and
-- y dimensions keep their extents but get group extents
-- @min(tile_size, sz_x)@ and @min(tile_size, sz_y)@, where @tile_size@ comes
-- from a 'GetSize' of class 'SizeTile'.  All outer dimensions get group
-- extent 1.  Returns the statements computing these sizes, the new space
-- structure, the group size (tile_size_x * tile_size_y), the number of
-- threads and the number of groups.
--
-- Requires @reg_tile > 0@: it is used as a divisor.
mkKerSpaceExtraStms :: Int32 -> [(VName, SubExp)]
-> TileM (Stms Kernels, SpaceStructure, SubExp, SubExp, SubExp)
mkKerSpaceExtraStms reg_tile gspace = do
dim_z_nm <- newVName "gidz_range"
tmp <- newVName "tmp"
-- gidz_range = (m_M + reg_tile - 1) / reg_tile, i.e. rounded up.
let tmp_stm = mkLet [] [Ident tmp $ Prim int32] $
BasicOp $ BinOp (Add Int32) m_M $
Constant $ IntValue $ Int32Value (reg_tile-1)
rgz_stm = mkLet [] [Ident dim_z_nm $ Prim int32] $
BasicOp $ BinOp (SQuot Int32) (Var tmp) $
Constant $ IntValue $ Int32Value reg_tile
(gidx,sz_x) : (gidy,sz_y) : (gidz,m_M) : untiled_gspace = reverse gspace
((tile_size_x, tile_size_y, tiled_group_size), tile_size_bnds) <- runBinder $ do
tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
tile_ct_size <- letSubExp "tile_size" $ Op $ GetSize tile_size_key SizeTile
tile_size_x <- letSubExp "tile_size_x" $ BasicOp $
BinOp (SMin Int32) tile_ct_size sz_x
tile_size_y <- letSubExp "tile_size_y" $ BasicOp $
BinOp (SMin Int32) tile_ct_size sz_y
tiled_group_size <- letSubExp "tiled_group_size" $
BasicOp $ BinOp (Mul Int32) tile_size_x tile_size_y
return (tile_size_x, tile_size_y, tiled_group_size)
-- Outer (untiled) dimensions keep their extent and get group extent 1.
untiled_gspace' <- fmap reverse $ forM (reverse untiled_gspace) $ \(gtid,gdim) -> do
ltid <- newVName "ltid"
return (gtid, gdim, ltid, constant (1::Int32))
ltidz <- newVName "ltid"
let dim_z = (gidz, Var dim_z_nm, ltidz, constant (1::Int32))
ltidy <- newVName "ltid"
let dim_y = (gidy, sz_y, ltidy, tile_size_y)
ltidx <- newVName "ltid"
let dim_x = (gidx, sz_x, ltidx, tile_size_x)
gspace' = reverse $ dim_x : dim_y : dim_z : untiled_gspace'
((num_threads, num_groups), num_bnds) <-
runBinder $ sufficientGroups gspace' tiled_group_size
let extra_stms = oneStm tmp_stm <> oneStm rgz_stm <> tile_size_bnds <> num_bnds
return ( extra_stms, NestedThreadSpace gspace'
, tiled_group_size, num_threads, num_groups )
-- | Whether @nm@ is the outer thread index itself or depends on it
-- according to the variance table.
variantToOuterDim :: VarianceTable -> VName -> VName -> Bool
variantToOuterDim variance gid_outer nm
  | gid_outer == nm = True
  | otherwise = maybe False (S.member gid_outer) (M.lookup nm variance)
-- | Extend a variance table with the bindings of a sequence of statements,
-- processed in program order.  Uses a strict left fold so that no chain of
-- table thunks is built up over long statement sequences.
varianceInStms :: VarianceTable -> Stms InKernel -> VarianceTable
varianceInStms = foldl' varianceInStm
-- | Extend a variance table with the bindings of one statement.
--
-- For a 'GroupStream', each array and accumulator parameter of the lambda
-- additionally depends on its stream argument (and on everything that
-- argument depends on).  Constant accumulator arguments add nothing.  The
-- lambda body is then processed recursively.
varianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
varianceInStm v0 bnd@(Let _ _ (Op (GroupStream _ _ lam accs arrs))) =
  varianceInStms with_acc_params $ bodyStms $ groupStreamLambdaBody lam
  where
    with_stm = defVarianceInStm v0 bnd
    -- The parameter inherits the argument's dependencies plus the argument.
    inherit tbl arg par =
      M.insert par (S.insert arg $ M.findWithDefault mempty arg tbl) tbl
    with_arr_params =
      foldl' (\tbl (arg, par) -> inherit tbl arg (paramName par)) with_stm $
      zip arrs $ groupStreamArrParams lam
    with_acc_params =
      foldl' accStep with_arr_params $ zip accs $ groupStreamAccParams lam
    accStep tbl (Var arg, par) = inherit tbl arg (paramName par)
    accStep tbl (Constant _, _) = tbl
varianceInStm variance bnd = defVarianceInStm variance bnd
-- | Default variance rule: every name bound by the statement depends on
-- each of the statement's free variables and on everything those variables
-- depend on.
defVarianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
defVarianceInStm variance bnd =
  foldl' (\tbl v -> M.insert v deps tbl) variance $ patternNames $ stmPattern bnd
  where deps = S.unions [ S.insert v $ M.findWithDefault mempty v variance
                        | v <- S.toList $ freeInStm bnd ]
-- | Number of threads and groups needed to cover a nested thread space.
--
-- Each entry is @(gtid, extent, ltid, group extent)@.  The group count is
-- the product over all dimensions of @ceil(extent / group extent)@, and the
-- thread count is that product times @group_size@.
sufficientGroups :: MonadBinder m =>
                    [(VName, SubExp, VName, SubExp)] -> SubExp
                 -> m (SubExp, SubExp)
sufficientGroups gspace group_size = do
  per_dim_groups <- mapM groupsInDim gspace
  num_groups <- letSubExp "num_groups" =<<
                foldBinOp (Mul Int32) (constant (1::Int32)) per_dim_groups
  num_threads <- letSubExp "num_threads" $
                 BasicOp $ BinOp (Mul Int32) num_groups group_size
  return (num_threads, num_groups)
  where groupsInDim (_, gd, _, ld) =
          letSubExp "groups_in_dim" =<< eDivRoundingUp Int32 (eSubExp gd) (eSubExp ld)