{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Optimise.MemoryBlockMerging.Existentials
( findExistentials
) where
import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad
import Control.Monad.Writer
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (ExplicitMemorish)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
newtype FindM lore a = FindM { unFindM :: Writer Names a }
deriving (Monad, Functor, Applicative,
MonadWriter Names)
type LoreConstraints lore = (ExplicitMemorish lore,
FullWalk lore)
record :: VName -> FindM lore ()
record = tell . S.singleton
coerce :: FindM flore a -> FindM tlore a
coerce = FindM . unFindM
findExistentials :: LoreConstraints lore =>
FunDef lore -> Names
findExistentials fundef =
let m = unFindM $ lookInBody $ funDefBody fundef
existentials = execWriter m
in existentials
lookInBody :: LoreConstraints lore =>
Body lore -> FindM lore ()
lookInBody (Body _ bnds _res) =
mapM_ lookInStm bnds
lookInKernelBody :: LoreConstraints lore =>
KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ bnds _res) =
mapM_ lookInStm bnds
lookInStm :: LoreConstraints lore =>
Stm lore -> FindM lore ()
lookInStm (Let (Pattern patctxelems patvalelems) _ e) = do
forM_ patvalelems $ \(PatElem var membound) ->
case membound of
ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
when (mem `L.elem` map patElemName patctxelems)
$ record var
_ -> return ()
case e of
DoLoop mergectxparams mergevalparams _loopform _body ->
forM_ mergevalparams $ \(Param var membound, _) ->
case membound of
ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
when (mem `L.elem` map (paramName . fst) mergectxparams)
$ record var
_ -> return ()
_ -> return ()
fullWalkExpM walker walker_kernel e
where walker = identityWalker
{ walkOnBody = lookInBody }
walker_kernel = identityKernelWalker
{ walkOnKernelBody = coerce . lookInBody
, walkOnKernelKernelBody = coerce . lookInKernelBody
, walkOnKernelLambda = coerce . lookInBody . lambdaBody
}