{-# LANGUAGE FlexibleContexts #-}

-- | A graph representation of a sequence of Futhark statements
-- (i.e. a 'Body'), built to handle fusion.  Could perhaps be made
-- more general.  An important property is that it does not handle
-- "nested bodies" (e.g. 'Match'); these are represented as single
-- nodes.
--
-- This is all implemented on top of the graph representation provided
-- by the @fgl@ package ("Data.Graph.Inductive").  The graph provided
-- by this package allows nodes and edges to have arbitrarily-typed
-- "labels".  It is these labels ('EdgeT', 'NodeT') that we use to
-- contain Futhark-specific information.  An edge goes *from* uses of
-- variables to the node that produces that variable.  There are also
-- edges that do not represent normal data dependencies, but other
-- things.  This means that a node can have multiple edges for the
-- same name, indicating different kinds of dependencies.
module Futhark.Optimise.Fusion.GraphRep
  ( -- * Data structure
    EdgeT (..),
    NodeT (..),
    DepContext,
    DepGraphAug,
    DepGraph (..),
    DepNode,

    -- * Queries
    getName,
    nodeFromLNode,
    mergedContext,
    mapAcross,
    edgesBetween,
    reachable,
    applyAugs,
    depsFromEdge,
    contractEdge,
    isRealNode,
    isCons,
    isDep,
    isInf,

    -- * Construction
    mkDepGraph,
    pprg,
  )
where

import Data.Bifunctor (bimap)
import Data.Foldable (foldlM)
import qualified Data.Graph.Inductive.Dot as G
import qualified Data.Graph.Inductive.Graph as G
import qualified Data.Graph.Inductive.Query.DFS as Q
import qualified Data.Graph.Inductive.Tree as G
import qualified Data.List as L
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.HORep.SOAC as H
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.Util (nubOrd)

-- | Information associated with an edge in the graph.  Every
-- constructor carries the variable name the edge is about (see
-- 'getName').  Edges point from the node that uses the name to the node
-- that produces it.
data EdgeT
  = -- | The source node's expression aliases this name (added by
    -- @addConsAndAliases@).
    Alias VName
  | -- | An input of the source statement that is not classified as a
    -- SOAC input, and is therefore treated as infusible (added by
    -- @addDeps@).
    InfDep VName
  | -- | An input of the source statement classified as a SOAC input,
    -- i.e. a fusible dependency (added by @addDeps@).
    Dep VName
  | -- | The source node's expression consumes this name (added by
    -- @addConsAndAliases@).
    Cons VName
  | -- | An ordering-only edge from a consuming node to another user of
    -- the consumed name or one of its aliases (added by @addExtraCons@).
    Fake VName
  | -- | Not created by the edge generators in this part of the module.
    -- NOTE(review): presumably links result nodes to their producers;
    -- confirm.
    Res VName
  deriving (Eq, Ord)

-- | Information associated with a node in the graph.
data NodeT
  = -- | An ordinary statement.
    StmNode (Stm SOACS)
  | -- | A SOAC in the higher-order representation, with its pattern and
    -- statement auxiliary information.  NOTE(review): the
    -- 'H.ArrayTransforms' are presumably applied to the SOAC's output;
    -- confirm.
    SoacNode H.ArrayTransforms (Pat Type) (H.SOAC SOACS) (StmAux (ExpDec SOACS))
  | -- | Node corresponding to a result of the entire computation
    -- (i.e. the 'Result' of a body).  Any node that is not
    -- transitively reachable from one of these can be considered
    -- dead.
    ResNode VName
  | -- | Node corresponding to a free variable.
    -- Unclear whether we actually need these.
    FreeNode VName
  | -- | A wrapped node between two statement sequences; it is shown
    -- as the wrapped node.  NOTE(review): presumably the statements
    -- go before and after it; confirm.
    FinalNode (Stms SOACS) NodeT (Stms SOACS)
  | -- | A 'Match' statement, with nested nodes and their edge labels.
    MatchNode (Stm SOACS) [(NodeT, [EdgeT])]
  | -- | A loop statement, with nested nodes and their edge labels.
    -- NOTE(review): assumed to hold a loop; confirm.
    DoNode (Stm SOACS) [(NodeT, [EdgeT])]
  deriving (Eq)

