{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TupleSections #-}
module Futhark.Optimise.MemoryBlockMerging.Reuse.Core
( coreReuseFunDef
) where
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (catMaybes, fromMaybe, isJust)
import Control.Monad
import Control.Monad.RWS
import Control.Monad.State
import Control.Monad.Identity
import Futhark.MonadFreshNames
import Futhark.Binder
import Futhark.Construct
import Futhark.Representation.AST
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.ExplicitMemory
(ExplicitMemory, ExplicitMemorish)
import Futhark.Pass.ExplicitAllocations()
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.PrimExps (findPrimExpsFunDef)
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.MemoryUpdater
import Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizes
import Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizeUses
-- | Read-only context of the reuse analysis.  It is built once in
-- 'coreReuseFunDef' and afterwards only read through 'asks'.
data Context = Context
  { ctxFirstUses :: FirstUses
    -- ^ Per variable, the memory blocks that variable is a first use of
    -- (queried with 'lookupEmptyable').
  , ctxInterferences :: Interferences
    -- ^ Per memory block, the memory blocks it interferes with.
  , ctxPotentialKernelInterferences
    :: PotentialKernelDataRaceInterferences
    -- ^ Groups of (memory block, array, element type, index function)
    -- entries that may race with each other inside a kernel.
  , ctxSizes :: Sizes
    -- ^ Per memory block, its allocation size ('fst') and memory space
    -- ('snd').
  , ctxVarToMem :: VarMemMappings MemorySrc
    -- ^ The original mapping from variables to their memory.
  , ctxActualVars :: M.Map VName Names
    -- ^ Per variable, the variables that share its memory (expanded through
    -- aliases by 'lookupActualVars'').
  , ctxExistentials :: Names
    -- ^ Existential variables; arrays that are, or whose actual variables
    -- are, existential are not considered for reuse.
  , ctxVarPrimExps :: M.Map VName (PrimExp VName)
    -- ^ Known 'PrimExp' definitions of scalar variables.
  , ctxSizeVarsUsesBefore :: M.Map VName Names
    -- ^ Per size variable, a set of size variables; a size may only be
    -- maximised with sizes in this set (NOTE(review): presumably the size
    -- variables in scope before it -- confirm against 'findSizeUsesFunDef').
  }
  deriving (Show)
-- | Mutable state threaded through the reuse analysis.
data Current = Current
  { curUses :: M.Map MName MNames
    -- ^ Per memory block that later arrays may reuse, the set of memory
    -- blocks that have been mapped onto it (including itself).
  , curEqAsserts :: M.Map VName Names
    -- ^ Variables asserted equal through an 'Assert' of an equality
    -- comparison; recorded in both directions.
  , curVarToMemRes :: VarMemMappings MemoryLoc
    -- ^ Result: the new memory location of every variable whose memory
    -- block was replaced.
  , curVarToMaxExpRes :: M.Map MName Names
    -- ^ Result: per memory block, the size variables whose maximum becomes
    -- its new allocation size.
  , curKernelMaxSizedRes :: M.Map MName (VName,
                                         ((VName, VName),
                                          (VName, VName)))
    -- ^ Result: per memory block, the kernel-array maximisation info
    -- @(kmem size, ((kmem array, kmem final dim),
    -- (xmem array, xmem final dim)))@.
  }
  deriving (Show)
-- | The initial analysis state, in which nothing has been recorded yet.
emptyCurrent :: Current
emptyCurrent = Current M.empty M.empty M.empty M.empty M.empty
-- | The analysis monad.  It reads a 'Context', writes nothing, and threads
-- a 'Current' state.  The @lore@ parameter does not occur in the
-- representation; 'coerce' changes it.
newtype FindM lore a = FindM { unFindM :: RWS Context () Current a }
  deriving (Monad, Functor, Applicative,
            MonadReader Context,
            MonadState Current)
-- | Constraints a lore must satisfy for the statement walkers in this
-- module ('lookInStm' and friends).
type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalk lore)
-- | Reinterpret an analysis computation at a different lore.  This is free,
-- because the lore parameter of 'FindM' is not part of its representation.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM computation) = FindM computation
-- | The original memory information of a variable.  Fails with an error
-- if the variable is not in 'ctxVarToMem'.
lookupVarMem :: MonadReader Context m =>
                VName -> m MemorySrc
