{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Optimise.MemoryBlockMerging.Liveness.FirstUse
( findFirstUses
, createsNewArrayBase
) where
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Control.Monad
import Control.Monad.RWS
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
(ExplicitMemory, InKernel, ExplicitMemorish)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
-- | Read-only environment carried through the first-use traversal.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemorySrc
    -- ^ Lookup table from a variable to its memory source; consulted by
    -- 'varMems'.
  , ctxMemAliases :: MemAliases
    -- ^ Alias information for memory blocks; 'varMems' adds whatever is
    -- registered here for a block to that block's own name.
  , ctxCurOuterFirstUses :: Names
    -- ^ Every memory block that had already been recorded as a first use at
    -- the point where the enclosing expression was entered (see 'lookInStm').
  }
  deriving Show
-- | The analysis monad.  It is an 'RWS' with a 'Context' environment, an
-- unused (unit) writer output, and the 'FirstUses' collected so far as state.
-- The @lore@ parameter is phantom; it only selects which representation the
-- traversal functions operate on.
newtype FindM lore a = FindM
  { unFindM :: RWS Context () FirstUses a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadReader Context
    , MonadWriter ()
    , MonadState FirstUses
    )
-- | Everything a representation must provide for the traversal in this
-- module to work on it.
type LoreConstraints lore =
  ( ExplicitMemorish lore
  , ArrayUtils lore
  , FullWalk lore
  )
-- | Re-tag an action with a different lore.  The lore parameter of 'FindM' is
-- phantom, so this only unwraps and rewraps the underlying computation.
coerce :: FindM flore a -> FindM tlore a
coerce action = FindM (unFindM action)
-- | The memory block that a variable lives in, together with whatever
-- 'lookupEmptyable' yields for that block in the alias table.  The result is
-- the empty set when the variable has no entry in the variable-to-memory
-- mapping.
varMems :: VName -> FindM lore MNames
varMems var =
  asks $ \ctx ->
    case M.lookup var (ctxVarToMem ctx) of
      Nothing -> S.empty
      Just src ->
        let mem = memSrcName src
        in S.insert mem (lookupEmptyable mem (ctxMemAliases ctx))
-- | Note in the state that the statement binding @stmt_var@ is a first use of
-- the memory block @mem@.  Blocks already recorded for the same variable are
-- kept; the new one is added to the set.
recordMapping :: VName -> MName -> FindM lore ()
recordMapping stmt_var mem =
  modify (M.insertWith S.union stmt_var (S.singleton mem))
-- | Compute the first uses of memory blocks in a function definition.
--
-- The function parameters are inspected first, then the body is traversed.
-- The raw result is post-processed with 'expandWithAliases' (using the given
-- alias table) and 'removeEmptyMaps' before being returned.
findFirstUses :: VarMemMappings MemorySrc -> MemAliases
              -> FunDef ExplicitMemory -> FirstUses
findFirstUses var_to_mem mem_aliases fundef =
  removeEmptyMaps (expandWithAliases mem_aliases raw_first_uses)
  where
    -- Nothing counts as an outer first use at the top level.
    context =
      Context
        { ctxVarToMem = var_to_mem
        , ctxMemAliases = mem_aliases
        , ctxCurOuterFirstUses = S.empty
        }

    analysis = do
      mapM_ lookInFunDefFParam (funDefParams fundef)
      lookInBody (funDefBody fundef)

    -- 'execRWS' yields (final state, writer output); only the state matters.
    raw_first_uses = fst (execRWS (unFindM analysis) context M.empty)
-- | Record a function parameter as a first use of its memory block when the
-- parameter is an array residing in memory; any other parameter is ignored.
lookInFunDefFParam :: LoreConstraints lore =>
                      FParam lore -> FindM lore ()
lookInFunDefFParam (Param x membound) =
  case membound of
    ExpMem.MemArray _ _ _ (ExpMem.ArrayIn xmem _) -> recordMapping x xmem
    _ -> return ()
-- | Process every statement of a body in order.  The body result is not
-- inspected.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _res) =
  forM_ stms lookInStm
-- | Process every statement of a kernel body in order.  The kernel result is
-- not inspected.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _res) =
  forM_ stms lookInStm