-- | Short edge labels, as used by 'pprg'.  Only 'Dep' and 'InfDep'
-- include the variable name.
instance Show EdgeT where
  show (Dep vName) = "Dep " <> pretty vName
  show (InfDep vName) = "iDep " <> pretty vName
  show (Cons _) = "Cons"
  show (Fake _) = "Fake"
  show (Res _) = "Res"
  show (Alias _) = "Alias"

-- | Short node labels, as used by 'pprg'.  Statement nodes show their
-- bound names, SOAC nodes their pattern, wrapped nodes the inner node,
-- and the remaining kinds a prefixed name.
instance Show NodeT where
  show (StmNode (Let pat _ _)) = L.intercalate ", " $ map pretty $ patNames pat
  show (SoacNode _ pat _ _) = pretty pat
  show (FinalNode _ nt _) = show nt
  -- NOTE(review): the outer 'pretty' is applied to a String here;
  -- confirm that its Pretty instance renders it verbatim.
  show (ResNode name) = pretty $ "Res: " ++ pretty name
  show (FreeNode name) = pretty $ "Input: " ++ pretty name
  show (MatchNode stm _) = "Match: " ++ L.intercalate ", " (map pretty $ stmNames stm)
  show (DoNode stm _) = "Do: " ++ L.intercalate ", " (map pretty $ stmNames stm)

-- | The name that this edge depends on.  Every 'EdgeT' constructor
-- carries exactly one name, so this is total.
getName :: EdgeT -> VName
getName (Alias vn) = vn
getName (InfDep vn) = vn
getName (Dep vn) = vn
getName (Cons vn) = vn
getName (Fake vn) = vn
getName (Res vn) = vn

-- | Does the node actually represent something in the program?
-- 'ResNode's (results of the body) and 'FreeNode's (free variables)
-- are bookkeeping nodes and are not considered real.  Every other kind
-- of node is.
isRealNode :: NodeT -> Bool
isRealNode ResNode {} = False
isRealNode FreeNode {} = False
isRealNode _ = True

-- | Prettyprint dependency graph as a @dot@ graph (via
-- "Data.Graph.Inductive.Dot").  Nodes and edges are labelled with
-- their 'Show' instances.
pprg :: DepGraph -> String
pprg = G.showDot . G.fglToDotString . G.nemap show show . dgGraph

-- | A pair of a 'G.Node' and the node label.
type DepNode = G.LNode NodeT

-- | A labelled edge: the node the edge starts at (the dependent), the
-- node it points to (the dependency), and its 'EdgeT' label.
type DepEdge = G.LEdge EdgeT

-- | A tuple with four parts: inbound links to the node, the node
-- itself, the 'NodeT' "label", and outbound links from the node.
-- This type is used to modify the graph in 'mapAcross'.
type DepContext = G.Context NodeT EdgeT

-- | A dependency graph.  Edges go from *consumers* to *producers*
-- (i.e. from usage to definition).  That means the incoming edges of
-- a node are the dependents of that node, and the outgoing edges are
-- the dependencies of that node.
data DepGraph = DepGraph
  { -- | The underlying @fgl@ graph.
    dgGraph :: G.Gr NodeT EdgeT,
    -- | Maps each variable name to the node that produces it.
    dgProducerMapping :: ProducerMapping,
    -- | A table mapping VNames to VNames that are aliased to it.
    dgAliasTable :: AliasTable
  }

-- | A "graph augmentation" is a monadic action that modifies the graph.
-- Several augmentations can be chained with 'applyAugs'.
type DepGraphAug m = DepGraph -> m DepGraph

-- | Given a node, the names it depends on, each paired with the label
-- of the edge to add from that node to the name's producer.
-- 'genEdges' resolves producers through the 'ProducerMapping' and
-- drops names that have no producer.
type EdgeGenerator = NodeT -> [(VName, EdgeT)]

-- | A mapping from variable name to the graph node that produces
-- it.
type ProducerMapping = M.Map VName G.Node

-- | Fill in 'dgProducerMapping'.  Every output name of every node (as
-- given by 'getOutputs') is mapped to that node.  Any previous mapping
-- is replaced.
makeMapping :: Monad m => DepGraphAug m
makeMapping dg@(DepGraph {dgGraph = g}) =
  pure dg {dgProducerMapping = M.fromList $ concatMap gen_dep_list $ G.labNodes g}
  where
    -- One (name, producing node) pair per output of the node.
    gen_dep_list :: DepNode -> [(VName, G.Node)]
    gen_dep_list (i, node) = [(name, i) | name <- getOutputs node]