lookupVarMem var = do
  var_to_mem <- asks ctxVarToMem
  return $ fromJust ("lookup memory block from " ++ pretty var)
         $ M.lookup var var_to_mem
-- | The actual variables of @var@ after expanding aliases, or just @var@
-- itself when it has no entry.
lookupActualVars' :: ActualVariables -> VName -> Names
lookupActualVars' actual_vars var =
  M.findWithDefault (S.singleton var) var expanded
  where expanded = expandWithAliases actual_vars actual_vars
-- | 'lookupActualVars'' applied to the context's 'ctxActualVars'.
lookupActualVars :: MonadReader Context m =>
                    VName -> m Names
lookupActualVars var = do
  actual_vars <- asks ctxActualVars
  return $ lookupActualVars' actual_vars var
-- | The allocation size of a memory block.  Fails with an error if the
-- block is not in 'ctxSizes'.
lookupSize :: MonadReader Context m =>
              VName -> m SubExp
lookupSize var = asks (fst . fromJust msg . M.lookup var . ctxSizes)
  where msg = "lookup size from " ++ pretty var
-- | The memory space of a memory block.  Fails with an error if the block
-- is not in 'ctxSizes'.
lookupSpace :: MonadReader Context m =>
               MName -> m Space
lookupSpace mem = asks (snd . fromJust msg . M.lookup mem . ctxSizes)
  where msg = "lookup space from " ++ pretty mem
-- | Record in 'curUses' (via 'insertOrUpdate') that @new_mem@ is mapped
-- onto @old_mem@.
insertUse :: VName -> VName -> FindM lore ()
insertUse old_mem new_mem = do
  uses <- gets curUses
  modify $ \cur -> cur { curUses = insertOrUpdate old_mem new_mem uses }
-- | Record the new memory location of variable @x@, overwriting any earlier
-- entry.
recordMemMapping :: VName -> MemoryLoc -> FindM lore ()
recordMemMapping x mem =
  modify $ \cur@Current { curVarToMemRes = mappings } ->
    cur { curVarToMemRes = M.insert x mem mappings }
-- | Add the size variable @y@ to the set of sizes whose maximum should
-- become the allocation size of @mem@.
recordMaxMapping :: MName -> VName -> FindM lore ()
recordMaxMapping mem y = modify addMax
  where addMax cur =
          cur { curVarToMaxExpRes = insertOrUpdate mem y (curVarToMaxExpRes cur) }
-- | Record the kernel-array maximisation info for @mem@, overwriting any
-- earlier entry.
recordKernelMaxMapping :: MName -> (VName, ((VName, VName), (VName, VName)))
                       -> FindM lore ()
recordKernelMaxMapping mem info = do
  infos <- gets curKernelMaxSizedRes
  modify $ \cur -> cur { curKernelMaxSizedRes = M.insert mem info infos }
-- | Apply a function to the recorded equality assertions.
modifyCurEqAsserts :: (M.Map VName Names -> M.Map VName Names) -> FindM lore ()
modifyCurEqAsserts f = do
  cur <- get
  put cur { curEqAsserts = f (curEqAsserts cur) }
-- | Run @m@, then drop every 'curUses' entry whose key was added during @m@.
-- Updates made during @m@ to keys that already existed are kept.  Used
-- when descending into nested bodies, presumably so that memory blocks
-- first used inside a nested body are not reused outside it.
withLocalUses :: FindM lore a -> FindM lore a
withLocalUses m = do
  before <- gets curUses
  res <- m
  -- 'M.intersection' keeps the keys of @before@ but the values from the
  -- current (post-@m@) map.
  modify $ \cur -> cur { curUses = curUses cur `M.intersection` before }
  return res
-- | Entry point of the reuse pass.
--
-- First the function is analysed (parameters, then body) to decide which
-- arrays can live in an existing memory block.  Then the results are
-- applied in three steps:
--
--   1. variable-to-memory remapping;
--   2. allocation sizes replaced by maxima of size variables;
--   3. kernel-array maximisations.
coreReuseFunDef :: MonadFreshNames m =>
                   FunDef ExplicitMemory -> FirstUses ->
                   Interferences -> PotentialKernelDataRaceInterferences ->
                   VarMemMappings MemorySrc -> ActualVariables -> Names ->
                   m (FunDef ExplicitMemory)
