{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Optimise.MemoryBlockMerging.Coalescing.Core
( coreCoalesceFunDef
) where
import qualified Data.Set as S
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe (maybe, fromMaybe, mapMaybe, isJust)
import Control.Monad
import Control.Monad.RWS
import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
ExplicitMemory, ExplicitMemorish)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Tools
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.MemoryUpdater
import Futhark.Optimise.MemoryBlockMerging.PrimExps (findPrimExpsFunDef)
import Futhark.Optimise.MemoryBlockMerging.Coalescing.Exps
import Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition2
import Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition3
import Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition5
import Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizes
-- | The mutable state of the coalescing search: everything that has been
-- optimistically coalesced so far.
data Current = Current
  {
    curCoalescedIntos :: CoalescedIntos
    -- ^ For each destination variable, the sources recorded as coalesced
    -- into it (see 'recordOptimisticCoalescing').
  , curMemsCoalesced :: MemsCoalesced
    -- ^ For each coalesced source variable, the memory location it has
    -- been assigned.
  }
  deriving (Show)
-- | Maps a destination variable to the set of entries coalesced into it.
-- Each entry is a triple of a source variable, an offset expression, and a
-- list of slices; 'tryCoalesce' combines these with later coalescings.
type CoalescedIntos = M.Map VName (S.Set (VName, PrimExp VName,
                                          [Slice (PrimExp VName)]))
-- | Maps a coalesced source variable to its new memory location.
type MemsCoalesced = M.Map VName MemoryLoc
-- | The starting state of the search: no coalescings have been recorded.
emptyCurrent :: Current
emptyCurrent = Current { curCoalescedIntos = mempty
                       , curMemsCoalesced = mempty
                       }
-- | The read-only environment of the coalescing search.  All fields except
-- 'ctxCurSnapshot' are filled in once by 'coreCoalesceFunDef'.
data Context = Context
  { ctxFunDef :: FunDef ExplicitMemory
    -- ^ The function definition being processed.
  , ctxVarToMem :: VarMemMappings MemorySrc
    -- ^ The memory source of each array variable before any coalescing.
  , ctxMemAliases :: MemAliases
    -- ^ Aliasing information for memory blocks (used by 'withMemAliases').
  , ctxVarAliases :: VarAliases
    -- ^ Aliasing information for variables (consulted by 'safetyCond4').
  , ctxFirstUses :: FirstUses
    -- ^ First-use information (consulted by 'safetyIf').
  , ctxLastUses :: LastUses
    -- ^ Last-use information (consulted by 'safetyCond1').
  , ctxActualVars :: M.Map VName Names
    -- ^ The "actual" variables associated with a variable (see
    -- 'lookupActualVars').
  , ctxExistentials :: Names
    -- ^ Variables treated as existential by the checks in this module.
  , ctxVarPrimExps :: M.Map VName (PrimExp VName)
    -- ^ Result of 'findPrimExpsFunDef'; used to expand index functions.
  , ctxVarExps :: M.Map VName Exp'
    -- ^ Result of 'findExpsFunDef'; the expression that defines a variable.
  , ctxAllocatedBlocksBeforeCreation :: M.Map VName MNames
    -- ^ Result of 'findSafetyCondition2FunDef' (consulted by 'safetyCond2').
  , ctxVarsInUseBeforeMem :: M.Map MName Names
    -- ^ Result of 'findSafetyCondition5FunDef' (consulted by 'safetyCond5').
  , ctxCurSnapshot :: Current
    -- ^ A snapshot of the 'Current' state, installed by 'lookInStm' with
    -- 'local' before a coalescing attempt.
  }
  deriving (Show)
-- | The monad of the coalescing search: a reader over 'Context' and a state
-- over 'Current', with no writer output.  The @lore@ parameter does not
-- occur in the representation; 'coerce' changes it freely.
newtype FindM lore a = FindM { unFindM :: RWS Context () Current a }
  deriving (Monad, Functor, Applicative,
            MonadReader Context,
            MonadState Current)
