{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE LambdaCase #-}
module Futhark.Optimise.MemoryBlockMerging.Liveness.Interference
( findInterferences
) where
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (mapMaybe, fromMaybe, catMaybes)
import Control.Monad
import Control.Monad.RWS
import Control.Monad.Writer
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
ExplicitMemorish, ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
-- | Read-only environment of the interference traversal.
data Context = Context
  { -- Maps a variable to the memory source it resides in.
    ctxVarToMem :: VarMemMappings MemorySrc
  , -- Memory aliasing information as supplied by the caller.
    ctxMemAliases :: MemAliases
  , -- For a variable, the memory blocks first used at its binding.
    ctxFirstUses :: FirstUses
  , -- For a statement or result position, the memory blocks last used there.
    ctxLastUses :: LastUses
  , -- Memory names that are tested for membership when looking for
    -- interference exceptions.
    ctxExistentials :: Names
  , -- Links discovered by 'findLoopCorrespondingVar': a variable maps to a
    -- loop pattern variable and the index written in the loop body.
    ctxLoopCorrespondingVar :: M.Map VName (VName, SubExp)
  }
  deriving Show
-- | Interference facts in the order they were emitted by the traversal:
-- a memory block paired with a set of blocks it interferes with.
type InterferencesList = [(MName, MNames)]

-- | Collapse the emitted facts into a map, taking the union of all sets
-- recorded for the same memory block.
getInterferencesMap :: InterferencesList -> Interferences
getInterferencesMap facts =
  M.unionsWith S.union [M.singleton mem others | (mem, others) <- facts]
-- | Mutable state of the interference traversal.
data Current = Current
  { -- Memory blocks that are alive at the current point of the traversal.
    curAlive :: MNames
  , -- Potential kernel data race interference groups collected so far, in
    -- the order they were added.
    curResPotentialKernelInterferences :: PotentialKernelDataRaceInterferences
  }
  deriving Show
-- | The traversal monad: reads a 'Context', writes an 'InterferencesList'
-- and carries a 'Current' state.  The @lore@ parameter is phantom; see
-- 'coerce' for switching it.
newtype FindM lore a = FindM
  { unFindM :: RWS Context InterferencesList Current a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadReader Context
    , MonadWriter InterferencesList
    , MonadState Current
    )
-- | Everything a lore has to provide for the traversal in this module.
type LoreConstraints lore =
  ( ExplicitMemorish lore
  , KernelInterferences lore
  , SpecialBodyExceptions lore
  , FullWalk lore
  )
-- | Change the phantom lore of a 'FindM' action.  The wrapped computation
-- is reused unchanged.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM action) = FindM action
-- | Mark a memory block as alive.
awaken :: MName -> FindM lore ()
awaken = modifyCurAlive . S.insert
-- | Mark a memory block as no longer alive.
kill :: MName -> FindM lore ()
kill = modifyCurAlive . S.delete
-- | Apply a function to the set of currently alive memory blocks, leaving
-- the rest of the state untouched.
modifyCurAlive :: (MNames -> MNames) -> FindM lore ()
modifyCurAlive f = modify update
  where
    update cur = cur { curAlive = f (curAlive cur) }
-- | Append one group to the end of the potential kernel data race
-- interferences collected in the state.
addPotentialKernelInterferenceGroup ::
  PotentialKernelDataRaceInterferenceGroup -> FindM lore ()
addPotentialKernelInterferenceGroup set = modify appendGroup
  where
    appendGroup cur =
      cur { curResPotentialKernelInterferences =
              curResPotentialKernelInterferences cur ++ [set] }
-- | Emit, for every alive memory block, the whole alive set as its
-- interferences.
recordCurrentInterferences :: FindM lore ()
recordCurrentInterferences = do
  alive <- gets curAlive
  tell [(mem, alive) | mem <- S.toList alive]
-- | Emit interferences between the alive memory blocks and the given set:
-- first every alive block against the given set, then every given block
-- against the alive set.
recordNewInterferences :: MNames -> FindM lore ()
recordNewInterferences mems_in_stm = do
  alive <- gets curAlive
  tell [(mem, mems_in_stm) | mem <- S.toList alive]
  tell [(mem, alive) | mem <- S.toList mems_in_stm]
