{-# LANGUAGE FlexibleContexts #-}

-- | Facilities for inspecting the data dependencies of a program.
module Futhark.Analysis.DataDependencies
  ( Dependencies,
    dataDependencies,
    findNecessaryForReturned,
  )
where

import Data.Foldable (foldl')
import qualified Data.Map.Strict as M
import Futhark.IR

-- NOTE(review): the dependency computation in this module only recurses
-- into the branches of 'If' expressions; for any other statement it only
-- records entries for the names in the statement's own pattern, so
-- statements nested inside loop bodies or lambdas do not appear to get
-- entries of their own.  Confirm what the final sentence of the
-- documentation below is meant to promise before relying on it.

-- | A mapping from a variable name @v@ to the variables on which the
-- value of @v@ depends.  The intuition is that we could remove all
-- other variables, and @v@ would still be computable.
-- This also includes names bound in loops or by lambdas.
type Dependencies = M.Map VName Names

-- | Compute the data dependencies of every statement in a body,
-- starting from a map with no prior knowledge about any variable.
dataDependencies :: ASTLore lore => Body lore -> Dependencies
dataDependencies body = dataDependencies' mempty body

-- | Extend an existing dependency map with the dependencies of every
-- statement in a body.  Statements are processed in order, so each
-- statement is analysed against the map produced by the statements
-- before it.
--
-- The fold is strict ('foldl'') so that the map is built up
-- incrementally instead of as a chain of one thunk per statement.
dataDependencies' ::
  ASTLore lore =>
  Dependencies ->
  Body lore ->
  Dependencies
dataDependencies' startdeps = foldl' grow startdeps . bodyStms
  where
    -- 'If': both branches are analysed recursively (seeded with the
    -- current map), and each pattern element depends on the condition,
    -- on the corresponding result of each branch, and on any names
    -- free in the pattern element itself.
    grow deps (Let pat _ (If c tb fb _)) =
      let tdeps = dataDependencies' deps tb
          fdeps = dataDependencies' deps fb
          cdeps = depsOf deps c
          comb (pe, tres, fres) =
            ( patElemName pe,
              mconcat
                ( [freeIn pe, cdeps, depsOf tdeps tres, depsOf fdeps fres]
                    ++ map (depsOfVar deps) (namesToList (freeIn pe))
                )
            )
          branchdeps =
            M.fromList $
              map comb $
                zip3 (patternElements pat) (bodyResult tb) (bodyResult fb)
       in -- 'M.unions' is left-biased: the entries for the pattern names
          -- come first so they take precedence.  The branch maps also
          -- contribute the names bound inside the branches.
          M.unions [branchdeps, deps, tdeps, fdeps]
    -- Any other statement: every name bound by the pattern depends on
    -- all names free in the pattern and the expression, plus whatever
    -- those names are already recorded as depending on in @deps@.
    grow deps (Let pat _ e) =
      let free = freeIn pat <> freeIn e
          freeDeps = foldMap (depsOfVar deps) (namesToList free)
       in -- Left-biased union: the new entries override existing ones.
          M.fromList [(name, freeDeps) | name <- patternNames pat]
            `M.union` deps

-- | The names a subexpression depends on.  A constant depends on
-- nothing; a variable depends on itself and on everything recorded
-- for it in the map.
depsOf :: Dependencies -> SubExp -> Names
depsOf deps se =
  case se of
    Constant _ -> mempty
    Var v -> depsOfVar deps v

-- | The dependencies of a variable: the variable itself, together with
-- whatever the map records for it (nothing extra if it has no entry).
depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name =
  case M.lookup name deps of
    Nothing -> oneName name
    Just recorded -> oneName name <> recorded

-- | @findNecessaryForReturned p merge deps@ determines which of the
-- loop parameters (@merge@) are needed to compute the result of the
-- loop.  The predicate @p@ tells whether the final value of a given
-- loop parameter is live after the loop, and @deps@ holds the data
-- dependencies of the loop body.
--
-- The result contains the names of the live parameters plus the
-- dependencies of the results of every parameter found to be
-- necessary; it may therefore also contain names that are not loop
-- parameters.  It is computed by straightforward fixpoint iteration,
-- starting from the empty set.
findNecessaryForReturned ::
  (Param dec -> Bool) ->
  [(Param dec, SubExp)] ->
  M.Map VName Names ->
  Names
findNecessaryForReturned usedAfterLoop merge_and_res allDependencies =
  fixpoint mempty <> liveParams
  where
    -- Names of the parameters whose final value is used after the loop.
    liveParams =
      namesFromList [paramName p | (p, _) <- merge_and_res, usedAfterLoop p]

    -- Keep applying 'step' until the set stops changing.
    fixpoint known =
      let next = step known
       in if next == known then next else fixpoint next

    -- One round: the dependencies of the results of all parameters that
    -- are live after the loop or already known to be necessary.
    step known =
      mconcat
        [ resultDeps res
          | (p, res) <- merge_and_res,
            usedAfterLoop p || paramName p `nameIn` known
        ]

    -- A variable with no entry in the map is treated as depending only
    -- on itself.
    resultDeps (Constant _) = mempty
    resultDeps (Var v) = M.findWithDefault (oneName v) v allDependencies