{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Optimise.MemoryBlockMerging.Miscellaneous
( makeCommutativeMap
, insertOrUpdate
, insertOrUpdateMany
, insertOrNew
, removeEmptyMaps
, removeKeyFromMapElems
, newDeclarationsStm
, lookupEmptyable
, fromJust
, maybeFromBoolM
, sortByKeyM
, mapMaybeM
, anyM
, whenM
, expandPrimExp
, expandIxFun
, mapFromListSetUnion
, fixpointIterateMay
, filterSetM
, (<&&>), (<||>)
, expandWithAliases
, FullWalk(..)
, fullWalkAliasesExpM
, FullWalkAliases
, FullMap
, fullMapExpM
) where
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad
import Data.Maybe (fromMaybe, catMaybes)
import Data.Function (on)
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
(ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Representation.Kernels.KernelExp
import Futhark.Representation.Aliases
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Optimise.MemoryBlockMerging.Types
-- | Make a relation symmetric: whenever @v@ is in the set stored under @k@,
-- the result also has @k@ in the set stored under @v@.
--
-- The keys of the result are the keys of the input together with every value
-- occurring in any of its sets.  Keys of the input are kept even when their
-- set is empty.
makeCommutativeMap :: Ord v => M.Map v (S.Set v) -> M.Map v (S.Set v)
makeCommutativeMap m = M.unionWith S.union m reversed
  where
    -- One reversed edge @v -> {k}@ per element @v@ of the set under @k@,
    -- built in a single pass instead of re-scanning the whole map for every
    -- name.
    reversed =
      M.fromListWith S.union
        [ (v, S.singleton k)
        | (k, vs) <- M.toList m
        , v <- S.toList vs
        ]
-- | Add a single value to the set stored under key @k@.  If the key is not
-- yet present, it is inserted with a set containing just that value.
insertOrUpdate :: (Ord k, Ord v) => k -> v ->
                  M.Map k (S.Set v) -> M.Map k (S.Set v)
insertOrUpdate k v m = insertOrUpdateMany k (S.singleton v) m
-- | Add a set of values to the set stored under key @k@.  If the key is not
-- yet present, it is inserted with exactly the given set.
insertOrUpdateMany :: (Ord k, Ord v) => k -> S.Set v ->
                      M.Map k (S.Set v) -> M.Map k (S.Set v)
insertOrUpdateMany k vs m = M.alter addValues k m
  where
    -- 'insertOrNew' always yields 'Just', so the key is never removed.
    addValues existing = insertOrNew vs existing
-- | Combine a set with an optional existing set, in the shape expected by
-- 'M.alter': the result is always 'Just', holding either the union of both
-- sets or, when there was no existing set, just @xs@.
insertOrNew :: Ord a => S.Set a -> Maybe (S.Set a) -> Maybe (S.Set a)
insertOrNew xs = Just . maybe xs (S.union xs)
-- | Drop every entry whose set is empty.
removeEmptyMaps :: M.Map k (S.Set v) -> M.Map k (S.Set v)
removeEmptyMaps m = M.filter hasElements m
  where
    hasElements vs = not (S.null vs)
-- | Remove each key from its own set, so that no entry refers to itself.
-- Entries are kept even if their set becomes empty.
removeKeyFromMapElems :: (Ord k) => M.Map k (S.Set k) -> M.Map k (S.Set k)
removeKeyFromMapElems m = M.mapWithKey withoutSelf m
  where
    withoutSelf key members = S.delete key members
-- | The names introduced by a statement: the names of all context and value
-- pattern elements, followed -- when the expression is a 'DoLoop' -- by the
-- names of its context and value merge parameters.
newDeclarationsStm :: Stm lore -> [VName]
newDeclarationsStm (Let (Pattern ctx_elems val_elems) _ e) =
  pattern_names ++ loop_param_names
  where
    pattern_names = map patElemName (ctx_elems ++ val_elems)
    loop_param_names = case e of
      DoLoop ctx_params val_params _ _ ->
        [ paramName param | (param, _) <- ctx_params ++ val_params ]
      _ -> []
-- | Look up a key in a map of monoidal values, returning 'mempty' when the
-- key is absent.
lookupEmptyable :: (Ord a, Monoid b) => a -> M.Map a b -> b
lookupEmptyable x m = M.findWithDefault mempty x m
-- | Extract the value from a 'Just'.  On 'Nothing', fail with 'error', using
-- the given description prefixed by @"error: "@.
fromJust :: String -> Maybe a -> a
fromJust mistake = fromMaybe (error ("error: " ++ mistake))
-- | Turn a monadic predicate into a monadic partial identity: the argument
-- is returned in 'Just' when the predicate holds, and 'Nothing' otherwise.
maybeFromBoolM :: Monad m => (a -> m Bool) -> (a -> m (Maybe a))
maybeFromBoolM f a = keepIf <$> f a
  where
    keepIf ok
      | ok = Just a
      | otherwise = Nothing
-- | Extend every set in the map with the aliases of its members, as given
-- by @mem_aliases@, and repeat until the map no longer changes.  Members
-- without an entry in @mem_aliases@ contribute nothing.
expandWithAliases :: forall v. Ord v => MemAliases -> M.Map v Names -> M.Map v Names
expandWithAliases mem_aliases = fixpointIterate (M.map withAliases)
  where
    -- A set together with the direct aliases of each of its members.
    withAliases mems =
      S.unions (mems : [ lookupEmptyable mem mem_aliases | mem <- S.toList mems ])
-- | Apply @f@ repeatedly, starting from the given value, until applying it
-- once more yields a value equal to its argument; that value is returned.
-- Does not terminate if @f@ never reaches such a fixed point.
fixpointIterate :: Eq a => (a -> a) -> a -> a
fixpointIterate f = go
  where
    go x
      | x' == x = x
      | otherwise = go x'
      where
        -- Computed once and shared between the comparison and the next step.
        x' = f x
-- | Apply @f@ repeatedly for as long as it returns 'Just', feeding each
-- result back in.  The last value for which @f@ returned 'Nothing' is the
-- result.
fixpointIterateMay :: (a -> Maybe a) -> a -> a
fixpointIterateMay f = go
  where
    go x = case f x of
      Just x' -> go x'
      Nothing -> x
-- | Build a map from an association list, taking the union of the sets of
-- keys that occur more than once.
mapFromListSetUnion :: (Ord k, Ord v) => [(k, S.Set v)] -> M.Map k (S.Set v)
mapFromListSetUnion kvs =
  M.unionsWith S.union [ M.singleton k vs | (k, vs) <- kvs ]
-- | Substitute variables in a primitive expression according to
-- @var_to_pe@, repeating the substitution until the expression no longer
-- changes.
expandPrimExp :: M.Map VName (ExpMem.PrimExp VName) -> ExpMem.PrimExp VName
              -> ExpMem.PrimExp VName
expandPrimExp var_to_pe pe = fixpointIterate substitute pe
  where
    substitute = substituteInPrimExp var_to_pe
-- | Substitute variables in an index function according to @var_to_pe@,
-- repeating the substitution until the index function no longer changes.
expandIxFun :: M.Map VName (ExpMem.PrimExp VName) -> ExpMem.IxFun -> ExpMem.IxFun
expandIxFun var_to_pe ixfun = fixpointIterate substitute ixfun
  where
    substitute = IxFun.substituteInIxFun var_to_pe
-- | Monadic conjunction.  Both actions are always run, left to right; the
-- second one is /not/ skipped when the first yields 'False'.
(<&&>) :: Monad m => m Bool -> m Bool -> m Bool
m <&&> n = do
  left <- m
  right <- n
  return (left && right)
-- | Monadic disjunction.  Both actions are always run, left to right; the
-- second one is /not/ skipped when the first yields 'True'.
(<||>) :: Monad m => m Bool -> m Bool -> m Bool
m <||> n = do
  left <- m
  right <- n
  return (left || right)
-- | Does the monadic predicate hold for any element of the list?  The
-- predicate is run on every element, even after one has yielded 'True'.
anyM :: Monad m => (a -> m Bool) -> [a] -> m Bool
anyM f xs = do
  verdicts <- mapM f xs
  return (or verdicts)
-- | Like 'when', but the condition is itself a monadic action.  The
-- condition is run first; the body is run only if it yields 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM b m = b >>= \holds -> when holds m
-- | Monadic 'Data.Maybe.mapMaybe': run @f@ on every element and keep the
-- 'Just' results, in order.
mapMaybeM :: Monad m => (a -> m (Maybe b)) -> [a] -> m [b]
mapMaybeM f xs = do
  results <- mapM f xs
  return (catMaybes results)
-- | Sort a list by a monadically computed key.  The key function is run
-- once per element, in list order, before sorting.  Elements with equal
-- keys keep their relative order, since 'L.sortBy' is stable.
sortByKeyM :: (Ord t, Monad m) => (a -> m t) -> [a] -> m [a]
sortByKeyM f xs = do
  keys <- mapM f xs
  return [ x | (x, _) <- L.sortBy (compare `on` snd) (zip xs keys) ]
-- | Monadic filter over a set: keep the elements for which the predicate
-- yields 'True'.  The predicate is run on the elements in ascending order.
filterSetM :: (Ord a, Monad m) => (a -> m Bool) -> S.Set a -> m (S.Set a)
filterSetM f xs = do
  kept <- filterM f (S.toList xs)
  return (S.fromList kept)
-- | Lores whose expressions can be mapped over with both an ordinary
-- 'Mapper' and a 'KernelMapper' for 'InKernel'.
class FullMap lore where
  -- | Map over an expression, given a mapper for the lore itself and a
  -- kernel mapper.  The instances in this module use the kernel mapper for
  -- @Op (ExpMem.Inner ...)@ expressions and 'mapExpM' with the ordinary
  -- mapper for everything else.
  fullMapExpM :: Monad m => Mapper lore lore m -> KernelMapper InKernel InKernel m
              -> Exp lore -> m (Exp lore)
-- | For 'ExplicitMemory', a kernel wrapped in @Op (ExpMem.Inner ...)@ is
-- mapped with the kernel mapper via 'mapKernelM'; every other expression is
-- mapped with the ordinary mapper via 'mapExpM'.
instance FullMap ExplicitMemory where
  fullMapExpM _ mapper_kernel (Op (ExpMem.Inner kernel)) =
    Op . ExpMem.Inner <$> mapKernelM mapper_kernel kernel
  fullMapExpM mapper _ e =
    mapExpM mapper e
-- | For 'InKernel', the kernel mapper is applied to the bodies, lambdas and
-- lambda parameters held by @Combine@, @GroupReduce@, @GroupScan@ and
-- @GroupStream@ kernel expressions.  Other kernel expressions are returned
-- unchanged, and expressions that are not @Op (ExpMem.Inner ...)@ are mapped
-- with the ordinary mapper via 'mapExpM'.
instance FullMap InKernel where
  fullMapExpM _ mapper_kernel (Op (ExpMem.Inner ke)) =
    Op . ExpMem.Inner <$> mapped
    where
      onBody = mapOnKernelBody mapper_kernel
      onLambda = mapOnKernelLambda mapper_kernel
      onLParams = mapM (mapOnKernelLParam mapper_kernel)

      mapped = case ke of
        ExpMem.Combine a b c body -> do
          body' <- onBody body
          return (ExpMem.Combine a b c body')
        ExpMem.GroupReduce a lambda b -> do
          lambda' <- onLambda lambda
          return (ExpMem.GroupReduce a lambda' b)
        ExpMem.GroupScan a lambda b -> do
          lambda' <- onLambda lambda
          return (ExpMem.GroupScan a lambda' b)
        ExpMem.GroupStream a b (ExpMem.GroupStreamLambda a1 b1 params0 params1 gsbody) c d -> do
          -- Same effect order as the fields: both parameter lists, then the body.
          params0' <- onLParams params0
          params1' <- onLParams params1
          gsbody' <- onBody gsbody
          let gslambda' = ExpMem.GroupStreamLambda a1 b1 params0' params1' gsbody'
          return (ExpMem.GroupStream a b gslambda' c d)
        _ -> return ke
  fullMapExpM mapper _ e =
    mapExpM mapper e
-- | Lores whose expressions can be walked with both an ordinary 'Walker'
-- and a 'KernelWalker' for 'InKernel'.
class FullWalk lore where
  -- | Walk an expression for its effects, given a walker for the lore
  -- itself and a kernel walker.  How the two walkers are combined is up to
  -- each instance.
  fullWalkExpM :: Monad m => Walker lore m -> KernelWalker InKernel m
               -> Exp lore -> m ()
-- | The counterpart of 'FullWalk' for expressions whose lore is wrapped in
-- 'Aliases'.
class FullWalkAliases lore where
  -- | Walk an expression of the 'Aliases'-wrapped lore for its effects,
  -- given a walker for that lore and a kernel walker for
  -- @Aliases InKernel@.
  fullWalkAliasesExpM :: Monad m => Walker (Aliases lore) m
                      -> KernelWalker (Aliases InKernel) m
                      -> Exp (Aliases lore) -> m ()
-- | For 'ExplicitMemory', every expression is first walked with the
-- ordinary walker via 'walkExpM'.  If it is a kernel wrapped in
-- @Op (ExpMem.Inner ...)@, the kernel is then also walked with the kernel
-- walker via 'walkKernelM'.
instance FullWalk ExplicitMemory where
  fullWalkExpM walker walker_kernel e@(Op (ExpMem.Inner kernel)) =
    walkExpM walker e >> walkKernelM walker_kernel kernel
  fullWalkExpM walker _ e =
    walkExpM walker e
-- | For 'ExplicitMemory' with aliases, every expression is first walked
-- with the ordinary walker via 'walkExpM'.  If it is a kernel wrapped in
-- @Op (ExpMem.Inner ...)@, the kernel is then also walked with the kernel
-- walker via 'walkKernelM'.
instance FullWalkAliases ExplicitMemory where
  fullWalkAliasesExpM walker walker_kernel e@(Op (ExpMem.Inner kernel)) =
    walkExpM walker e >> walkKernelM walker_kernel kernel
  fullWalkAliasesExpM walker _ e =
    walkExpM walker e
-- | For 'InKernel', a kernel expression wrapped in @Op (ExpMem.Inner ...)@
-- is walked with the kernel walker only; every other expression is walked
-- with the ordinary walker via 'walkExpM'.
instance FullWalk InKernel where
  fullWalkExpM _ walker_kernel (Op (ExpMem.Inner ke)) =
    walkOnKernelExpM walker_kernel ke
  fullWalkExpM walker _ e =
    walkExpM walker e
-- | For 'InKernel' with aliases, a kernel expression wrapped in
-- @Op (ExpMem.Inner ...)@ is walked with the kernel walker only; every other
-- expression is walked with the ordinary walker via 'walkExpM'.
instance FullWalkAliases InKernel where
  fullWalkAliasesExpM _ walker_kernel (Op (ExpMem.Inner ke)) =
    walkOnKernelExpM walker_kernel ke
  fullWalkAliasesExpM walker _ e =
    walkExpM walker e
-- | Walk the body or lambda held by a kernel expression with the kernel
-- walker.  Only @Combine@, @GroupReduce@, @GroupScan@ and @GroupStream@ are
-- traversed; any other kernel expression is a no-op.
walkOnKernelExpM :: Monad m => KernelWalker lore m ->
                    KernelExp lore -> m ()
walkOnKernelExpM walker_kernel (ExpMem.Combine _ _ _ body) =
  walkOnKernelBody walker_kernel body
walkOnKernelExpM walker_kernel (ExpMem.GroupReduce _ lambda _) =
  walkOnKernelLambda walker_kernel lambda
walkOnKernelExpM walker_kernel (ExpMem.GroupScan _ lambda _) =
  walkOnKernelLambda walker_kernel lambda
walkOnKernelExpM walker_kernel (ExpMem.GroupStream _ _ gslambda _ _) =
  walkOnGroupStreamLambdaM walker_kernel gslambda
walkOnKernelExpM _ _ =
  return ()
-- | Walk a group stream lambda with the kernel walker: first every
-- parameter of its two parameter lists, in order, and then its body.
walkOnGroupStreamLambdaM :: Monad m => KernelWalker lore m ->
                            GroupStreamLambda lore -> m ()
walkOnGroupStreamLambdaM walker_kernel (GroupStreamLambda _ _ first_params second_params body) =
  mapM_ walkParam (first_params ++ second_params)
    >> walkOnKernelBody walker_kernel body
  where
    walkParam = walkOnKernelLParam walker_kernel