-- | Make a table to handle transitive aliases.  'dgAliasTable' is set
-- to the alias table that 'Alias.analyseStms' computes for the given
-- statements, starting from an empty table.
makeAliasTable :: Monad m => Stms SOACS -> DepGraphAug m
makeAliasTable stms dg = pure dg {dgAliasTable = alias_table}
  where
    -- Only the accumulated alias table is needed; the annotated
    -- statements and the consumed names are discarded.
    (_, (alias_table, _)) = Alias.analyseStms mempty stms

-- | Apply several graph augmentations in sequence, left to right.  Each
-- augmentation receives the graph produced by the previous one.
applyAugs :: Monad m => [DepGraphAug m] -> DepGraphAug m
applyAugs augs g = foldlM (\dg aug -> aug dg) g augs

-- | Creates deps for the given nodes on the graph using the
-- 'EdgeGenerator'.  For each node @(from, nodeT)@ and each
-- @(name, label)@ the generator returns for @nodeT@, an edge labelled
-- @label@ is added from @from@ to the producer of @name@.  Names
-- without an entry in 'dgProducerMapping' are skipped.
genEdges :: Monad m => [DepNode] -> EdgeGenerator -> DepGraphAug m
genEdges l_stms edge_fun dg =
  depGraphInsertEdges (concatMap (genEdge (dgProducerMapping dg)) l_stms) dg
  where
    -- name_map: mapping from declared names to the producing node.
    genEdge :: ProducerMapping -> DepNode -> [DepEdge]
    genEdge name_map (from, node) =
      [ G.toLEdge (from, to) edgeT
        | (dep, edgeT) <- edge_fun node,
          -- A failed match simply drops the candidate edge.
          Just to <- [M.lookup dep name_map]
      ]

-- | Insert the given labelled edges into the graph.  Nothing else in
-- the 'DepGraph' is changed.
depGraphInsertEdges :: Monad m => [DepEdge] -> DepGraphAug m
depGraphInsertEdges edgs dg = pure dg {dgGraph = G.insEdges edgs (dgGraph dg)}