-- | Entry point: traverse a function definition and return the
-- interferences between its memory blocks together with the potential
-- kernel data race interference groups that were collected on the way.
--
-- The function parameters are visited first, then the function body.  The
-- emitted facts are merged into a map which is then passed through
-- 'makeCommutativeMap', 'removeKeyFromMapElems' and 'removeEmptyMaps'.
findInterferences :: VarMemMappings MemorySrc -> MemAliases ->
                     FirstUses -> LastUses -> Names -> FunDef ExplicitMemory
                  -> (Interferences, PotentialKernelDataRaceInterferences)
findInterferences var_to_mem mem_aliases first_uses last_uses existentials fundef =
  (interferences, curResPotentialKernelInterferences final_state)
  where
    initial_context =
      Context { ctxVarToMem = var_to_mem
              , ctxMemAliases = mem_aliases
              , ctxFirstUses = first_uses
              , ctxLastUses = last_uses
              , ctxExistentials = existentials
              , ctxLoopCorrespondingVar = M.empty
              }

    -- Nothing is alive and no group has been collected at the start.
    initial_state =
      Current { curAlive = S.empty
              , curResPotentialKernelInterferences = []
              }

    traversal = do
      mapM_ lookInFunDefFParam $ funDefParams fundef
      lookInBody $ funDefBody fundef

    (final_state, emitted) =
      execRWS (unFindM traversal) initial_context initial_state

    interferences =
      removeEmptyMaps . removeKeyFromMapElems . makeCommutativeMap $
        getInterferencesMap emitted
-- | Awaken the memory blocks first used by a function parameter and record
-- the interferences among everything alive afterwards.
lookInFunDefFParam :: FParam lore -> FindM lore ()
lookInFunDefFParam (Param var _) = do
  first_uses <- asks ctxFirstUses
  forM_ (S.toList $ lookupEmptyable var first_uses) awaken
  recordCurrentInterferences
-- | Visit the statements of a body in order, then its result.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody body = do
  forM_ (bodyStms body) lookInStm
  lookInRes (bodyResult body)
-- | Visit the statements of a kernel body in order, then the
-- subexpressions of its results.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody kbody = do
  forM_ (kernelBodyStms kbody) lookInStm
  lookInRes [kernelResultSubExp r | r <- kernelBodyResult kbody]
-- | Awaken every memory block that has its first use in one of the given
-- pattern elements.
awakenFirstUses :: [PatElem lore] -> FindM lore ()
awakenFirstUses patvalelems = do
  first_uses <- asks ctxFirstUses
  forM_ patvalelems $ \(PatElem var _) ->
    forM_ (S.toList $ lookupEmptyable var first_uses) awaken
