{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE LambdaCase #-}
module Futhark.Optimise.MemoryBlockMerging.MemoryUpdater
( transformFromVarMemMappings
) where
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (mapMaybe, fromMaybe)
import Control.Applicative ((<|>))
import Control.Arrow (second)
import Control.Monad
import Control.Monad.RWS
import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
(ExplicitMemorish, ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
-- | Read-only environment of the memory rewrite.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemoryLoc
    -- ^ The memory location (block and index function) each variable is
    -- rebound to.
  , ctxVarToMemOrig :: VarMemMappings MName
    -- ^ The memory block each variable originally used.
  , ctxAllocSizes :: M.Map MName SubExp
    -- ^ Allocation size of each memory block, as used for the new blocks.
  , ctxAllocSizesOrig :: M.Map MName SubExp
    -- ^ Original allocation size of each memory block.
  , ctxHasMaxedSize :: Bool
    -- ^ When set, array results of @if@ statements returned in a new block
    -- of variable size get a fresh existential size (see 'transformStm').
  }
  deriving Show
-- | The rewriting monad.  It reads a 'Context' and threads a pair of state:
-- the fresh-name source, and an association list from memory blocks to size
-- variables created while rewriting (appended to by 'modifyMemSizeMapping').
-- The @lore@ parameter does not occur in the representation; 'coerce'
-- changes it for free.
newtype FindM lore a = FindM
  { unFindM :: RWS Context () (VNameSource, [(MName, VName)]) a }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader Context
           , MonadState (VNameSource, [(MName, VName)])
           )
-- | Fresh names are drawn from the first component of the state; updating
-- the name source leaves the second component untouched.
instance MonadFreshNames (FindM lore) where
  getNameSource = fst <$> get
  putNameSource src = modify (\(_, sizes) -> (src, sizes))
-- | Apply a function to the memory-block-to-size-variable association list
-- kept in the second component of the state.
modifyMemSizeMapping :: ([(MName, VName)] -> [(MName, VName)]) -> FindM lore ()
modifyMemSizeMapping f =
  -- Lazy pattern: same strictness as 'Control.Arrow.second' on pairs.
  modify (\ ~(src, sizes) -> (src, f sizes))
-- | What every lore handled by this pass must provide: an explicit-memory
-- lore with a 'FullMap' instance, whose bodies and expressions carry unit
-- annotations.
type LoreConstraints lore = ( ExplicitMemorish lore
                            , FullMap lore
                            , BodyAttr lore ~ ()
                            , ExpAttr lore ~ ()
                            )
-- | Reinterpret a computation at another lore.  Only the phantom @lore@
-- index of 'FindM' changes; the wrapped computation is reused as is.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM m) = FindM m
-- | Entry point of the pass: rewrite the memory annotations of a function
-- body.
--
-- Arguments, in order:
--
-- * the new memory location of each variable;
-- * the original memory block of each variable;
-- * the allocation sizes of the new blocks;
-- * the original allocation sizes;
-- * the @has_maxed_size@ flag, which makes 'transformStm' give @if@ results
--   in new blocks a fresh existential size;
-- * the function itself.
--
-- Only 'funDefBody' is replaced.  Fresh names are taken from, and returned
-- to, the ambient name source.
transformFromVarMemMappings :: MonadFreshNames m =>
                               VarMemMappings MemoryLoc ->
                               VarMemMappings MName ->
                               M.Map MName SubExp -> M.Map MName SubExp -> Bool ->
                               FunDef ExplicitMemory ->
                               m (FunDef ExplicitMemory)
