module Futhark.Optimise.MemoryBlockMerging.CrudeMovingUp
( moveUpInFunDef
) where
import qualified Data.Set as S
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import Control.Monad
import Control.Monad.RWS
import Control.Monad.Writer
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Control.Monad.State
import Control.Monad.Identity
-- | Zero-based index of a statement within the statement list of a body.
type Line = Int
-- | Where a set of names tracked in a 'BindingMap' was introduced.
--
-- The derived 'Ord' instance places 'FromFParam' below every 'FromLine', and
-- orders 'FromLine' values by line first; 'moveLetUpwards' relies on this
-- when it takes the maximum origin of a statement's dependencies.
data Origin = FromFParam
-- ^ Not bound by a statement of the body currently being processed (used for
-- names coming from the scope, for kernel-space names, and for every entry
-- inherited from an enclosing body).
| FromLine Line (Exp ExplicitMemory)
-- ^ Bound by the statement at the given line of the current body, together
-- with that statement's expression.
deriving (Eq, Ord, Show)
-- | What the pass records about a group of names bound together.
data PrimBinding = PrimBinding { pbFrees :: Names
-- ^ Names the binding depends on (for statements: the free names of the
-- statement minus the names reported by 'boundInExpExtra').
, _pbConsumed :: Names
-- ^ Names consumed by the binding (the source array of an @Update@, else
-- empty).
, pbOrigin :: Origin
-- ^ Where the binding was introduced.
}
deriving (Show)
-- | Association list from a set of names to the 'PrimBinding' describing
-- them.  Lookups ('lookupPrimBinding') return the first entry whose name set
-- contains the queried name, so earlier entries shadow later ones.
type BindingMap = [(Names, PrimBinding)]
-- | Hoist statements upwards in the body of a function definition.
--
-- @findHoistees@ is asked, for the function body (with the function
-- parameters) and for every nested body (with 'Nothing'), which names should
-- be moved up.  The returned definition differs from the input only in its
-- body.
moveUpInFunDef :: FunDef ExplicitMemory
               -> (Body ExplicitMemory -> Maybe [FParam ExplicitMemory] -> [VName])
               -> FunDef ExplicitMemory
moveUpInFunDef fundef findHoistees = fundef { funDefBody = hoisted_body }
  where
    -- The top-level body starts out with an empty binding map; the function's
    -- own scope is turned into entries by 'hoistInBody'.
    hoisted_body =
      hoistInBody (scopeOf fundef) [] (Just (funDefParams fundef))
                  findHoistees (funDefBody fundef)
-- | Fetch the 'PrimBinding' of the first entry in the current 'BindingMap'
-- whose name set contains @vname@.
--
-- A missing name is treated as an internal error: the message-taking
-- @fromJust@ in scope here (not the one from "Data.Maybe") is handed the
-- text built below.
lookupPrimBinding :: VName -> State BindingMap PrimBinding
lookupPrimBinding vname = gets (snd . fromJust not_found . L.find binds_vname)
  where
    binds_vname (names, _) = vname `S.member` names
    not_found = pretty vname ++ " was not found in BindingMap."
                ++ " This should not happen!"
-- | All names, over every entry of the current 'BindingMap', whose binding
-- lists @v@ among its free names ('pbFrees').
namesDependingOn :: VName -> State BindingMap Names
namesDependingOn v = gets dependants
  where
    dependants bmap =
      S.unions [ names | (names, pb) <- bmap, v `S.member` pbFrees pb ]
-- | Turn one scope entry into a single-entry 'BindingMap': the name has no
-- recorded dependencies, consumes nothing, and is marked 'FromFParam'.  The
-- 'NameInfo' part of the entry is ignored.
scopeBindingMap :: (VName, NameInfo ExplicitMemory)
                -> BindingMap
scopeBindingMap (name, _info) = [(S.singleton name, from_scope)]
  where
    from_scope = PrimBinding S.empty S.empty FromFParam