-- | The constraints a lore must satisfy for its statements to be traversed
-- by this module.
type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalk lore)
-- | Change the (unused) lore parameter of a 'FindM' action by unwrapping
-- and rewrapping the underlying computation.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM action) = FindM action
-- | Apply a function to the 'curCoalescedIntos' field of the state.
modifyCurCoalescedIntos :: (CoalescedIntos -> CoalescedIntos) -> FindM lore ()
modifyCurCoalescedIntos f = modify update
  where update cur = cur { curCoalescedIntos = f (curCoalescedIntos cur) }
-- | Apply a function to the 'curMemsCoalesced' field of the state.
modifyCurMemsCoalesced :: (MemsCoalesced -> MemsCoalesced) -> FindM lore ()
modifyCurMemsCoalesced f = modify update
  where update cur = cur { curMemsCoalesced = f (curMemsCoalesced cur) }
-- | Return the expression recorded for the variable in 'ctxVarExps', but
-- only if it is an @If@; otherwise return 'Nothing'.
ifExp :: MonadReader Context m =>
         VName -> m (Maybe Exp')
ifExp var = asks $ onlyIf . M.lookup var . ctxVarExps
  where onlyIf found@(Just (Exp _ _ If{})) = found
        onlyIf _ = Nothing
-- | Is the variable recorded in 'ctxVarExps' as the result of an @If@?
isIfExp :: MonadReader Context m =>
           VName -> m Bool
isIfExp = fmap isJust . ifExp
-- | Is the variable recorded in 'ctxVarExps' as the result of a @DoLoop@?
isLoopExp :: MonadReader Context m =>
             VName -> m Bool
isLoopExp var = asks $ isLoop . M.lookup var . ctxVarExps
  where isLoop (Just (Exp _ _ DoLoop{})) = True
        isLoop _ = False
-- | Is the variable recorded in 'ctxVarExps' as the result of a @Reshape@?
isReshapeExp :: MonadReader Context m =>
                VName -> m Bool
isReshapeExp var = asks $ isReshape . M.lookup var . ctxVarExps
  where isReshape (Just (Exp _ _ (BasicOp Reshape{}))) = True
        isReshape _ = False
-- | Look up the original memory source of a variable in 'ctxVarToMem'.
-- The result is passed through the message-carrying 'fromJust', so a
-- variable without an entry is not handled gracefully here.
lookupVarMem :: MonadReader Context m =>
                VName -> m MemorySrc
lookupVarMem var = do
  var_to_mem <- asks ctxVarToMem
  return $ fromJust ("lookup memory block from " ++ pretty var)
                    (M.lookup var var_to_mem)
-- | Look up the actual variables of a variable in the alias-expanded
-- 'ctxActualVars' map.  A variable without an entry maps to itself.
lookupActualVars :: MonadReader Context m =>
                    VName -> m Names
lookupActualVars var = do
  direct <- asks ctxActualVars
  let expanded = expandWithAliases direct direct
  return $ M.findWithDefault (S.singleton var) var expanded
-- | Find the name of the memory block a variable currently resides in.
-- A coalescing recorded in the snapshot ('ctxCurSnapshot') takes priority
-- over the original mapping in 'ctxVarToMem'.  Returns 'Nothing' if the
-- variable is in neither.
lookupCurrentVarMem :: VName -> FindM lore (Maybe VName)
lookupCurrentVarMem var = do
  coalesced <- asks (curMemsCoalesced . ctxCurSnapshot)
  original <- asks ctxVarToMem
  return $ (memLocName <$> M.lookup var coalesced)
           `mplus` (memSrcName <$> M.lookup var original)
-- | The given memory block together with everything 'ctxMemAliases'
-- records as aliasing it.
withMemAliases :: MonadReader Context m =>
                  VName -> m Names
withMemAliases mem = do
  aliases <- asks ctxMemAliases
  return $ S.insert mem (lookupEmptyable mem aliases)
-- | How the destination of a coalescing attempt is bound.
data Bindage = BindInPlace VName (Slice SubExp)
               -- ^ The destination comes from an in-place update
               -- (@Update@) of the given original array with the given
               -- slice.
             | BindVar
               -- ^ The destination is an ordinary binding (used for @Copy@
               -- and @Concat@ in 'lookInStm').
-- | Record in the state that @src@ has been coalesced into @dst@.
--
-- The entry @(src, offset, ixfun_slices)@ is added to the entries of @dst@
-- in 'curCoalescedIntos'.  For an in-place binding, the original array is
-- additionally added with a zero offset and no slices.  Finally @src@ is
-- mapped to @dst_memloc@ in 'curMemsCoalesced'.
recordOptimisticCoalescing :: VName -> PrimExp VName
                           -> [Slice (PrimExp VName)]
                           -> VName -> MemoryLoc -> Bindage -> FindM lore ()
recordOptimisticCoalescing src offset ixfun_slices dst dst_memloc bindage = do
  intoDst (src, offset, ixfun_slices)
  case bindage of
    BindInPlace orig _ -> intoDst (orig, zeroOffset, [])
    BindVar -> return ()
  modifyCurMemsCoalesced (M.insert src dst_memloc)
  where intoDst entry = modifyCurCoalescedIntos $ insertOrUpdate dst entry
-- | Run the coalescing search over a function definition and apply the
-- resulting variable-to-memory mappings to it.
--
-- The search runs 'lookInBody' on the function body, starting from
-- 'emptyCurrent'.  Only the final 'curMemsCoalesced' map is used; it is
-- handed to 'transformFromVarMemMappings' together with the original
-- memory names and the memory block sizes.
coreCoalesceFunDef :: MonadFreshNames m =>
                      FunDef ExplicitMemory -> VarMemMappings MemorySrc
                   -> MemAliases -> VarAliases -> FirstUses -> LastUses
                   -> ActualVariables -> Names -> m (FunDef ExplicitMemory)
coreCoalesceFunDef fundef var_to_mem mem_aliases var_aliases first_uses
  last_uses actual_vars existentials =
  transformFromVarMemMappings coalesced original_mems block_sizes block_sizes
    False fundef
  where
    context = Context
      { ctxFunDef = fundef
      , ctxVarToMem = var_to_mem
      , ctxMemAliases = mem_aliases
      , ctxVarAliases = var_aliases
      , ctxFirstUses = first_uses
      , ctxLastUses = last_uses
      , ctxActualVars = actual_vars
      , ctxExistentials = existentials
      , ctxVarPrimExps = findPrimExpsFunDef fundef
      , ctxVarExps = findExpsFunDef fundef
      , ctxAllocatedBlocksBeforeCreation = findSafetyCondition2FunDef fundef
      , ctxVarsInUseBeforeMem = findSafetyCondition5FunDef fundef first_uses
      , ctxCurSnapshot = emptyCurrent
      }

    search = unFindM (lookInBody (funDefBody fundef))
    (final_state, _) = execRWS search context emptyCurrent

    -- The new memory location of every variable that was coalesced.
    coalesced = curMemsCoalesced final_state
    original_mems = M.map memSrcName var_to_mem
    block_sizes = M.map fst (memBlockSizesFunDef fundef)
-- | Visit every statement of a body, in order.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _res) = forM_ stms lookInStm
-- | Visit every statement of a kernel body, in order.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _res) = forM_ stms lookInStm
-- | The 32-bit integer constant zero as a 'PrimExp'.  Used both as the
-- "no offset" value and as the value other offsets are compared against.
zeroOffset :: PrimExp VName
zeroOffset = primExpFromSubExp (IntType Int32) zero
  where zero = constant (0 :: Int32)