transformFromVarMemMappings var_to_mem var_to_mem_orig alloc_sizes alloc_sizes_orig has_maxed_size fundef =
  modifyNameSource rewrite
  where
    env = Context { ctxVarToMem = var_to_mem
                  , ctxVarToMemOrig = var_to_mem_orig
                  , ctxAllocSizes = alloc_sizes
                  , ctxAllocSizesOrig = alloc_sizes_orig
                  , ctxHasMaxedSize = has_maxed_size
                  }
    action = unFindM (transformFunDefBody (funDefBody fundef))
    -- The memory-to-size list starts empty and is dropped afterwards.
    rewrite src =
      let (body', (src', _), ()) = runRWS action env (src, [])
      in (fundef { funDefBody = body' }, src')
-- | Rewrite the top-level body of a function: its statements and also its
-- results (see 'transformFunDefBodyResult').
--
-- Statements are processed first, so the result rewrite sees every size
-- variable that the statements recorded in the state.
transformFunDefBody :: LoreConstraints lore =>
                       Body lore -> FindM lore (Body lore)
transformFunDefBody (Body () stms result) =
  Body ()
    <$> (stmsFromList <$> mapM transformStm (stmsToList stms))
    <*> transformFunDefBodyResult result
-- | Rewrite the results of a function body so that they name the new memory
-- blocks and their sizes.
--
-- Two kinds of replacement are collected.
--
-- 1. For each result variable that has both an original and a new memory
--    block:
--
--    * the original block is replaced, wherever it occurs as a result, by
--      the new block;
--    * when both sizes are known, the original block's size is replaced by
--      the new block's size.  This only happens at a result that is
--      immediately followed by the original block.
--
-- 2. For each result that is a memory block whose size changed, that
--    original size is replaced by the new one.  Again this only happens when
--    the size is immediately followed by that block.
--
-- The size of a new block is a size variable recorded in the state when
-- there is one, and otherwise the allocation size from the context.
transformFunDefBodyResult :: [SubExp] -> FindM lore [SubExp]
transformFunDefBodyResult ses = do
  env <- ask
  new_sizes <- gets snd
  let orig_mems = ctxVarToMemOrig env
      new_mems = ctxVarToMem env
      orig_sizes = ctxAllocSizesOrig env
      cur_sizes = ctxAllocSizes env

      sizeAfter mem =
        (Var <$> L.lookup mem new_sizes) <|> M.lookup mem cur_sizes

      arrayReplacements (Var v)
        | Just mem_orig <- M.lookup v orig_mems
        , Just mem_new <- memLocName <$> M.lookup v new_mems =
            ((Var mem_orig, Nothing), Var mem_new) :
            [ ((size_orig, Just (Var mem_orig)), size_new)
            | Just size_orig <- [M.lookup mem_orig orig_sizes]
            , Just size_new <- [sizeAfter mem_new] ]
      arrayReplacements _ = []

      sizeReplacements (Var mem)
        | Just size_orig <- M.lookup mem orig_sizes
        , Just size_new <- sizeAfter mem
        , size_orig /= size_new =
            [((size_orig, Just (Var mem)), size_new)]
      sizeReplacements _ = []

      table = concatMap arrayReplacements ses ++ concatMap sizeReplacements ses

      -- Unconditional replacements win over those keyed on the next result.
      replace se next =
        fromMaybe se $
          L.lookup (se, Nothing) table
            <|> (next >>= \succ_se -> L.lookup (se, Just succ_se) table)

      -- The (original) result following each position, if any.
      successors = map Just (drop 1 ses) ++ repeat Nothing

  return (zipWith replace ses successors)
-- | Rewrite every statement of a body.  The body result is kept as is.
transformBody :: LoreConstraints lore =>
                 Body lore -> FindM lore (Body lore)
transformBody (Body () stms result) =
  (\stms' -> Body () (stmsFromList stms') result)
    <$> mapM transformStm (stmsToList stms)
-- | Rewrite every statement of a kernel body.  The kernel results are kept
-- as is.
transformKernelBody :: LoreConstraints lore =>
                       KernelBody lore -> FindM lore (KernelBody lore)
