{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Optimise.MemoryBlockMerging.ActualVariables
( findActualVariables
) where
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (fromMaybe, mapMaybe, catMaybes)
import Control.Monad
import Control.Monad.RWS
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
ExplicitMemorish, ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.AllExpVars
-- | Read-only environment for the traversal.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemorySrc
    -- ^ Memory source of each variable, looked up by variable name.
  , ctxFirstUses :: FirstUses
    -- ^ First-use information; consulted when handling @If@ expressions.
  }
  deriving Show
-- | Traversal monad.  It reads a 'Context', writes no log, and threads the
-- 'ActualVariables' mapping being built as its state.  The @lore@ parameter
-- does not occur on the right-hand side; it only tags computations so that
-- 'coerce' can move them between lores.
newtype FindM lore a = FindM { unFindM :: RWS Context () ActualVariables a }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadState ActualVariables
           , MonadReader Context
           )
-- | Constraints a lore must satisfy to be walked by 'lookInBody',
-- 'lookInKernelBody', 'lookInLambda' and 'lookInStm'.
type LoreConstraints lore =
  ( ExplicitMemorish lore
  , FullWalk lore
  , LookInKernelExp lore
  )
-- | Reuse a computation at another lore.  Only the phantom @lore@ tag
-- changes; the wrapped 'RWS' computation is passed through untouched.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM m) = FindM m
-- | Add @more_actuals@ to the entry for @stmt_var@ using 'insertOrUpdateMany'.
--
-- If @stmt_var@ is already present with an /empty/ set, the state is left
-- unchanged.  An entry recorded as empty therefore never grows through this
-- function.
recordActuals :: VName -> Names -> FindM lore ()
recordActuals stmt_var more_actuals = do
  existing <- gets (M.lookup stmt_var)
  let recordedEmpty = maybe False S.null existing
  unless recordedEmpty $
    modify (insertOrUpdateMany stmt_var more_actuals)
-- | Compute the 'ActualVariables' mapping for a function.
--
-- The function body is walked starting from an empty mapping, and the final
-- state is returned.  The function's own parameters are not visited here.
findActualVariables :: VarMemMappings MemorySrc -> FirstUses ->
                       FunDef ExplicitMemory -> ActualVariables
findActualVariables var_mem_mappings first_uses fundef =
  fst (execRWS walk ctx M.empty)
  where ctx = Context { ctxVarToMem = var_mem_mappings
                      , ctxFirstUses = first_uses
                      }
        walk = unFindM (lookInBody (funDefBody fundef))
-- | A function parameter is recorded as its own (only) actual variable.
lookInFParam :: FParam lore -> FindM lore ()
lookInFParam param = recordActuals name (S.singleton name)
  where name = paramName param
-- | A lambda parameter is recorded as its own (only) actual variable.
lookInLParam :: LParam lore -> FindM lore ()
lookInLParam param = recordActuals name (S.singleton name)
  where name = paramName param
-- | Walk a lambda.  Each parameter is recorded first, then the body is
-- walked.  The lambda's return types are ignored.
lookInLambda :: LoreConstraints lore => Lambda lore -> FindM lore ()
lookInLambda (Lambda params body _) =
  mapM_ lookInLParam params >> lookInBody body