-- | Whether an expression is treated as a no-op by 'lookInStm'.  Only
-- @Scratch@ qualifies.
isNoOp :: Exp lore -> Bool
isNoOp (BasicOp Scratch{}) = True
isNoOp _ = False
-- | Process a single statement.
--
-- The memory blocks first used by the value pattern are awakened.  Unless
-- the expression is a no-op (see 'isNoOp'), the statement's interferences
-- are then recorded, nested bodies are traversed, any exception pairs found
-- for "special" bodies are filtered out of what was emitted, potential
-- kernel data race interference groups are collected, and finally the
-- memory blocks whose last use is this statement are killed.
lookInStm :: LoreConstraints lore =>
Stm lore -> FindM lore ()
lookInStm stm@(Let (Pattern _patctxelems patvalelems) _ e)
-- No-op expressions only awaken first uses; nothing is recorded or killed.
| isNoOp e =
awakenFirstUses patvalelems
| otherwise = do
awakenFirstUses patvalelems
ctx <- ask
-- Context extended with the loop correspondences contributed by this
-- statement; entries already in the context take precedence (M.union is
-- left-biased).
let ctx' = ctx { ctxLoopCorrespondingVar =
M.union (ctxLoopCorrespondingVar ctx)
(findLoopCorrespondingVar ctx stm)
}
-- Pairs of memory blocks that must not be reported as interfering.  Only
-- computed when the expression has special body indices; otherwise empty.
let stm_exceptions = fromMaybe [] $ do
indices <- specialBodyIndices e
let walker_exc =
identityWalker
-- NOTE: innermostLoopNestBody is given ctx (not ctx'); the map it
-- returns is merged into ctx' afterwards.
{ walkOnBody = \body -> let (body', lcv) = innermostLoopNestBody ctx body
ctx'' = ctx' { ctxLoopCorrespondingVar =
M.union (ctxLoopCorrespondingVar ctx') lcv }
in tell $ interferenceExceptions ctx''
(bodyStms body') (bodyResult body')
indices Nothing }
walker_kernel_exc =
identityKernelWalker
{ walkOnKernelBody = \body -> let (body', lcv) = innermostLoopNestBody ctx body
ctx'' = ctx' { ctxLoopCorrespondingVar =
M.union (ctxLoopCorrespondingVar ctx') lcv }
in tell $ interferenceExceptions ctx''
(bodyStms body') (bodyResult body')
indices Nothing
-- For kernel bodies only the ThreadsReturn results are passed on, along
-- with the write memories of the statement itself.
, walkOnKernelKernelBody = \kbody -> tell $ interferenceExceptions ctx'
(kernelBodyStms kbody)
(mapMaybe (\case
ThreadsReturn _ se -> Just se
_ -> Nothing)
$ kernelBodyResult kbody)
indices
(specialBodyWriteMems stm)
}
return $ execWriter $ fullWalkExpM walker_exc walker_kernel_exc e
first_uses <- asks ctxFirstUses
last_uses <- asks ctxLastUses
-- Memory blocks first used or last used (FromStm) by any value pattern
-- element of this statement.
let stm_mems =
S.unions $ map (\pelem ->
let v = patElemName pelem
in S.union
(lookupEmptyable v first_uses)
(lookupEmptyable (FromStm v) last_uses)) patvalelems
-- Run the recording and the nested traversal under ctx', capturing what
-- they emit (listen) while suppressing it (censor) so that it can be
-- filtered before being re-emitted below.
((), stm_interferences) <- censor (const []) $ listen $ do
recordNewInterferences stm_mems
local (const ctx') $ fullWalkExpM walker walker_kernel e
-- Drop every pair that is listed as an exception in either orientation.
let stm_interferences' =
map (\(k, vs) ->
(k, S.fromList
$ filter (\v -> not ((k, v) `L.elem` stm_exceptions
|| (v, k) `L.elem` stm_exceptions))
$ S.toList vs))
stm_interferences
tell stm_interferences'
potential_kernel_interferences <- findKernelDataRaceInterferences e
forM_ potential_kernel_interferences addPotentialKernelInterferenceGroup
-- Kill the memory blocks whose last use is this statement.
forM_ patvalelems $ \(PatElem var _) -> do
last_uses_var <- lookupEmptyable (FromStm var) <$> asks ctxLastUses
mapM_ kill last_uses_var
-- Walkers that continue the regular traversal inside nested bodies.
where walker = identityWalker
{ walkOnBody = lookInBody }
walker_kernel = identityKernelWalker
{ walkOnKernelBody = coerce . lookInBody
, walkOnKernelKernelBody = coerce . lookInKernelBody
, walkOnKernelLambda = coerce . lookInBody . lambdaBody
}
-- | For a loop statement, relate loop results to the loop's pattern
-- variables when the last statement of the loop body is an in-place
-- @Update@ whose first slice dimension is a fixed index.
--
-- Each value pattern element is paired positionally with a body result.
-- When the pattern element is an array in the same memory block as the
-- array produced by that final @Update@, an entry is produced mapping a
-- variable to @(pattern variable, fixed index)@.  The key is the variable
-- copied by the @Update@ if that variable resides in the same memory block,
-- and the result variable otherwise.  Any other statement yields an empty
-- map.
findLoopCorrespondingVar :: LoreConstraints lore =>
                            Context -> Stm lore -> M.Map VName (VName, SubExp)
findLoopCorrespondingVar ctx (Let (Pattern _patctxelems patvalelems) _
                              (DoLoop _ _ _ (Body _ stms res))) =
  M.fromList $ catMaybes $ zipWith correspond patvalelems res
  where
    -- Memory block, fixed first index and copied variable of the body's
    -- final statement, provided it has the expected shape.
    final_update = case reverse (stmsToList stms) of
      Let (Pattern _ [PatElem _last_v
                       (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn last_stm_mem _))]) _
          (BasicOp (Update _ (DimFix slice_part : _) (Var copy_v))) : _ ->
        Just (last_stm_mem, slice_part, copy_v)
      _ -> Nothing

    correspond (PatElem pat_v (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn pat_mem _)))
               (Var res_v) = do
      (last_stm_mem, slice_part, copy_v) <- final_update
      guard $ pat_mem == last_stm_mem
      let copy_mem = memSrcName <$> M.lookup copy_v (ctxVarToMem ctx)
          key
            | copy_mem == Just last_stm_mem = copy_v
            | otherwise = res_v
      return (key, (pat_v, slice_part))
    correspond _ _ = Nothing
findLoopCorrespondingVar _ _ = M.empty
-- | Descend through a nest of bodies of the form "a @Scratch@ statement
-- immediately followed by a loop statement" and return the innermost body
-- reached, together with the loop correspondences of every loop passed on
-- the way (outer loops take precedence on key clashes).  A body that does
-- not start with that shape is returned as is with an empty map.
innermostLoopNestBody :: LoreConstraints lore =>
                         Context -> Body lore -> (Body lore, M.Map VName (VName, SubExp))
innermostLoopNestBody ctx body
  | Let _ _ (BasicOp Scratch{}) : loopstm@(Let _ _ (DoLoop _ _ _ inner)) : _ <-
      stmsToList (bodyStms body) =
      let (innermost, deeper) = innermostLoopNestBody ctx inner
      in (innermost, findLoopCorrespondingVar ctx loopstm `M.union` deeper)
  | otherwise = (body, M.empty)
-- | Handle a body result: the memory blocks whose last use is one of the
-- result variables are recorded as interfering with everything alive, and
-- are then killed.
lookInRes :: [SubExp] -> FindM lore ()
lookInRes ses = do
  last_uses <- asks ctxLastUses
  let dying =
        S.unions [lookupEmptyable (FromRes v) last_uses | v <- subExpVars ses]
  recordNewInterferences dying
  forM_ (S.toList dying) kill
-- | All kernel first uses found in a statement, including those in bodies
-- nested inside its expression.
firstUsesInStm :: LoreConstraints lore => FirstUses ->
                  Stm lore -> [KernelFirstUse]
firstUsesInStm first_uses stm = found
  where
    ((), found) = evalRWS (lookFUInStm stm) first_uses ()
-- | All kernel first uses found in the bodies nested inside an expression,
-- using the first-use table of the current context.
firstUsesInExp :: LoreConstraints lore =>
                  Exp lore -> FindM lore [KernelFirstUse]
firstUsesInExp e = do
  first_uses <- asks ctxFirstUses
  let ((), found) = evalRWS (lookFUInExp e) first_uses ()
  return found
-- | Emit a 'KernelFirstUse' for every memory block first used by an array
-- pattern element of the statement, then continue into the statement's
-- expression.  Pattern elements that are not arrays are skipped.
lookFUInStm :: LoreConstraints lore =>
               Stm lore -> RWS FirstUses [KernelFirstUse] () ()
lookFUInStm (Let (Pattern _patctxelems patvalelems) _ e_stm) = do
  first_uses <- ask
  tell
    [ (fu, patname, pt, ixfun)
    | PatElem patname (ExpMem.MemArray pt _ _ (ExpMem.ArrayIn _ ixfun)) <- patvalelems
    , fu <- S.toList (lookupEmptyable patname first_uses)
    ]
  lookFUInExp e_stm
-- | Walk an expression and run 'lookFUInStm' on every statement of the
-- bodies, kernel bodies and kernel lambdas it contains.
lookFUInExp :: LoreConstraints lore =>
               Exp lore -> RWS FirstUses [KernelFirstUse] () ()
lookFUInExp e = fullWalkExpM on_bodies on_kernel_bodies e
  where
    on_bodies =
      identityWalker
        { walkOnBody = \body -> mapM_ lookFUInStm (bodyStms body) }
    on_kernel_bodies =
      identityKernelWalker
        { walkOnKernelBody = \body -> mapM_ lookFUInStm (bodyStms body)
        , walkOnKernelKernelBody = \kbody -> mapM_ lookFUInStm (kernelBodyStms kbody)
        , walkOnKernelLambda = \lam -> mapM_ lookFUInStm (bodyStms (lambdaBody lam))
        }
-- | Lore-specific collection of potential kernel data race interferences.
class KernelInterferences lore where
  -- | The group of first uses to consider for an expression, if any.
  findKernelDataRaceInterferences ::
    Exp lore -> FindM lore (Maybe PotentialKernelDataRaceInterferenceGroup)

-- | Only kernel expressions produce a group: the first uses found inside
-- the kernel.
instance KernelInterferences ExplicitMemory where
  findKernelDataRaceInterferences e@(Op (ExpMem.Inner Kernel{})) =
    Just <$> firstUsesInExp e
  findKernelDataRaceInterferences _ = return Nothing

-- | Nothing is collected inside kernels.
instance KernelInterferences InKernel where
  findKernelDataRaceInterferences = const (return Nothing)
-- | Lore-specific information used when looking for interference
-- exceptions in the bodies of loops and kernels.
class SpecialBodyExceptions lore where
  -- | The index variables of an expression with a special body, or
  -- 'Nothing' if the expression has none.
  specialBodyIndices :: Exp lore -> Maybe [MName]
  -- | The memory block, index function and element type of every array
  -- written by a statement with a special body, or 'Nothing'.
  specialBodyWriteMems :: Stm lore -> Maybe [(MName, ExpMem.IxFun, PrimType)]

-- | Kernels contribute the names of their space dimensions as indices and
-- their array pattern elements as write memories; everything else falls
-- back to 'specialBodyIndicesBase' and has no write memories.
instance SpecialBodyExceptions ExplicitMemory where
  specialBodyIndices e = case e of
    Op (ExpMem.Inner (Kernel _ kernelspace _ _)) ->
      Just [i | (i, _) <- spaceDimensions kernelspace]
    _ -> specialBodyIndicesBase e

  specialBodyWriteMems (Let (Pattern _patctxelems patvalelems) _
                        (Op (ExpMem.Inner Kernel{}))) =
    Just
      [ (mem, ixfun, t)
      | ExpMem.MemArray t _ _ (ExpMem.ArrayIn mem ixfun) <- map patElemAttr patvalelems
      ]
  specialBodyWriteMems _ = Nothing

-- | Inside kernels only the base case applies, and there are never any
-- write memories.
instance SpecialBodyExceptions InKernel where
  specialBodyIndices e = specialBodyIndicesBase e
  specialBodyWriteMems _ = Nothing

-- | The lore-independent case: a for-loop contributes its loop variable.
specialBodyIndicesBase :: Exp lore -> Maybe [MName]
specialBodyIndicesBase e = case e of
  DoLoop _ _ (ForLoop i _ _ _) _ -> Just [i]
  _ -> Nothing
-- | Find pairs of memory blocks that should be exempt from interference
-- within the statements of a special (loop or kernel) body.
--
-- Arguments: the context, the statements and result of the body, the index
-- variables of the enclosing construct, and optionally the memory blocks
-- written by the enclosing statement.  When the latter is given, the
-- outputs are taken from it (sliced by the index variables); otherwise they
-- are taken from the @Update@ statements whose variables appear in the
-- result.
--
-- A pair @(first-used block, killed block)@ is produced when a block has a
-- first use at a point where the other block has already had its last use,
-- the former is an output and the latter an input, and both have equal
-- index functions, equal slices and element types of equal byte size.
interferenceExceptions :: LoreConstraints lore =>
Context -> Stms lore -> [SubExp] -> [MName] ->
Maybe [(MName, ExpMem.IxFun, PrimType)] -> [(MName, MName)]
interferenceExceptions ctx stms res indices output_mems_may =
let output_vars = subExpVars res
indices_slice = map (DimFix . Var) indices
-- Memory blocks that have a first use somewhere inside these statements.
stms_first_uses = map (\(mem, _, _, _) -> mem)
$ concatMap (firstUsesInStm (ctxFirstUses ctx)) stms
-- For every value pattern element of every statement, a candidate read and
-- a candidate write.
results =
concat $ flip map (stmsToList stms) $ \(Let (Pattern _patctxelems patvalelems) _ e) ->
flip map patvalelems $ \(PatElem v membound) ->
-- A read: an Index into an array whose memory block is neither first used
-- inside these statements nor existential.
let fromread = case e of
BasicOp (Index orig slice) -> do
orig_mem <- M.lookup orig $ ctxVarToMem ctx
if
memSrcName orig_mem `L.notElem` stms_first_uses &&
not (memSrcName orig_mem `S.member` ctxExistentials ctx)
then return (v, typeOf membound, orig_mem, slice)
else Nothing
_ -> Nothing
-- A write: an Update producing an array.  The originating variable and the
-- slice are found by following ctxLoopCorrespondingVar from the pattern
-- variable (presumably until no further entry exists -- confirm against
-- fixpointIterateMay); the same memory block conditions as for reads apply.
fromwrite = case e of
BasicOp Update{}
| ExpMem.MemArray pt _ _ _ <- membound -> do
let (orig', slice') =
fixpointIterateMay
(\(v0, ss0) -> do
(v1, s1) <- M.lookup v0 (ctxLoopCorrespondingVar ctx)
return (v1, DimFix s1 : ss0))
(v, [])
orig_mem <- M.lookup orig' $ ctxVarToMem ctx
if
memSrcName orig_mem `L.notElem` stms_first_uses &&
not (memSrcName orig_mem `S.member` ctxExistentials ctx)
then return (v, Prim pt, orig_mem, slice')
else Nothing
_ -> Nothing
in (fromread, fromwrite)
fromreads = mapMaybe fst results
fromwrites = mapMaybe snd results
-- Only writes whose variable is part of the body result count as outputs.
fromwrites' = filter (\(v, _, _, _) -> v `L.elem` output_vars) fromwrites
-- Treat each read as a first use of the memory block it reads from.
fus_input_vars = M.fromList $ map (\(v, _, mem, _) ->
(v, S.singleton $ memSrcName mem)) fromreads
-- Treat the last statement that uses a read as a last use of the read
-- memory block.  Searching from the end, a statement "uses" a primitive
-- read if the read variable is free in its expression, and a
-- non-primitive read if any free variable resides in the same memory block.
lus_input_vars = mapFromListSetUnion $ mapMaybe
(\(v, typ, mem, _) ->
let check e_pat =
let frees = freeInExp e_pat
b = case typ of
Prim _ ->
v `S.member` frees
_ ->
memSrcName mem `L.elem`
mapMaybe ((memSrcName <$>) . (`M.lookup` ctxVarToMem ctx))
(S.toList frees)
in b
check' (Let _ _ e) = check e
-- NOTE(review): `head` is partial; a matching statement without any value
-- pattern elements would make this crash -- confirm that cannot happen.
in (\stm -> (FromStm $ patElemName $ head $ patternValueElements $ stmPattern stm,
S.singleton $ memSrcName mem)) <$>
L.find check' (reverse $ stmsToList stms)) fromreads
-- Writes count as first uses only when no explicit output memories were
-- given; otherwise the explicit memories are attached to the result
-- variables positionally.
fus_output_vars = mapFromListSetUnion $ case output_mems_may of
Just _ -> []
_ -> map (\(v, _, mem, _) -> (v, S.singleton $ memSrcName mem)) fromwrites'
fus_result = mapFromListSetUnion $ case output_mems_may of
Just mems -> zip output_vars $ map (S.singleton . (\(mem, _, _) -> mem)) mems
_ -> []
fus = M.unionsWith S.union [ctxFirstUses ctx, fus_input_vars, fus_output_vars]
lus = M.unionsWith S.union [ctxLastUses ctx, lus_input_vars]
-- Per memory block: the slice, index function and element type it is
-- accessed with.  Inputs take precedence over outputs on key clashes
-- (M.union is left-biased).
input_mem_slices = M.fromList $ map (\(_, _, mem, slice) ->
(memSrcName mem, slice)) fromreads
output_mem_slices = M.fromList $ case output_mems_may of
Just mems ->
map (\(mem, _, _) -> (mem, indices_slice)) mems
_ ->
map (\(_, _, mem, slice) -> (memSrcName mem, slice)) fromwrites'
mem_slices = M.union input_mem_slices output_mem_slices
input_mem_ixfuns = M.fromList $ map (\(_, _, mem, _) ->
(memSrcName mem, memSrcIxFun mem)) fromreads
output_mem_ixfuns = M.fromList $ case output_mems_may of
Just mems -> map (\(mem, ixfun, _) -> (mem, ixfun)) mems
_ -> map (\(_, _, mem, _) -> (memSrcName mem, memSrcIxFun mem)) fromwrites'
mem_ixfuns = M.union input_mem_ixfuns output_mem_ixfuns
input_mem_primtypes = M.fromList
$ map (\(_, t, mem, _) -> (memSrcName mem, elemType t)) fromreads
output_mem_primtypes = M.fromList $ case output_mems_may of
Just mems -> map (\(mem, _, pt) -> (mem, pt)) mems
_ -> map (\(_, t, mem, _) -> (memSrcName mem, elemType t)) fromwrites'
mem_primtypes = M.union input_mem_primtypes output_mem_primtypes
mem_ins0 = S.fromList $ map (\(_, _, mem, _) -> memSrcName mem) fromreads
mem_outs0 = S.fromList $ case output_mems_may of
Just mems -> map (\(mem, _, _) -> mem) mems
_ -> map (\(_, _, mem, _) -> memSrcName mem) fromwrites'
-- A memory block that is both read and written is neither a pure input nor
-- a pure output.
mem_ins = S.difference mem_ins0 mem_outs0
mem_outs = S.difference mem_outs0 mem_ins0
-- Run the search starting with no local deaths and keep only what it emits.
exceptions = snd $ evalRWS (findExceptions fus fus_result lus
mem_ins mem_outs mem_slices mem_ixfuns
mem_primtypes output_vars) () S.empty
in exceptions
-- Walk the statements in order.  The state is the set of memory blocks that
-- have had their last use so far; each statement's first uses are checked
-- against it before the statement's own last uses are added.  Afterwards
-- the first uses attached to the result variables are checked as well.
where findExceptions :: FirstUses -> FirstUses -> LastUses -> Names -> Names ->
M.Map VName (Slice SubExp) -> M.Map VName ExpMem.IxFun ->
M.Map VName PrimType -> [VName] ->
RWS () [(VName, VName)] LocalDeaths ()
findExceptions fus fus_result lus mem_ins mem_outs mem_slices mem_ixfuns mem_primtypes output_vars = do
forM_ stms $ \(Let (Pattern _patctxelems patvalelems) _ _) -> do
let vs = map patElemName patvalelems
fus_stm = S.unions $ map (`lookupEmptyable` fus) vs
lus_stm = S.unions $ map ((`lookupEmptyable` lus) . FromStm) vs
recordNewExceptions mem_ins mem_outs mem_slices mem_ixfuns mem_primtypes fus_stm
modify $ S.union lus_stm
forM_ output_vars $ \ov -> do
let fus_ov = lookupEmptyable ov fus_result
recordNewExceptions mem_ins mem_outs mem_slices mem_ixfuns mem_primtypes fus_ov
-- Emit (first-used, killed) for every combination of a currently first-used
-- block and an already killed block that satisfies all conditions below.
-- Combinations for which a slice, index function or element type is
-- unknown are silently skipped.
recordNewExceptions :: Names -> Names ->
M.Map VName (Slice SubExp) -> M.Map VName ExpMem.IxFun ->
M.Map VName PrimType -> Names ->
RWS () [(VName, VName)] LocalDeaths ()
recordNewExceptions mem_ins mem_outs mem_slices mem_ixfuns mem_primtypes fus_cur = do
deaths <- get
forM_ (S.toList fus_cur) $ \mem_fu -> forM_ deaths $ \mem_killed ->
fromMaybe (return ()) $ do
slice_fu <- M.lookup mem_fu mem_slices
slice_killed <- M.lookup mem_killed mem_slices
ixfun_fu <- M.lookup mem_fu mem_ixfuns
ixfun_killed <- M.lookup mem_killed mem_ixfuns
pt_fu <- M.lookup mem_fu mem_primtypes
pt_killed <- M.lookup mem_killed mem_primtypes
return $ when
(
mem_fu `S.member` mem_outs && mem_killed `S.member` mem_ins &&
ixfun_fu == ixfun_killed &&
slice_fu == slice_killed &&
(primByteSize pt_fu :: Int) == primByteSize pt_killed
) $ tell [(mem_fu, mem_killed)]
-- | State of the exception search inside 'interferenceExceptions': the
-- memory blocks that have had their last use so far.
type LocalDeaths = Names