{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.TileLoops
( tileLoops )
where
import Control.Applicative
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.List
import Data.Maybe
import Futhark.MonadFreshNames
import Futhark.Representation.Kernels
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (mapAccumLM)
import Futhark.Optimise.TileLoops.RegTiling3D
tileLoops :: Pass Kernels Kernels
tileLoops = Pass "tile loops" "Tile stream loops inside kernels" $
fmap Prog . mapM optimiseFunDef . progFunctions
optimiseFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
optimiseFunDef fundec = do
body' <- modifyNameSource $ runState $
runReaderT m (scopeOfFParams (funDefParams fundec))
return fundec { funDefBody = body' }
where m = optimiseBody $ funDefBody fundec
type TileM = ReaderT (Scope Kernels) (State VNameSource)
optimiseBody :: Body Kernels -> TileM (Body Kernels)
optimiseBody (Body () bnds res) = localScope (scopeOf bnds) $
Body () <$> (mconcat <$> mapM optimiseStm (stmsToList bnds)) <*> pure res
optimiseStm :: Stm Kernels -> TileM (Stms Kernels)
optimiseStm stmt@(Let pat aux (Op old_kernel@(Kernel desc space ts body))) = do
res3dtiling <- doRegTiling3D stmt
case res3dtiling of
Just (extra_bnds, stmt') -> return $ extra_bnds <> oneStm stmt'
Nothing -> do
(extra_bnds, space', body') <- tileInKernelBody mempty initial_variance space body
let new_kernel = Kernel desc space' ts body'
if kernelType old_kernel == kernelType new_kernel
then return $ extra_bnds <> oneStm (Let pat aux $ Op new_kernel)
else return $ oneStm $ Let pat aux $ Op old_kernel
where initial_variance = M.map mempty $ scopeOfKernelSpace space
optimiseStm (Let pat aux e) =
pure <$> (Let pat aux <$> mapExpM optimise e)
where optimise = identityMapper { mapOnBody = const optimiseBody }
tileInKernelBody :: Names -> VarianceTable
-> KernelSpace -> KernelBody InKernel
-> TileM (Stms Kernels, KernelSpace, KernelBody InKernel)
tileInKernelBody branch_variant initial_variance initial_kspace (KernelBody () kstms kres) = do
(extra_bnds, kspace', kstms') <-
tileInStms branch_variant initial_variance initial_kspace kstms
return (extra_bnds, kspace', KernelBody () kstms' kres)
tileInBody :: Names -> VarianceTable
-> KernelSpace -> Body InKernel
-> TileM (Stms Kernels, KernelSpace, Body InKernel)
tileInBody branch_variant initial_variance initial_kspace (Body () stms res) = do
(extra_bnds, kspace', stms') <-
tileInStms branch_variant initial_variance initial_kspace stms
return (extra_bnds, kspace', Body () stms' res)
tileInStms :: Names -> VarianceTable
-> KernelSpace -> Stms InKernel
-> TileM (Stms Kernels, KernelSpace, Stms InKernel)
tileInStms branch_variant initial_variance initial_kspace kstms = do
((kspace, extra_bndss), kstms') <-
mapAccumLM tileInKernelStatement (initial_kspace,mempty) $ stmsToList kstms
return (extra_bndss, kspace, stmsFromList kstms')
where variance = varianceInStms initial_variance kstms
tileInKernelStatement (kspace, extra_bnds)
(Let pat attr (Op (GroupStream w max_chunk lam accs arrs)))
| max_chunk == w,
not $ null arrs,
chunk_size <- Var $ groupStreamChunkSize lam,
arr_chunk_params <- groupStreamArrParams lam,
maybe_1d_tiles <-
zipWith (is1dTileable branch_variant kspace variance chunk_size) arrs arr_chunk_params,
maybe_1_5d_tiles <-
zipWith (is1_5dTileable branch_variant kspace variance chunk_size) arrs arr_chunk_params,
Just mk_tilings <-
zipWithM (<|>) maybe_1d_tiles maybe_1_5d_tiles = do
(kspaces, arr_chunk_params', tile_kstms) <- unzip3 <$> sequence mk_tilings
let (kspace', kspace_bnds) =
case kspaces of
[] -> (kspace, mempty)
new_kspace : _ -> new_kspace
Body () lam_kstms lam_res <- syncAtEnd $ groupStreamLambdaBody lam
let lam_kstms' = mconcat tile_kstms <> lam_kstms
group_size = spaceGroupSize kspace
lam' = lam { groupStreamLambdaBody = Body () lam_kstms' lam_res
, groupStreamArrParams = arr_chunk_params'
}
return ((kspace', extra_bnds <> kspace_bnds),
Let pat attr $ Op $ GroupStream w group_size lam' accs arrs)
tileInKernelStatement (kspace, extra_bnds)
(Let pat attr (Op (GroupStream w max_chunk lam accs arrs)))
| w == max_chunk,
not $ null arrs,
FlatThreadSpace gspace <- spaceStructure kspace,
chunk_size <- Var $ groupStreamChunkSize lam,
arr_chunk_params <- groupStreamArrParams lam,
Just mk_tilings <-
zipWithM (is2dTileable branch_variant kspace variance chunk_size)
arrs arr_chunk_params = do
((tile_size, tiled_group_size), tile_size_bnds) <- runBinder $ do
tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
tile_size <- letSubExp "tile_size" $ Op $ GetSize tile_size_key SizeTile
tiled_group_size <- letSubExp "tiled_group_size" $
BasicOp $ BinOp (Mul Int32) tile_size tile_size
return (tile_size, tiled_group_size)
let (tiled_gspace,untiled_gspace) = splitAt 2 $ reverse gspace
untiled_gspace' <- fmap reverse $ forM (reverse untiled_gspace) $ \(gtid,gdim) -> do
ltid <- newVName "ltid"
return (gtid,gdim,
ltid, constant (1::Int32))
tiled_gspace' <- fmap reverse $ forM (reverse tiled_gspace) $ \(gtid,gdim) -> do
ltid <- newVName "ltid"
return (gtid,gdim,
ltid, tile_size)
let gspace' = reverse $ tiled_gspace' ++ untiled_gspace'
((num_threads, num_groups), num_bnds) <-
runBinder $ sufficientGroups gspace' tiled_group_size
let kspace' = kspace { spaceStructure = NestedThreadSpace gspace'
, spaceGroupSize = tiled_group_size
, spaceNumThreads = num_threads
, spaceNumGroups = num_groups
}
local_ids = map (\(_, _, ltid, _) -> ltid) gspace'
(arr_chunk_params', tile_kstms) <-
fmap unzip $ forM mk_tilings $ \mk_tiling ->
mk_tiling tile_size local_ids
Body () lam_kstms lam_res <- syncAtEnd $ groupStreamLambdaBody lam
let lam_kstms' = mconcat tile_kstms <> lam_kstms
lam' = lam { groupStreamLambdaBody = Body () lam_kstms' lam_res
, groupStreamArrParams = arr_chunk_params'
}
return ((kspace', extra_bnds <> tile_size_bnds <> num_bnds),
Let pat attr $ Op $ GroupStream w tile_size lam' accs arrs)
tileInKernelStatement (kspace, extra_bnds)
(Let pat attr (Op (GroupStream w maxchunk lam accs arrs))) = do
let branch_variant' = branch_variant <>
fromMaybe mempty (flip M.lookup variance =<< subExpVar w)
(bnds, kspace', lam') <- tileInStreamLambda branch_variant' variance kspace lam
return ((kspace', extra_bnds <> bnds),
Let pat attr $ Op $ GroupStream w maxchunk lam' accs arrs)
tileInKernelStatement acc stm =
return (acc, stm)
tileInStreamLambda :: Names -> VarianceTable -> KernelSpace -> GroupStreamLambda InKernel
-> TileM (Stms Kernels, KernelSpace, GroupStreamLambda InKernel)
tileInStreamLambda branch_variant variance kspace lam = do
(bnds, kspace', kbody') <-
tileInBody branch_variant variance' kspace $ groupStreamLambdaBody lam
return (bnds, kspace', lam { groupStreamLambdaBody = kbody' })
where variance' = varianceInStms variance $
bodyStms $ groupStreamLambdaBody lam
is1dTileable :: MonadFreshNames m =>
Names -> KernelSpace -> VarianceTable -> SubExp -> VName -> LParam InKernel
-> Maybe (m ((KernelSpace, Stms Kernels),
LParam InKernel,
Stms InKernel))
is1dTileable branch_variant kspace variance block_size arr block_param = do
guard $ S.null $ M.findWithDefault mempty arr variance
guard $ S.null branch_variant
guard $ primType $ rowType $ paramType block_param
return $ do
(outer_block_param, kstms) <- tile1d kspace block_size block_param
return ((kspace, mempty), outer_block_param, kstms)
is1_5dTileable :: (MonadFreshNames m, HasScope Kernels m) =>
Names -> KernelSpace -> VarianceTable
-> SubExp -> VName -> LParam InKernel
-> Maybe (m ((KernelSpace, Stms Kernels),
LParam InKernel,
Stms InKernel))
is1_5dTileable branch_variant kspace variance block_size arr block_param = do
guard $ primType $ rowType $ paramType block_param
(inner_gtid, inner_gdim) <- invariantToInnermostDimension
mk_structure <-
case spaceStructure kspace of
NestedThreadSpace{} -> Nothing
FlatThreadSpace gtids_and_gdims ->
return $ do
let n_dims = length gtids_and_gdims
outer <- forM (take (n_dims-1) gtids_and_gdims) $ \(gtid, gdim) -> do
ltid <- newVName "ltid"
return (gtid, gdim, ltid, gdim)
inner_ltid <- newVName "inner_ltid"
inner_ldim <- newVName "inner_ldim"
let compute_tiled_group_size =
mkLet [] [Ident inner_ldim $ Prim int32] $
BasicOp $ BinOp (SMin Int32) (spaceGroupSize kspace) inner_gdim
structure = NestedThreadSpace $ outer ++ [(inner_gtid, inner_gdim,
inner_ltid, Var inner_ldim)]
((num_threads, num_groups), num_bnds) <- runBinder $ do
threads_necessary <-
letSubExp "threads_necessary" =<<
foldBinOp (Mul Int32)
(constant (1::Int32)) (map snd gtids_and_gdims)
groups_necessary <-
letSubExp "groups_necessary" =<<
eDivRoundingUp Int32 (eSubExp threads_necessary) (eSubExp $ Var inner_ldim)
num_threads <-
letSubExp "num_threads" $
BasicOp $ BinOp (Mul Int32) groups_necessary (Var inner_ldim)
return (num_threads, groups_necessary)
let kspace' = kspace { spaceGroupSize = Var inner_ldim
, spaceNumGroups = num_groups
, spaceNumThreads = num_threads
, spaceStructure = structure
}
return (oneStm compute_tiled_group_size <> num_bnds,
kspace')
return $ do
(outer_block_param, kstms) <- tile1d kspace block_size block_param
(structure_bnds, kspace') <- mk_structure
return ((kspace', structure_bnds), outer_block_param, kstms)
where invariantToInnermostDimension :: Maybe (VName, SubExp)
invariantToInnermostDimension =
case reverse $ spaceDimensions kspace of
(i,d) : _
| not $ i `S.member` M.findWithDefault mempty arr variance,
not $ i `S.member` branch_variant -> Just (i,d)
_ -> Nothing
tile1d :: MonadFreshNames m =>
KernelSpace
-> SubExp
-> LParam InKernel
-> m (LParam InKernel, Stms InKernel)
tile1d kspace block_size block_param = do
outer_block_param <- do
name <- newVName $ baseString (paramName block_param) ++ "_outer"
return block_param { paramName = name }
let ltid = spaceLocalId kspace
read_elem_bnd <- do
name <- newVName $ baseString (paramName outer_block_param) ++ "_elem"
return $
mkLet [] [Ident name $ rowType $ paramType outer_block_param] $
BasicOp $ Index (paramName outer_block_param) [DimFix $ Var ltid]
cid <- newVName "cid"
let block_cspace = combineSpace [(cid, block_size)]
block_pe =
PatElem (paramName block_param) $ paramType outer_block_param
write_block_stms =
[ Let (Pattern [] [block_pe]) (defAux ()) $ Op $
Combine block_cspace [patElemType pe] [] $
Body () (oneStm read_elem_bnd) [Var $ patElemName pe]
| pe <- patternElements $ stmPattern read_elem_bnd ]
return (outer_block_param, stmsFromList write_block_stms)
is2dTileable :: MonadFreshNames m =>
Names -> KernelSpace -> VarianceTable -> SubExp -> VName -> LParam InKernel
-> Maybe (SubExp -> [VName] -> m (LParam InKernel, Stms InKernel))
is2dTileable branch_variant kspace variance block_size arr block_param = do
guard $ primType $ rowType $ paramType block_param
pt <- case rowType $ paramType block_param of
Prim pt -> return pt
_ -> Nothing
inner_perm <- invariantToOneOfTwoInnerDims
Just $ \tile_size local_is -> do
let num_outer = length local_is - 2
perm = [0..num_outer-1] ++ map (+num_outer) inner_perm
invariant_i : variant_i : _ = reverse $ rearrangeShape perm local_is
(global_i,global_d):_ = rearrangeShape inner_perm $ drop num_outer $ spaceDimensions kspace
outer_block_param <- do
name <- newVName $ baseString (paramName block_param) ++ "_outer"
return block_param { paramName = name }
elem_name <- newVName $ baseString (paramName outer_block_param) ++ "_elem"
let read_elem_bnd = mkLet [] [Ident elem_name $ Prim pt] $
BasicOp $ Index (paramName outer_block_param) $
fullSlice (paramType outer_block_param) [DimFix $ Var invariant_i]
cids <- replicateM (length local_is - num_outer) $ newVName "cid"
let block_size_2d = Shape $ rearrangeShape inner_perm [tile_size, block_size]
block_cspace = combineSpace $ zip cids $
rearrangeShape inner_perm [tile_size,block_size]
block_name_2d <- newVName $ baseString (paramName block_param) ++ "_2d"
let block_pe =
PatElem block_name_2d $
rowType (paramType outer_block_param) `arrayOfShape` block_size_2d
write_block_stm =
Let (Pattern [] [block_pe]) (defAux ()) $
Op $ Combine block_cspace [Prim pt] [(global_i, global_d)] $
Body () (oneStm read_elem_bnd) [Var elem_name]
let index_block_kstms =
[mkLet [] [paramIdent block_param] $
BasicOp $ Index block_name_2d $
rearrangeShape inner_perm $
fullSlice (rearrangeType inner_perm $ patElemType block_pe)
[DimFix $ Var variant_i]]
return (outer_block_param,
oneStm write_block_stm <> stmsFromList index_block_kstms)
where invariantToOneOfTwoInnerDims :: Maybe [Int]
invariantToOneOfTwoInnerDims = do
(j,_) : (i,_) : _ <- Just $ reverse $ spaceDimensions kspace
let variant_to = M.findWithDefault mempty arr variance
branch_invariant = not $ S.member j branch_variant || S.member i branch_variant
if branch_invariant && i `S.member` variant_to && not (j `S.member` variant_to) then
Just [0,1]
else if branch_invariant && j `S.member` variant_to && not (i `S.member` variant_to) then
Just [1,0]
else
Nothing
syncAtEnd :: MonadFreshNames m => Body InKernel -> m (Body InKernel)
syncAtEnd (Body () stms res) = do
(res', stms') <- (`runBinderT` mempty) $ do
mapM_ addStm stms
map Var <$> letTupExp "sync" (Op $ Barrier res)
return $ Body () stms' res'
type VarianceTable = M.Map VName Names
varianceInStms :: VarianceTable -> Stms InKernel -> VarianceTable
varianceInStms = foldl varianceInStm
varianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
varianceInStm variance bnd =
foldl' add variance $ patternNames $ stmPattern bnd
where add variance' v = M.insert v binding_variance variance'
look variance' v = S.insert v $ M.findWithDefault mempty v variance'
binding_variance = mconcat $ map (look variance) $ S.toList (freeInStm bnd)
sufficientGroups :: MonadBinder m =>
[(VName, SubExp, VName, SubExp)] -> SubExp
-> m (SubExp, SubExp)
sufficientGroups gspace group_size = do
groups_in_dims <- forM gspace $ \(_, gd, _, ld) ->
letSubExp "groups_in_dim" =<< eDivRoundingUp Int32 (eSubExp gd) (eSubExp ld)
num_groups <- letSubExp "num_groups" =<<
foldBinOp (Mul Int32) (constant (1::Int32)) groups_in_dims
num_threads <- letSubExp "num_threads" $
BasicOp $ BinOp (Mul Int32) num_groups group_size
return (num_threads, num_groups)