-- | Walk each statement of a body in order.  The body's result is not
-- inspected.
lookInBody :: LoreConstraints lore => Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm
-- | Walk each statement of a kernel body in order.  The kernel results are
-- not inspected here.
lookInKernelBody :: LoreConstraints lore => KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm
-- | Record actual variables for one statement.  Afterwards the lore-specific
-- 'lookInKernelExp' runs, and then nested bodies, parameters, lambdas and
-- kernel bodies are walked through 'fullWalkExpM'.
--
-- The order of the 'recordActuals' calls matters.  'recordActuals' never
-- extends an entry that is already present with an empty set, so an earlier
-- call with 'S.empty' can block a later update.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm stm@(Let (Pattern patctxelems patvalelems) _ e) = do
  -- An in-place update with a single value result: the result and the
  -- updated array are each recorded as actuals of the other.
  case (patvalelems, e) of
    ([PatElem var _], BasicOp (Update orig _ _)) -> do
      let actuals = S.fromList [var, orig]
      recordActuals var actuals
      recordActuals orig actuals
    _ -> return ()

  -- A body's results with the first (length patctxelems) results dropped,
  -- i.e. the results lining up with the value pattern elements.
  let bodyResult' = drop (length patctxelems) . bodyResult
      -- All variables of the expression as reported by 'findAllExpVars'.
      -- Loop-invariant, so computed once and shared by the cases below.
      exp_vars = S.toList $ findAllExpVars e

  case e of
    DoLoop _mergectxparams mergevalparams loopform body -> do
      let body_vars0 = mapMaybe (subExpVar . snd) mergevalparams
          body_vars1 = map (paramName . fst) mergevalparams
          body_vars = body_vars0 ++ body_vars1 ++ exp_vars
          -- Each context pattern element paired with the loop-body result
          -- in the same position.
          ctx_results = zip patctxelems (bodyResult body)
      forM_ patvalelems $ \(PatElem var membound) -> do
        case membound of
          ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) -> do
            -- If mem is bound by a context pattern element whose body
            -- result is a variable, search that variable's memory instead.
            let mem_search =
                  case L.find ((== mem) . patElemName . fst) ctx_results of
                    Just (_, Var res_mem) -> res_mem
                    _ -> mem
            body_vars' <- filterM (lookupGivesMem mem_search) body_vars
            -- The result and every candidate in that memory become
            -- mutual actuals.
            let actuals = var : body_vars'
            forM_ actuals $ \a -> recordActuals a (S.fromList actuals)
          _ -> return ()
        -- Called for every loop result.  What this does to an entry that
        -- is already non-empty depends on 'insertOrUpdateMany'.
        -- NOTE(review): confirm whether this is meant only for non-array
        -- results.
        recordActuals var S.empty

      -- Each for-loop iteration variable is handled as an alias of the
      -- array it iterates over.
      case loopform of
        ForLoop _ _ _ loop_vars ->
          forM_ loop_vars $ \(Param lvar _, array) ->
            aliasOpHandleVar array lvar
        WhileLoop _ -> return ()

    If _se body_then body_else _types ->
      forM_ (zip3 patvalelems (bodyResult' body_then) (bodyResult' body_else))
        $ \(PatElem var membound, res_then, res_else) ->
          case membound of
            ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
              if mem `L.elem` map patElemName patctxelems
                then
                  -- mem is bound by a context pattern element: link the
                  -- result with the variables returned by both branches.
                  recordActuals var
                    $ S.fromList (var : catMaybes [subExpVar res_then, subExpVar res_else])
                else do
                  body_vars' <- filterM (lookupGivesMem mem) exp_vars
                  first_uses <- asks ctxFirstUses
                  -- If any candidate has mem among its first uses, record
                  -- the result and all candidates with empty sets.
                  -- Otherwise, link the result with all candidates.
                  case filter ((mem `S.member`) . (`lookupEmptyable` first_uses)) body_vars' of
                    [] ->
                      recordActuals var $ S.fromList (var : body_vars')
                    _ ->
                      forM_ (var : body_vars') $ \v -> recordActuals v S.empty
            _ -> return ()

    BasicOp (Index orig _) ->
      -- An index producing an array is handled as an alias of @orig@.  Only
      -- the first value pattern element is inspected.  An empty value
      -- pattern is ignored instead of crashing in 'head'.
      case patvalelems of
        PatElem var ExpMem.MemArray{} : _ -> aliasOpHandleVar orig var
        _ -> return ()

    BasicOp (Reshape shapechange_var orig) ->
      forM_ (map patElemName patvalelems) $ \var -> do
        orig' <- aliasOpRoot' orig
        mem_orig <- M.lookup orig' <$> asks ctxVarToMem
        case (shapechange_var, mem_orig) of
          -- One-dimensional new shape, and the alias root's memory source
          -- has a one-dimensional shape: link the result with @orig@.
          ([_], Just (MemorySrc _ _ (Shape [_]))) ->
            recordActuals var $ S.fromList [var, orig]
          _ ->
            recordActuals var S.empty
        -- In both cases the result is added to the alias root's entry.
        recordActuals orig' $ S.fromList [orig', var]

    BasicOp (Rearrange _ orig) ->
      aliasOpHandle orig patvalelems
    BasicOp (Rotate _ orig) ->
      aliasOpHandle orig patvalelems
    BasicOp (Opaque (Var orig)) ->
      aliasOpHandle orig patvalelems

    -- Any other expression: each array result is linked with the
    -- expression's variables that live in the same memory block.
    _ -> forM_ patvalelems $ \(PatElem var membound) ->
      case membound of
        ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) -> do
          body_vars' <- filterM (lookupGivesMem mem) exp_vars
          recordActuals var $ S.fromList (var : body_vars')
        _ -> return ()

  lookInKernelExp stm
  fullWalkExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody
          , walkOnFParam = lookInFParam
          , walkOnLParam = lookInLParam
          }
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInLambda
          , walkOnKernelLParam = lookInLParam
          }