-- | The set of names mentioned by a kernel space: the names stored in
-- 'ExpMem.spaceGlobalId', 'ExpMem.spaceLocalId' and 'ExpMem.spaceGroupId',
-- plus every name in the thread-space structure (name components directly,
-- 'SubExp' components only when they are variables).
boundInKernelSpace :: ExpMem.KernelSpace -> Names
boundInKernelSpace space = S.fromList (id_names ++ structure_names)
  where
    id_names =
      [ ExpMem.spaceGlobalId space
      , ExpMem.spaceLocalId space
      , ExpMem.spaceGroupId space
      ]

    structure_names =
      case ExpMem.spaceStructure space of
        ExpMem.FlatThreadSpace ts ->
          [ name | (name, _) <- ts ]
          ++ mapMaybe subExpVar [ se | (_, se) <- ts ]
        ExpMem.NestedThreadSpace ts ->
          [ name1 | (name1, _, _, _) <- ts ]
          ++ mapMaybe subExpVar [ se2 | (_, se2, _, _) <- ts ]
          ++ [ name3 | (_, _, name3, _) <- ts ]
          ++ mapMaybe subExpVar [ se4 | (_, _, _, se4) <- ts ]
-- | Kernel-space names found inside an expression.
--
-- If the expression itself is a kernel, its space names
-- ('boundInKernelSpace') are reported and the kernel is not traversed
-- further.  Otherwise the bodies of the expression are walked and every
-- statement expression found there is examined the same way.
boundInExpExtra :: Exp ExplicitMemory -> Names
boundInExpExtra = execWriter . collect
  where
    collect :: Exp ExplicitMemory -> Writer Names ()
    collect (Op (ExpMem.Inner (ExpMem.Kernel _ space _ _))) =
      tell (boundInKernelSpace space)
    collect other =
      walkExpM body_walker other

    -- Only bodies are of interest; everything else keeps the identity walk.
    body_walker = identityWalker
      { walkOnBody = \body -> forM_ (bodyStms body) (collect . stmExp)
      }
-- | Build the 'BindingMap' entries for the statements of a body, numbering
-- the statements from line 0.
--
-- Each statement contributes three entries, in this order:
--
--   1. the names bound by its pattern (context and value elements),
--   2. the variable sizes occurring in the shapes of its memory-array pattern
--      elements, with the same dependencies and origin as (1),
--   3. the kernel-space names if the expression is a kernel (an empty name
--      set otherwise), recorded as 'FromFParam' with no dependencies.
bodyBindingMap :: [Stm ExplicitMemory] -> BindingMap
bodyBindingMap stms = concat (zipWith stmBindings [0..] stms)
  where
    stmBindings :: Line -> Stm ExplicitMemory -> BindingMap
    stmBindings line stm@(Let (Pattern ctxelems valelems) _ e) =
      [ (bound_names, stm_binding)
      , (size_names, stm_binding)
      , (kernel_names, PrimBinding S.empty S.empty FromFParam)
      ]
      where
        patelems = ctxelems ++ valelems
        bound_names = S.fromList (map patElemName patelems)
        size_names = S.fromList (concatMap shapeSizes patelems)

        -- Names introduced by kernel spaces inside the expression are not
        -- counted as dependencies of the statement.
        frees = freeInStm stm `S.difference` boundInExpExtra e

        -- Only an in-place update is recorded as consuming something.
        consumed = case e of
          BasicOp (Update src _ _) -> S.singleton src
          _ -> mempty

        stm_binding = PrimBinding frees consumed (FromLine line e)

        kernel_names = case e of
          Op (ExpMem.Inner (ExpMem.Kernel _ space _ _)) ->
            boundInKernelSpace space
          _ -> S.empty

    -- Variable dimensions of a memory-array pattern element; nothing for any
    -- other kind of pattern element.
    shapeSizes (PatElem _ (ExpMem.MemArray _ shape _ _)) =
      mapMaybe subExpVar (shapeDims shape)
    shapeSizes _ = []