transformKernelBody (KernelBody () stms results) =
  (\stms' -> KernelBody () (stmsFromList stms') results)
    <$> mapM transformStm (stmsToList stms)
-- | Make an array return annotation return in the block of @memloc@.
--
-- The index function is @memloc@'s, passed through
-- @existentialiseIxFun []@.  The element type, shape and uniqueness are
-- kept.  Non-array annotations are returned unchanged.
transformMemInfo :: ExpMem.MemInfo d u ExpMem.MemReturn -> MemoryLoc ->
                    ExpMem.MemInfo d u ExpMem.MemReturn
transformMemInfo (ExpMem.MemArray pt shape u _) memloc =
  ExpMem.MemArray pt shape u $
    ExpMem.ReturnsInBlock (memLocName memloc) ixfun
  where ixfun = ExpMem.existentialiseIxFun [] (memLocIxFun memloc)
transformMemInfo meminfo _ = meminfo
-- | How one value result of an @if@ is to be returned after the rewrite
-- (decided in 'transformStm').
data BranchReturn
  = ExistingBranchReturn ExpMem.BodyReturns
    -- ^ Keep this return annotation.
  | NewBranchReturn (Int -> ExpMem.BodyReturns) VName VName VName
    -- ^ Return in a new block with a fresh existential size.  The fields
    -- are:
    --
    -- * a builder taking the index of that existential;
    -- * the size variable of the then-branch;
    -- * the size variable of the else-branch;
    -- * the pattern memory block that receives the new size.
-- | Rewrite a single statement.
--
-- Value pattern elements get their new memory via 'transformPatValElem'.
-- Nested bodies, parameters and lambdas are rewritten through
-- 'fullMapExpM'.  Two expression forms need extra work because memory also
-- flows through their results.
--
-- * @if@: return annotations and existential memory results are fixed up.
--   When 'ctxHasMaxedSize' is set, array results in new blocks of variable
--   size also get fresh existential size results, and the pattern context
--   is extended with them.
--
-- * @loop@: merge parameters, for-loop variables and the memory results of
--   the loop body are rewritten.
transformStm :: LoreConstraints lore =>
                Stm lore -> FindM lore (Stm lore)