-- | Handle each of the given pattern elements as an alias of @orig@,
-- delegating to 'aliasOpHandleVar'.
aliasOpHandle :: VName -> [PatElem lore] -> FindM lore ()
aliasOpHandle orig = mapM_ (aliasOpHandleVar orig . patElemName)
-- | Handle @var@ as an alias of @orig@.
--
-- First @var@ is recorded with an empty set.  Then, in the resulting state,
-- @orig@'s alias root is found and @var@ is added to the root's entry.
aliasOpHandleVar :: VName -> VName -> FindM lore ()
aliasOpHandleVar orig var =
  recordActuals var S.empty
    >> aliasOpRoot' orig
    >>= \root -> recordActuals root (S.fromList [root, var])
-- | Find the alias root of @orig@ in the current state.
--
-- * If @orig@ is recorded with an empty set, the root is the smallest key
--   whose set contains @orig@.  If no such key exists, the result is
--   'Nothing'.
-- * Otherwise (absent, or recorded with a non-empty set), @orig@ is its own
--   root.
aliasOpRoot :: VName -> FindM lore (Maybe VName)
aliasOpRoot orig = gets rootIn
  where rootIn actuals
          | maybe False S.null (M.lookup orig actuals) =
              fst <$> L.find (S.member orig . snd) (M.toList actuals)
          | otherwise = Just orig
-- | Like 'aliasOpRoot', but unwraps the result with the project's
-- message-taking 'fromJust' helper.  It presumably aborts with the message
-- below, which names @orig@, when no root exists (TODO confirm the helper's
-- behaviour).
aliasOpRoot' :: VName -> FindM lore VName
aliasOpRoot' orig = do
  found <- aliasOpRoot orig
  return (fromJust msg found)
  where msg = "at some point there will have been a proper statement: "
              ++ pretty orig
-- | Is @v@'s memory source in 'ctxVarToMem' the block @mem@?  Returns
-- 'False' when @v@ has no entry.
lookupGivesMem :: MName -> VName -> FindM lore Bool
lookupGivesMem mem v =
  asks (maybe False ((mem ==) . memSrcName) . M.lookup v . ctxVarToMem)
-- | Lore-specific extra handling.  'lookInStm' calls it on every statement
-- before walking nested code.
class LookInKernelExp lore where
  lookInKernelExp :: Stm lore -> FindM lore ()
-- | For a kernel statement, value pattern elements are paired positionally
-- with the kernel's results.  For each 'WriteReturn' result, the written
-- array is recorded with the matching pattern element as its actual.  Other
-- statements are ignored.
instance LookInKernelExp ExplicitMemory where
  lookInKernelExp (Let (Pattern _ patvalelems) _ e) =
    case e of
      Op (ExpMem.Inner (Kernel _ _ _ (KernelBody _ _ ress))) ->
        zipWithM_ onResult patvalelems ress
      _ -> return ()
    where onResult (PatElem var _) (WriteReturn _ arr _) =
            recordActuals arr $ S.singleton var
          onResult _ _ = return ()
-- | For group reduce/scan/stream operations, pass their input arrays to
-- 'extendActualVarsInKernel'.  Any other statement is ignored.
instance LookInKernelExp InKernel where
  lookInKernelExp (Let _ _ e) =
    maybe (return ()) (extendActualVarsInKernel e) groupInputs
    where groupInputs = case e of
            Op (ExpMem.Inner (ExpMem.GroupReduce _ _ input)) -> Just (map snd input)
            Op (ExpMem.Inner (ExpMem.GroupScan _ _ input)) -> Just (map snd input)
            Op (ExpMem.Inner (ExpMem.GroupStream _ _ _ _ arrs)) -> Just arrs
            _ -> Nothing
-- | For each array in @arrs@:
--
-- * find its alias root (using the array itself when 'aliasOpRoot' returns
--   'Nothing');
-- * if the array (not the root) has a memory source in 'ctxVarToMem',
--   extend the root's entry with the root plus every variable of @e@ in
--   that same memory block.
--
-- Arrays without a memory source are skipped.
extendActualVarsInKernel :: Exp InKernel -> [VName] -> FindM InKernel ()
extendActualVarsInKernel e = mapM_ extend
  where extend arr = do
          root <- fromMaybe arr <$> aliasOpRoot arr
          arr_mem <- asks (M.lookup arr . ctxVarToMem)
          case arr_mem of
            Nothing -> return ()
            Just src -> do
              same_mem <- filterSetM (lookupGivesMem (memSrcName src))
                                     (findAllExpVars e)
              recordActuals root (S.insert root same_mem)