coreReuseFunDef fundef first_uses interferences potential_kernel_interferences
                var_to_mem actual_vars existentials = do
  let var_to_pe = findPrimExpsFunDef fundef
      sizes = memBlockSizesFunDef fundef
      size_vars = M.map fst sizes
      context = Context
        { ctxFirstUses = first_uses
        , ctxInterferences = interferences
        , ctxPotentialKernelInterferences = potential_kernel_interferences
        , ctxSizes = sizes
        , ctxVarToMem = var_to_mem
        , ctxActualVars = actual_vars
        , ctxExistentials = existentials
        , ctxVarPrimExps = var_to_pe
        , ctxSizeVarsUsesBefore = findSizeUsesFunDef fundef
        }
      analysis = do
        mapM_ lookInFParam $ funDefParams fundef
        lookInBody $ funDefBody fundef
      (final, ()) = execRWS (unFindM analysis) context emptyCurrent
      mem_res = curVarToMemRes final

  fundef_remapped <-
    transformFromVarMemMappings mem_res (M.map memSrcName var_to_mem)
                                size_vars size_vars False fundef
  fundef_maxed <-
    transformFromVarMaxExpMappings (curVarToMaxExpRes final) fundef_remapped
  transformFromKernelMaxSizedMappings var_to_pe var_to_mem
    (M.map memLocName mem_res) (memBlockSizesFunDef fundef_remapped)
    actual_vars (curKernelMaxSizedRes final) fundef_maxed
-- | A unique array parameter makes its memory block a reuse candidate by
-- adding it to 'curUses'.  Other parameters are ignored.
lookInFParam :: LoreConstraints lore =>
                FParam lore -> FindM lore ()
lookInFParam (Param _ attr)
  | ExpMem.MemArray _ _ Unique (ExpMem.ArrayIn mem _) <- attr = insertUse mem mem
  | otherwise = return ()
-- | Analyse every statement of a body, in order.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody body = forM_ (bodyStms body) lookInStm
-- | Analyse every statement of a kernel body, in order.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm
-- | Analyse one statement in three steps:
--
--   1. If the statement asserts an equality @v0 == v1@ whose comparison has
--      a known 'PrimExp', record the equality in both directions.
--   2. For each array in the value pattern whose memory block is first used
--      here, and where neither the array nor any of its actual variables is
--      existential, try to reuse memory ('handleNewArray').
--   3. Recurse into nested bodies, each inside 'withLocalUses'.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern _patctxelems patvalelems) _ e) = do
  recordEqualityAssert
  existentials <- asks ctxExistentials
  forM_ patvalelems $ \(PatElem var membound) ->
    case membound of
      ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) -> do
        first_uses_var <- lookupEmptyable var <$> asks ctxFirstUses
        actual_vars_var <- lookupActualVars var
        let isExistential = (`S.member` existentials)
            eligible = mem `S.member` first_uses_var
                       && not (isExistential var)
                       && not (any isExistential actual_vars_var)
        when eligible $ handleNewArray var mem
      _ -> return ()
  fullWalkExpM walker walker_kernel e
  where recordEqualityAssert = do
          var_to_pe <- asks ctxVarPrimExps
          case e of
            BasicOp (Assert (Var v) _ _)
              | Just (CmpOpExp (CmpEq _) (LeafExp v0 _) (LeafExp v1 _))
                  <- M.lookup v var_to_pe -> do
                    modifyCurEqAsserts $ insertOrUpdate v0 v1
                    modifyCurEqAsserts $ insertOrUpdate v1 v0
            _ -> return ()

        walker = identityWalker
          { walkOnBody = withLocalUses . lookInBody }

        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . withLocalUses . lookInBody
          , walkOnKernelKernelBody = coerce . withLocalUses . lookInKernelBody
          , walkOnKernelLambda = coerce . withLocalUses . lookInBody . lambdaBody
          }
