{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Optimise.MemoryBlockMerging.Liveness.LastUse
( findLastUses
) where
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Control.Monad
import Control.Monad.RWS
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
(ExplicitMemorish, ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
-- | The writer output of 'FindM': every recorded last-use mapping, each one
-- emitted as its own map.
type LastUsesList = [LastUses]

-- | Merge all recorded last-use maps into a single map.  When the same key
-- was recorded more than once, the memory-name sets are unioned.
getLastUsesMap :: LastUsesList -> LastUses
getLastUsesMap recorded = M.unionsWith S.union recorded
-- | Tentative ("optimistic") last uses, keyed by memory block name.  Each
-- entry holds the statement or result currently believed to be the last use
-- of the block, plus a flag that is 'True' for an indirect entry, i.e. one
-- inserted by 'setOptimistic' because the block is listed among the aliases
-- of the memory block that was actually used.
type OptimisticLastUses = M.Map VName (StmOrRes, Bool)
-- | Read-only environment of the last-use analysis.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemorySrc
    -- ^ Memory information for variables; consulted by 'varMems' and
    -- 'lookInRes'.
  , ctxMemAliases :: MemAliases
    -- ^ For each memory block, the set of memory blocks treated as its
    -- aliases.
  , ctxFirstUses :: FirstUses
    -- ^ For each variable, the memory blocks it is a first use of.
  , ctxExistentials :: Names
    -- ^ Names that 'lookInRes' skips (presumably the existential results of
    -- the function; verify against the caller).
  , ctxCurFirstUsesOuter :: Names
    -- ^ Snapshot of 'curFirstUses' taken at the start of the enclosing
    -- statement; 'lookInStm' installs it when walking into the sub-bodies of
    -- any expression other than an @If@.
  }
  deriving (Show)
-- | Mutable state of the last-use analysis.
-- NOTE: constructed positionally elsewhere, so keep the field order.
data Current = Current
  { curOptimisticLastUses :: OptimisticLastUses
    -- ^ Tentative last uses; entries are recorded into the writer output by
    -- 'commitOptimistic'.
  , curFirstUses :: Names
    -- ^ Memory blocks whose first use has been seen so far in the current
    -- scope; reset by 'withLocalCurFirstUses' after walking nested code.
  }
  deriving (Show)
-- | The analysis monad: reads a 'Context', accumulates recorded last uses as
-- a 'LastUsesList', and threads a 'Current' state.  The @lore@ parameter does
-- not occur in the representation (it is phantom), which is what makes
-- 'coerce' possible.
newtype FindM lore a = FindM { unFindM :: RWS Context LastUsesList Current a }
  deriving (Monad, Functor, Applicative,
            MonadReader Context,
            MonadWriter LastUsesList,
            MonadState Current)
-- | Constraints a lore must satisfy for the statement and body walkers of
-- this module ('lookInBody', 'lookInKernelBody', 'lookInStm').
type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalk lore)
-- | Reinterpret an analysis computation at another lore.  Because the lore
-- parameter of 'FindM' is phantom, the underlying computation is reused
-- unchanged.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM action) = FindM action
-- | The memory blocks backing a variable: a singleton set holding its memory
-- block name when the variable has an entry in 'ctxVarToMem', and the empty
-- set otherwise.
varMems :: VName -> FindM lore MNames
varMems var = do
  var_to_mem <- asks ctxVarToMem
  return $ case M.lookup var var_to_mem of
    Just mem -> S.singleton (memSrcName mem)
    Nothing -> S.empty
-- | Apply a function to the optimistic last uses in the state, leaving the
-- rest of the state untouched.
modifyCurOptimisticLastUses :: (OptimisticLastUses -> OptimisticLastUses) -> FindM lore ()
modifyCurOptimisticLastUses f = do
  current <- get
  put current { curOptimisticLastUses = f (curOptimisticLastUses current) }
-- | Apply a function to the first uses seen so far, leaving the rest of the
-- state untouched.
modifyCurFirstUses :: (Names -> Names) -> FindM lore ()
modifyCurFirstUses f = do
  current <- get
  put current { curFirstUses = f (curFirstUses current) }
