{-# LANGUAGE FlexibleContexts #-}
module Futhark.Analysis.DataDependencies
( Dependencies,
dataDependencies,
findNecessaryForReturned,
)
where
import Data.Foldable (foldl')
import qualified Data.Map.Strict as M
import Futhark.IR
-- | Maps each bound variable name to the set of names its value is
-- computed from (transitively, as accumulated by 'dataDependencies').
type Dependencies = M.Map VName Names
-- | Compute the data dependencies of every name bound in a body,
-- starting from an empty dependency map.
dataDependencies :: ASTLore lore => Body lore -> Dependencies
dataDependencies body = dataDependencies' noDeps body
  where
    noDeps = M.empty
-- | Extend a dependency map with the dependencies of every statement
-- in a body, visiting the statements in order.
--
-- For an @if@ statement, each pattern element depends on:
--
-- * the condition,
-- * the corresponding result of each branch, looked up in that
--   branch's own extended map, and
-- * the free names of the pattern element, together with their
--   recorded dependencies.
--
-- Names bound inside either branch are kept in the result as well.
--
-- For any other statement, every name bound by the pattern is mapped
-- to the same set. That set contains the free names of the pattern
-- and the expression, each together with its recorded dependencies.
dataDependencies' ::
  ASTLore lore =>
  Dependencies ->
  Body lore ->
  Dependencies
-- Strict left fold: the accumulated map is forced after each
-- statement instead of building a chain of thunks as long as the body.
dataDependencies' startdeps = foldl' grow startdeps . bodyStms
  where
    grow deps (Let pat _ (If c tb fb _)) =
      let tdeps = dataDependencies' deps tb
          fdeps = dataDependencies' deps fb
          cdeps = depsOf deps c
          -- One entry per pattern element, pairing it with the
          -- corresponding result of each branch.
          comb (pe, tres, fres) =
            ( patElemName pe,
              mconcat $
                [freeIn pe, cdeps, depsOf tdeps tres, depsOf fdeps fres]
                  ++ map (depsOfVar deps) (namesToList $ freeIn pe)
            )
          branchdeps =
            M.fromList $
              map comb $
                zip3
                  (patternElements pat)
                  (bodyResult tb)
                  (bodyResult fb)
       in -- M.unions is left-biased, so the entries for the names bound
          -- by this statement take precedence over older ones.
          M.unions [branchdeps, deps, tdeps, fdeps]
    grow deps (Let pat _ e) =
      let free = freeIn pat <> freeIn e
          freeDeps = mconcat $ map (depsOfVar deps) $ namesToList free
       in -- Left-biased union: new bindings shadow existing entries.
          M.fromList [(name, freeDeps) | name <- patternNames pat] `M.union` deps
-- | The names a sub-expression depends on. A constant depends on
-- nothing. A variable depends on itself plus whatever the map records
-- for it (see 'depsOfVar').
depsOf :: Dependencies -> SubExp -> Names
depsOf deps se =
  case se of
    Constant _ -> mempty
    Var v -> depsOfVar deps v
-- | The names a variable depends on: the variable itself, together
-- with its entry in the dependency map (empty if it has none).
depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name = itself <> recorded
  where
    itself = oneName name
    recorded = M.findWithDefault mempty name deps
-- | @findNecessaryForReturned usedAfterLoop merge_and_res deps@
-- computes which names are necessary for producing the results of
-- the parameters selected by @usedAfterLoop@.
--
-- A parameter's result is needed if @usedAfterLoop@ holds for it, or
-- if its name is already known to be necessary. The set of necessary
-- names is grown by fixpoint iteration until it stops changing. The
-- returned set also contains the names of all parameters satisfying
-- @usedAfterLoop@.
--
-- NOTE(review): the pairs are presumably (loop merge parameter, body
-- result) and @deps@ the body's 'Dependencies'; confirm at call sites.
findNecessaryForReturned ::
  (Param dec -> Bool) ->
  [(Param dec, SubExp)] ->
  M.Map VName Names ->
  Names
findNecessaryForReturned usedAfterLoop merge_and_res allDependencies =
  fixpoint mempty <> liveParamNames
  where
    liveParamNames =
      namesFromList [paramName p | (p, _) <- merge_and_res, usedAfterLoop p]

    -- Recompute the necessary set from the current one until it no
    -- longer grows.
    fixpoint known =
      let next = mconcat [resultDeps res | (p, res) <- merge_and_res, keep known p]
       in if next == known then next else fixpoint next

    keep known p = usedAfterLoop p || paramName p `nameIn` known

    -- A name absent from the map counts as depending only on itself.
    -- A name present in the map contributes its recorded entry, and
    -- not the name itself.
    resultDeps res =
      case res of
        Constant _ -> mempty
        Var v -> M.findWithDefault (oneName v) v allDependencies