transformStm (Let (Pattern patctxelems patvalelems) aux e) = do
  patvalelems' <- mapM transformPatValElem patvalelems
  -- Rewrite nested bodies, parameters and lambdas first.
  e' <- fullMapExpM mapper mapper_kernel e
  var_to_mem <- asks ctxVarToMem
  var_to_mem_orig <- asks ctxVarToMemOrig
  mem_to_size <- asks ctxAllocSizes
  mem_to_new_size <- gets snd
  (e'', patctxelems') <- case e' of
    If cond body_then body_else (IfAttr rets sort) -> do
      -- New memory location (if any) of each value result of a branch.
      -- The first @length patctxelems@ results are context and are skipped.
      let bodyVarMemLocs body =
            map (flip M.lookup var_to_mem <=< subExpVar)
            $ drop (length patctxelems) $ bodyResult body
          -- Context result @i@ is a memory block of the pattern.  If exactly
          -- one value pattern element is stored in that block, return the
          -- new memory of the corresponding branch value result.
          findBodyResMem i body_results =
            let imem = patElemName (patctxelems L.!! i)
                matching_var = mapMaybe (
                  \(p, p_i) ->
                    case patElemAttr p of
                      ExpMem.MemArray _ _ _ (ExpMem.ArrayIn vmem _) ->
                        if imem == vmem
                        then Just p_i
                        else Nothing
                      _ ->
                        Nothing
                  ) (zip patvalelems [0..])
            in do
              j <- case matching_var of
                     [t] -> Just t
                     _ -> Nothing
              body_res_var <- subExpVar (body_results L.!! (length patctxelems + j))
              MemoryLoc mem _ixfun <- M.lookup body_res_var var_to_mem
              return mem
          -- Point every context result at the memory found above.  A
          -- context result is kept when nothing is found.
          fixBodyExistentials body =
            body { bodyResult =
                     zipWith (\res i -> if i < length patctxelems
                                        then maybe res Var $ findBodyResMem i (bodyResult body)
                                        else res)
                     (bodyResult body) [0..] }
      let ms_then = bodyVarMemLocs body_then
          ms_else = bodyVarMemLocs body_else
      -- Return annotations are pinned to the new blocks only when both
      -- branches agree on the new location of every value result.
      let rets' =
            if ms_then == ms_else
            then zipWith (\r m -> case m of
                                    Nothing -> r
                                    Just m' ->
                                      transformMemInfo r m'
                         ) rets ms_then
            else rets
      let body_then' = fixBodyExistentials body_then
          body_else' = fixBodyExistentials body_else
      -- Size variable of a memory block.  A size recorded in the state wins
      -- over the allocation size (which is used only if it is a variable).
      let mem_size mem = L.lookup mem mem_to_new_size <|> (subExpVar =<< M.lookup mem mem_to_size)
          -- Size variable of the block a variable lives in.  The variable's
          -- new memory is preferred over its original one.
          v_size v = do
            mem <- M.lookup v (M.map memLocName var_to_mem) <|> M.lookup v var_to_mem_orig
            mem_size mem
      has_maxed_size <- asks ctxHasMaxedSize
      -- Array results returned in a new block whose size is a variable get
      -- a fresh existential size when sizes are maxed.  All other results
      -- keep their return annotation.
      -- NOTE(review): the @error@ branch does not consult @has_maxed_size@,
      -- so it can fire even when that flag is False and the sizes would go
      -- unused.  Confirm this is intended.
      let rets_branch_returns =
            L.zipWith4 (\r pat th el -> case (r, pat, th, el) of
                           (ExpMem.MemArray pt shape u
                              (ExpMem.ReturnsNewBlock space n
                                 (Free (Var _size)) extixfun),
                            PatElem _
                              (ExpMem.MemArray _ _ _
                                 (ExpMem.ArrayIn patmem _)),
                            Var v_th, Var v_el) ->
                             case (v_size v_th, v_size v_el) of
                               (Just s_th, Just s_el) ->
                                 if not has_maxed_size
                                 then ExistingBranchReturn r
                                 else NewBranchReturn
                                        (\nth_ctxelem ->
                                           ExpMem.MemArray pt shape u
                                             (ExpMem.ReturnsNewBlock space n
                                                (Ext nth_ctxelem) extixfun))
                                        s_th s_el patmem
                               _ -> error ("both branch return arrays should use a memory block with a size: " ++ show v_th ++ " and " ++ show v_el)
                           _ -> ExistingBranchReturn r
                       )
            rets'
            patvalelems
            (drop (length patctxelems) (bodyResult body_then'))
            (drop (length patctxelems) (bodyResult body_else'))
      -- One fresh size variable per result that gets a new existential size.
      patctxelems_new <-
        replicateM
        (length (filter (\case
                            NewBranchReturn{} -> True
                            ExistingBranchReturn{} -> False
                        ) rets_branch_returns))
        (newVName "new_memory_size")
      -- This fold does three things:
      --   * numbers the new existentials after the existing context
      --     elements;
      --   * collects the size variable each branch must return;
      --   * maps each pattern memory block to its new size variable.
      -- NOTE(review): @prev ++ [..]@ inside a left fold is quadratic in the
      -- number of results.  Consider mapAccumL if results can be many.
      let (rets'', _, body_ext_new, _, patmem_to_new_size) =
            foldl (\(prev, i, ext, patctxelems_new', mapping) rb -> case rb of
                      ExistingBranchReturn r ->
                        (prev ++ [r], i, ext, patctxelems_new', mapping)
                      NewBranchReturn rf s_th s_el patmem ->
                        (prev ++ [rf i], i + 1, ext ++ [(s_th, s_el)],
                         tail patctxelems_new',
                         mapping ++ [(patmem, head patctxelems_new')])
                  ) ([], length patctxelems, [], patctxelems_new, []) rets_branch_returns
      -- Record the new sizes in the state.  Later statements and the
      -- function-result rewrite read this list.
      modifyMemSizeMapping (++ patmem_to_new_size)
      -- Insert the new size results right after the existing context
      -- results, matching the existential indices chosen above.
      let (th_ext_new, el_ext_new) = unzip body_ext_new
          body_then'' = body_then' { bodyResult =
                                       take (length patctxelems) (bodyResult body_then') ++
                                       map Var th_ext_new ++
                                       drop (length patctxelems) (bodyResult body_then')
                                   }
          body_else'' = body_else' { bodyResult =
                                       take (length patctxelems) (bodyResult body_else') ++
                                       map Var el_ext_new ++
                                       drop (length patctxelems) (bodyResult body_else')
                                   }
          -- Memory pattern elements that received a new size now refer to
          -- that size.
          patctxelems_replaced = map (\pe -> case pe of
                                         PatElem name (ExpMem.MemMem _size space) ->
                                           case L.lookup name patmem_to_new_size of
                                             Just size_new ->
                                               PatElem name (ExpMem.MemMem (Var size_new) space)
                                             Nothing -> pe
                                         _ -> pe
                                     ) patctxelems
          -- The new size context elements are 64-bit integers.
          patctxelems' = patctxelems_replaced ++ map (\v -> PatElem v (ExpMem.MemPrim (IntType Int64))) patctxelems_new
      return (If cond body_then'' body_else'' (IfAttr rets'' sort),
              patctxelems')
    DoLoop mergectxparams mergevalparams loopform body -> do
      mergectxparams' <- mapM (transformMergeCtxParam mergevalparams) mergectxparams
      mergevalparams' <- mapM transformMergeValParam mergevalparams
      -- Positions below index the context-then-value pattern.  The body
      -- results are presumably laid out the same way (TODO confirm).
      let zipped = zip [(0::Int)..] (patctxelems ++ patvalelems)
          -- Pairs (memory position, array position) for every array pattern
          -- element whose block is itself an element of the pattern.
          findMemLinks (i, PatElem _x (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn xmem _))) =
            case L.find (\(_, PatElem ymem _) -> ymem == xmem) zipped of
              Just (j, _) -> Just (j, i)
              Nothing -> Nothing
          findMemLinks _ = Nothing
          mem_links = mapMaybe findMemLinks zipped
          res = bodyResult body
          -- A memory result linked to an array position is replaced by the
          -- new memory of the body result at that array position.
          fixResRecord i se
            | Var _mem <- se
            , Just j <- L.lookup i mem_links
            , Var related_var <- res L.!! j
            , Just mem_new <- M.lookup related_var var_to_mem =
                Var $ memLocName mem_new
            | otherwise = se
          res' = zipWith fixResRecord [(0::Int)..] res
          body' = body { bodyResult = res' }
      loopform' <- case loopform of
        ForLoop i it bound loop_vars ->
          ForLoop i it bound <$> mapM transformForLoopVar loop_vars
        WhileLoop _ -> return loopform
      return (DoLoop mergectxparams' mergevalparams' loopform' body',
              patctxelems)
    -- Other expressions keep their pattern context.
    _ -> return (e', patctxelems)
  return (Let (Pattern patctxelems' patvalelems') aux e'')
  where mapper = identityMapper
          { mapOnBody = const transformBody
          , mapOnFParam = transformFParam
          , mapOnLParam = transformLParam
          }
        -- Kernel bodies and lambdas are rewritten by the same functions.
        -- 'coerce' adjusts the phantom lore index of the computation.
        mapper_kernel = identityKernelMapper
          { mapOnKernelBody = coerce . transformBody
          , mapOnKernelKernelBody = coerce . transformKernelBody
          , mapOnKernelLambda = coerce . transformLambda
          , mapOnKernelLParam = transformLParam
          }
-- | Rewrite the initial value of a loop context parameter that is a memory
-- block.
--
-- The first value merge parameter stored in that block is looked up.  If
-- its initial value is a variable with a new memory location, that
-- location's block becomes the context parameter's initial value.
-- Otherwise the initial value is kept.  Non-memory context parameters are
-- returned unchanged.
transformMergeCtxParam :: [(FParam ExplicitMemory, SubExp)] ->
                          (FParam ExplicitMemory, SubExp)
                       -> FindM lore (FParam ExplicitMemory, SubExp)
transformMergeCtxParam mergevalparams (param@(Param ctxmem ExpMem.MemMem{}), mem) =
  asks $ \env ->
    let storedInCtxMem (Param _ (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn pmem _))) =
          ctxmem == pmem
        storedInCtxMem _ = False
        relocated = case L.find (storedInCtxMem . fst) mergevalparams of
          Just (_, Var orig_var) ->
            Var . memLocName <$> M.lookup orig_var (ctxVarToMem env)
          _ -> Nothing
    in (param, fromMaybe mem relocated)
transformMergeCtxParam _ t = return t
-- | Give a value merge parameter its new memory annotation.  The initial
-- value is kept.
transformMergeValParam :: (FParam ExplicitMemory, SubExp)
                       -> FindM lore (FParam ExplicitMemory, SubExp)
transformMergeValParam (Param x membound, se) =
  fmap (\membound' -> (Param x membound', se)) (newMemBound membound x)
-- | Give a value pattern element its new memory annotation.
transformPatValElem :: PatElem ExplicitMemory -> FindM lore (PatElem ExplicitMemory)
transformPatValElem (PatElem x membound) = do
  membound' <- newMemBound membound x
  return (PatElem x membound')
-- | Give a function parameter its new memory annotation.
transformFParam :: LoreConstraints lore =>
                   FParam lore -> FindM lore (FParam lore)
transformFParam (Param x membound) = do
  membound' <- newMemBound membound x
  return (Param x membound')
-- | Give a lambda parameter its new memory annotation.
transformLParam :: LoreConstraints lore =>
                   LParam lore -> FindM lore (LParam lore)
transformLParam (Param x membound) = do
  membound' <- newMemBound membound x
  return (Param x membound')
-- | Rewrite a lambda: its parameters first, then its body.  The return
-- types are kept.
transformLambda :: LoreConstraints lore =>
                   Lambda lore -> FindM lore (Lambda lore)
transformLambda (Lambda params body types) =
  Lambda <$> mapM transformLParam params
         <*> transformBody body
         <*> pure types
-- | Give a for-loop iteration variable its new memory annotation.  The
-- array it iterates over is kept.
transformForLoopVar :: LoreConstraints lore =>
                       (LParam lore, VName) ->
                       FindM lore (LParam lore, VName)
transformForLoopVar (Param x membound, array) =
  (\membound' -> (Param x membound', array)) <$> newMemBound membound x
-- | Rewrite the memory annotation of @var@.
--
-- An array annotation is pointed at @var@'s new memory location (block and
-- index function) when one is recorded in 'ctxVarToMem'.  The element type,
-- shape and uniqueness are kept.  In every other case the annotation is
-- returned unchanged.
newMemBound :: ExpMem.MemBound u -> VName -> FindM lore (ExpMem.MemBound u)
newMemBound membound var = asks $ \env ->
  case (membound, M.lookup var (ctxVarToMem env)) of
    (ExpMem.MemArray pt shape u _, Just (MemoryLoc mem ixfun)) ->
      ExpMem.MemArray pt shape u (ExpMem.ArrayIn mem ixfun)
    _ -> membound