-- | Monadically modify every node of the graph.  The traversal visits
-- each node present when it starts.  Each node is taken out of the
-- graph together with its 'DepContext' (via 'G.match'), passed to the
-- function, and the result is put back with 'G.&'.
mapAcross :: Monad m => (DepContext -> m DepContext) -> DepGraphAug m
mapAcross f dg = do
  let g = dgGraph dg
  g' <- foldlM visit g (G.nodes g)
  pure dg {dgGraph = g'}
  where
    -- 'G.match' yields 'Nothing' when the node is not in the graph; the
    -- graph is then left unchanged.
    visit gr n = case G.match n gr of
      (Just ctx, rest) -> (G.& rest) <$> f ctx
      (Nothing, _) -> pure gr

-- | The statement held by a 'StmNode', as a single-statement 'Stms'.
-- Every other kind of node yields no statements.
stmFromNode :: NodeT -> Stms SOACS -- do not use outside of edge generation
stmFromNode (StmNode x) = oneStm x
stmFromNode _ = mempty

-- | Get the underlying @fgl@ node, discarding its label.
nodeFromLNode :: DepNode -> G.Node
nodeFromLNode (node, _) = node

-- | Get the variable name that this edge refers to (the name carried
-- by its 'EdgeT' label).
depsFromEdge :: DepEdge -> VName
depsFromEdge = getName . G.edgeLabel

-- | Find all the edges connecting the two nodes, in either direction.
-- These are the edges of the subgraph induced by the two nodes, so any
-- self-loops on either node are included as well.
edgesBetween :: DepGraph -> G.Node -> G.Node -> [DepEdge]
edgesBetween dg n1 n2 = G.labEdges (G.subgraph [n1, n2] (dgGraph dg))

-- | @reachable dg from to@ is true if @to@ is reachable from @from@.
-- Edges are followed in their stored direction (dependent to
-- dependency), so this asks whether @from@ transitively depends on
-- @to@.
reachable :: DepGraph -> G.Node -> G.Node -> Bool
reachable dg source target = elem target $ Q.reachable source $ dgGraph dg

-- | Utility for augmentations: generate edges for every node currently
-- in the graph with the given 'EdgeGenerator' (see 'genEdges').
augWithFun :: Monad m => EdgeGenerator -> DepGraphAug m
augWithFun f dg = genEdges (G.labNodes (dgGraph dg)) f dg

-- | Add data-dependency edges.  Each input of a statement node that
-- 'stmInputs' classifies as 'SOACInput' gets a 'Dep' edge.  Every other
-- input gets an 'InfDep' edge.  Only 'StmNode's contribute, because
-- 'stmFromNode' yields no statements for other nodes.
addDeps :: Monad m => DepGraphAug m
addDeps = augWithFun toDep
  where
    toDep :: EdgeGenerator
    toDep stmt =
      let (fusible, infusible) =
            bimap (map fst) (map fst)
              . L.partition ((== SOACInput) . snd)
              . S.toList
              $ foldMap stmInputs (stmFromNode stmt)
          mkDep :: VName -> (VName, EdgeT)
          mkDep vname = (vname, Dep vname)
          mkInfDep :: VName -> (VName, EdgeT)
          mkInfDep vname = (vname, InfDep vname)
       in map mkDep fusible <> map mkInfDep infusible

-- | For every 'StmNode', add a 'Cons' edge for each name its expression
-- consumes and an 'Alias' edge for each name its expression's results
-- alias.  The expression is analysed with an empty alias table, so
-- 'dgAliasTable' is not consulted here.
addConsAndAliases :: Monad m => DepGraphAug m
addConsAndAliases = augWithFun edges
  where
    edges :: EdgeGenerator
    edges (StmNode s) = consEdges e <> aliasEdges e
      where
        e = Alias.analyseExp mempty $ stmExp s
    edges _ = mempty

    -- One 'Cons' edge per consumed name.
    consEdges e = zip names (map Cons names)
      where
        names = namesToList $ consumedInExp e

    -- One 'Alias' edge per name aliased by any of the results.
    aliasEdges =
      map (\vname -> (vname, Alias vname))
        . namesToList
        . mconcat
        . expAliases

-- | Extra dependencies mask the fact that consuming nodes "depend" on
-- all other nodes coming before them.  Consider every 'Cons' edge
-- @from -> to@ on the name @cname@, where @aliases@ are the aliases of
-- @cname@ in 'dgAliasTable'.  A 'Fake' edge is added from @from@ to
-- every other node that has an edge into the producer of @cname@ (or of
-- one of @aliases@) labelled with @cname@ or one of @aliases@.  This
-- forces the consumer to be ordered after all other uses.  Fake edges
-- are also added for aliases, which is hoped to preserve the
-- asymptotic complexity guarantees.
addExtraCons :: Monad m => DepGraphAug m
addExtraCons dg =
  depGraphInsertEdges (concatMap makeEdge (G.labEdges g)) dg
  where
    g = dgGraph dg
    alias_table = dgAliasTable dg
    mapping = dgProducerMapping dg

    makeEdge :: DepEdge -> [DepEdge]
    makeEdge (from, to, Cons cname) = do
      let aliases = namesToList $ M.findWithDefault mempty cname alias_table
          -- Aliases with no producer node in this graph are skipped,
          -- as 'genEdges' does for unmapped names.  The partial 'M.!'
          -- used previously would crash on them.
          to' = [node | alias <- aliases, Just node <- [M.lookup alias mapping]]
          p (tonode, toedge) =
            tonode /= from && getName toedge `elem` (cname : aliases)
      (to2, _) <- filter p $ concatMap (G.lpre g) to' <> G.lpre g to
      pure $ G.toLEdge (from, to2) (Fake cname)
    makeEdge _ = []

-- | Apply a monadic transformation to the label of every node in the
-- graph, passing each node's incoming and outgoing adjacencies through
-- unchanged.
mapAcrossNodeTs :: Monad m => (NodeT -> m NodeT) -> DepGraphAug m
mapAcrossNodeTs f = mapAcross relabel
  where
    relabel (ins, node, lbl, outs) = rebuild <$> f lbl
      where
        rebuild lbl' = (ins, node, lbl', outs)

-- | Refine a plain 'StmNode' into a more specific node kind.
--
-- * A statement whose 'Op' expression is recognised by 'H.fromExp'
--   becomes a 'SoacNode' with no array transforms.
-- * A loop becomes a 'DoNode' with no nested nodes.
-- * A 'Match' becomes a 'MatchNode' with no nested nodes.
--
-- Every other node, including an 'Op' that is not a SOAC, is returned
-- unchanged.
nodeToSoacNode :: (HasScope SOACS m, Monad m) => NodeT -> m NodeT
nodeToSoacNode n@(StmNode s@(Let pat aux e))
  | Op {} <- e = fromSoac <$> H.fromExp e
  | DoLoop {} <- e = pure $ DoNode s []
  | Match {} <- e = pure $ MatchNode s []
  where
    fromSoac (Right soac) = SoacNode mempty pat soac aux
    fromSoac (Left H.NotSOAC) = n