-- | Hoist the requested statements of a body upwards, then do the same in
-- every body nested inside its statements.
--
-- Arguments:
--
--   * @scope_new@: scope whose names become 'FromFParam' entries.
--   * @bindingmap_old@: entries inherited from enclosing bodies; they are
--     placed first and therefore take precedence in lookups.
--   * @params@: function parameters when this is a function body, otherwise
--     'Nothing'; only passed on to @findHoistees@.
--   * @findHoistees@: chooses the names to move up in a given body.
--   * @body@: the body to transform.
--
-- The result has the same result expressions as the input body; only its
-- statements are reordered and recursively transformed.
hoistInBody :: Scope ExplicitMemory
            -> BindingMap
            -> Maybe [FParam ExplicitMemory]
            -> (Body ExplicitMemory -> Maybe [FParam ExplicitMemory] -> [VName])
            -> Body ExplicitMemory
            -> Body ExplicitMemory
hoistInBody scope_new bindingmap_old params findHoistees body =
  Body () bnds' res
  where
    hoistees = findHoistees body params

    bindingmap_fromscope = concatMap scopeBindingMap $ M.toList scope_new
    bindingmap_body = bodyBindingMap $ stmsToList $ bodyStms body
    bindingmap = bindingmap_old ++ bindingmap_fromscope ++ bindingmap_body

    -- Hoist one name at a time, threading both the reordered body and the
    -- updated binding map.  A strict left fold is used so that no chain of
    -- suspended 'hoist' calls builds up for long hoistee lists.
    (Body () bnds res, bindingmap') =
      L.foldl' (\(body0, lbindingmap) -> hoist lbindingmap body0)
               (body, bindingmap) hoistees

    -- Nested bodies see the final binding map of this body.
    bnds' = fmap (hoistRecursivelyStm bindingmap' findHoistees) bnds
-- | Apply 'hoistInBody' to every body directly contained in the expression
-- of a statement; the pattern and auxiliary information are kept as they
-- are.
--
-- Inside a nested body every entry of the enclosing binding map is
-- re-labelled as 'FromFParam', i.e. it is no longer associated with a line.
hoistRecursivelyStm :: BindingMap
                    -> (Body ExplicitMemory -> Maybe [FParam ExplicitMemory] -> [VName])
                    -> Stm ExplicitMemory
                    -> Stm ExplicitMemory
hoistRecursivelyStm bindingmap findHoistees (Let pat aux e) =
  Let pat aux (runIdentity (mapExpM body_mapper e))
  where
    body_mapper = identityMapper
      { mapOnBody = \scope_new ->
          return . hoistInBody scope_new outer_bindings Nothing findHoistees
      }

    outer_bindings =
      [ (ns, PrimBinding frees consumed FromFParam)
      | (ns, PrimBinding frees consumed _) <- bindingmap
      ]
-- | Move the statement binding @hoistee@ upwards in @body@.
--
-- The binding map handed to 'moveLetUpwards' is @bindingmap_cur@ followed by
-- freshly computed entries for the current statements of @body@.  Returns
-- the reordered body together with the binding map after the move.
hoist :: BindingMap
      -> Body ExplicitMemory
      -> VName
      -> (Body ExplicitMemory, BindingMap)
hoist bindingmap_cur body hoistee =
  runState (moveLetUpwards hoistee body) initial_map
  where
    initial_map = bindingmap_cur <> bodyBindingMap (stmsToList (bodyStms body))
-- | Move the statement binding @letname@ as far up in @body@ as its
-- dependencies allow, after first moving those dependencies up themselves.
--
-- Nothing is moved when @letname@ is not bound by a statement of this body
-- ('FromFParam'), or when its expression is a loop or an 'ExpMem.Inner'
-- operation (such as a kernel).
--
-- The state is the 'BindingMap', which is kept in sync with the statement
-- order by 'moveLetToLine'.
moveLetUpwards :: VName -> Body ExplicitMemory
               -> State BindingMap (Body ExplicitMemory)