-- | Inspect one statement and then recurse into any bodies nested inside
-- its expression.
--
-- A coalescing attempt is made only when the pattern has exactly one value
-- element and that element is a memory array.  The attempt runs with the
-- current state installed as 'ctxCurSnapshot', and depends on the
-- expression:
--
-- * @Update orig slice (Var src)@: coalesce @src@ in place through
--   @slice@, provided @src@ has an entry in 'ctxVarToMem'.
--
-- * @Copy src@: coalesce @src@ with no slices and a zero offset.
--
-- * @Concat 0 ...@: coalesce every operand, each at an offset equal to the
--   sum of the outermost dimensions of the operands before it.
--
-- Every other expression is left alone.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern _patctxelems patvalelems) _ e) = do
  case patvalelems of
    [PatElem dst ExpMem.MemArray{}] -> do
      snapshot <- get
      known_mems <- asks ctxVarToMem
      local (\ctx -> ctx { ctxCurSnapshot = snapshot }) $
        coalesceInto dst known_mems
    _ -> return ()
  fullWalkExpM walker walker_kernel e
  where
    toPrimExp = primExpFromSubExp (IntType Int32)

    coalesceInto dst known_mems = case e of
      BasicOp (Update orig slice (Var src))
        -- An unknown source fails the guard and falls through to the
        -- catch-all case below.
        | M.member src known_mems ->
            tryCoalesce dst [map (fmap toPrimExp) slice]
              (BindInPlace orig slice) src zeroOffset
      BasicOp (Copy src) ->
        tryCoalesce dst [] BindVar src zeroOffset
      BasicOp (Concat 0 src0 src0s _) -> do
        let srcs = src0 : src0s
        shapes <- mapM (fmap memSrcShape . lookupVarMem) srcs
        let outer_lens = map (toPrimExp . head . shapeDims) shapes
            -- Running sum of the outer lengths; the final total is not the
            -- start of any operand, so it is dropped.
            offsets = init (scanl (+) zeroOffset outer_lens)
        zipWithM_ (tryCoalesce dst [] BindVar) srcs offsets
      _ -> return ()

    walker = identityWalker { walkOnBody = lookInBody }
    walker_kernel = identityKernelWalker
      { walkOnKernelBody = coerce . lookInBody
      , walkOnKernelKernelBody = coerce . lookInKernelBody
      , walkOnKernelLambda = coerce . lookInBody . lambdaBody
      }