nodeToSoacNode n = pure n

-- | Refine every node of the graph with 'nodeToSoacNode'.  This must run
-- after all edges have been added (see 'initialGraphConstruction').
convertGraph :: (HasScope SOACS m, Monad m) => DepGraphAug m
convertGraph dg = mapAcrossNodeTs nodeToSoacNode dg

-- | Add the edges to a graph of bare nodes, then refine the node kinds
-- with 'convertGraph'.
initialGraphConstruction :: (HasScope SOACS m, Monad m) => DepGraphAug m
initialGraphConstruction = applyAugs (edgeAugs <> [convertGraph])
  where
    -- Edge-adding augmentations, applied in this order.  'convertGraph'
    -- must come after all of them.
    edgeAugs =
      [ addDeps,
        addConsAndAliases,
        addExtraCons,
        addResEdges
      ]

-- | Construct a graph with only nodes, but no edges.
--
-- Nodes are numbered consecutively from 0, in this order:
--
-- * one 'StmNode' per statement of the body, in statement order;
-- * one 'ResNode' per name free in the body result;
-- * one 'FreeNode' per name free in the body.
--
-- The producer mapping and alias table start out empty.
emptyGraph :: Body SOACS -> DepGraph
emptyGraph body =
  DepGraph
    { dgGraph = G.mkGraph (zip [0 ..] nodes) [],
      dgProducerMapping = mempty,
      dgAliasTable = mempty
    }
  where
    nodes =
      concat
        [ map StmNode (stmsToList (bodyStms body)),
          map ResNode (namesToList (freeIn (bodyResult body))),
          map FreeNode (namesToList (freeIn body))
        ]

-- | Make a dependency graph corresponding to a 'Body'.
--
-- Starts from the node-only graph built by 'emptyGraph', then applies, in
-- order: 'makeMapping', 'makeAliasTable' over the body's statements, and
-- 'initialGraphConstruction'.
mkDepGraph :: (HasScope SOACS m, Monad m) => Body SOACS -> m DepGraph
mkDepGraph body =
  applyAugs [makeMapping, makeAliasTable stms, initialGraphConstruction] (emptyGraph body)
  where
    stms = bodyStms body

-- | Merges two contexts into one.
--
-- The result keeps the node of the first context and uses the given
-- label.  Its incoming and outgoing adjacency lists are the duplicate-free
-- concatenations of the corresponding lists of both contexts, with every
-- entry that refers to either of the two merged nodes removed.
mergedContext :: Ord b => a -> G.Context a b -> G.Context a b -> G.Context a b
mergedContext mergedlabel (inp1, n1, _, out1) (inp2, n2, _, out2) =
  (combine inp1 inp2, n1, mergedlabel, combine out1 out2)
  where
    combine xs ys = filter (notMerged . snd) (nubOrd (xs <> ys))
    -- Drop links between the two merged nodes (they become self-loops).
    notMerged n = n /= n1 && n /= n2

-- | Remove the given node, and insert the 'DepContext' into the
-- graph, replacing any existing information about the node contained
-- in the 'DepContext'.
contractEdge :: Monad m => G.Node -> DepContext -> DepGraphAug m
contractEdge n2 ctx dg =
  pure (dg {dgGraph = ctx G.& G.delNodes [survivor, n2] (dgGraph dg)})
  where
    -- The node described by the context is deleted and then re-inserted
    -- from the context, so it is the one that survives the contraction.
    survivor = G.node' ctx

-- | Augment the graph using 'getStmRes' as the edge generator.  That
-- generator yields a 'Res' edge only for 'ResNode's.
addResEdges :: Monad m => DepGraphAug m
addResEdges dg = augWithFun getStmRes dg

-- Utilities for classifying dependencies as fusible or infusible.
-- Edges are generated based on these classifications.

-- | A classification of how a free variable is used.
data Classification
  = -- | Used as array input to a SOAC (meaning fusible).
    SOACInput
  | -- | Used in some other way.
    Other
  deriving (Show, Eq, Ord)

-- | A set of variables, each paired with how it is used.  The pairs are
-- the set elements, so one variable may occur with both classifications.
type Classifications = S.Set (VName, Classification)

