{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.TileLoops
( tileLoops )
where
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Sequence as Seq
import qualified Data.Map.Strict as M
import Data.List
import Prelude hiding (quot)
import Futhark.MonadFreshNames
import Futhark.Representation.Kernels
import Futhark.Transform.Rename
import Futhark.Pass
import Futhark.Tools
-- | The tiling pass: rebuilds a 'Kernels' program from its functions after
-- running 'optimiseFunDef' on each of them.
tileLoops :: Pass Kernels Kernels
tileLoops =
  Pass "tile loops" "Tile stream loops inside kernels" tileProg
  where
    -- Optimise every function and wrap the results in a new program.
    tileProg prog = Prog <$> mapM optimiseFunDef (progFunctions prog)
-- | Optimise the body of a single function.  The 'TileM' computation is run
-- with the function parameters as the initial scope and with the name source
-- of the surrounding monad as its state.
optimiseFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
optimiseFunDef fundec = do
  let initial_scope = scopeOfFParams $ funDefParams fundec
      tileBody = optimiseBody $ funDefBody fundec
  body' <- modifyNameSource $ runState $ runReaderT tileBody initial_scope
  return fundec {funDefBody = body'}
-- | The monad used by the tiling functions in this module: a reader over the
-- current 'Scope' layered on a state holding the 'VNameSource'.
type TileM = ReaderT (Scope Kernels) (State VNameSource)
-- | Optimise every statement of a body, with the names bound by the body's
-- statements added to the scope.  The body result is left untouched.
optimiseBody :: Body Kernels -> TileM (Body Kernels)
optimiseBody (Body () bnds res) =
  localScope (scopeOf bnds) $ do
    stms' <- mconcat <$> mapM optimiseStm (stmsToList bnds)
    return $ Body () stms' res
-- | Optimise a single statement.
--
-- A thread-level 'SegMap' is handed to 'tileInKernelBody'; any host
-- statements it produces are placed in front of the rebuilt 'SegMap'.  For
-- every other statement the bodies nested in its expression are optimised
-- recursively.
optimiseStm :: Stm Kernels -> TileM (Stms Kernels)
optimiseStm (Let pat aux (Op (SegOp (SegMap lvl@SegThread{} space ts kbody)))) = do
  -- Every name bound by the kernel space starts out with empty variance.
  let initial_variance = M.map mempty $ scopeOfSegSpace space
  (host_stms, (lvl', space', kbody')) <-
    tileInKernelBody mempty initial_variance lvl space ts kbody
  let tiled_stm = Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody'
  return $ host_stms <> oneStm tiled_stm
optimiseStm (Let pat aux e) =
  fmap (pure . Let pat aux) (mapExpM optimise e)
  where
    optimise =
      identityMapper {mapOnBody = \scope -> localScope scope . optimiseBody}
-- | Try to tile the body of a kernel.
--
-- Tiling is only attempted when every kernel result is a plain 'Returns'.
-- On success the result is the host-level statements to insert before the
-- kernel, together with the new level, space and body; the tiled results
-- are converted with 'tilingTileReturns'.  If nothing could be tiled the
-- original level, space and body are returned with no host statements.
tileInKernelBody :: Names -> VarianceTable
                 -> SegLevel -> SegSpace -> [Type] -> KernelBody Kernels
                 -> TileM (Stms Kernels, (SegLevel, SegSpace, KernelBody Kernels))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody =
  case mapM isSimpleResult $ kernelBodyResult kbody of
    Nothing -> return untiled
    Just kbody_res -> do
      let body = Body () (kernelBodyStms kbody) kbody_res
      maybe_tiled <-
        tileInBody branch_variant initial_variance lvl initial_kspace ts body
      case maybe_tiled of
        Nothing -> return untiled
        Just (host_stms, tiling, tiledBody) -> do
          (res', stms') <- runBinder $ do
            tiled_res <- tiledBody mempty
            mapM (tilingTileReturns tiling) tiled_res
          let kbody' = KernelBody () stms' res'
          return (host_stms, (tilingLevel tiling, tilingSpace tiling, kbody'))
  where
    -- What we return whenever the kernel is left as it was.
    untiled = (mempty, (lvl, initial_kspace, kbody))

    isSimpleResult (Returns se) = Just se
    isSimpleResult _ = Nothing
-- | Search the statements of a body, in order, for the first statement that
-- can be tiled.  The statements passed over so far are accumulated as the
-- prelude ('prestms'); the statements after the tiled one form the postlude.
-- Returns 'Nothing' when no statement of the body could be tiled.
tileInBody :: Names -> VarianceTable
-> SegLevel -> SegSpace -> [Type] -> Body Kernels
-> TileM (Maybe (Stms Kernels, Tiling, TiledBody))
tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) =
descend mempty $ stmsToList initial_kstms
where
-- Variance of every name bound in the body, seeded with the caller's table.
variance = varianceInStms initial_variance initial_kstms
-- Ran out of statements without finding anything to tile.
descend _ [] =
return Nothing
descend prestms (stm_to_tile:poststms)
-- Case 1: one-dimensional tiling along the last dimension of the space.
-- Requires a 'tileable' statement whose input arrays all have a variance
-- that does not contain that dimension's index, and that the index is not
-- in 'branch_variant'.
| (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space,
Just (w, arrs, form) <- tileable stm_to_tile,
all (not . nameIn gtid .
flip (M.findWithDefault mempty) variance) arrs,
not $ gtid `nameIn` branch_variant,
(prestms', poststms') <-
preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) =
Just . injectPrelude initial_space variance prestms' (freeIn stm_to_tile) <$>
tileGeneric (tiling1d $ reverse top_space_rev)
initial_lvl res_ts (stmPattern stm_to_tile)
gtid kdim
w form (zip arrs $ repeat [0]) poststms' stms_res
-- Case 2: two-dimensional tiling along the last two dimensions of the
-- space.  Each input array must be accepted by
-- 'invariantToOneOfTwoInnerDims', which also yields its permutation.
| (gtids, kdims) <- unzip $ unSegSpace initial_space,
Just (w, arrs, form) <- tileable stm_to_tile,
Just inner_perm <- mapM (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs,
gtid_y : gtid_x : top_gtids_rev <- reverse gtids,
kdim_y : kdim_x : top_kdims_rev <- reverse kdims,
(prestms', poststms') <-
preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) =
Just . injectPrelude initial_space variance prestms' (freeIn stm_to_tile) <$>
tileGeneric (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev)
initial_lvl res_ts (stmPattern stm_to_tile)
(gtid_x, gtid_y) (kdim_x, kdim_y)
w form (zip arrs inner_perm) poststms' stms_res
-- Case 3: a for-loop whose first 'DoLoop' field and last 'ForLoop' field
-- are both empty (presumably: no context parameters and no loop-bound
-- array variables -- confirm against the IR definition).  We try to tile
-- inside the loop body and, on success, rebuild the loop with 'tileDoLoop'.
| DoLoop [] merge (ForLoop i it bound []) loopbody <- stmExp stm_to_tile,
(prestms', poststms') <-
preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms)= do
-- Everything the loop bound varies with becomes branch-variant inside
-- the loop body.
let branch_variant' =
branch_variant <>
mconcat (map (flip (M.findWithDefault mempty) variance)
(namesToList (freeIn bound)))
merge_params = map fst merge
maybe_tiled <-
localScope (M.insert i (IndexInfo it) $ scopeOfFParams merge_params) $
tileInBody branch_variant' variance initial_lvl initial_space
(map paramType merge_params) $ mkBody (bodyStms loopbody) (bodyResult loopbody)
case maybe_tiled of
Nothing -> next
Just tiled ->
Just <$> tileDoLoop initial_space variance prestms'
(freeIn loopbody <> freeIn merge) tiled
res_ts (stmPattern stm_to_tile) (stmAux stm_to_tile)
merge i it bound poststms' stms_res
| otherwise = next
-- Not tileable here: add the statement to the prelude (and to the scope)
-- and continue with the remaining statements.
where next = localScope (scopeOf stm_to_tile) $
descend (prestms <> oneStm stm_to_tile) poststms
-- | Split the prelude into the statements that matter to the statement being
-- tiled and the rest; the rest is moved in front of the postlude.
--
-- A prelude statement is kept when any name it binds is either free in the
-- tiled statement or occurs in the variance of a name free in it.
preludeToPostlude :: VarianceTable
                  -> Stms Kernels -> Stm Kernels -> Stms Kernels
                  -> (Stms Kernels, Stms Kernels)
preludeToPostlude variance prelude stm_to_tile postlude =
  (needed, not_needed <> postlude)
  where
    direct = freeIn stm_to_tile
    varianceOf v = M.findWithDefault mempty v variance
    relevant = direct <> mconcat (map varianceOf $ namesToList direct)

    isNeeded stm = any (`nameIn` relevant) $ patternNames $ stmPattern stm

    (needed, not_needed) = Seq.partition isNeeded prelude
-- | Partition the prelude statements into three groups, returned in this
-- order:
--
-- 1. statements invariant to the tiled dimensions,
--
-- 2. variant statements that can be computed once up front, and
--
-- 3. variant statements that must be recomputed wherever they are needed:
--    those binding an array slice, rotate, rearrange or reshape, and those
--    that vary with such a binding.
--
-- Invariance of a statement is judged from the variance of the first name
-- it binds; a statement binding nothing counts as invariant.
partitionPrelude :: VarianceTable -> Stms Kernels -> Names
                 -> (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude variance prestms tiled_kdims =
  (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms)
  where
    definedBy = patternNames . stmPattern

    invariantTo names stm
      | v : _ <- definedBy stm =
          all (not . (`nameIn` names)) $
            namesToList $ M.findWithDefault mempty v variance
      | otherwise = True

    (invariant_prestms, variant_prestms) =
      Seq.partition (invariantTo tiled_kdims) prestms

    mustBeInlined stm =
      case stmExp stm of
        BasicOp (Index _ slice) -> not $ null $ sliceDims slice
        BasicOp Rotate{} -> True
        BasicOp Rearrange{} -> True
        BasicOp Reshape{} -> True
        _ -> False

    must_be_inlined =
      namesFromList $
        concatMap definedBy $
          stmsToList $ Seq.filter mustBeInlined variant_prestms

    recompute stm =
      any (`nameIn` must_be_inlined) (definedBy stm)
        || not (invariantTo must_be_inlined stm)

    (recomputed_variant_prestms, precomputed_variant_prestms) =
      Seq.partition recompute variant_prestms
-- | Extend a tiled body so that the prelude statements are available to it.
--
-- Invariant prelude statements are emitted directly.  Precomputed variant
-- statements are evaluated once with 'doPrelude', and the values that are
-- live afterwards are read back through the private statements.  Recomputed
-- variant statements are carried along in the private statements.
injectPrelude :: SegSpace -> VarianceTable
              -> Stms Kernels -> Names
              -> (Stms Kernels, Tiling, TiledBody)
              -> (Stms Kernels, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, withPrelude)
  where
    -- Dimensions of the original space that are absent from the tiled space.
    remaining_dims = unSegSpace $ tilingSpace tiling
    tiled_kdims =
      namesFromList
        [ gtid
        | dim@(gtid, _) <- unSegSpace initial_space,
          dim `notElem` remaining_dims
        ]

    (invariant, precomputed, recomputed) =
      partitionPrelude variance prestms tiled_kdims

    -- Precomputed values needed by the tiled statement or by the
    -- recomputed prelude statements.
    live_set = namesToList $ liveSet precomputed $ used <> freeIn recomputed

    withPrelude privstms = do
      addStms invariant
      prelude_arrs <-
        inScopeOf precomputed $ doPrelude tiling precomputed live_set
      let readPrelude = mkReadPreludeValues prelude_arrs live_set
      tiledBody $ PrivStms recomputed readPrelude <> privstms
-- | Rebuild a sequential for-loop around a body that has already been tiled
-- (the @(host_stms, tiling, tiledBody)@ argument, as produced by
-- 'tileInBody' on the loop body).
--
-- The merge parameters of the loop are expanded with the tile shape, their
-- initial values are computed by a "tiled_loopinit" segmented map, and the
-- original merge parameters are re-bound inside the loop by indexing the
-- expanded ones.  The postlude is produced by 'postludeGeneric'.
--
-- The prelude is split with 'partitionPrelude' and handled as in
-- 'injectPrelude'.  In particular, the live set includes names free in the
-- recomputed prelude statements: those statements are replayed through the
-- private statements and may refer to precomputed prelude values that the
-- loop body itself never mentions.
tileDoLoop :: SegSpace -> VarianceTable
           -> Stms Kernels -> Names
           -> (Stms Kernels, Tiling, TiledBody)
           -> [Type] -> Pattern Kernels -> StmAux (ExpAttr Kernels)
           -> [(FParam Kernels, SubExp)] -> VName -> IntType -> SubExp
           -> Stms Kernels -> Result
           -> TileM (Stms Kernels, Tiling, TiledBody)
tileDoLoop initial_space variance prestms used_in_body (host_stms, tiling, tiledBody) res_ts pat aux merge i it bound poststms poststms_res = do
  let (invariant_prestms,
       precomputed_variant_prestms,
       recomputed_variant_prestms) =
        partitionPrelude variance prestms tiled_kdims

  let (mergeparams, mergeinits) = unzip merge

      -- Expand a merge parameter type with the tile shape.
      tileDim t = arrayOf t (tilingTileShape tiling) $ uniqueness t

      tiledBody' privstms = inScopeOf host_stms $ do
        addStms invariant_prestms

        -- Keep alive everything the loop body needs *and* everything the
        -- recomputed prelude statements need (consistent with
        -- 'injectPrelude').
        let live_set =
              namesToList $
                liveSet precomputed_variant_prestms $
                  used_in_body <> freeIn recomputed_variant_prestms

        prelude_arrs <-
          inScopeOf precomputed_variant_prestms $
            doPrelude tiling precomputed_variant_prestms live_set

        mergeparams' <- forM mergeparams $ \(Param pname pt) ->
          Param <$> newVName (baseString pname ++ "_group") <*> pure (tileDim pt)

        let merge_ts = map paramType mergeparams

        let inloop_privstms =
              PrivStms recomputed_variant_prestms $
                mkReadPreludeValues prelude_arrs live_set

        -- Initial values of the expanded merge parameters; out-of-bounds
        -- threads get blank values via 'protectOutOfBounds'.
        mergeinit' <-
          fmap (map Var) $
            certifying (stmAuxCerts aux) $
              tilingSegMap tiling "tiled_loopinit" (scalarLevel tiling) $
                \in_bounds slice ->
                  fmap (map Var) $
                    protectOutOfBounds "loopinit" in_bounds merge_ts $ do
                      addPrivStms slice inloop_privstms
                      addPrivStms slice privstms
                      return mergeinits

        let merge' = zip mergeparams' mergeinit'

        -- Bind each original merge parameter to the current thread's
        -- element of the corresponding expanded parameter.
        let indexMergeParams slice =
              localScope (scopeOfFParams mergeparams') $
                forM_ (zip mergeparams mergeparams') $ \(to, from) ->
                  letBindNames_ [paramName to] $
                    BasicOp $
                      Index (paramName from) $
                        fullSlice (paramType from) slice

        loopbody' <-
          runBodyBinder $
            resultBody . map Var
              <$> tiledBody (privstms <> inloop_privstms <> PrivStms mempty indexMergeParams)
        accs' <-
          letTupExp "tiled_inside_loop" $
            DoLoop [] merge' (ForLoop i it bound []) loopbody'

        postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts

  return (host_stms, tiling, tiledBody')
  where
    -- Dimensions of the original space that are absent from the tiled space.
    tiled_kdims =
      namesFromList $
        map fst $
          filter (`notElem` unSegSpace (tilingSpace tiling)) $
            unSegSpace initial_space
-- | Run the given prelude statements inside a "prelude" segmented map and
-- return one array per live name.  Where the in-bounds condition supplied by
-- the tiling is false, blank values of the right types are produced instead
-- of running the statements.
doPrelude :: Tiling -> Stms Kernels -> [VName] -> Binder Kernels [VName]
doPrelude tiling prestms prestms_live =
  tilingSegMap tiling "prelude" (scalarLevel tiling) $ \in_bounds _slice -> do
    live_ts <- mapM lookupType prestms_live
    let runPrelude = addStms prestms >> resultBodyM (map Var prestms_live)
        blanks = eBody $ map eBlank live_ts
    guarded <- eIf (toExp in_bounds) runPrelude blanks
    map Var <$> letTupExp "pre" guarded
-- | The names bound by the given statements that are also free in @after@.
liveSet :: FreeIn a => Stms Kernels -> a -> Names
liveSet stms after = defined `namesIntersection` freeIn after
  where
    defined = namesFromList $ concatMap (patternNames . stmPattern) stms
-- | Decide whether a statement can be tiled.
--
-- A statement qualifies when it is a 'Screma' that is a redomap with at
-- least one input array, whose map and reduction lambdas have the same
-- return types, and where the map lambda's return types and parameter types
-- are all primitive.  On success the width, the input arrays and the pieces
-- of the reduction and map are returned.
tileable :: Stm Kernels
         -> Maybe (SubExp, [VName],
                   (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels))
tileable stm =
  case stmExp stm of
    Op (OtherOp (Screma w form arrs))
      | Just (reds, map_lam) <- isRedomapSOAC form,
        Reduce red_comm red_lam red_nes <- singleReduce reds,
        lambdaReturnType map_lam == lambdaReturnType red_lam,
        not $ null arrs,
        all primType $ lambdaReturnType map_lam,
        all (primType . paramType) $ lambdaParams map_lam ->
          Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
    _ -> Nothing
-- | Statements to be added wherever private values are needed, paired with
-- an action that is run with the local slice before those statements are
-- added (see 'addPrivStms').
data PrivStms = PrivStms (Stms Kernels) ReadPrelude
-- | Private statements with a prelude-reading action that does nothing.
privStms :: Stms Kernels -> PrivStms
privStms stms = PrivStms stms (\_ -> return ())
-- | Run the prelude-reading action for the given local slice, then add the
-- private statements.
addPrivStms :: Slice SubExp -> PrivStms -> Binder Kernels ()
addPrivStms local_slice (PrivStms stms readPrelude) =
  readPrelude local_slice >> addStms stms
-- | Combining concatenates the statements and runs the two prelude-reading
-- actions in left-to-right order.
instance Semigroup PrivStms where
  PrivStms stms_x readPrelude_x <> PrivStms stms_y readPrelude_y =
    PrivStms (stms_x <> stms_y) $ \slice -> do
      readPrelude_x slice
      readPrelude_y slice
-- | The identity has no statements and a prelude-reading action that does
-- nothing.
instance Monoid PrivStms where
  mempty = PrivStms mempty $ const $ return ()
-- | An action, parameterised by a slice, that runs in the 'Binder' monad for
-- its effects only.  'mkReadPreludeValues' builds one that binds prelude
-- values by indexing arrays with the slice.
type ReadPrelude = Slice SubExp -> Binder Kernels ()
-- | The operations and parameters that make up one tiling strategy.  The 1D
-- and 2D strategies are built by 'tiling1d' and 'tiling2d'.
data Tiling =
Tiling
-- Build a segmented map with the given description and level.  The
-- callback receives an in-bounds condition and the slice identifying the
-- current thread within the tile.
{ tilingSegMap :: String -> SegLevel
-> (PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
-- Read one tile of each input array; the 'SubExp' is the tile index.
, tilingReadTile :: TileKind -> PrivStms
-> SubExp -> [(VName, [Int])]
-> Binder Kernels [VName]
-- Fold a tile that has been read into the per-thread accumulators.
, tilingProcessTile :: PrivStms
-> Commutativity -> Lambda Kernels -> Lambda Kernels
-> [(VName, [Int])] -> [VName]
-> Binder Kernels [VName]
-- Handle the input elements left over after the whole tiles.
, tilingProcessResidualTile :: PrivStms
-> Commutativity -> Lambda Kernels -> Lambda Kernels
-> SubExp -> [VName] -> SubExp
-> [(VName, [Int])]
-> Binder Kernels [VName]
-- Turn a tiled result array into a kernel result.
, tilingTileReturns :: VName -> Binder Kernels KernelResult
-- The space of the tiled kernel.
, tilingSpace :: SegSpace
-- The shape of one tile.
, tilingTileShape :: Shape
-- The level of the tiled kernel.
, tilingLevel :: SegLevel
-- The number of whole tiles (the input width divided by the tile size).
, tilingNumWholeTiles :: SubExp
}
-- | A function that constructs a 'Tiling' from the initial level, the
-- thread indices and sizes of the dimensions being tiled, and a further
-- 'SubExp' (the input width at the call in 'tileGeneric').
type DoTiling gtids kdims =
SegLevel -> gtids -> kdims -> SubExp -> Binder Kernels Tiling
-- | A 'SegThreadScalar' level with the same number of groups and group size
-- as the level of the tiling, and without virtualisation.
scalarLevel :: Tiling -> SegLevel
scalarLevel tiling =
  let lvl = tilingLevel tiling
   in SegThreadScalar (segNumGroups lvl) (segGroupSize lvl) SegNoVirt
-- | Run the given computation only where the in-bounds condition holds;
-- elsewhere produce blank values of the given types.  The results are bound
-- to names derived from the description.
protectOutOfBounds :: String -> PrimExp VName -> [Type] -> Binder Kernels [SubExp]
                   -> Binder Kernels [VName]
protectOutOfBounds desc in_bounds ts m = do
  guarded <- eIf (toExp in_bounds) (resultBody <$> m) (eBody $ map eBlank ts)
  letTupExp desc guarded
-- | Build the "thread_res" segmented map that finishes off a tiled
-- statement for every thread.
--
-- Each name of the pattern is bound to the thread's own element of the
-- corresponding accumulator array.  If there are statements after the tiled
-- one, they are run (after the private statements) under an in-bounds guard,
-- with blank values of types @res_ts@ for out-of-bounds threads.  If there
-- are none, the result is returned directly.
postludeGeneric :: Tiling -> PrivStms
                -> Pattern Kernels -> [VName]
                -> Stms Kernels -> Result -> [Type]
                -> Binder Kernels [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" (scalarLevel tiling) $ \in_bounds slice -> do
    -- Read this thread's result out of the accumulators.
    forM_ (zip (patternNames pat) accs') $ \(us, everyone) ->
      letBindNames_ [us] $ BasicOp $ Index everyone slice

    -- 'null' checks emptiness directly, without an 'Eq' comparison against
    -- 'mempty'.
    --
    -- NOTE(review): this shortcut does not add the private statements, so a
    -- result that refers to a name bound only by them would be left unbound
    -- here -- confirm that callers never pass such results.
    if null poststms
      then return poststms_res
      else fmap (map Var) $
        protectOutOfBounds "postlude" in_bounds res_ts $ do
          addPrivStms slice privstms
          addStms poststms
          return poststms_res
-- | A tiled body still waiting for the private statements to use; running
-- it in the 'Binder' monad yields the names of its results.
type TiledBody = PrivStms -> Binder Kernels [VName]
-- | Tile a redomap using the supplied tiling strategy.
--
-- The strategy is instantiated first; the statements it needs are returned
-- as the first component.  The tiled body, once given private statements:
--
-- 1. initialises per-thread accumulators from the neutral elements,
--
-- 2. loops over the whole tiles, reading each tile and folding it into the
--    accumulators,
--
-- 3. handles the residual tile, and
--
-- 4. runs the postlude with 'postludeGeneric'.
tileGeneric :: DoTiling gtids kdims
            -> SegLevel
            -> [Type]
            -> Pattern Kernels
            -> gtids
            -> kdims
            -> SubExp
            -> (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
            -> [(VName, [Int])]
            -> Stms Kernels -> Result
            -> TileM (Stms Kernels, Tiling, TiledBody)
tileGeneric doTiling initial_lvl res_ts pat gtids kdims w form arrs_and_perms poststms poststms_res = do
  (tiling, tiling_stms) <- runBinder $ doTiling initial_lvl gtids kdims w
  return (tiling_stms, tiling, tiledBody tiling)
  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> PrivStms -> Binder Kernels [VName]
    tiledBody tiling privstms = do
      let num_whole_tiles = tilingNumWholeTiles tiling
          tile_shape = tilingTileShape tiling
          mergeType p = paramType p `arrayOfShape` tile_shape `toDecl` Unique

      -- Neutral elements without free variables need no bounds protection.
      mergeinits <-
        tilingSegMap tiling "mergeinit" (scalarLevel tiling) $ \in_bounds slice ->
          if freeIn red_nes == mempty
            then return red_nes
            else fmap (map Var) $
              protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
                addPrivStms slice privstms
                return red_nes

      merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) -> do
        p' <- newParam (baseString (paramName p) ++ "_merge") (mergeType p)
        return (p', Var mergeinit)

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int32 num_whole_tiles []
          acc_names = map (paramName . fst) merge

      unrenamed_loopbody <-
        runBodyBinder $
          inScopeOf loopform $
            localScope (scopeOfFParams $ map fst merge) $ do
              -- Read a tile, then fold it into the accumulators.
              tile <-
                tilingReadTile tiling TileFull privstms (Var tile_id) arrs_and_perms
              new_accs <-
                tilingProcessTile tiling privstms red_comm red_lam map_lam
                  (zip tile (map snd arrs_and_perms)) acc_names
              return $ resultBody $ map Var new_accs
      loopbody <- renameBody unrenamed_loopbody

      accs <- letTupExp "accs" $ DoLoop [] merge loopform loopbody

      -- The residual tile gets freshly renamed copies of the lambdas.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      accs' <-
        tilingProcessResidualTile tiling privstms red_comm red_lam' map_lam'
          num_whole_tiles accs w arrs_and_perms

      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts
-- | Whether a tile read must be guarded by bounds checks ('TilePartial') or
-- can be performed unconditionally ('TileFull'); see 'readTile1D' and
-- 'readTile2D'.
data TileKind = TilePartial | TileFull
-- | Build the action that re-binds each live prelude name by indexing the
-- corresponding prelude array with the given slice (extended to a full slice
-- for the array's type).  The two lists are paired up positionally.
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  -- Only the bindings matter, so use 'forM_' rather than collecting and then
  -- collapsing a list of units.
  forM_ (zip prestms_live_arrs prestms_live) $ \(arr, v) -> do
    arr_t <- lookupType arr
    letBindNames_ [v] $ BasicOp $ Index arr $ fullSlice arr_t slice
-- | Build the 'TileReturns' kernel result for a tiled array.
--
-- When there are dimensions on top of the tiled ones, the array is first
-- reshaped to gain one unit dimension per such dimension, and each of them
-- contributes a pair of its size and a tile size of one.
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Binder Kernels KernelResult
tileReturns dims_on_top dims arr = do
  arr' <- addUnitDims
  return $ TileReturns (top_tile_dims ++ dims) arr'
  where
    unit_dims = map (const $ intConst Int32 1) dims_on_top
    top_tile_dims = zip (map snd dims_on_top) unit_dims

    addUnitDims
      | null dims_on_top = return arr
      | otherwise = do
          arr_t <- lookupType arr
          let new_shape = unit_dims ++ arrayDims arr_t
          letExp (baseString arr) $ BasicOp $ Reshape (map DimNew new_shape) arr
-- | Construct a one-dimensional 'SegMap' at the given level whose single
-- dimension has the group size as its extent.  The callback is given the
-- fresh thread index and produces the results; the body is renamed before
-- the 'SegMap' is bound.
segMap1D :: String
         -> SegLevel
         -> (VName -> Binder Kernels [SubExp])
         -> Binder Kernels [VName]
segMap1D desc lvl f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid, unCount $ segGroupSize lvl)]

  ((ts, res), stms) <- runBinder $ do
    body_res <- f ltid
    body_ts <- mapM subExpType body_res
    return (body_ts, body_res)

  Body _ stms' res' <- renameBody $ mkBody stms res
  let kbody = KernelBody () stms' $ map Returns res'
  letTupExp desc $ Op $ SegOp $ SegMap lvl space ts kbody
-- | Bind the global thread index to @gid * group_size + ltid@.
reconstructGtids1D :: Count GroupSize SubExp -> VName -> VName -> VName
                   -> Binder Kernels ()
reconstructGtids1D group_size gtid gid ltid = do
  gtid_exp <-
    toExp $
      LeafExp gid int32 * primExpFromSubExp int32 (unCount group_size)
        + LeafExp ltid int32
  letBindNames_ [gtid] gtid_exp
-- | Read one tile of each input array in a "full_tile" segmented map.
--
-- Each thread reads element @tile_id * tile_size + ltid@ of every array.
-- For 'TilePartial' the reads are guarded by a check against the outer size
-- of the arrays, with blank values produced when out of range; for
-- 'TileFull' the reads are unconditional.  The permutations paired with the
-- arrays are not used here.
readTile1D :: SubExp -> VName -> VName
           -> Count NumGroups SubExp -> Count GroupSize SubExp
           -> TileKind -> PrivStms
           -> SubExp
           -> [(VName, [Int])]
           -> Binder Kernels [VName]
readTile1D tile_size gid gtid num_groups group_size kind privstms tile_id arrs_and_perms =
  segMap1D "full_tile" (SegThread num_groups group_size SegNoVirt) $ \ltid -> do
    j <- letSubExp "j" =<< toExp (tile_start + LeafExp ltid int32)

    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    arr_ts <- mapM lookupType arrs
    let w = arraysSize 0 arr_ts
        readElem arr = letExp "tile_elem" $ BasicOp $ Index arr [DimFix j]
        readAllGuarded = do
          guarded <-
            eIf
              (toExp $ primExpFromSubExp int32 j .<. primExpFromSubExp int32 w)
              (resultBody <$> mapM (fmap Var . readElem) arrs)
              (eBody $ map (eBlank . rowType) arr_ts)
          letTupExp "pre" guarded

    map Var <$> case kind of
      TilePartial -> readAllGuarded
      TileFull -> mapM readElem arrs
  where
    arrs = map fst arrs_and_perms
    tile_start =
      primExpFromSubExp int32 tile_id * primExpFromSubExp int32 tile_size
-- | Fold one tile into the per-thread accumulators in an "acc" segmented
-- map.
--
-- Each thread reads its own accumulators and uses them in place of the
-- neutral elements of a redomap over the tile.  Threads whose global index
-- is not below @kdim@ return their accumulators unchanged.
processTile1D :: VName -> VName -> SubExp -> SubExp
              -> Count NumGroups SubExp -> Count GroupSize SubExp
              -> PrivStms
              -> Commutativity -> Lambda Kernels -> Lambda Kernels
              -> [(VName, [Int])] -> [VName]
              -> Binder Kernels [VName]
processTile1D gid gtid kdim tile_size num_groups group_size privstms red_comm red_lam map_lam tiles_and_perm accs =
  segMap1D "acc" (SegThreadScalar num_groups group_size SegNoVirt) $ \ltid -> do
    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    thread_accs <- forM accs $ \acc ->
      letSubExp "acc" $ BasicOp $ Index acc [DimFix $ Var ltid]

    let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam
        in_bounds = LeafExp gtid int32 .<. primExpFromSubExp int32 kdim
        reduceTile = eBody [pure $ Op $ OtherOp $ Screma tile_size form' tiles]

    new_accs <- eIf (toExp in_bounds) reduceTile (resultBodyM thread_accs)
    map Var <$> letTupExp "acc" new_accs
  where
    tiles = map fst tiles_and_perm
-- | Process the input elements left over after the whole tiles.
--
-- The residual count is @w@ modulo the tile size.  If it is zero the
-- accumulators are returned as they are.  Otherwise a partial tile is read
-- at index @num_whole_tiles@, cut down to the residual count, and folded
-- into the accumulators with 'processTile1D'.
processResidualTile1D :: VName -> VName -> SubExp -> SubExp
                      -> Count NumGroups SubExp -> Count GroupSize SubExp -> PrivStms
                      -> Commutativity -> Lambda Kernels -> Lambda Kernels
                      -> SubExp -> [VName] -> SubExp -> [(VName, [Int])]
                      -> Binder Kernels [VName]
processResidualTile1D gid gtid kdim tile_size num_groups group_size privstms red_comm red_lam map_lam num_whole_tiles accs w arrs_and_perms = do
  residual_input <-
    letSubExp "residual_input" $ BasicOp $ BinOp (SRem Int32) w tile_size

  let no_residual = primExpFromSubExp int32 residual_input .==. 0
  residual_branch <-
    eIf
      (toExp no_residual)
      (resultBodyM $ map Var accs)
      (nonemptyTile residual_input)
  letTupExp "acc_after_residual" residual_branch
  where
    nonemptyTile residual_input = runBodyBinder $ do
      full_tiles <-
        readTile1D tile_size gid gtid num_groups group_size
          TilePartial privstms num_whole_tiles arrs_and_perms

      let prefix = [DimSlice (intConst Int32 0) residual_input (intConst Int32 1)]
      partial_tiles <- forM full_tiles $ \full_tile ->
        letExp "partial_tile" $ BasicOp $ Index full_tile prefix

      new_accs <-
        processTile1D gid gtid kdim residual_input num_groups group_size privstms
          red_comm red_lam map_lam (zip partial_tiles $ repeat [0]) accs
      return $ resultBody $ map Var new_accs
-- | The one-dimensional tiling strategy.
--
-- With no dimensions on top, the number of groups, group size and
-- virtualisation of the initial level are reused.  Otherwise the group size
-- is the minimum of the initial group size and @kdim@, and the number of
-- groups is the number of groups needed to cover @kdim@ multiplied by the
-- sizes of the dimensions on top.  The tile size equals the group size.
tiling1d :: [(VName,SubExp)] -> DoTiling VName SubExp
tiling1d dims_on_top initial_lvl gtid kdim w = do
  gid <- newVName "gid"
  gid_flat <- newVName "gid_flat"

  (lvl, space) <- case dims_on_top of
    [] ->
      return
        ( SegGroup (segNumGroups initial_lvl) (segGroupSize initial_lvl) $
            segVirt initial_lvl,
          SegSpace gid_flat [(gid, unCount $ segNumGroups initial_lvl)]
        )
    _ -> do
      group_size <-
        letSubExp "computed_group_size" $
          BasicOp $ BinOp (SMin Int32) (unCount (segGroupSize initial_lvl)) kdim
      ldim <-
        letSubExp "ldim" =<<
          eDivRoundingUp Int32 (eSubExp kdim) (eSubExp group_size)
      num_groups <-
        letSubExp "computed_num_groups" =<<
          foldBinOp (Mul Int32) ldim (map snd dims_on_top)
      return
        ( SegGroup (Count num_groups) (Count group_size) SegNoVirt,
          SegSpace gid_flat $ dims_on_top ++ [(gid, ldim)]
        )

  let tile_size = unCount $ segGroupSize lvl
      num_groups' = segNumGroups lvl
      group_size' = segGroupSize lvl

  num_whole_tiles <-
    letSubExp "num_whole_tiles" $ BasicOp $ BinOp (SQuot Int32) w tile_size

  let tiledSegMap desc lvl' f =
        segMap1D desc lvl' $ \ltid -> do
          -- The group size of the level is the tile size, so this binds
          -- gtid to gid * tile_size + ltid.
          reconstructGtids1D group_size' gtid gid ltid
          f (LeafExp gtid int32 .<. primExpFromSubExp int32 kdim)
            [DimFix $ Var ltid]

  return Tiling
    { tilingSegMap = tiledSegMap
    , tilingReadTile =
        readTile1D tile_size gid gtid num_groups' group_size'
    , tilingProcessTile =
        processTile1D gid gtid kdim tile_size num_groups' group_size'
    , tilingProcessResidualTile =
        processResidualTile1D gid gtid kdim tile_size num_groups' group_size'
    , tilingTileReturns = tileReturns dims_on_top [(kdim, tile_size)]
    , tilingTileShape = Shape [tile_size]
    , tilingNumWholeTiles = num_whole_tiles
    , tilingLevel = lvl
    , tilingSpace = space
    }
-- | Check an array against the two innermost dimensions of a space.
--
-- Let @j@ be the last and @i@ the second-to-last dimension.  Provided
-- neither is branch-variant, the result is @[0,1]@ when the array varies
-- with @i@ but not @j@, and @[1,0]@ when it varies with @j@ but not @i@.
-- In every other case, including fewer than two dimensions, the result is
-- 'Nothing'.
invariantToOneOfTwoInnerDims :: Names -> M.Map VName Names -> [VName] -> VName
                             -> Maybe [Int]
invariantToOneOfTwoInnerDims branch_variant variance dims arr =
  case reverse dims of
    j : i : _
      | not branch_invariant -> Nothing
      | i_variant && not j_variant -> Just [0, 1]
      | j_variant && not i_variant -> Just [1, 0]
      | otherwise -> Nothing
      where
        variant_to = M.findWithDefault mempty arr variance
        i_variant = i `nameIn` variant_to
        j_variant = j `nameIn` variant_to
        branch_invariant =
          not $ nameIn j branch_variant || nameIn i branch_variant
    _ -> Nothing
-- | Construct a two-dimensional 'SegMap' at the given level with the given
-- extents.  The callback is given the two fresh thread indices and produces
-- the results; the body is renamed before the 'SegMap' is bound.
segMap2D :: String
         -> SegLevel -> (SubExp, SubExp)
         -> ((VName, VName) -> Binder Kernels [SubExp])
         -> Binder Kernels [VName]
segMap2D desc lvl (dim_x, dim_y) f = do
  ltid_x <- newVName "ltid_x"
  ltid_y <- newVName "ltid_y"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid_x, dim_x), (ltid_y, dim_y)]

  ((ts, res), stms) <- runBinder $ do
    body_res <- f (ltid_x, ltid_y)
    body_ts <- mapM subExpType body_res
    return (body_ts, body_res)

  Body _ stms' res' <- renameBody $ mkBody stms res
  let kbody = KernelBody () stms' $ map Returns res'
  letTupExp desc $ Op $ SegOp $ SegMap lvl space ts kbody
-- | Bind both global thread indices, each to @gid * tile_size + ltid@ for
-- its own dimension; the x index is bound first.
reconstructGtids2D :: SubExp -> (VName, VName) -> (VName, VName) -> (VName, VName)
                   -> Binder Kernels ()
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) = do
  reconstruct gtid_x gid_x ltid_x
  reconstruct gtid_y gid_y ltid_y
  where
    reconstruct gtid gid ltid =
      letBindNames_ [gtid] =<<
        toExp (LeafExp gid int32 * primExpFromSubExp int32 tile_size
                 + LeafExp ltid int32)
-- | Read one two-dimensional tile of each input array in a "full_tile"
-- segmented map of extent (tile_size, tile_size).
--
-- Each thread computes two candidate indices, @i@ and @j@, as
-- @tile_id * tile_size@ plus its x or y local index.  For each array, the
-- index used is picked from @[i,j]@ according to the array's permutation.
-- For 'TilePartial' each read is guarded; for 'TileFull' it is
-- unconditional.
readTile2D :: (SubExp, SubExp) -> (VName, VName) -> (VName, VName) -> SubExp
-> Count NumGroups SubExp -> Count GroupSize SubExp
-> TileKind -> PrivStms -> SubExp
-> [(VName, [Int])]
-> Binder Kernels [VName]
readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size num_groups group_size kind privstms tile_id arrs_and_perms =
segMap2D "full_tile" (SegThread num_groups group_size SegNoVirt) (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do
i <- letSubExp "i" =<<
toExp (primExpFromSubExp int32 tile_id *
primExpFromSubExp int32 tile_size +
LeafExp ltid_x int32)
j <- letSubExp "j" =<<
toExp (primExpFromSubExp int32 tile_id *
primExpFromSubExp int32 tile_size +
LeafExp ltid_y int32)
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms
let (arrs, perms) = unzip arrs_and_perms
arr_ts <- mapM lookupType arrs
let tile_ts = map rowType arr_ts
w = arraysSize 0 arr_ts
-- Unconditional read of a single element.
let readTileElem arr perm =
letExp "tile_elem" $ BasicOp $ Index arr
[DimFix $ last $ rearrangeShape perm [i,j]]
-- Guarded read: the index must be below the outer size of the arrays and
-- the global-index check selected by the permutation must also hold;
-- otherwise a blank value is produced.
readTileElemIfInBounds (tile_t, arr, perm) = do
let idx = last $ rearrangeShape perm [i,j]
othercheck = last $ rearrangeShape perm
[ LeafExp gtid_y int32 .<. primExpFromSubExp int32 kdim_y
, LeafExp gtid_x int32 .<. primExpFromSubExp int32 kdim_x
]
eIf (toExp $
primExpFromSubExp int32 idx .<. primExpFromSubExp int32 w .&&. othercheck)
(eBody [return $ BasicOp $ Index arr [DimFix idx]])
(eBody [eBlank tile_t])
fmap (map Var) $
case kind of
TilePartial ->
mapM (letExp "pre" <=< readTileElemIfInBounds) (zip3 tile_ts arrs perms)
TileFull ->
zipWithM readTileElem arrs perms
-- | Fold one two-dimensional tile into the per-thread accumulators in an
-- "acc" segmented map of extent (tile_size, tile_size).
--
-- Each thread reads its own accumulators and uses them in place of the
-- neutral elements of a redomap over its slice of every tile.  Threads whose
-- global indices are not both in bounds return their accumulators unchanged.
processTile2D :: (VName, VName) -> (VName, VName) -> (SubExp, SubExp) -> SubExp
-> Count NumGroups SubExp -> Count GroupSize SubExp
-> PrivStms
-> Commutativity -> Lambda Kernels -> Lambda Kernels
-> [(VName,[Int])] -> [VName]
-> Binder Kernels [VName]
processTile2D
(gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size num_groups group_size
privstms red_comm red_lam map_lam tiles_and_perms accs = do
-- The width of the redomap is the outer size of the tiles handed to us,
-- not the 'tile_size' argument.
actual_tile_size <- arraysSize 0 <$> mapM (lookupType . fst) tiles_and_perms
segMap2D "acc" (SegThreadScalar num_groups group_size SegNoVirt) (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms
thread_accs <- forM accs $ \acc ->
letSubExp "acc" $ BasicOp $ Index acc [DimFix $ Var ltid_x, DimFix $ Var ltid_y]
let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam
-- Fix one dimension of each tile; both the dimension and the local index
-- used are chosen by the tile's permutation.
tiles' <- forM tiles_and_perms $ \(tile, perm) -> do
tile_t <- lookupType tile
letExp "tile" $ BasicOp $ Index tile $ sliceAt tile_t (head perm)
[DimFix $ Var $ head $ rearrangeShape perm [ltid_x, ltid_y]]
fmap (map Var) $
letTupExp "acc" =<< eIf (toExp $
LeafExp gtid_x int32 .<. primExpFromSubExp int32 kdim_x .&&.
LeafExp gtid_y int32 .<. primExpFromSubExp int32 kdim_y)
(eBody [pure $ Op $ OtherOp $ Screma actual_tile_size form' tiles'])
(resultBodyM thread_accs)
-- | Process the input elements left over after the whole tiles, for
-- two-dimensional tiling.
--
-- The residual count is @w@ modulo the tile size.  If it is zero the
-- accumulators are returned as they are.  Otherwise a partial tile is read
-- at index @num_whole_tiles@, cut down to the residual count in both
-- dimensions, and folded into the accumulators with 'processTile2D'.  The
-- body of that branch is renamed after construction.
processResidualTile2D :: (VName, VName) -> (VName, VName) -> (SubExp, SubExp) -> SubExp
                      -> Count NumGroups SubExp -> Count GroupSize SubExp -> PrivStms
                      -> Commutativity -> Lambda Kernels -> Lambda Kernels
                      -> SubExp -> [VName] -> SubExp -> [(VName, [Int])]
                      -> Binder Kernels [VName]
processResidualTile2D gids gtids kdims tile_size num_groups group_size privstms red_comm red_lam map_lam num_whole_tiles accs w arrs_and_perms = do
  residual_input <-
    letSubExp "residual_input" $ BasicOp $ BinOp (SRem Int32) w tile_size

  let no_residual = primExpFromSubExp int32 residual_input .==. 0
  residual_branch <-
    eIf
      (toExp no_residual)
      (resultBodyM $ map Var accs)
      (nonemptyTile residual_input)
  letTupExp "acc_after_residual" residual_branch
  where
    nonemptyTile residual_input = do
      body <- runBodyBinder $ do
        full_tiles <-
          readTile2D kdims gtids gids tile_size num_groups group_size
            TilePartial privstms num_whole_tiles arrs_and_perms

        let prefix = DimSlice (intConst Int32 0) residual_input (intConst Int32 1)
        partial_tiles <- forM full_tiles $ \full_tile ->
          letExp "partial_tile" $ BasicOp $ Index full_tile [prefix, prefix]

        new_accs <-
          processTile2D gids gtids kdims tile_size num_groups group_size
            privstms red_comm red_lam map_lam
            (zip partial_tiles (map snd arrs_and_perms)) accs
        return $ resultBody $ map Var new_accs
      renameBody body
-- | The two-dimensional tiling strategy.
--
-- The tile size is obtained with 'GetSize' under a freshly generated key,
-- and the group size is the tile size squared.  The number of groups is the
-- product of the groups needed to cover each of the two tiled dimensions
-- and the sizes of the dimensions on top.  The initial level is not used.
tiling2d :: [(VName,SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d dims_on_top _initial_lvl (gtid_x, gtid_y) (kdim_x, kdim_y) w = do
  gid_x <- newVName "gid_x"
  gid_y <- newVName "gid_y"

  tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
  tile_size <- letSubExp "tile_size" $ Op $ GetSize tile_size_key SizeTile
  group_size <-
    letSubExp "group_size" $ BasicOp $ BinOp (Mul Int32) tile_size tile_size

  num_groups_x <- groupsAlong "num_groups_x" kdim_x tile_size
  num_groups_y <- groupsAlong "num_groups_y" kdim_y tile_size

  num_groups <-
    letSubExp "num_groups_top" =<<
      foldBinOp (Mul Int32) num_groups_x (num_groups_y : map snd dims_on_top)

  gid_flat <- newVName "gid_flat"
  let lvl = SegGroup (Count num_groups) (Count group_size) SegNoVirt
      space =
        SegSpace gid_flat $
          dims_on_top ++ [(gid_x, num_groups_x), (gid_y, num_groups_y)]

  num_whole_tiles <-
    letSubExp "num_whole_tiles" $ BasicOp $ BinOp (SQuot Int32) w tile_size

  let gids = (gid_x, gid_y)
      gtids = (gtid_x, gtid_y)
      kdims = (kdim_x, kdim_y)
      lvl_num_groups = segNumGroups lvl
      lvl_group_size = segGroupSize lvl

      tiledSegMap desc lvl' f =
        segMap2D desc lvl' (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do
          reconstructGtids2D tile_size gtids gids (ltid_x, ltid_y)
          f (LeafExp gtid_x int32 .<. primExpFromSubExp int32 kdim_x .&&.
             LeafExp gtid_y int32 .<. primExpFromSubExp int32 kdim_y)
            [DimFix $ Var ltid_x, DimFix $ Var ltid_y]

  return Tiling
    { tilingSegMap = tiledSegMap
    , tilingReadTile =
        readTile2D kdims gtids gids tile_size lvl_num_groups lvl_group_size
    , tilingProcessTile =
        processTile2D gids gtids kdims tile_size lvl_num_groups lvl_group_size
    , tilingProcessResidualTile =
        processResidualTile2D gids gtids kdims tile_size lvl_num_groups lvl_group_size
    , tilingTileReturns =
        tileReturns dims_on_top [(kdim_x, tile_size), (kdim_y, tile_size)]
    , tilingTileShape = Shape [tile_size, tile_size]
    , tilingNumWholeTiles = num_whole_tiles
    , tilingLevel = lvl
    , tilingSpace = space
    }
  where
    -- Number of groups needed to cover a dimension, rounding up.
    groupsAlong desc kdim tile_size =
      letSubExp desc =<<
        eDivRoundingUp Int32 (eSubExp kdim) (eSubExp tile_size)
-- | Maps a name to the set of names it varies with, as computed by
-- 'varianceInStm'.  Names missing from the table are treated as having
-- empty variance by the lookups in this module.
type VarianceTable = M.Map VName Names
-- | Extend a variance table with every statement of a sequence, in order.
--
-- Uses the strict left fold (as 'varianceInStm' does) so that the table is
-- forced after each statement instead of building a chain of thunks.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms = foldl' varianceInStm
-- | Extend a variance table with one statement.  Every name the statement
-- binds is given the same variance: the names free in the statement,
-- together with whatever the incoming table records for each of them.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance bnd =
  foldl' record variance $ patternNames $ stmPattern bnd
  where
    record table v = M.insert v stm_variance table
    stm_variance =
      mconcat
        [ oneName v <> M.findWithDefault mempty v variance
        | v <- namesToList $ freeIn bnd
        ]