-- | Run an action, then reset 'curFirstUses' to the value it had before the
-- action started.  Other state changes made by the action (such as updates
-- to the optimistic last uses) are kept.
withLocalCurFirstUses :: FindM lore a -> FindM lore a
withLocalCurFirstUses action = do
  saved <- gets curFirstUses
  action <* modifyCurFirstUses (const saved)
-- | Record @var@ as a last use of memory block @mem@ by emitting a singleton
-- map to the writer output.
recordMapping :: StmOrRes -> MName -> FindM lore ()
recordMapping var mem =
  let mapping = M.singleton var (S.singleton mem)
  in tell [mapping]
-- | Compute the last uses of memory blocks in a function.
--
-- The parameters, the body and the body result of @fundef@ are analysed in
-- that order, maintaining tentative ("optimistic") last uses along the way.
-- Every optimistic entry still present at the end is committed.  The merged
-- writer output is then passed through 'removeEmptyMaps' (presumably to drop
-- keys with no memory blocks).
findLastUses :: VarMemMappings MemorySrc -> MemAliases -> FirstUses -> Names
             -> FunDef ExplicitMemory -> LastUses
findLastUses var_to_mem mem_aliases first_uses existentials fundef =
  removeEmptyMaps $ getLastUsesMap recorded
  where
    context = Context
      { ctxVarToMem = var_to_mem
      , ctxMemAliases = mem_aliases
      , ctxFirstUses = first_uses
      , ctxExistentials = existentials
      , ctxCurFirstUsesOuter = S.empty
      }
    initial = Current { curOptimisticLastUses = M.empty
                      , curFirstUses = S.empty
                      }
    body = funDefBody fundef
    (_, recorded) = evalRWS (unFindM analysis) context initial
    analysis = do
      mapM_ lookInFunDefFParam $ funDefParams fundef
      lookInBody body
      mapM_ lookInRes $ bodyResult body
      -- Whatever is still optimistic at the very end is a real last use.
      pending <- gets curOptimisticLastUses
      mapM_ commitOptimistic $ M.keys pending
-- | Tentatively mark @x_lu@ as the last use of @mem@ and of every alias of
-- @mem@, skipping the memory blocks in @exclude@.  The entry for @mem@
-- itself is flagged direct ('False'); entries for aliases are flagged
-- indirect ('True').  Existing entries for the same blocks are overwritten.
setOptimistic :: MName -> StmOrRes -> MNames -> FindM lore ()
setOptimistic mem x_lu exclude = do
  aliases <- asks (lookupEmptyable mem . ctxMemAliases)
  let targets = S.insert mem aliases `S.difference` exclude
  forM_ targets $ \target ->
    modifyCurOptimisticLastUses $ M.insert target (x_lu, target /= mem)
-- | Drop the optimistic last use of @mem@, but only if that entry is
-- indirect (it was set through an alias).  Direct entries and missing
-- entries are left as they are.
removeIndirectOptimistic :: MName -> FindM lore ()
removeIndirectOptimistic mem = do
  entry <- gets (M.lookup mem . curOptimisticLastUses)
  when (maybe False snd entry) $
    modifyCurOptimisticLastUses (M.delete mem)
-- | If @mem@ has an optimistic last use, record it as an actual last use.
-- The optimistic entry itself stays in the state.
commitOptimistic :: MName -> FindM lore ()
commitOptimistic mem = do
  entry <- gets (M.lookup mem . curOptimisticLastUses)
  forM_ entry $ \(x_lu, _) -> recordMapping x_lu mem
-- | Register the first uses of a function parameter: every memory block the
-- parameter is a first use of is added to 'curFirstUses'.
lookInFunDefFParam :: FParam lore -> FindM lore ()
lookInFunDefFParam (Param x _) = do
  first_uses <- asks ctxFirstUses
  modifyCurFirstUses (lookupEmptyable x first_uses `S.union`)