moveLetUpwards letname body = do
  PrimBinding deps consumed letorig <- lookupPrimBinding letname
  -- Besides its free names, the statement must stay below every binding that
  -- mentions one of the names it consumes.  The statement itself is dropped
  -- from that set so that it does not depend on itself.
  deps' <- S.delete letname
           <$> (S.union deps
                <$> (S.unions <$> mapM namesDependingOn (S.toList consumed)))
  case letorig of
    FromFParam -> return body
    FromLine line_cur exp_cur ->
      case exp_cur of
        DoLoop{} -> return body
        Op ExpMem.Inner{} -> return body
        _ -> do
          -- Handle the dependencies in the order given by their origins
          -- (sortByKeyM is defined elsewhere; presumably ascending).
          deps'' <- sortByKeyM (fmap pbOrigin . lookupPrimBinding)
                    $ S.toList deps'
          body' <- foldM (flip moveLetUpwards) body deps''
          -- Look the origins up again: moving the dependencies may have
          -- changed their lines.
          origins <- mapM (fmap pbOrigin . lookupPrimBinding) deps''
          -- The destination is the line right after the latest dependency,
          -- or the top of the body if no dependency is bound in this body.
          let line_dest = case L.foldl' max FromFParam origins of
                            FromFParam -> 0
                            FromLine n _e -> n + 1
          -- Moving the dependencies must not have moved this statement.
          PrimBinding _ _ letorig' <- lookupPrimBinding letname
          when (letorig' /= letorig) $ error "Assertion: This should not happen."
          stms' <- moveLetToLine letname line_cur line_dest $ stmsToList $ bodyStms body'
          return body' { bodyStms = stmsFromList stms' }
-- | Move the statement at index @line_cur@ to index @line_dest@ in @stms@
-- and renumber the 'BindingMap' accordingly.
--
-- Every entry whose line lies in @[line_dest, line_cur)@ is shifted down by
-- one, and every entry whose name set contains @stm_cur_name@ is replaced by
-- a single binding located at @line_dest@.  When both lines are equal,
-- neither the list nor the state is touched.
--
-- NOTE(review): entries that sit on @line_cur@ but do not contain
-- @stm_cur_name@ in their name set (e.g. the shape-size entry created for
-- the same statement by 'bodyBindingMap') keep their old line number --
-- confirm that this is intended.
moveLetToLine :: VName -> Line -> Line -> [Stm ExplicitMemory]
              -> State BindingMap [Stm ExplicitMemory]
moveLetToLine stm_cur_name line_cur line_dest stms
  | line_cur == line_dest = return stms
  | otherwise = do
      modify (map shiftEntry)
      binding <- lookupPrimBinding stm_cur_name
      case binding of
        PrimBinding frees consumed (FromLine _ exp_cur) ->
          modify $ replaceWhere stm_cur_name
                     (PrimBinding frees consumed (FromLine line_dest exp_cur))
        _ -> error "moveLetToLine: unhandled case"
      return reordered
  where
    -- The statement being moved, the list without it, and the list with it
    -- re-inserted at its destination.
    moved = stms !! line_cur
    remaining = take line_cur stms ++ drop (line_cur + 1) stms
    reordered = take line_dest remaining ++ [moved] ++ drop line_dest remaining

    -- Statements between the destination and the old position slide down one
    -- line to make room.
    shiftEntry entry@(ns, PrimBinding frees consumed orig) =
      case orig of
        FromLine l e
          | l >= line_dest && l < line_cur ->
              (ns, PrimBinding frees consumed (FromLine (l + 1) e))
        _ -> entry
-- | Replace the 'PrimBinding' of every entry whose name set contains @n@
-- with @pb1@; all other entries, and all name sets, are left unchanged.
replaceWhere :: VName -> PrimBinding -> BindingMap -> BindingMap
replaceWhere n pb1 bmap =
  [ (ns, chosen)
  | (ns, pb) <- bmap
  , let chosen
          | n `S.member` ns = pb1
          | otherwise = pb
  ]