module Futhark.Analysis.DataDependencies
( Dependencies,
dataDependencies,
depsOf,
depsOf',
depsOfArrays,
depsOfShape,
lambdaDependencies,
reductionDependencies,
findNecessaryForReturned,
)
where
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Futhark.IR
-- | A mapping from a variable name to the set of names on which the
-- value bound to that variable depends.  Names absent from the map are
-- treated as depending only on themselves (see 'depsOfVar').
type Dependencies = M.Map VName Names
-- | Compute the data dependencies of every name bound by the statements
-- of a body, starting from an empty dependency map.
dataDependencies :: (ASTRep rep) => Body rep -> Dependencies
dataDependencies body = dataDependencies' mempty body
-- | Compute the data dependencies of the statements in a body, extending
-- an existing dependency map.  Nested bodies (match branches, lambda
-- bodies) are analysed with the enclosing map so that names bound
-- outside them still resolve to their own dependencies.
--
-- The result contains every entry of the starting map plus one entry per
-- name bound by a statement of the body.
dataDependencies' ::
  (ASTRep rep) =>
  Dependencies ->
  Body rep ->
  Dependencies
-- A strict left fold: the accumulator is a strict Map that is touched by
-- every statement, so a lazy 'foldl' would only build a thunk chain.
dataDependencies' startdeps = L.foldl' grow startdeps . bodyStms
  where
    grow deps (Let pat _ (WithAcc inputs lam)) =
      let input_deps = foldMap depsOfWithAccInput inputs
          -- The input dependencies are supplied twice to the lambda's
          -- parameters; presumably these are the certificates followed by
          -- the accumulators -- TODO confirm against the WithAcc definition.
          -- NOTE(review): input_deps has one entry per array (not per
          -- input); verify the alignment with the lambda parameters and
          -- pattern names when an input carries several arrays.
          lam_deps = lambdaDependencies deps lam (input_deps <> input_deps)
          transitive = map (depsOfNames deps) lam_deps
       in M.fromList (zip (patNames pat) transitive) `M.union` deps
      where
        -- Each array depends on itself and on the names in the shape.
        depsOfArrays' shape =
          map (\arr -> oneName arr <> depsOfShape shape)
        depsOfWithAccInput (shape, arrs, Nothing) =
          depsOfArrays' shape arrs
        depsOfWithAccInput (shape, arrs, Just (lam', nes)) =
          reductionDependencies deps lam' nes (depsOfArrays' shape arrs)
    grow deps (Let pat _ (Op op)) =
      let op_deps = map (depsOfNames deps) (opDependencies op)
          pat_deps = map (depsOfNames deps . freeIn) (patElems pat)
       in if length op_deps /= length pat_deps
            then
              -- The operation must report exactly one dependency set per
              -- pattern element; anything else is an internal error.
              error . unlines $
                [ "dataDependencies':",
                  "Pattern size: " <> show (length pat_deps),
                  "Op deps size: " <> show (length op_deps),
                  "Expression:",
                  prettyString op
                ]
            else
              M.fromList (zip (patNames pat) $ zipWith (<>) pat_deps op_deps)
                `M.union` deps
    grow deps (Let pat _ (Match c cases defbody _)) =
      let cases_deps = map (dataDependencies' deps . caseBody) cases
          defbody_deps = dataDependencies' deps defbody
          -- Every result of a match depends on the scrutinees.
          cdeps = foldMap (depsOf deps) c
          -- For each result position, the dependencies of that result in
          -- every case.  'L.transpose' of an empty list is empty, which
          -- would make 'zip3' below drop every pattern element when there
          -- are no cases, so pad with empty per-position lists.
          cases_res_deps =
            L.transpose
              ( zipWith (map . depsOf) cases_deps $
                  map (map resSubExp . bodyResult . caseBody) cases
              )
              ++ repeat []
          defbody_res_deps =
            map (depsOf defbody_deps . resSubExp) (bodyResult defbody)
          -- 'depsOfNames' already includes each free name of the pattern
          -- element itself, so 'freeIn pe' need not be added separately.
          comb (pe, se_cases_deps, se_defbody_deps) =
            ( patElemName pe,
              mconcat se_cases_deps
                <> cdeps
                <> se_defbody_deps
                <> depsOfNames deps (freeIn pe)
            )
          branchdeps =
            M.fromList . map comb $
              zip3 (patElems pat) cases_res_deps defbody_res_deps
       in M.unions $ [branchdeps, deps, defbody_deps] ++ cases_deps
    grow deps (Let pat _ e) =
      -- Generic statement: every bound name depends on everything free in
      -- the pattern and the expression, transitively through 'deps'.
      let free_deps = depsOfNames deps (freeIn pat <> freeIn e)
       in M.fromList [(name, free_deps) | name <- patNames pat] `M.union` deps
-- | The names a 'SubExp' depends on.  A constant depends on nothing.
-- A variable depends on itself together with whatever the given map
-- records for it.
depsOf :: Dependencies -> SubExp -> Names
depsOf deps se = case se of
  Constant _ -> mempty
  Var v -> depsOfVar deps v
-- | Like 'depsOf', but without any dependency map.  A variable depends
-- only on itself, and a constant on nothing.
depsOf' :: SubExp -> Names
depsOf' = depsOf mempty
-- | The dependencies of a variable: the variable itself, plus anything
-- recorded for it in the map (nothing if it has no entry).
depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name =
  case M.lookup name deps of
    Nothing -> oneName name
    Just recorded -> oneName name <> recorded
-- | The dependencies of a body result.  Only the underlying 'SubExp' is
-- consulted; the attached certificates are ignored.
depsOfRes :: Dependencies -> SubExpRes -> Names
depsOfRes deps = depsOf deps . resSubExp
-- | Extend a set of names with the dependencies recorded for each of
-- them.  The result always contains the input names themselves.
depsOfNames :: Dependencies -> Names -> Names
depsOfNames deps = foldMap (depsOfVar deps) . namesToList
-- | For each array, the names it depends on: the array itself plus the
-- variable (if any) used as the size.  No dependency map is consulted.
depsOfArrays :: SubExp -> [VName] -> [Names]
depsOfArrays size arrs = [oneName arr <> size_deps | arr <- arrs]
  where
    size_deps = depsOf mempty size
-- | The variables appearing as dimensions of a shape.  Constant
-- dimensions contribute nothing, and no dependency map is consulted.
depsOfShape :: Shape -> Names
depsOfShape = foldMap (depsOf mempty) . shapeDims
-- | For each result of the lambda @lam@, the names it depends on when
-- applied to arguments whose dependencies are given by @inputs@.
--
-- @inputs@ is matched positionally against the names bound by the
-- lambda.  Each result's dependencies are restricted to names free in
-- the lambda or mentioned in @inputs@, so names local to the body are
-- not reported.
lambdaDependencies ::
  (ASTRep rep) =>
  Dependencies ->
  Lambda rep ->
  [Names] ->
  [Names]
lambdaDependencies deps lam inputs =
  map (namesIntersection in_scope . depsOfRes body_deps) (bodyResult body)
  where
    body = lambdaBody lam
    in_scope = freeIn lam <> mconcat inputs
    -- Map union is left-biased: parameter entries shadow outer ones.
    param_deps = M.fromList (zip (boundByLambda lam) inputs)
    body_deps = dataDependencies' (param_deps <> deps) body
-- | Like 'lambdaDependencies', for an operator that comes with neutral
-- elements @nes@.  The dependencies of each neutral element are merged
-- into the input at the same position before the lambda is analysed.
reductionDependencies ::
  (ASTRep rep) =>
  Dependencies ->
  Lambda rep ->
  [SubExp] ->
  [Names] ->
  [Names]
reductionDependencies deps lam nes inputs =
  lambdaDependencies deps lam $ zipWith withNeutral nes inputs
  where
    withNeutral ne input = depsOf deps ne <> input
-- | Compute the names needed to produce the loop results that matter.
--
-- A parameter/result pair counts if @usedAfterLoop@ holds for the
-- parameter, or if the parameter's name is already known to be needed.
-- The dependencies (looked up in @allDependencies@) of the results of
-- such pairs are collected repeatedly until the set stops changing.
-- The returned set is that fixpoint together with the names of all
-- parameters satisfying @usedAfterLoop@.
findNecessaryForReturned ::
  (Param dec -> Bool) ->
  [(Param dec, SubExp)] ->
  M.Map VName Names ->
  Names
findNecessaryForReturned usedAfterLoop merge_and_res allDependencies =
  fixpoint mempty <> live_params
  where
    live_params =
      namesFromList [paramName p | (p, _) <- merge_and_res, usedAfterLoop p]

    fixpoint known
      | next == known = next
      | otherwise = fixpoint next
      where
        next = step known

    step known =
      mconcat
        [ depsOfResult res
          | (p, res) <- merge_and_res,
            usedAfterLoop p || paramName p `nameIn` known
        ]

    -- A variable with no recorded entry depends on itself only.
    depsOfResult (Constant _) = mempty
    depsOfResult (Var v) = M.findWithDefault (oneName v) v allDependencies