-- | Try to coalesce @src@ into the memory of @dst@, recording the result in
-- the state if every safety condition holds.
--
-- Arguments, in order: the destination variable; the slices to apply to
-- the destination's index function; how the destination is bound; the
-- source variable; and the offset of the source inside the destination.
--
-- The candidates are the actual variables of @src@ plus everything the
-- snapshot ('ctxCurSnapshot') records as already coalesced into @src@ or
-- into one of its actual variables.  The attempt is all-or-nothing: if any
-- candidate fails a check, nothing is recorded.
tryCoalesce :: VName -> [Slice (PrimExp VName)] -> Bindage ->
               VName -> PrimExp VName -> FindM lore ()
tryCoalesce dst ixfun_slices bindage src offset = do
  mem_dst <- lookupVarMem dst
  src's <- S.toList <$> lookupActualVars src
  coalesced_intos <- curCoalescedIntos <$> asks ctxCurSnapshot
  var_to_pe <- asks ctxVarPrimExps
  -- The reader environment is fixed for the duration of this call, so the
  -- set of existentials is fetched once and reused below.
  existentials <- asks ctxExistentials

  let -- Entries previously coalesced into the source or its actual
      -- variables; these must follow the source into the destination.
      (src0s, offset0s, ixfun_slice0ss) =
        unzip3 $ S.toList $ S.unions
        $ map (`lookupEmptyable` coalesced_intos) (src : src's)

      n_direct = length src's

      -- The offset of an earlier coalescing relative to the destination.
      combineOffset o0
        | o0 == zeroOffset && offset == zeroOffset = zeroOffset
        | otherwise = offset + o0

      srcs = src's ++ src0s
      offsets = replicate n_direct offset ++ map combineOffset offset0s
      ixfun_slicess = replicate n_direct ixfun_slices
                      ++ map (ixfun_slices ++) ixfun_slice0ss

      -- Number of leading fixed dimensions.  This is computed from the
      -- @ixfun_slices@ argument and is therefore the same for every
      -- candidate, exactly as before this rewrite.
      -- NOTE(review): confirm that the per-candidate slice list is not
      -- what was intended here.
      n_initial_dimfixes =
        length $ L.takeWhile (isJust . dimFix) (concat ixfun_slices)

      -- The index function a candidate gets inside the destination memory.
      mkIxFun offset_local islices =
        let sliced = L.foldl' IxFun.slice (memSrcIxFun mem_dst) islices
            shifted
              | offset_local == zeroOffset = sliced
              | otherwise =
                  IxFun.offsetIndexDWIM n_initial_dimfixes sliced offset_local
        in expandIxFun var_to_pe shifted

      ixfuns' = zipWith mkIxFun offsets ixfun_slicess

      -- Candidates that are both loop results and existential are not
      -- coalesced.
      currentlyDisabled src_local = do
        is_loop <- isLoopExp src_local
        return $ is_loop && S.member src_local existentials

  safe0 <- not . or <$> mapM currentlyDisabled srcs
  mem_src_base <- lookupVarMem src
  safe1 <- safetyCond1 dst mem_src_base

  when (safe0 && safe1) $ do
    safes <- zipWithM (canBeCoalesced dst) srcs ixfuns'
    when (and safes) $ do
      -- The entries of @src@ are re-recorded under @dst@ by the loop
      -- below, so its own entry is dropped first.
      modifyCurCoalescedIntos $ M.delete src
      forM_ (L.zip4 srcs offsets ixfun_slicess ixfuns')
        $ \(src_local, offset_local, ixfun_slices_local, ixfun_local) -> do
        is_if <- isIfExp src_local
        -- An existential candidate that is not an if-result keeps the name
        -- of its own memory block; every other candidate gets the
        -- destination's.  The index function is the new one in both cases.
        mem_name <-
          if S.member src_local existentials && not is_if
          then memSrcName <$> lookupVarMem src_local
          else return $ memSrcName mem_dst
        recordOptimisticCoalescing
          src_local offset_local ixfun_slices_local
          dst (MemoryLoc mem_name ixfun_local) bindage
-- | Check whether @src@ can be coalesced into the memory of @dst@ with the
-- given index function: safety conditions 2 through 5 and the
-- if-expression check must all hold.
canBeCoalesced :: VName -> VName -> ExpMem.IxFun -> FindM lore Bool
canBeCoalesced dst src ixfun = do
  mem_dst <- lookupVarMem dst
  mem_src <- lookupVarMem src
  and <$> sequence
    [ safetyCond2 src mem_dst
    , safetyCond3 src dst mem_dst
    , safetyCond4 src
    , safetyCond5 mem_src ixfun
    , safetyIf src dst
    ]
-- | Safety condition 1: the source memory block must be among the last
-- uses recorded for the statement that creates @dst@.
safetyCond1 :: MonadReader Context m =>
               VName -> MemorySrc -> m Bool
safetyCond1 dst mem_src =
  asks $ S.member (memSrcName mem_src)
       . lookupEmptyable (FromStm dst)
       . ctxLastUses
-- | Safety condition 2: the destination memory block must be among the
-- blocks recorded in 'ctxAllocatedBlocksBeforeCreation' for @src@.
safetyCond2 :: MonadReader Context m =>
               VName -> MemorySrc -> m Bool
safetyCond2 src mem_dst =
  asks $ S.member (memSrcName mem_dst)
       . lookupEmptyable src
       . ctxAllocatedBlocksBeforeCreation
-- | Safety condition 3: no variable returned by 'getVarUsesBetween' for
-- @src@ and @dst@ may currently reside in the destination memory block or
-- in one of its aliases.
safetyCond3 :: VName -> VName -> MemorySrc -> FindM lore Bool
safetyCond3 src dst mem_dst = do
  fundef <- asks ctxFunDef
  mems_used <- forM (S.toList $ getVarUsesBetween fundef src dst) $ \var -> do
    cur_mem <- lookupCurrentVarMem var
    case cur_mem of
      Just mem -> withMemAliases mem
      -- A variable with no memory block contributes nothing.
      Nothing -> return S.empty
  return $ all (S.notMember (memSrcName mem_dst)) mems_used
-- | Safety condition 4: the source must have no variable aliases, unless
-- it is the result of an @If@, or the result of a @Reshape@ with a
-- non-empty entry in 'ctxActualVars'.
safetyCond4 :: MonadReader Context m =>
               VName -> m Bool
safetyCond4 src = do
  is_if <- isIfExp src
  is_reshape <- isReshapeExp src
  has_actuals <- asks $ not . S.null . lookupEmptyable src . ctxActualVars
  alias_free <- asks $ S.null . lookupEmptyable src . ctxVarAliases
  return $ is_if || (is_reshape && has_actuals) || alias_free
-- | Safety condition 5: every free variable of the index function must be
-- among the variables recorded in 'ctxVarsInUseBeforeMem' for the source
-- memory block.
safetyCond5 :: MonadReader Context m =>
               MemorySrc -> ExpMem.IxFun -> m Bool
safetyCond5 mem_src ixfun = do
  vars_in_use <- asks ctxVarsInUseBeforeMem
  let available = lookupEmptyable (memSrcName mem_src) vars_in_use
  return $ freeIn ixfun `S.isSubsetOf` available
-- | Extra safety check for sources that are connected to an @If@
-- expression.  The result is the conjunction of a general check
-- (@res_general@) and a check against the memory blocks of related
-- variables (@res_current@).  If no single @If@ expression is found, both
-- checks hold trivially.
safetyIf :: VName -> VName -> FindM lore Bool
safetyIf src dst = do
  mem_src <- lookupVarMem src
  actual_srcs <- S.toList <$> lookupActualVars src
  existentials <- asks ctxExistentials
  var_to_mem <- asks ctxVarToMem
  first_uses_all <- asks ctxFirstUses
  -- Union of all value sets in 'ctxActualVars' that contain @src@.
  reverse_actual_srcs <-
    S.toList . S.unions . M.elems . M.filter (src `S.member`)
    <$> asks ctxActualVars
  -- Those of the variables above that are defined by an @If@.
  outer <- mapMaybeM ifExp reverse_actual_srcs
  let (is_in_if,
       if_branch_results_from_outer,
       at_least_one_creation_inside) = case outer of
        -- Only the case of exactly one @If@ expression is handled.
        -- NOTE(review): @nctx@ and @nthpat@ are presumably the number of
        -- context results and the index of the pattern element; confirm
        -- against the definition of Exp'.
        [Exp nctx nthpat (If _ body0 body1 _)] ->
          let -- Variables returned by branches that have no statements,
              -- skipping the first @nctx@ results of each branch.
              results_from_outer = S.fromList $ mapMaybe subExpVar
                                   $ concatMap (drop nctx . bodyResult)
                                   $ filter (null . bodyStms) [body0, body1]
              -- True if the result is a variable with a known memory block
              -- and that block is among the first uses of the variables
              -- bound by the statements of the body.
              resultCreatedInside body se = fromMaybe False $ do
                res <- subExpVar se
                res_mem <- memSrcName <$> M.lookup res var_to_mem
                let body_vars = concatMap (map patElemName . patternValueElements
                                           . stmPattern) $ bodyStms body
                    body_first_uses = S.unions $ map (`lookupEmptyable` first_uses_all)
                                      body_vars
                return $ S.member res_mem body_first_uses
              at_least = resultCreatedInside body0 (bodyResult body0 !! (nctx + nthpat))
                         || resultCreatedInside body1 (bodyResult body1 !! (nctx + nthpat))
          in (True, results_from_outer, at_least)
        _ -> (False, S.empty, False)
  -- Inside an if, either no actual source may be existential, or at least
  -- one branch must create its result inside the branch.
  let res_general = not is_in_if || (not (any (`S.member` existentials) actual_srcs)
                                     || at_least_one_creation_inside)
  -- The extra check below applies only inside an if, when no actual source
  -- is returned by a statement-free branch and @src@ is not existential.
  let if_handling =
        is_in_if
        && not (any (`S.member` if_branch_results_from_outer) actual_srcs)
        && not (src `S.member` existentials)
  res_current <-
    if if_handling
    then do
      -- Safety condition 3 must also hold for every distinct memory source
      -- of the related variables, other than that of @src@ itself.
      mem_actual_srcs <- L.nub <$> mapM lookupVarMem reverse_actual_srcs
      let mem_actual_srcs_cur = L.delete mem_src mem_actual_srcs
      and <$> mapM (safetyCond3 src dst) mem_actual_srcs_cur
    else return True
  let res = res_general && res_current
  return res