-- | Try to place the array @x@, whose memory block @xmem@ has its first use
-- here, into a memory block that is already in 'curUses'.
--
-- A candidate @kmem@, with the set @used_mems@ of blocks already mapped
-- onto it, is accepted when all of the following hold:
--
--   * @kmem@ is not @xmem@;
--   * no block in @used_mems@ interferes with @xmem@;
--   * @kmem@ and @xmem@ have the same memory space;
--   * no variable other than the actual variables of @x@ lives in @xmem@;
--   * the sizes work out, in one of two ways:
--
--       - @xmem@'s size is a variable, no block in @used_mems@ races with
--         @xmem@ inside a kernel, and the sizes either match or can be
--         replaced by their maximum; or
--       - the two kernel arrays can be unified by maximising their final
--         dimension.
--
-- The first accepted candidate is used.  Every actual variable of @x@ is
-- then remapped to @kmem@, and any applicable size maximisation is
-- recorded.  If no candidate is accepted, @xmem@ itself is added to
-- 'curUses' as a new candidate.
handleNewArray :: VName -> MName -> FindM lore ()
handleNewArray x xmem = do
  interferences <- asks ctxInterferences
  actual_vars <- lookupActualVars x

  let notTheSame :: Monad m => MName -> MNames -> m Bool
      notTheSame kmem _used_mems = return (kmem /= xmem)

  let noneInterfere :: Monad m => MName -> MNames -> m Bool
      noneInterfere _kmem used_mems =
        return $ all (\used_mem -> not $ S.member xmem
                                   $ lookupEmptyable used_mem interferences)
               $ S.toList used_mems

  let noneInterfereKernelArray :: MonadReader Context m => MNames -> m Bool
      noneInterfereKernelArray used_mems =
        not <$> anyM (interferesInKernel xmem) (S.toList used_mems)

  let sameSpace :: MonadReader Context m =>
                   MName -> MNames -> m Bool
      sameSpace kmem _used_mems = do
        kspace <- lookupSpace kmem
        xspace <- lookupSpace xmem
        return (kspace == xspace)

  -- Is xmem's size equal to the size of one of the used blocks?  This is
  -- checked both syntactically and by comparing the expanded PrimExps,
  -- where variables compare equal modulo the recorded equality assertions.
  let sizesMatch :: MNames -> FindM lore Bool
      sizesMatch used_mems = do
        ok_sizes <- mapM lookupSize $ S.toList used_mems
        new_size <- lookupSize xmem
        let eq_simple = new_size `L.elem` ok_sizes
        var_to_pe <- asks ctxVarPrimExps
        eq_asserts <- gets curEqAsserts
        let sePrimExp se = do
              v <- subExpVar se
              pe <- M.lookup v var_to_pe
              let pe_expanded = expandPrimExp var_to_pe pe
              traverse (\v_inner -> pure $ VarWithLooseEquality v_inner
                                           $ lookupEmptyable v_inner eq_asserts)
                       pe_expanded
        let ok_sizes_pe = map sePrimExp ok_sizes
        let new_size_pe = sePrimExp new_size
        let eq_advanced = isJust new_size_pe && new_size_pe `L.elem` ok_sizes_pe
        return (eq_simple || eq_advanced)

  -- Can kmem's size be replaced by the maximum of both sizes?  Requires both
  -- sizes to be variables, and xmem's size variable to be in kmem's size
  -- variable's entry of 'ctxSizeVarsUsesBefore'.
  let sizesCanBeMaxed :: MName -> FindM lore Bool
      sizesCanBeMaxed kmem = do
        ksize <- lookupSize kmem
        xsize <- lookupSize xmem
        uses_before <- asks ctxSizeVarsUsesBefore
        let ok = fromMaybe False $ do
              ksize' <- subExpVar ksize
              xsize' <- subExpVar xsize
              return (xsize' `S.member`
                      fromJust ("is recorded for all size variables " ++ pretty ksize')
                               (M.lookup ksize' uses_before))
        return ok

  -- If kmem has exactly one used block and kmem and xmem occur together in
  -- exactly one potential kernel-race group, check whether their index
  -- functions become compatible once their final dimensions are unified.
  -- If so, return the info needed to allocate with the maximised final
  -- dimension.
  let sizesCanBeMaxedKernelArray :: MName -> MNames ->
                                    FindM lore (Maybe (VName, ((VName, VName),
                                                               (VName, VName))))
      sizesCanBeMaxedKernelArray kmem used_mems = do
        potentials <- asks ctxPotentialKernelInterferences
        uses_before <- asks ctxSizeVarsUsesBefore
        let involvesBoth grp =
              let pot_mems = map (\(m, _, _, _) -> m) grp
              in kmem `elem` pot_mems && xmem `elem` pot_mems
            first_usess = filter involvesBoth potentials
        kmem_size <- fromJust "should be a var" . subExpVar <$> lookupSize kmem
        return $ case (S.toList used_mems, first_usess) of
          ([_], [first_uses]) -> do
            (_, kmem_array, kmem_pt, kmem_ixfun) <-
              L.find (\(mname, _, _, _) -> mname == kmem) first_uses
            (_, xmem_array, xmem_pt, xmem_ixfun) <-
              L.find (\(mname, _, _, _) -> mname == xmem) first_uses
            -- Already compatible: nothing to maximise.
            if (kmem, kmem_ixfun) `ixFunsCompatible` (xmem, xmem_ixfun)
              then Nothing
              else do
                (kmem_ixfun_start, kmem_indices_start, kmem_final_dim) <-
                  IxFun.getInfoMaxUnification kmem_ixfun
                (xmem_ixfun_start, xmem_indices_start, xmem_final_dim) <-
                  IxFun.getInfoMaxUnification xmem_ixfun
                let xmem_final_dim_before_kmem_final_dim =
                      maybe False (xmem_final_dim `S.member`) $
                      M.lookup kmem_final_dim uses_before
                    -- Treat each final dimension as loosely equal to the
                    -- other one before comparing the index functions.
                    kmem_ixfun_start' = getIxFun' kmem_ixfun_start
                                          (M.singleton kmem_final_dim xmem_final_dim)
                    xmem_ixfun_start' = getIxFun' xmem_ixfun_start
                                          (M.singleton xmem_final_dim kmem_final_dim)
                if kmem_indices_start == xmem_indices_start &&
                   (kmem, kmem_ixfun_start') `ixFunsCompatible`
                   (xmem, xmem_ixfun_start') &&
                   (primByteSize kmem_pt :: Int) == primByteSize xmem_pt &&
                   xmem_final_dim_before_kmem_final_dim
                  then return (kmem_size,
                               ((kmem_array, kmem_final_dim),
                                (xmem_array, xmem_final_dim)))
                  else Nothing
          _ -> Nothing
        where getIxFun' :: ExpMem.IxFun -> M.Map VName VName ->
                           IxFun.IxFun (PrimExp VarWithLooseEquality)
              getIxFun' ixfun others =
                let loose_eq_map name_inner =
                      pure $ VarWithLooseEquality name_inner
                      $ maybe S.empty S.singleton $ M.lookup name_inner others
                in runIdentity $ traverse (traverse loose_eq_map) ixfun

  let sizesCanBeMaxedKernelArray' :: MName -> MNames -> FindM lore Bool
      sizesCanBeMaxedKernelArray' kmem used_mems =
        isJust <$> sizesCanBeMaxedKernelArray kmem used_mems

  -- Every variable whose memory is xmem must be one of x's actual variables.
  -- 'actual_vars' is a Set, so use 'S.member' (O(log n)) rather than the
  -- Foldable 'elem' (a linear scan of the set).
  let noOtherUsesOfMemory :: MName -> MNames -> FindM lore Bool
      noOtherUsesOfMemory _kmem _used_mems = do
        var_to_mem <- asks ctxVarToMem
        return $ all (\(v, m) -> memSrcName m /= xmem || v `S.member` actual_vars)
                     (M.toList var_to_mem)

  -- xmem's size must be a variable (not a constant).
  let notCurrentlyDisabled :: FindM lore Bool
      notCurrentlyDisabled =
        isJust . subExpVar <$> lookupSize xmem

  let sizesWorkOut :: MName -> MNames -> FindM lore Bool
      sizesWorkOut kmem used_mems =
        (notCurrentlyDisabled <&&> noneInterfereKernelArray used_mems <&&>
         (sizesMatch used_mems <||> sizesCanBeMaxed kmem))
        <||> sizesCanBeMaxedKernelArray' kmem used_mems

  let canBeUsed t = and <$> mapM (($ t) . uncurry)
                    [notTheSame, noneInterfere, sameSpace, noOtherUsesOfMemory,
                     sizesWorkOut]

  cur_uses <- gets curUses
  found_use <- catMaybes <$> mapM (maybeFromBoolM canBeUsed) (M.assocs cur_uses)
  case found_use of
    (kmem, used_mems) : _ -> do
      insertUse kmem xmem
      -- Move every actual variable of x into kmem, keeping its index function.
      forM_ actual_vars $ \var -> do
        ixfun <- memSrcIxFun <$> lookupVarMem var
        recordMemMapping var $ MemoryLoc kmem ixfun
      whenM (sizesCanBeMaxed kmem) $ do
        ksize <- lookupSize kmem
        xsize <- lookupSize xmem
        fromMaybe (return ()) $ do
          ksize' <- subExpVar ksize
          xsize' <- subExpVar xsize
          return $ do
            recordMaxMapping kmem ksize'
            recordMaxMapping kmem xsize'
      kernel_maxing <- sizesCanBeMaxedKernelArray kmem used_mems
      forM_ kernel_maxing $ \info ->
        recordKernelMaxMapping kmem info
    _ ->
      insertUse xmem xmem
-- | A variable paired with a set of variables it is considered equal to.
-- Used to compare size 'PrimExp's and index functions modulo known
-- equalities.
data VarWithLooseEquality = VarWithLooseEquality VName Names
  deriving (Show)

-- | Two values are equal when their classes share at least one member.
-- Each class is the variable itself plus its loosely-equal set.  This
-- relation is reflexive and symmetric but NOT transitive, so it is not a
-- lawful 'Eq'.  It is only meant for direct pairwise comparisons.
instance Eq VarWithLooseEquality where
  VarWithLooseEquality x xs == VarWithLooseEquality y ys =
    let x_class = S.insert x xs
        y_class = S.insert y ys
    in not (S.null (x_class `S.intersection` y_class))
-- | Whether @mem0@ and @mem1@ may race inside a kernel.  They do when some
-- potential-interference group contains both, they are distinct, and
-- either of these holds:
--
--   * one of them is indexed and their index functions are incompatible;
--   * their element types have different byte sizes.
interferesInKernel :: MonadReader Context m => MName -> MName -> m Bool
interferesInKernel mem0 mem1 =
  any groupInterferes <$> asks ctxPotentialKernelInterferences
  where groupInterferes :: PotentialKernelDataRaceInterferenceGroup -> Bool
        groupInterferes grp = fromMaybe False $ do
          (_, _, pt0, ixfun0) <- findMem mem0 grp
          (_, _, pt1, ixfun1) <- findMem mem1 grp
          return $ conflict (pt0, ixfun0) (pt1, ixfun1)

        findMem mem = L.find (\(mname, _, _, _) -> mname == mem)

        conflict :: (PrimType, ExpMem.IxFun) -> (PrimType, ExpMem.IxFun) -> Bool
        conflict (pt0, ixfun0) (pt1, ixfun1) =
          mem0 /= mem1 && (incompatible_indexing || different_widths)
          where incompatible_indexing =
                  (ixFunHasIndex ixfun0 || ixFunHasIndex ixfun1) &&
                  not (ixFunsCompatible (mem0, ixfun0) (mem1, ixfun1))
                different_widths = (primByteSize pt0 :: Int) /= primByteSize pt1
-- | Thin wrapper around 'IxFun.ixFunHasIndex'.
ixFunHasIndex :: IxFun.IxFun num -> Bool
ixFunHasIndex ixfun = IxFun.ixFunHasIndex ixfun
-- | Compare two index functions with 'IxFun.ixFunsCompatibleRaw'.  The
-- memory block names paired with them are currently ignored.
ixFunsCompatible :: Eq v =>
                    (MName, IxFun.IxFun (PrimExp v)) -> (MName, IxFun.IxFun (PrimExp v)) ->
                    Bool
ixFunsCompatible lhs rhs = IxFun.ixFunsCompatibleRaw (snd lhs) (snd rhs)
-- | For every memory block in @var_to_max@, make its allocation size the
-- maximum of the recorded size variables (see 'maxsToReplacement' and
-- 'insertAndReplace').
transformFromVarMaxExpMappings :: MonadFreshNames m =>
                                  M.Map VName Names
                               -> FunDef ExplicitMemory -> m (FunDef ExplicitMemory)
transformFromVarMaxExpMappings var_to_max fundef = do
  -- Map's 'traverse' visits keys in ascending order, so fresh names are
  -- generated in the same order as iterating over 'M.assocs'.
  replacements <- traverse (maxsToReplacement . S.toList) var_to_max
  return $ insertAndReplace replacements fundef
-- | A variable to use as a new allocation size, together with the
-- statements that define it.
data Replacement = Replacement
  { replName :: VName
    -- ^ The variable that replaces the original allocation size.
  , replStms :: [Stm ExplicitMemory]
    -- ^ Statements defining 'replName'.  May be empty, e.g. when
    -- 'replName' is an existing variable.
  }
  deriving (Show)
-- | Build statements computing the signed 64-bit maximum of the given
-- variables, as a balanced tree of binary 'SMax' operations.  A single
-- variable needs no statements.  Calls 'error' on an empty list.
maxsToReplacement :: MonadFreshNames m =>
                     [VName] -> m Replacement
maxsToReplacement vs = case vs of
  [] -> error "maxsToReplacements: Cannot take max of zero variables"
  [v] -> return $ Replacement v []
  _ -> do
    let (lefts, rights) = splitAt (length vs `div` 2) vs
    Replacement left_max left_stms <- maxsToReplacement lefts
    Replacement right_max right_stms <- maxsToReplacement rights
    vmax <- newVName "max"
    let max_exp = BasicOp $ BinOp (SMax Int64) (Var left_max) (Var right_max)
        max_stm = Let (Pattern [] [PatElem vmax (ExpMem.MemPrim (IntType Int64))])
                      (defAux ()) max_exp
    return Replacement { replName = vmax
                       , replStms = left_stms ++ right_stms ++ [max_stm]
                       }
-- | Rewrite allocations whose memory block is a key of @replaces0@.
--
-- For each such allocation statement, the replacement's statements are
-- emitted first.  They are emitted only once per block, since they are
-- cleared from the state after first use.  Then the allocation itself is
-- emitted with its size replaced by the replacement variable.  The
-- original statement's auxiliary data (e.g. certificates) is preserved.
-- Nested bodies are handled through 'mapExpM'.
insertAndReplace :: M.Map MName Replacement -> FunDef ExplicitMemory ->
                    FunDef ExplicitMemory
insertAndReplace replaces0 fundef =
  let body' = evalState (transformBody $ funDefBody fundef) replaces0
  in fundef { funDefBody = body' }
  where transformBody :: Body ExplicitMemory ->
                         State (M.Map VName Replacement) (Body ExplicitMemory)
        transformBody body = do
          stms' <- concat <$> mapM transformStm (stmsToList $ bodyStms body)
          return $ body { bodyStms = stmsFromList stms' }

        transformStm :: Stm ExplicitMemory ->
                        State (M.Map VName Replacement) [Stm ExplicitMemory]
        transformStm stm@(Let (Pattern [] [PatElem mem_name
                                           (ExpMem.MemMem _ pat_space)]) aux
                          (Op (ExpMem.Alloc _ space))) = do
          replaces <- get
          case M.lookup mem_name replaces of
            Just repl -> do
              let prev = replStms repl
                  new_size = Var (replName repl)
                  -- Keep the original aux instead of 'defAux ()' so that any
                  -- certificates on the allocation are not silently dropped.
                  new = Let (Pattern [] [PatElem mem_name
                                         (ExpMem.MemMem new_size pat_space)]) aux
                        (Op (ExpMem.Alloc new_size space))
              -- The defining statements must only be emitted once.
              modify $ M.adjust (\repl0 -> repl0 { replStms = [] }) mem_name
              return (prev ++ [new])
            Nothing -> return [stm]
        transformStm (Let pat attr e) = do
          let mapper = identityMapper { mapOnBody = const transformBody }
          e' <- mapExpM mapper e
          return [Let pat attr e']
-- | Apply the kernel-array maximisations recorded in @mem_to_info@.
--
-- For each memory block @mem@:
--
--   * bind a fresh variable to the signed maximum of the two final
--     dimensions;
--   * recompute @mem@'s allocation size from its 'PrimExp', with kmem's
--     final dimension replaced by that maximum;
--   * map every actual variable of both arrays to @mem@, with an index
--     function in which the respective final dimension is substituted by
--     the maximum.
--
-- The new sizes are inserted with 'insertAndReplace', and the array
-- mappings are applied with 'transformFromVarMemMappings'.
transformFromKernelMaxSizedMappings :: MonadFreshNames m =>
                                       M.Map VName (PrimExp VName) -> VarMemMappings MemorySrc -> VarMemMappings MName ->
                                       Sizes -> ActualVariables -> M.Map MName (VName, ((VName, VName),
                                                                                       (VName, VName))) ->
                                       FunDef ExplicitMemory -> m (FunDef ExplicitMemory)
transformFromKernelMaxSizedMappings
  var_to_pe var_to_mem var_to_mem_res sizes_orig actual_vars mem_to_info fundef = do
  (mem_to_size_var, arr_to_mem_ixfun) <-
    unzip <$> mapM (uncurry withNewMaxVar) (M.assocs mem_to_info)
  let mem_to_size_var' = M.fromList mem_to_size_var
      arr_to_memloc = M.fromList $ map (\(arr, destmem, ixfun) ->
                                          (arr, MemoryLoc destmem ixfun))
                      $ concat arr_to_mem_ixfun
      -- Replace the allocation sizes first, then recompute the sizes of the
      -- rewritten function for the final remapping.
      fundef' = insertAndReplace mem_to_size_var' fundef
      sizes = memBlockSizesFunDef fundef'
  transformFromVarMemMappings arr_to_memloc (M.union var_to_mem_res (M.map memSrcName var_to_mem)) (M.map fst sizes) (M.map fst sizes_orig) True fundef'
  -- Build the size replacement and the array-to-memory mappings for one
  -- memory block.
  where withNewMaxVar :: MonadFreshNames m =>
                         MName -> (VName,
                                   ((VName, VName),
                                    (VName, VName))) ->
                         m ((MName, Replacement),
                            [(VName, MName, ExpMem.IxFun)])
        withNewMaxVar mem (kmem_size,
                           ((kmem_array, kmem_final_dim),
                            (xmem_array, xmem_final_dim))) = do
          final_dim_max_v <- newVName "max_final_dim"
          let final_dim_max_e =
                BasicOp (BinOp (SMax Int32)
                         (Var kmem_final_dim) (Var xmem_final_dim))
              -- Make kmem's final dimension expand to the new maximum, so
              -- that the full size expression uses it.
              var_to_pe_extension =
                M.singleton kmem_final_dim (LeafExp final_dim_max_v (IntType Int32))
              var_to_pe' = M.union var_to_pe_extension var_to_pe
              full_size_pe = fromJust "should exist" $ M.lookup kmem_size var_to_pe
              full_size_pe_expanded = expandPrimExp var_to_pe' full_size_pe
              new_full_size_m =
                letExp "max" =<< primExpToExp (return . BasicOp . SubExp . Var)
                full_size_pe_expanded
          (alloc_size_var, alloc_size_stms) <-
            modifyNameSource $ runState $ runBinderT new_full_size_m mempty
          -- The statement defining the maximised final dimension must come
          -- before the statements that use it.
          let alloc_size_fd_stm =
                Let (Pattern [] [PatElem final_dim_max_v
                                 (ExpMem.MemPrim (IntType Int32))]) (defAux ()) final_dim_max_e
              alloc_size_stms' = oneStm alloc_size_fd_stm <> alloc_size_stms
              vars_kmem =
                S.insert kmem_array $ lookupActualVars' actual_vars kmem_array
              vars_xmem =
                S.insert xmem_array $ lookupActualVars' actual_vars xmem_array
              -- Move v into mem, substituting the given final dimension by
              -- the maximum in v's original index function.
              arrayToMapping final_dim v =
                let ixfun = memSrcIxFun $ fromJust "should exist"
                            $ M.lookup v var_to_mem
                    ixfun_new = IxFun.subsInIndexIxFun ixfun final_dim final_dim_max_v
                in (v, mem, ixfun_new)
              arr_to_mem_ixfun_kmem = map (arrayToMapping kmem_final_dim)
                                      $ S.toList vars_kmem
              arr_to_mem_ixfun_xmem = map (arrayToMapping xmem_final_dim)
                                      $ S.toList vars_xmem
              arr_to_mem_ixfun = arr_to_mem_ixfun_kmem ++ arr_to_mem_ixfun_xmem
          return ((mem, Replacement alloc_size_var $ stmsToList alloc_size_stms'),
                  arr_to_mem_ixfun)