-- | Analyse every statement of a body, in order.  The body result is not
-- inspected here.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm
-- | Analyse every statement of a kernel body, in order.  The kernel results
-- are not inspected here.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm
-- | Analyse a single statement.
--
-- 1. For every array in the pattern, add the memory blocks it is a first
--    use of to 'curFirstUses'.  If the array is a first use of its own
--    memory block, commit that block's pending optimistic last use.
-- 2. For every memory block backing a free variable of the expression
--    (ignoring the names from 'freeExcludes'), update the optimistic last
--    uses:
--
--    * If the block or one of its aliases is in 'ctxCurFirstUsesOuter',
--      only an indirect entry for the block is removed.  When the block has
--      no aliases of its own, the blocks that list it as an alias get this
--      statement as their tentative last use instead.
--    * Otherwise the block and its aliases get this statement as their
--      tentative last use.
--
-- 3. Walk into the sub-bodies of the expression.  'curFirstUses' is reset
--    afterwards.  For anything but an @If@, the first uses seen before this
--    statement become 'ctxCurFirstUsesOuter' for the nested code.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern _patctxelems patvalelems) _ e) = do
  -- Snapshot taken before the pattern's own first uses are added below.
  cur_first_uses <- gets curFirstUses
  let mMod = case e of
        If{} -> id
        _ -> local $ \ctx -> ctx { ctxCurFirstUsesOuter = cur_first_uses }

  forM_ patvalelems $ \(PatElem x membound) ->
    case membound of
      ExpMem.MemArray _ _ _ (ExpMem.ArrayIn xmem _) -> do
        first_uses_x <- lookupEmptyable x <$> asks ctxFirstUses
        modifyCurFirstUses $ S.union first_uses_x
        -- A new first use of xmem ends the previous lifetime of the block.
        when (S.member xmem first_uses_x) $ commitOptimistic xmem
      _ -> return ()

  let e_free_vars = freeInExp e `S.difference` S.fromList (freeExcludes e)
  e_mems <- S.unions <$> mapM varMems (S.toList e_free_vars)
  mem_aliases <- asks ctxMemAliases
  first_uses_outer <- asks ctxCurFirstUsesOuter
  forM_ patvalelems $ \(PatElem x _) ->
    forM_ (S.toList e_mems) $ \mem -> do
      let aliases = lookupEmptyable mem mem_aliases
          from_outer = any (`S.member` first_uses_outer) (mem : S.toList aliases)
      if from_outer
        then do
          removeIndirectOptimistic mem
          -- Point every block that lists mem among its aliases (and that
          -- block's other aliases) at this statement, excluding mem itself.
          when (S.null aliases) $ do
            let reverse_mem_aliases =
                  M.keys $ M.filter (mem `S.member`) mem_aliases
                exclude = S.singleton mem
            forM_ reverse_mem_aliases $ \mem' ->
              setOptimistic mem' (FromStm x) exclude
        else
          -- 'setOptimistic' covers mem and all of its aliases in one call.
          setOptimistic mem (FromStm x) S.empty

  withLocalCurFirstUses $ mMod $ fullWalkExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody }
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }
-- | Analyse one element of the function result.  A variable that is not in
-- 'ctxExistentials' and has memory information makes the result itself the
-- tentative last use of its memory block (and of that block's aliases).
-- Any other sub-expression is ignored.
lookInRes :: SubExp -> FindM lore ()
lookInRes (Var v) = do
  exis <- asks ctxExistentials
  var_to_mem <- asks ctxVarToMem
  unless (v `S.member` exis) $
    forM_ (M.lookup v var_to_mem) $ \mem ->
      setOptimistic (memSrcName mem) (FromRes v) S.empty
lookInRes _ = return ()
-- | Variables to leave out of an expression's free variables when looking
-- for memory uses.  For an @Update@ this is its array operand @orig@.  Every
-- other expression excludes nothing.
freeExcludes :: Exp lore -> [VName]
freeExcludes (BasicOp (Update orig _ _)) = [orig]
freeExcludes _ = []