-- | Record the first uses introduced by a single statement, then descend into
-- any bodies nested inside its expression.
--
-- Three sources of first uses are considered:
--
-- 1. If the expression creates a new array ('createsNewArray'), every array
--    pattern element is a first use of its memory block, provided that the
--    expression does not also read from that memory and that the block was
--    not already a first use of an enclosing statement.
--
-- 2. Every memory block bound in the pattern context is recorded against each
--    value pattern element of the statement.
--
-- 3. For loops, every memory block among the merge context parameters is
--    recorded against each value pattern element of the statement.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern patctxelems patvalelems) _ e) = do
  -- First uses that were known when the enclosing expression was entered.
  outer_first_uses <- asks ctxCurOuterFirstUses
  when (createsNewArray e) $ do
    let e_free_vars = freeInExp e
    -- Memory blocks (with aliases) of every free variable of the expression.
    e_mems <- S.unions <$> mapM varMems (S.toList e_free_vars)
    forM_ patvalelems $ \(PatElem x membound) ->
      case membound of
        ExpMem.MemArray _ _ _ (ExpMem.ArrayIn xmem _) -> do
          -- 'varMems' is keyed by *variable*, so it has to be asked about the
          -- array @x@.  Asking about the memory block name @xmem@ would find
          -- nothing in the variable-to-memory mapping and make the overlap
          -- test below vacuously true.
          x_mems <- varMems x
          -- Only a write that does not also read the destination memory
          -- counts as a first use.
          when (S.null $ S.intersection x_mems e_mems)
            $ unless (xmem `S.member` outer_first_uses)
            $ recordMapping x xmem
        _ -> return ()
  forM_ patctxelems
    $ \p -> forM_ patvalelems
    $ \el -> lookInPatCtxElem (patElemName el) p
  case e of
    DoLoop mergectxparams _mergevalparams _loopform _body ->
      forM_ mergectxparams
        $ \p -> forM_ patvalelems
        $ \el -> lookInMergeCtxParam (patElemName el) p
    _ -> return ()
  -- Everything recorded so far (including by this statement) is an "outer"
  -- first use from the point of view of the nested bodies.
  cur_first_uses <- get
  local (\ctx -> ctx { ctxCurOuterFirstUses = S.unions $ M.elems cur_first_uses })
    $ fullWalkExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody }
        -- Kernel bodies live in a different lore, hence the 'coerce'.
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }
-- | If the given context pattern element binds a memory block, record that
-- block as a first use belonging to the variable @x@; otherwise do nothing.
lookInPatCtxElem :: LoreConstraints lore =>
                    VName -> PatElem lore -> FindM lore ()
lookInPatCtxElem x ctx_elem =
  case ctx_elem of
    PatElem xmem ExpMem.MemMem{} -> recordMapping x xmem
    _ -> return ()
-- | If the given loop merge context parameter is a memory block, record that
-- block as a first use belonging to the variable @x@; otherwise do nothing.
-- The initial value paired with the parameter is not inspected.
lookInMergeCtxParam :: LoreConstraints lore =>
                       VName -> (FParam lore, SubExp) -> FindM lore ()
lookInMergeCtxParam x (param, _) =
  case param of
    Param xmem ExpMem.MemMem{} -> recordMapping x xmem
    _ -> return ()
-- | Lore-specific knowledge about expressions.
class ArrayUtils lore where
  -- | Does evaluating this expression create a new array?  'lookInStm' only
  -- records array pattern elements as first uses when this holds.
  createsNewArray :: Exp lore -> Bool
-- | The lore-independent part of 'createsNewArray': true exactly for the
-- basic operations listed below, false for every other expression.
createsNewArrayBase :: Exp lore -> Bool
createsNewArrayBase (BasicOp op) =
  case op of
    Replicate{} -> True
    Iota{} -> True
    Manifest{} -> True
    Copy{} -> True
    Concat{} -> True
    ArrayLit{} -> True
    Scratch{} -> True
    _ -> False
createsNewArrayBase _ = False
-- | Outside kernels, a kernel invocation counts as creating a new array, in
-- addition to the cases covered by 'createsNewArrayBase'.
instance ArrayUtils ExplicitMemory where
  createsNewArray (Op (ExpMem.Inner ExpMem.Kernel{})) = True
  createsNewArray e = createsNewArrayBase e
-- | Inside kernels, the group-level reduce, scan and stream operations and
-- @Combine@ count as creating a new array, in addition to the cases covered
-- by 'createsNewArrayBase'.
instance ArrayUtils InKernel where
  createsNewArray (Op (ExpMem.Inner ExpMem.GroupReduce{})) = True
  createsNewArray (Op (ExpMem.Inner ExpMem.GroupScan{})) = True
  createsNewArray (Op (ExpMem.Inner ExpMem.GroupStream{})) = True
  createsNewArray (Op (ExpMem.Inner ExpMem.Combine{})) = True
  createsNewArray e = createsNewArrayBase e