-- | Classify every free variable of a value as 'Other'.
freeClassifications :: FreeIn a => a -> Classifications
freeClassifications x =
  S.fromList [(v, Other) | v <- namesToList (freeIn x)]

-- | Classify the free variables of a statement.  Names free in the
-- pattern and the statement auxiliaries are 'Other'.  The expression is
-- classified by 'expInputs'.
stmInputs :: Stm SOACS -> Classifications
stmInputs (Let pat aux e) = fromDecls `S.union` expInputs e
  where
    fromDecls = freeClassifications (pat, aux)

-- | Classify the free variables of a body.  Each statement is classified
-- with 'stmInputs'; names free in the result are 'Other'.
bodyInputs :: Body SOACS -> Classifications
bodyInputs body = stmClasses <> resClasses
  where
    stmClasses = foldMap stmInputs (bodyStms body)
    resClasses = freeClassifications (bodyResult body)

-- | Classify the free variables of an expression.
--
-- The following are 'SOACInput':
--
-- * the input arrays of a Screma, Hist, Scatter or Stream SOAC;
-- * the array operand of an expression recognised by
--   'H.transformFromExp'.
--
-- Everything else is 'Other'.  'Match' branches and loop bodies are
-- classified recursively with 'bodyInputs'.  For JVP and VJP, every free
-- variable is 'Other'.
expInputs :: Exp SOACS -> Classifications
expInputs e = case e of
  Match cond cases defbody attr ->
    foldMap (bodyInputs . caseBody) cases
      <> bodyInputs defbody
      <> freeClassifications (cond, attr)
  DoLoop params form loopbody ->
    freeClassifications (params, form) <> bodyInputs loopbody
  Op soac -> case soac of
    Futhark.Screma w arrs form ->
      asInputs arrs <> freeClassifications (w, form)
    Futhark.Hist w arrs ops lam ->
      asInputs arrs <> freeClassifications (w, ops, lam)
    Futhark.Scatter w arrs lam iws ->
      asInputs arrs <> freeClassifications (w, lam, iws)
    Futhark.Stream w arrs form nes lam ->
      asInputs arrs <> freeClassifications (w, form, nes, lam)
    Futhark.JVP {} -> freeClassifications soac
    Futhark.VJP {} -> freeClassifications soac
  _
    | Just (arr, _) <- H.transformFromExp mempty e ->
        S.singleton (arr, SOACInput)
          <> freeClassifications (freeIn e `namesSubtract` oneName arr)
    | otherwise -> freeClassifications e
  where
    -- Arrays passed directly to a SOAC are potential fusion inputs.
    asInputs arrs = S.fromList [(a, SOACInput) | a <- arrs]

-- | The names bound by the pattern of a statement.
stmNames :: Stm SOACS -> [VName]
stmNames stm = patNames (stmPat stm)

-- | Edge generator that yields a single 'Res' edge, named after the
-- result, for a 'ResNode'.  It yields nothing for any other node.
getStmRes :: EdgeGenerator
getStmRes node = case node of
  ResNode name -> [(name, Res name)]
  _ -> []

-- | The names produced by a node:
--
-- * the pattern names of statement-like nodes;
-- * the single name of a 'FreeNode';
-- * nothing for a 'ResNode'.
--
-- Calling this on a 'FinalNode' is an error.
getOutputs :: NodeT -> [VName]
getOutputs (StmNode stm) = stmNames stm
getOutputs (MatchNode stm _) = stmNames stm
getOutputs (DoNode stm _) = stmNames stm
getOutputs (SoacNode _ pat _ _) = patNames pat
getOutputs (FreeNode name) = [name]
getOutputs (ResNode _) = []
getOutputs FinalNode {} = error "Final nodes cannot generate edges"

-- | Is there a possibility of fusion?  Holds for 'Dep' and 'Res' edges
-- only.
isDep :: EdgeT -> Bool
isDep edge = case edge of
  Dep _ -> True
  Res _ -> True
  _ -> False

-- | Is this an infusible edge?  Holds for 'InfDep' and 'Fake' edges.
isInf :: (G.Node, G.Node, EdgeT) -> Bool
isInf (_, _, InfDep _) = True
-- 'Fake' edges count as infusible to avoid simultaneous cons/dep edges.
isInf (_, _, Fake _) = True
isInf _ = False

-- | Is this a 'Cons' edge?
isCons :: EdgeT -> Bool
isCons edge = case edge of
  Cons _ -> True
  _ -> False