{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition3
( getVarUsesBetween
) where
import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad
import Control.Monad.RWS
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
ExplicitMemory, ExplicitMemorish)
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
-- | Read-only parameters of the traversal: the two variables whose
-- declarations delimit the region of statements being examined.
data Context =
  Context { ctxSource :: VName
            -- ^ Recording begins once a statement declaring this name has
            -- been seen.
          , ctxDestination :: VName
            -- ^ Recording stops at the statement declaring this name.
          }
  deriving (Show)
-- | Mutable state threaded through the traversal.
--
-- The fields are strict.  The state lives in the lazy 'RWS' monad (the one
-- re-exported by "Control.Monad.RWS") and is updated once per statement.
-- With lazy fields, those updates would pile up as unevaluated thunks.  This
-- matters most for the growing 'curVars' set, which is rebuilt with
-- 'S.union' on every recorded statement.
data Current = Current
  { curHasReachedSource :: !Bool
    -- ^ Whether a statement declaring the source variable has been seen.
  , curHasReachedDestination :: !Bool
    -- ^ Whether a statement declaring the destination variable has been seen.
  , curVars :: !Names
    -- ^ Variables collected so far (free variables and new declarations of
    -- the recorded statements).
  }
  deriving (Show)
-- | The traversal monad.
--
-- It reads a 'Context', writes nothing (@()@), and threads a 'Current' state.
-- The @lore@ parameter is phantom: it only records which lore the traversed
-- code has, and can be changed with 'coerce'.
newtype FindM lore a = FindM { unFindM :: RWS Context () Current a }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadState Current
           , MonadReader Context
           )
-- | Everything a lore must support for this module's traversal to run over
-- it: explicit-memory information and a full expression walker.
type LoreConstraints lore =
  ( ExplicitMemorish lore
  , FullWalk lore
  )
-- | Reinterpret a computation at a different (phantom) lore.
--
-- The underlying 'RWS' computation is reused as-is.  It is used by the kernel
-- walker in 'lookInStm', presumably because kernel bodies are typed with a
-- different lore than the surrounding code.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM computation) = FindM computation
-- | Transform the set of collected variables with the given function,
-- leaving the reached-source/reached-destination flags untouched.
modifyCurVars :: (Names -> Names) -> FindM lore ()
modifyCurVars f = modify update
  where update cur = cur { curVars = f (curVars cur) }
-- | Collect the variables used or declared between the declaration of @src@
-- and the declaration of @dst@ in the body of @fundef@.
--
-- Statements are visited in order, descending into nested bodies.
--
-- * The statement declaring @dst@ is excluded, and so is everything after it.
-- * The statement declaring @src@ is itself excluded, unless @src@ had
--   already been reached.  Bodies nested inside it are walked after @src@
--   counts as reached.
--
-- For an @if@, each branch is walked as a separate path and the results are
-- merged.
getVarUsesBetween :: FunDef ExplicitMemory
                  -> VName -> VName
                  -> Names
getVarUsesBetween fundef src dst =
  curVars final_state
  where final_state = fst $ execRWS traversal ctx initial_state
        traversal = unFindM $ lookInBody $ funDefBody fundef
        ctx = Context { ctxSource = src
                      , ctxDestination = dst
                      }
        initial_state = Current { curHasReachedSource = False
                                , curHasReachedDestination = False
                                , curVars = S.empty
                                }
-- | Process every statement of a body, in order.  The body's result is not
-- inspected.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _) =
  forM_ stms lookInStm
-- | Process every statement of a kernel body, in order.  The kernel body's
-- results are not inspected.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) =
  forM_ stms lookInStm
-- | Process one statement of the traversal.
--
-- In order, this:
--
-- 1. marks the destination as reached if this statement declares it;
--
-- 2. stops if the destination has been reached (by this or an earlier
--    statement), so nothing at or after the destination is recorded;
--
-- 3. records the expression's free variables plus the statement's new
--    declarations, but only if the source had been reached before this
--    statement;
--
-- 4. marks the source as reached if this statement declares it;
--
-- 5. recurses into nested bodies.  Both branches of an 'If' are walked as
--    alternative paths and their results merged.  Any other expression is
--    walked with 'fullWalkExpM', including kernel bodies and lambdas.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm stm@(Let _ _ e) = do
  let new_decls = newDeclarationsStm stm

  dst <- asks ctxDestination
  when (dst `L.elem` new_decls)
    $ modify $ \c -> c { curHasReachedDestination = True }

  -- Both flags are read after the destination check above.  As a result, a
  -- statement that declares the destination is itself skipped.
  has_reached_source <- gets curHasReachedSource
  has_reached_destination <- gets curHasReachedDestination

  unless has_reached_destination $ do
    let e_free_vars = freeInExp e
        e_used_vars = S.union e_free_vars (S.fromList new_decls)
    -- 'has_reached_source' was read before the source check below.  The
    -- statement declaring the source is therefore only recorded if the
    -- source had already been reached.
    when has_reached_source
      $ modifyCurVars $ S.union e_used_vars

    src <- asks ctxSource
    when (src `L.elem` new_decls)
      $ modify $ \c -> c { curHasReachedSource = True }

    case e of
      If _ body0 body1 _ -> do
        before <- get
        lookInBody body0
        after0 <- get
        -- Walk the second branch as an alternative path.  The source flag
        -- and the collected variables are restored to their pre-'If' values.
        -- The destination flag is kept from the first branch, so if body0
        -- declares the destination, body1 is effectively skipped.
        -- NOTE(review): this asymmetry with the source flag is undocumented;
        -- confirm it is intended.
        put Current { curHasReachedSource = curHasReachedSource before
                    , curHasReachedDestination = curHasReachedDestination after0
                    , curVars = curVars before
                    }
        lookInBody body1
        after1 <- get
        -- Merge the two paths.  A flag is set if either branch set it, and
        -- the collected variables are the union of both branches.
        put Current { curHasReachedSource =
                        curHasReachedSource after0 || curHasReachedSource after1
                    , curHasReachedDestination =
                        curHasReachedDestination after0 || curHasReachedDestination after1
                    , curVars =
                        S.union (curVars after0) (curVars after1)
                    }
      _ -> do
        let walker = identityWalker { walkOnBody = lookInBody }
            walker_kernel = identityKernelWalker
              { walkOnKernelBody = coerce . lookInBody
              , walkOnKernelKernelBody = coerce . lookInKernelBody
              , walkOnKernelLambda = coerce . lookInBody . lambdaBody
              }
        fullWalkExpM walker walker_kernel e