{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fprof-auto-top #-}

--
-- Copyright (c) 2010, João Dias, Simon Marlow, Simon Peyton Jones,
-- and Norman Ramsey
--
-- Modifications copyright (c) The University of Glasgow 2012
--
-- This module is a specialised and optimised version of
-- Compiler.Hoopl.Dataflow in the hoopl package.  In particular it is
-- specialised to the UniqSM monad.
--

module Hoopl.Dataflow
  ( C, O, Block
  , lastNode, entryLabel
  , foldNodesBwdOO
  , foldRewriteNodesBwdOO
  , DataflowLattice(..), OldFact(..), NewFact(..), JoinedFact(..)
  , TransferFun, RewriteFun
  , Fact, FactBase
  , getFact, mkFactBase
  , analyzeCmmFwd, analyzeCmmBwd
  , rewriteCmmBwd
  , changedIf
  , joinOutFacts
  , joinFacts
  )
where

import GhcPrelude

import Cmm
import UniqSupply

import Data.Array
import Data.Maybe
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet

import Hoopl.Block
import Hoopl.Graph
import Hoopl.Collections
import Hoopl.Label

-- | The shape of a fact depends on whether the relevant end is closed or
-- open: closed ('C') gives a 'FactBase' (facts per label), open ('O') gives
-- a single fact.
type family   Fact x f :: *
type instance Fact C f = FactBase f
type instance Fact O f = f

-- | Tags the previously known fact passed as the first argument of a join.
newtype OldFact a = OldFact a

-- | Tags the incoming fact passed as the second argument of a join.
newtype NewFact a = NewFact a

-- | The result of joining OldFact and NewFact.
data JoinedFact a
    = Changed !a     -- ^ Result is different from OldFact.
    | NotChanged !a  -- ^ Result is the same as OldFact.

-- | Extract the joined fact, discarding whether it changed.
getJoined :: JoinedFact a -> a
getJoined joined = case joined of
    Changed a    -> a
    NotChanged a -> a

-- | Tag a join result as 'Changed' when the flag is 'True' and as
-- 'NotChanged' otherwise.
changedIf :: Bool -> a -> JoinedFact a
changedIf changed = if changed then Changed else NotChanged

-- | Join an old fact with a new one, reporting (via 'JoinedFact') whether the
-- result differs from the old fact.
type JoinFun a = OldFact a -> NewFact a -> JoinedFact a

-- | A lattice of dataflow facts, described by its bottom element and its
-- join operation.
data DataflowLattice a = DataflowLattice
    { fact_bot  :: a          -- ^ Bottom fact; used by 'getFact' for labels
                              --   that have no fact yet.
    , fact_join :: JoinFun a  -- ^ Joins an old fact with a new one.
    }

-- | Direction in which the analysis walks the control-flow graph.
data Direction = Fwd | Bwd

-- | Analysis-only transfer function: given a block and the current facts it
-- returns the facts it produces, keyed by label. The fixpoint loop joins each
-- of them into its running 'FactBase'.
type TransferFun f = CmmBlock -> FactBase f -> FactBase f

-- | Function for rewriting and analysis combined: given a block and the
-- current facts, it returns a (possibly rewritten) block together with the
-- facts it produces. To be used with 'rewriteCmmBwd' (via @rewriteCmm@).
--
-- Currently set to work with @UniqSM@ monad, but we could probably abstract
-- that away (if we do that, we might want to specialize the fixpoint algorithms
-- to the particular monads through SPECIALIZE).
type RewriteFun f = CmmBlock -> FactBase f -> UniqSM (CmmBlock, FactBase f)

-- | Run a dataflow analysis over a 'CmmGraph' until the facts stop changing,
-- walking the blocks backward ('analyzeCmmBwd') or forward ('analyzeCmmFwd')
-- and starting from the given initial facts.
analyzeCmmBwd, analyzeCmmFwd
    :: DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmmFwd = analyzeCmm Fwd
analyzeCmmBwd = analyzeCmm Bwd

-- | Analyze the blocks of a 'CmmGraph' with 'fixpointAnalysis', starting from
-- the graph's entry label and the given initial facts.
analyzeCmm
    :: Direction
    -> DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmm dir lattice transfer cmmGraph initFact =
    fixpointAnalysis dir lattice transfer (g_entry cmmGraph) blockMap initFact
  where
    -- The graph has type Graph CmmNode C C, so (per Hoopl.Graph) its only
    -- possible shape is a 'GMany' with 'NothingO' at both ends.
    blockMap :: LabelMap CmmBlock
    blockMap = case g_graph cmmGraph of
        GMany NothingO bm NothingO -> bm

-- | Fixpoint algorithm for analysis: repeatedly applies the transfer function
-- to blocks taken from a worklist until the worklist is empty.
--
-- Blocks are numbered by their position in 'sortBlocks' order. The worklist
-- is an 'IntSet' of those numbers, so the lowest-numbered pending block is
-- always processed next.
fixpointAnalysis
    :: forall f.
       Direction
    -> DataflowLattice f
    -> TransferFun f
    -> Label
    -> LabelMap CmmBlock
    -> FactBase f
    -> FactBase f
fixpointAnalysis direction lattice do_block entry blockmap = iterateFrom allBlocks
  where
    -- Sorting the blocks helps to minimize the number of times we need to
    -- process blocks. For instance, for forward analysis we want to look at
    -- blocks in reverse postorder. Also, see comments for sortBlocks.
    ordered :: [CmmBlock]
    ordered = sortBlocks direction entry blockmap

    lastIndex :: Int
    lastIndex = length ordered - 1

    byIndex :: Array Int CmmBlock
    byIndex = {-# SCC "block_arr" #-} listArray (0, lastIndex) ordered

    -- Initially every block is pending.
    allBlocks :: IntHeap
    allBlocks = {-# SCC "start" #-} IntSet.fromDistinctAscList [0 .. lastIndex]

    -- For each label, the block indices to revisit when its fact changes.
    dependents :: LabelMap IntSet
    dependents = {-# SCC "dep_blocks" #-} mkDepBlocks direction ordered

    joinFn :: JoinFun f
    joinFn = fact_join lattice

    iterateFrom
        :: IntHeap     -- ^ Worklist, i.e., blocks to process
        -> FactBase f  -- ^ Current result (increases monotonically)
        -> FactBase f
    iterateFrom pending !facts =
        case IntSet.minView pending of
            Nothing -> facts
            Just (ix, rest) ->
                let outFacts = {-# SCC "do_block" #-} do_block (byIndex ! ix) facts
                    -- Join every outgoing fact into the current facts. When a
                    -- label's fact changes, its dependent blocks go back on
                    -- the worklist.
                    (pending', facts') = {-# SCC "mapFoldWithKey" #-}
                        mapFoldlWithKey (updateFact joinFn dependents)
                                        (rest, facts) outFacts
                in iterateFrom pending' facts'

-- | Combined backward analysis and rewriting of a 'CmmGraph'. Returns the
-- rewritten graph together with the final facts.
rewriteCmmBwd
    :: DataflowLattice f
    -> RewriteFun f
    -> CmmGraph
    -> FactBase f
    -> UniqSM (CmmGraph, FactBase f)
rewriteCmmBwd = rewriteCmm Bwd

-- | Rewrite the blocks of a 'CmmGraph' with 'fixpointRewrite'. Returns the
-- graph with its body replaced by the resulting blocks, together with the
-- final facts.
rewriteCmm
    :: Direction
    -> DataflowLattice f
    -> RewriteFun f
    -> CmmGraph
    -> FactBase f
    -> UniqSM (CmmGraph, FactBase f)
rewriteCmm dir lattice rwFun cmmGraph initFact = do
    (rewrittenBlocks, facts) <-
        fixpointRewrite dir lattice rwFun (g_entry cmmGraph) originalBlocks initFact
    return (cmmGraph { g_graph = GMany NothingO rewrittenBlocks NothingO }, facts)
  where
    -- The graph has type Graph CmmNode C C, so (per Hoopl.Graph) its only
    -- possible shape is a 'GMany' with 'NothingO' at both ends.
    originalBlocks :: LabelMap CmmBlock
    originalBlocks = case g_graph cmmGraph of
        GMany NothingO bm NothingO -> bm

-- | Fixpoint algorithm for combined analysis and rewriting. It uses the same
-- worklist scheme as 'fixpointAnalysis', but the per-block function also
-- returns a rewritten block. The most recent rewrite of each processed block
-- replaces it in the result; blocks that are never processed keep their
-- original form.
fixpointRewrite
    :: forall f.
       Direction
    -> DataflowLattice f
    -> RewriteFun f
    -> Label
    -> LabelMap CmmBlock
    -> FactBase f
    -> UniqSM (LabelMap CmmBlock, FactBase f)
fixpointRewrite dir lattice do_block entry blockmap = iterateFrom allBlocks blockmap
  where
    -- Sorting the blocks helps to minimize the number of times we need to
    -- process blocks. For instance, for forward analysis we want to look at
    -- blocks in reverse postorder. Also, see comments for sortBlocks.
    ordered :: [CmmBlock]
    ordered = sortBlocks dir entry blockmap

    lastIndex :: Int
    lastIndex = length ordered - 1

    byIndex :: Array Int CmmBlock
    byIndex = {-# SCC "block_arr_rewrite" #-} listArray (0, lastIndex) ordered

    -- Initially every block is pending.
    allBlocks :: IntHeap
    allBlocks = {-# SCC "start_rewrite" #-}
        IntSet.fromDistinctAscList [0 .. lastIndex]

    -- For each label, the block indices to revisit when its fact changes.
    dependents :: LabelMap IntSet
    dependents = {-# SCC "dep_blocks_rewrite" #-} mkDepBlocks dir ordered

    joinFn :: JoinFun f
    joinFn = fact_join lattice

    iterateFrom
        :: IntHeap            -- ^ Worklist, i.e., blocks to process
        -> LabelMap CmmBlock  -- ^ Rewritten blocks.
        -> FactBase f         -- ^ Current facts.
        -> UniqSM (LabelMap CmmBlock, FactBase f)
    iterateFrom pending !rewritten !facts =
        case IntSet.minView pending of
            Nothing -> return (rewritten, facts)
            Just (ix, rest) -> do
                -- Note that we use the *original* block here. This is
                -- important. We're optimistically rewriting blocks even before
                -- reaching the fixed point, which means that the rewrite might
                -- be incorrect. So if the facts change, we need to rewrite the
                -- original block again (taking into account the new facts).
                (newBlock, outFacts) <- {-# SCC "do_block_rewrite" #-}
                    do_block (byIndex ! ix) facts
                let rewritten' = mapInsert (entryLabel newBlock) newBlock rewritten
                    (pending', facts') = {-# SCC "mapFoldWithKey_rewrite" #-}
                        mapFoldlWithKey (updateFact joinFn dependents)
                                        (rest, facts) outFacts
                iterateFrom pending' rewritten' facts'


{-
Note [Unreachable blocks]
~~~~~~~~~~~~~~~~~~~~~~~~~
A block that is not in the domain of tfb_fbase is "currently unreachable".
A currently-unreachable block is not even analyzed.  Reason: consider
constant prop and this graph, with entry point L1:
  L1: x:=3; goto L4
  L2: x:=4; goto L4
  L4: if x>3 goto L2 else goto L5
Here L2 is actually unreachable, but if we process it with bottom input fact,
we'll propagate (x=4) to L4, and nuke the otherwise-good rewriting of L4.

* If a currently-unreachable block is not analyzed, then its rewritten
  graph will not be accumulated in tfb_rg.  And that is good:
  unreachable blocks simply do not appear in the output.

* Note that clients must be careful to provide a fact (even if bottom)
  for each entry point. Otherwise useful blocks may be garbage collected.

* Note that updateFact must set the change-flag if a label goes from
  not-in-fbase to in-fbase, even if its fact is bottom.  In effect the
  real fact lattice is
       UNR
       bottom
       the points above bottom

* Even if the fact is going from UNR to bottom, we still call the
  client's fact_join function because it might give the client
  some useful debugging information.

* All of this only applies for *forward* fixpoints.  For the backward
  case we must treat every block as reachable; it might finish with a
  'return', and therefore have no successors, for example.
-}


-----------------------------------------------------------------------------
--  Pieces that are shared by fixpointAnalysis and fixpointRewrite
-----------------------------------------------------------------------------

-- | Sort the blocks into the right order for analysis. A forward analysis
-- uses reverse postorder from the entry label. A backward analysis uses the
-- reverse of that list (see Note [Backward vs forward analysis]).
sortBlocks
    :: NonLocal n
    => Direction -> Label -> LabelMap (Block n C C) -> [Block n C C]
sortBlocks direction entry blockmap = orient (revPostorderFrom blockmap entry)
  where
    orient = case direction of
        Fwd -> id
        Bwd -> reverse

-- Note [Backward vs forward analysis]
--
-- The forward and backward cases are not dual.  In the forward case, the entry
-- points are known, and one simply traverses the body blocks from those points.
-- In the backward case, something is known about the exit points, but a
-- backward analysis must also include reachable blocks that don't reach the
-- exit (as in a procedure that loops forever and has side effects).
-- For instance, let E be the entry and X the exit blocks (arrows indicate
-- control flow)
--   E -> X
--   E -> B
--   B -> C
--   C -> B
-- We do need to include B and C even though they're unreachable in the
-- *reverse* graph (that we could use for backward analysis):
--   E <- X
--   E <- B
--   B <- C
--   C <- B
-- So when sorting the blocks for the backward analysis, we simply take the
-- reverse of what is used for the forward one.


-- | Construct a mapping from a @Label@ to the block indexes that should be
-- re-analyzed if the facts at that @Label@ change. A block's index is its
-- position in the given list.
--
-- Note that we're considering here the entry point of the block, so if the
-- facts change at the entry:
-- * for a backward analysis we need to re-analyze all the predecessors, but
-- * for a forward analysis, we only need to re-analyze the current block
--   (and that will in turn propagate facts into its successors).
mkDepBlocks :: Direction -> [CmmBlock] -> LabelMap IntSet
mkDepBlocks direction blocks = foldl' step mapEmpty (zip [0 ..] blocks)
  where
    step :: LabelMap IntSet -> (Int, CmmBlock) -> LabelMap IntSet
    step = case direction of
        -- Forward: a block depends only on the fact at its own entry label.
        Fwd -> \deps (ix, blk) ->
            mapInsert (entryLabel blk) (IntSet.singleton ix) deps
        -- Backward: every block that jumps to a label depends on its fact.
        Bwd -> \deps (ix, blk) ->
            let addPred m succLbl =
                    mapInsertWith IntSet.union succLbl (IntSet.singleton ix) m
            in foldl' addPred deps (successors blk)

-- | After some new facts have been generated by analysing a block, we
-- fold this function over them to generate (a) the set of block
-- indices to (re-)analyse, and (b) the new FactBase.
updateFact
    :: JoinFun f
    -> LabelMap IntSet
    -> (IntHeap, FactBase f)
    -> Label
    -> f -- out fact
    -> (IntHeap, FactBase f)
updateFact joinFn dep_blocks (todo, fbase) lbl new_fact =
    case lookupFact lbl fbase of
        -- Note [No old fact]
        Nothing -> store new_fact
        Just old_fact ->
            case joinFn (OldFact old_fact) (NewFact new_fact) of
                NotChanged _  -> (todo, fbase)
                Changed fact' -> store fact'
  where
    -- Record the fact for 'lbl' and reschedule every block that depends on it.
    store fact = let !fbase' = mapInsert lbl fact fbase
                 in (rescheduled, fbase')
    rescheduled =
        todo `IntSet.union` mapFindWithDefault IntSet.empty lbl dep_blocks

{-
Note [No old fact]

We know that the new_fact is >= _|_, so we don't need to join.  However,
if the new fact is also _|_, and we have already analysed its block,
we don't need to record a change.  So there's a tradeoff here.  It turns
out that always recording a change is faster.
-}

----------------------------------------------------------------
--       Utilities
----------------------------------------------------------------

-- | Fact lookup: the fact recorded for the label, or the lattice's bottom
-- element if there is none.
getFact  :: DataflowLattice f -> Label -> FactBase f -> f
getFact lat l fb = fromMaybe (fact_bot lat) (lookupFact l fb)

-- | Returns the result of joining the facts from all the successors of the
-- provided node or block. Successors with no entry in the 'FactBase' are
-- skipped. If no successor has a fact, the result is the lattice's bottom.
joinOutFacts :: (NonLocal n) => DataflowLattice f -> n e C -> FactBase f -> f
joinOutFacts lattice nonLocal fact_base = foldl' join (fact_bot lattice) facts
  where
    -- Facts of those successors that have one (no partial 'fromJust').
    facts = mapMaybe (`lookupFact` fact_base) (successors nonLocal)
    -- The running accumulator goes in the 'NewFact' slot and each successor
    -- fact in the 'OldFact' slot; this argument order is unchanged from the
    -- previous version. Only the joined value is used; the
    -- Changed/NotChanged tag is ignored.
    join acc fact = getJoined $ fact_join lattice (OldFact fact) (NewFact acc)

-- | Join a list of facts, starting from the lattice's bottom element.
joinFacts :: DataflowLattice f -> [f] -> f
joinFacts lattice = foldl' step (fact_bot lattice)
  where
    -- The accumulator goes in the 'NewFact' slot and the list element in the
    -- 'OldFact' slot. Only the joined value is kept.
    step acc fact = getJoined (fact_join lattice (OldFact fact) (NewFact acc))

-- | Returns the joined facts for each label. When a label occurs more than
-- once, its facts are combined with the lattice's join.
mkFactBase :: DataflowLattice f -> [(Label, f)] -> FactBase f
mkFactBase lattice = foldl' insertJoined mapEmpty
  where
    -- The incoming fact goes in the 'OldFact' slot and the fact already
    -- stored for the label goes in the 'NewFact' slot. Only the joined value
    -- is kept.
    insertJoined acc (lbl, incoming) =
        let joinWith existing =
                getJoined (fact_join lattice (OldFact incoming) (NewFact existing))
            !merged = maybe incoming joinWith (mapLookup lbl acc)
        in mapInsert lbl merged acc

-- | Folds backward over all nodes of an open-open block: later nodes are
-- visited before earlier ones.
-- Strict in the accumulator.
foldNodesBwdOO
    :: forall f. (CmmNode O O -> f -> f) -> Block CmmNode O O -> f -> f
foldNodesBwdOO visit = walk
  where
    -- Every intermediate accumulator is forced with '$!' before it is passed
    -- on to the earlier part of the block.
    walk :: Block CmmNode O O -> f -> f
    walk block acc = case block of
        BNil             -> acc
        BMiddle node     -> visit node acc
        BCons node rest  -> visit node $! walk rest acc
        BSnoc front node -> walk front $! visit node acc
        BCat front back  -> walk front $! walk back acc
{-# INLINABLE foldNodesBwdOO #-}

-- | Folds backward over all the nodes of an open-open block and allows
-- rewriting them. The accumulator is both the block of nodes and @f@ (usually
-- dataflow facts).
-- Strict in both accumulated parts.
foldRewriteNodesBwdOO
    :: forall f.
       (CmmNode O O -> f -> UniqSM (Block CmmNode O O, f))
    -> Block CmmNode O O
    -> f
    -> UniqSM (Block CmmNode O O, f)
foldRewriteNodesBwdOO rewriteOO initBlock initFacts =
    rewriteBlock initBlock initFacts
  where
    rewriteBlock :: Block CmmNode O O -> f -> UniqSM (Block CmmNode O O, f)
    rewriteBlock block !facts = case block of
        BNil            -> return (BNil, facts)
        BMiddle node    -> rewriteOO node facts
        BCons node rest -> (rewriteOO node `thenBwd` rewriteBlock rest) facts
        BSnoc rest node -> (rewriteBlock rest `thenBwd` rewriteOO node) facts
        BCat front back -> (rewriteBlock front `thenBwd` rewriteBlock back) facts

    -- @earlier `thenBwd` later@ runs @later@ first (we walk backwards) and
    -- feeds the facts it produces to @earlier@. It then glues the two rewritten
    -- pieces back together in forward order, forcing both the facts and the
    -- resulting block.
    thenBwd
        :: (f -> UniqSM (Block CmmNode O O, f))
        -> (f -> UniqSM (Block CmmNode O O, f))
        -> f
        -> UniqSM (Block CmmNode O O, f)
    thenBwd earlier later = \factsIn -> do
        (laterBlock, factsMid) <- later factsIn
        (earlierBlock, !factsOut) <- earlier factsMid
        let !joined = joinBlocksOO earlierBlock laterBlock
        return (joined, factsOut)
    {-# INLINE thenBwd #-}
{-# INLINABLE foldRewriteNodesBwdOO #-}

-- | Concatenate two open/open blocks. An empty side is dropped, and a
-- single-node side is attached with 'blockCons' or 'blockSnoc'; 'BCat' is
-- used only when neither shortcut applies.
joinBlocksOO :: Block n O O -> Block n O O -> Block n O O
joinBlocksOO front back = case (front, back) of
    (BNil, _)         -> back
    (_, BNil)         -> front
    (BMiddle node, _) -> blockCons node back
    (_, BMiddle node) -> blockSnoc front node
    _                 -> BCat front back

-- | Worklist of block indices. An 'IntSet' serves as a priority queue:
-- 'IntSet.minView' yields the smallest pending index, and adding an index
-- that is already pending (via 'IntSet.union') leaves the set unchanged.
type IntHeap = IntSet