--
-- Copyright (c) 2018 Andreas Klebinger
--

{-# LANGUAGE TypeFamilies, ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}

module CFG
    ( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
    , TransitionSource(..)

    --Modify the CFG
    , addWeightEdge, addEdge
    , delEdge, delNode
    , addNodesBetween, shortcutWeightMap
    , reverseEdges, filterEdges
    , addImmediateSuccessor
    , mkWeightInfo, adjustEdgeWeight, setEdgeWeight

    --Query the CFG
    , infoEdgeList, edgeList
    , getSuccessorEdges, getSuccessors
    , getSuccEdgesSorted
    , getEdgeInfo
    , getCfgNodes, hasNode

    -- Loop Information
    , loopMembers, loopLevels, loopInfo

    --Construction/Misc
    , getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg

    --Find backedges and update their weight
    , optimizeCFG
    , mkGlobalWeights

     )
where

#include "HsVersions.h"

import GhcPrelude

import BlockId
import Cmm

import CmmUtils
import CmmSwitch
import Hoopl.Collections
import Hoopl.Label
import Hoopl.Block
import qualified Hoopl.Graph as G

import Util
import Digraph
import Maybes

import Unique
import qualified Dominators as Dom
import Data.IntMap.Strict (IntMap)
import Data.IntSet (IntSet)

import qualified Data.IntMap.Strict as IM
import qualified Data.Map as M
import qualified Data.IntSet as IS
import qualified Data.Set as S
import Data.Tree
import Data.Bifunctor

import Outputable
-- DEBUGGING ONLY
--import Debug
-- import Debug.Trace
--import OrdList
--import Debug.Trace
import PprCmm () -- For Outputable instances
import qualified DynFlags as D

import Data.List (sort, nub, partition)
import Data.STRef.Strict
import Control.Monad.ST

import Data.Array.MArray
import Data.Array.ST
import Data.Array.IArray
import Data.Array.Unsafe (unsafeFreeze)
import Data.Array.Base (unsafeRead, unsafeWrite)

import Control.Monad

type Prob = Double

type Edge = (BlockId, BlockId)
type Edges = [Edge]

-- | The weight annotated on an edge of the 'CFG', stored as a plain 'Double'.
--
-- The numeric classes ('Num', 'Real', 'Fractional') come from
-- GeneralizedNewtypeDeriving, so weights can be combined arithmetically
-- without unwrapping them first.
newtype EdgeWeight = EdgeWeight { weightToDouble :: Double }
  deriving ( Eq, Ord, Enum
           , Num, Real, Fractional )

-- | Pretty-print the underlying 'Double' using @doublePrec 5@.
instance Outputable EdgeWeight where
  ppr = doublePrec 5 . weightToDouble

-- | Two-level map: source block -> destination block -> per-edge @info@.
type EdgeInfoMap info = LabelMap (LabelMap info)

-- | A control flow graph where edges have been annotated with a weight.
-- Implemented as a LabelMap (LabelMap EdgeInfo): source block, then
-- destination block, then the information stored for that edge.
--
-- We must uphold the invariant that for each edge A -> B we have:
-- an entry B in the outer map (possibly with no successors), and
-- an entry B in the inner map we get when looking up A.
-- Maintaining this invariant is useful as any failed lookup now indicates
-- an actual error in code which might go unnoticed for a while
-- otherwise.
type CFG = EdgeInfoMap EdgeInfo

-- | An explicit edge of the 'CFG': source block, destination block and the
--   information attached to that edge.
data CfgEdge = CfgEdge
  { edgeFrom :: !BlockId
  , edgeTo   :: !BlockId
  , edgeInfo :: !EdgeInfo
  }

-- | Careful! Since we assume there is at most one edge from A to B, two
--   edges are equal as soon as their endpoints match; the attached
--   'EdgeInfo' (and hence the weight) is ignored.
instance Eq CfgEdge where
  CfgEdge src1 dst1 _ == CfgEdge src2 dst2 _ = (src1, dst1) == (src2, dst2)

-- | Edges are ordered lexicographically and ascending by weight, then by
--   source, then by destination.
--
--   NOTE: unlike the 'Eq' instance, this ordering does look at the weight,
--   so two edges that are '==' may still compare as 'LT' or 'GT'.
instance Ord CfgEdge where
  compare (CfgEdge src1 dst1 (EdgeInfo { edgeWeight = w1 }))
          (CfgEdge src2 dst2 (EdgeInfo { edgeWeight = w2 }))
    = compare (w1, src1, dst1) (w2, src2, dst2)

-- | Rendered as @(from -(weight: w)-> to)@.
instance Outputable CfgEdge where
  ppr (CfgEdge src dst info) =
    parens (ppr src <+> text "-(" <> ppr info <> text ")->" <+> ppr dst)

-- | Can we trace back an edge to a specific Cmm node, or has it been
-- introduced during assembly codegen? We use this to maintain some
-- information which would otherwise be lost during the
-- Cmm <-> asm transition.
-- See also Note [Inverting Conditional Branches]
data TransitionSource
  = CmmSource { trans_cmmNode :: CmmNode O C
              , trans_info    :: BranchInfo }
  | AsmCodeGen
  deriving Eq

-- | Extra information recorded for a branch that originated from Cmm.
data BranchInfo
  = NoInfo         -- ^ Unknown, but not heap or stack check.
  | HeapStackCheck -- ^ Heap or stack check
  deriving (Eq)

instance Outputable BranchInfo where
    ppr info = case info of
        NoInfo         -> text "regular"
        HeapStackCheck -> text "heap/stack"

-- | Is this transition a Cmm branch tagged as a heap or stack check?
isHeapOrStackCheck :: TransitionSource -> Bool
isHeapOrStackCheck src = case src of
  CmmSource { trans_info = HeapStackCheck } -> True
  _                                         -> False

-- | Information about edges: where the edge comes from and its weight.
data EdgeInfo = EdgeInfo
  { transitionSource :: !TransitionSource -- ^ Cmm origin, or introduced by the asm codegen
  , edgeWeight       :: !EdgeWeight       -- ^ Weight annotated on the edge
  } deriving (Eq)

-- | Only the weight is printed; the transition source is not shown.
instance Outputable EdgeInfo where
  ppr = (text "weight:" <+>) . ppr . edgeWeight

-- | Convenience function: build edge info carrying the given weight and
--   marked as not originating from Cmm ('AsmCodeGen').
mkWeightInfo :: EdgeWeight -> EdgeInfo
mkWeightInfo w = EdgeInfo { transitionSource = AsmCodeGen, edgeWeight = w }

-- | Adjust the weight between the blocks using the given function.
--   If there is no such edge returns the original map.
--   Both the old and the new weight are forced before the edge is rewritten.
adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
                 -> BlockId -> BlockId -> CFG
adjustEdgeWeight cfg f from to =
  case getEdgeInfo from to cfg of
    Nothing   -> cfg
    Just info ->
      let !oldWeight = edgeWeight info
          !newWeight = f oldWeight
      in  addEdge from to (info { edgeWeight = newWeight }) cfg

-- | Set the weight between the blocks to the given (strictly evaluated)
--   weight. If there is no such edge returns the original map.
setEdgeWeight :: CFG -> EdgeWeight
              -> BlockId -> BlockId -> CFG
setEdgeWeight cfg !weight from to =
  maybe cfg
        (\info -> addEdge from to (info { edgeWeight = weight }) cfg)
        (getEdgeInfo from to cfg)


-- | All nodes of the graph, i.e. the keys of the outer map.
getCfgNodes :: CFG -> [BlockId]
getCfgNodes = mapKeys

-- | Is this block part of this graph?
hasNode :: CFG -> BlockId -> Bool
hasNode m node =
  -- Invariant check: a node missing from the outer map must not appear
  -- in any inner (successor) map either.
  ASSERT( isNode || all (not . mapMember node) m )
  isNode
  where
    isNode :: Bool
    isNode = mapMember node m



-- | Check that the nodes in the cfg and the given set of blocks coincide.
--   Returns 'True' when they do. In case of a mismatch we panic, showing
--   the symmetric difference, the block set, the cfg and the extra message.
sanityCheckCfg :: CFG -> LabelSet -> SDoc -> Bool
sanityCheckCfg m blockSet msg
    | blockSet /= cfgNodes
    = pprPanic "Block list and cfg nodes don't match"
        ( text "difference:" <+> ppr diff $$
          text "blocks:" <+> ppr blockSet $$
          text "cfg:" <+> pprEdgeWeights m $$
          msg )
    | otherwise
    = True
  where
    cfgNodes :: LabelSet
    cfgNodes = setFromList (getCfgNodes m)
    -- Nodes that occur in exactly one of the two sets.
    diff :: LabelSet
    diff = setDifference (setUnion cfgNodes blockSet)
                         (setIntersection cfgNodes blockSet)

-- | Keep only the edges for which the predicate holds.
--   The predicate is called as `f from to edgeInfo`.
--   Nodes are never removed: a node whose edges are all filtered out stays
--   in the graph with an empty successor map.
filterEdges :: (BlockId -> BlockId -> EdgeInfo -> Bool) -> CFG -> CFG
filterEdges f = mapMapWithKey (\from -> mapFilterWithKey (f from))


{- Note [Updating the CFG during shortcutting]

See Note [What is shortcutting] in the control flow optimization
code (CmmContFlowOpt.hs) for a slightly more in-depth explanation of shortcutting.

In the native backend we shortcut jumps at the assembly level. (AsmCodeGen.hs)
This means we remove blocks containing only one jump from the code
and instead redirect all jumps targeting this block to the deleted
block's jump target.

However we want to have an accurate representation of control
flow in the CFG. So we add/remove edges accordingly to account
for the eliminated blocks and new edges.

If we shortcut A -> B -> C to A -> C:
* We delete edges A -> B and B -> C
* Replacing them with the edge A -> C

We also try to preserve jump weights while doing so.

Note that:
* The edge B -> C can't have interesting weights since
  the block B consists of a single unconditional jump without branching.
* We delete the edge A -> B and add the edge A -> C.
* The edge A -> B can be one of many edges originating from A so likely
  has edge weights we want to preserve.

For this reason we simply store the edge info from the original A -> B
edge and apply this information to the new edge A -> C.

Sometimes we have a scenario where jump target C is not represented by a
BlockId but by an immediate value. I'm only aware of this happening without
tables next to code currently.

Then we go from A ---> B - -> IMM   to   A - -> IMM where the dashed arrows
are not stored in the CFG.

In that case we simply delete the edge A -> B.

In terms of implementation the native backend first builds a mapping
from blocks suitable for shortcutting to their jump targets.
Then it redirects all jump instructions to these blocks using the
built up mapping.
This function (shortcutWeightMap) takes the same mapping and
applies the mapping to the CFG in the way laid out above.

-}
shortcutWeightMap :: LabelMap (Maybe BlockId) -> CFG -> CFG
shortcutWeightMap cuts cfg = foldl' applyCut cfg (mapToList cuts)
  where
    -- Apply one entry (B, target) of the cuts map; the names B/C/A follow
    -- Note [Updating the CFG during shortcutting].
    applyCut :: CFG -> (BlockId, Maybe BlockId) -> CFG
    -- B jumps to an immediate: drop B together with every edge into it.
    applyCut g (skipped, Nothing) =
        fmap (mapDelete skipped) (mapDelete skipped g)
    -- B jumps to block C: drop B and turn every A -> B edge into A -> C,
    -- keeping the info of the A -> B edge.
    applyCut g (skipped, Just target) =
        let redirected = fmap (redirect skipped target) (mapDelete skipped g)
        -- C may itself be cut (A -> B -> C -> D ...), so follow the chain.
        in  maybe redirected
                  (\next -> applyCut redirected (target, next))
                  (mapLookup target cuts)

    -- Within one successor map, move the edge to @old@ over to @new@.
    redirect :: BlockId -> BlockId -> LabelMap EdgeInfo -> LabelMap EdgeInfo
    redirect old new succs =
        case mapLookup old succs of
          Just info -> mapInsert new info (mapDelete old succs)
          Nothing   -> succs

-- | Sometimes we insert a block which should unconditionally be executed
--   after a given block. This function updates the CFG for these cases:
--   all outgoing edges of @node@ are moved to @follower@, and a single
--   edge @node -> follower@ with the unconditional weight is added.
--
--   Before:  A -> B,  A -> C
--   After:   A -> A', A' -> B, A' -> C
addImmediateSuccessor :: BlockId -> BlockId -> CFG -> CFG
addImmediateSuccessor node follower cfg =
    let withFollower = addWeightEdge node follower uncondWeight cfg
        detached     = foldl' (\g s -> delEdge node s g) withFollower
                              (map fst targets)
    in  foldl' (\g (dst, info) -> addEdge follower dst info g) detached targets
  where
    -- Weight of the new unconditional node -> follower edge.
    uncondWeight :: EdgeWeight
    uncondWeight =
        fromIntegral . D.uncondWeight . D.cfgWeightInfo $ D.unsafeGlobalDynFlags
    -- Outgoing edges of node in the original graph.
    targets :: [(BlockId, EdgeInfo)]
    targets = getSuccessorEdges cfg node

-- | Adds a new edge, overwrites existing edges if present.
--   The destination is registered as a node too (with no successors if it
--   was not present yet), upholding the 'CFG' invariant.
addEdge :: BlockId -> BlockId -> EdgeInfo -> CFG -> CFG
addEdge from to info cfg =
    mapAlter (Just . insertEdge) from $ mapAlter ensureNode to cfg
  where
    -- Insert the edge into the successor map of @from@, creating it if needed.
    insertEdge :: Maybe (LabelMap EdgeInfo) -> LabelMap EdgeInfo
    insertEdge = maybe (mapSingleton to info) (mapInsert to info)
    -- Make sure @to@ has an entry in the outer map.
    ensureNode :: Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
    ensureNode Nothing  = Just mapEmpty
    ensureNode existing = existing


-- | Adds an edge with the given weight to the cfg.
--   If there already existed an edge it is overwritten.
--   `addWeightEdge from to weight cfg`
addWeightEdge :: BlockId -> BlockId -> EdgeWeight -> CFG -> CFG
addWeightEdge from to = addEdge from to . mkWeightInfo

-- | Remove the edge @from -> to@ if present. Both nodes stay in the graph;
--   if @from@ is not a node at all the graph is returned unchanged.
delEdge :: BlockId -> BlockId -> CFG -> CFG
delEdge from to = mapAlter (fmap (mapDelete to)) from

-- | Remove a block from the CFG together with every edge that touches it.
delNode :: BlockId -> CFG -> CFG
delNode node cfg = fmap dropIncoming withoutOutgoing
  where
    -- Drop the node's own entry, i.e. all of its outgoing edges.
    withoutOutgoing :: CFG
    withoutOutgoing = mapDelete node cfg
    -- Drop the node from every remaining successor map (its incoming edges).
    dropIncoming :: LabelMap EdgeInfo -> LabelMap EdgeInfo
    dropIncoming = mapDelete node

-- | Outgoing edges of @bid@, heaviest edge first.
--   A block that is not in the CFG is treated as having no successors.
getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccEdgesSorted m bid = sortWith heaviestFirst successors
  where
    successors :: [(BlockId, EdgeInfo)]
    successors = mapToList (mapFindWithDefault mapEmpty bid m)
    -- Negating the weight turns the ascending sort into a descending one.
    heaviestFirst :: (BlockId, EdgeInfo) -> EdgeWeight
    heaviestFirst = negate . edgeWeight . snd

-- | Outgoing edges of a block, together with their edge info.
--   Panics (printing the whole CFG) if the block is not part of the CFG.
getSuccessorEdges :: HasDebugCallStack => CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccessorEdges m bid =
    case mapLookup bid m of
      Just succs -> mapToList succs
      Nothing    -> pprPanic "getSuccessorEdges: Block does not exist" $
                      ppr bid <+> pprEdgeWeights m

-- | Look up the info stored on the edge @from -> to@, if that edge exists.
getEdgeInfo :: BlockId -> BlockId -> CFG -> Maybe EdgeInfo
getEdgeInfo from to m =
    case mapLookup from m >>= mapLookup to of
      Just info -> Just $! info   -- hand back an evaluated EdgeInfo
      Nothing   -> Nothing

-- | Weight of the edge @from -> to@.
--   Fails via 'expectJust' when the edge is not in the CFG.
getEdgeWeight :: CFG -> BlockId -> BlockId -> EdgeWeight
getEdgeWeight cfg from to =
    let info = expectJust "Edgeweight for noexisting block"
                          (getEdgeInfo from to cfg)
    in edgeWeight info

-- | The 'TransitionSource' recorded on the edge @from -> to@.
--   Fails via 'expectJust' when the edge is not in the CFG.
getTransitionSource :: BlockId -> BlockId -> CFG -> TransitionSource
getTransitionSource from to cfg =
    transitionSource (expectJust "Source info for noexisting block" edgeInfo)
  where
    edgeInfo = getEdgeInfo from to cfg

-- | Flip the direction of every edge, keeping each edge's info.
--   Every node of the input, including nodes without outgoing edges,
--   is also a node of the result.
reverseEdges :: CFG -> CFG
reverseEdges cfg = mapFoldlWithKey reverseFrom mapEmpty cfg
  where
    -- Register @from@ as a node, then add @to -> from@ for each of its
    -- outgoing edges @from -> to@.
    reverseFrom :: CFG -> BlockId -> LabelMap EdgeInfo -> CFG
    reverseFrom acc from succs =
        mapFoldlWithKey (\acc' to info -> addEdge to from info acc')
                        (ensureNode acc from)
                        succs

    -- We must preserve nodes without outgoing edges! Inserting an empty
    -- map (merged via union) adds the node without touching edges it may
    -- already have acquired.
    ensureNode :: CFG -> BlockId -> CFG
    ensureNode acc b = mapInsertWith mapUnion b mapEmpty acc


-- | All edges of the CFG, with their edge info, in no particular order.
infoEdgeList :: CFG -> [CfgEdge]
infoEdgeList m = walkNodes (mapToList m) []
  where
    -- Explicit accumulator loops rather than foldMap, to avoid thunk buildup.
    walkNodes :: [(BlockId,LabelMap EdgeInfo)] -> [CfgEdge] -> [CfgEdge]
    walkNodes nodes acc = case nodes of
      []                   -> acc
      (from, succs) : rest -> walkEdges rest from (mapToList succs) acc

    walkEdges :: [(BlockId,LabelMap EdgeInfo)] -> BlockId
              -> [(BlockId,EdgeInfo)] -> [CfgEdge] -> [CfgEdge]
    walkEdges rest from succs acc = case succs of
      []               -> walkNodes rest acc
      (to, info) : tos -> walkEdges rest from tos (CfgEdge from to info : acc)

-- | All edges of the CFG as @(from, to)@ pairs, without their info,
--   in no particular order.
edgeList :: CFG -> [Edge]
edgeList m = walkNodes (mapToList m) []
  where
    -- Explicit accumulator loops rather than foldMap, to avoid thunk buildup.
    walkNodes :: [(BlockId,LabelMap EdgeInfo)] -> [Edge] -> [Edge]
    walkNodes nodes acc = case nodes of
      []                   -> acc
      (from, succs) : rest -> walkTargets rest from (mapKeys succs) acc

    walkTargets :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [BlockId]
                -> [Edge] -> [Edge]
    walkTargets rest from targets acc = case targets of
      []       -> walkNodes rest acc
      to : tos -> walkTargets rest from tos ((from, to) : acc)

-- | Targets of the outgoing edges of @bid@, without edge info.
--   Panics (printing the whole CFG) if the block is not part of the CFG.
getSuccessors :: HasDebugCallStack => CFG -> BlockId -> [BlockId]
getSuccessors m bid = maybe missingBlock mapKeys (mapLookup bid m)
  where
    missingBlock = pprPanic "getSuccessors: Block does not exist" $
                     ppr bid <+> pprEdgeWeights m

-- | Render the CFG as a Graphviz @digraph@. Each edge is printed on one line,
--   labelled and weighted with its edge weight (edges sorted first). These
--   are followed by one line per node that appears in no edge at all.
pprEdgeWeights :: CFG -> SDoc
pprEdgeWeights m =
    text "digraph {\n" <> edgeLines <> nodeLines <> text "}\n"
  where
    edges = sort (infoEdgeList m) :: [CfgEdge]
    edgeLines = foldl' (<>) empty (map printEdge edges)
    nodeLines = foldl' (<>) empty (map printNode isolatedNodes)

    printEdge (CfgEdge from to (EdgeInfo { edgeWeight = weight }))
        = text "\t" <> ppr from <+> text "->" <+> ppr to <>
          text "[label=\"" <> ppr weight <> text "\",weight=\"" <>
          ppr weight <> text "\"];\n"

    -- Nodes with no edges from or to them. This should rarely happen, but
    -- listing them explicitly makes such nodes easy to spot.
    printNode node = text "\t" <> ppr node <> text ";\n"

    endpoints (CfgEdge from to _) = [from, to]
    edgeEndpoints = setFromList (concatMap endpoints edges) :: LabelSet

    isolatedNodes = filter (\n -> not (setMember n edgeEndpoints))
                  . mapKeys $ mapFilter null m

{-# INLINE updateEdgeWeight #-} --Allows eliminating the tuple when possible
-- | Apply @f@ to the weight of the edge @(from, to)@. The rest of the edge's
--   'EdgeInfo' is kept unchanged.
--
--   Invariant: The edge **must** exist already in the graph. Otherwise this
--   panics and reports the offending edge, so the broken caller can be found.
updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
updateEdgeWeight f (from, to) cfg
    | Just oldInfo <- getEdgeInfo from to cfg
    = let !oldWeight = edgeWeight oldInfo
          !newWeight = f oldWeight
      in addEdge from to (oldInfo {edgeWeight = newWeight}) cfg
    | otherwise
    = pprPanic "Trying to update invalid edge" (ppr (from, to))

-- | Recompute every edge weight as @f from to oldWeight@. The rest of each
--   edge's info is kept unchanged.
mapWeights :: (BlockId -> BlockId -> EdgeWeight -> EdgeWeight) -> CFG -> CFG
mapWeights f cfg = foldl' reweigh cfg (infoEdgeList cfg)
  where
    reweigh acc (CfgEdge from to info) =
        addEdge from to (info { edgeWeight = f from to (edgeWeight info) }) acc


-- | Insert a block in the control flow between two other blocks.
-- Each update is a triple (A,B,C) where
-- * A -> C: Old edge
-- * A -> B -> C : New Arc, where B is the new block.
-- A block may jump to the same target more than once in the assembly code;
-- we still only store a single edge for such cases.
-- The edge A -> B receives the old edge info of A -> C, and B -> C gets the
-- weight of an unconditional jump.
addNodesBetween :: CFG -> [(BlockId,BlockId,BlockId)] -> CFG
addNodesBetween m updates = foldl' applyUpdate m (map withOldInfo updates)
  where
    -- Weight given to every new B -> C edge.
    uncondEdgeWeight :: EdgeWeight
    uncondEdgeWeight =
        fromIntegral (D.uncondWeight (D.cfgWeightInfo D.unsafeGlobalDynFlags))

    -- Several triples may split the same edge, e.g. A -> B -> C and
    -- A -> D -> C. After the first update A -> C is gone, so the old edge
    -- info is always looked up in the original CFG, never the updated one.
    withOldInfo :: (BlockId,BlockId,BlockId) -> (BlockId,BlockId,BlockId,EdgeInfo)
    withOldInfo (from,between,old) =
        case getEdgeInfo from old m of
          Just edgeInfo -> (from,between,old,edgeInfo)
          Nothing -> pprPanic "Can't find weight for edge that should have one" (
                        text "triple" <+> ppr (from,between,old) $$
                        text "updates" <+> ppr updates $$
                        text "cfg:" <+> pprEdgeWeights m )

    -- Replace A -> C by A -> B (old info) and B -> C (unconditional weight).
    applyUpdate :: CFG -> (BlockId,BlockId,BlockId,EdgeInfo) -> CFG
    applyUpdate cfg (from,between,old,edgeInfo) =
        addEdge from between edgeInfo $
        addWeightEdge between old uncondEdgeWeight $
        delEdge from old cfg

{-
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  ~~~       Note [CFG Edge Weights]    ~~~
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  Edge weights assigned do not currently represent a specific
  cost model and rather just a ranking of which blocks should
  be placed next to each other given their connection type in
  the CFG.
  This is especially relevant whenever two blocks
  jump to the same target.

                     A   B
                      \ /
                       C

  Should A or B be placed in front of C? The block layout algorithm
  decides this based on which edge (A,C)/(B,C) is heavier. So we
  make an educated guess on which branch should be preferred.

  We rank edges in this order:
  * Unconditional Control Transfer - They will always
    transfer control to their target. Unless there is an info table
    we can turn the jump into a fallthrough as well.
    We use 20k as default, so it's easy to spot if values have been
    modified but unlikely that we run into issues with overflow.
  * If branches (likely) - We assume branches marked as likely
    are taken more than 80% of the time.
    By ranking them below unconditional jumps we make sure we
    prefer the unconditional if there is a conditional and
    unconditional edge towards a block.
  * If branches (regular) - The false branch can potentially be turned
    into a fallthrough so we prefer it slightly over the true branch.
  * Unlikely branches - These can be assumed to be taken less than 20%
    of the time. So we give them one of the lowest priorities.
  * Switches - Switches at this level are implemented as jump tables
    so have a larger number of successors. So without more information
    we can only say that each individual successor is unlikely to be
    jumped to and we rank them accordingly.
  * Calls - We currently ignore calls completely:
        * By the time we return from a call there is a good chance
          that the address we return to has already been evicted from
          cache eliminating a main advantage sequential placement brings.
        * Calls always require an info table in front of their return
          address. This reduces the chance that we return to the same
          cache line further.

-}
-- | Generate weights for a Cmm proc based on some simple heuristics.
--   A data declaration ('CmmData') yields an empty CFG.
getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
getCfgProc weights decl =
    case decl of
      CmmData {}                     -> mapEmpty
      CmmProc _info _lab _live graph -> getCfg weights graph

getCfg :: D.CfgWeights -> CmmGraph -> CFG
getCfg :: CfgWeights -> CmmGraph -> CFG
getCfg CfgWeights
weights CmmGraph
graph =
  (CFG -> ((BlockId, BlockId), EdgeInfo) -> CFG)
-> CFG -> [((BlockId, BlockId), EdgeInfo)] -> CFG
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' CFG -> ((BlockId, BlockId), EdgeInfo) -> CFG
insertEdge CFG
edgelessCfg ([((BlockId, BlockId), EdgeInfo)] -> CFG)
-> [((BlockId, BlockId), EdgeInfo)] -> CFG
forall a b. (a -> b) -> a -> b
$ (CmmBlock -> [((BlockId, BlockId), EdgeInfo)])
-> [CmmBlock] -> [((BlockId, BlockId), EdgeInfo)]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap CmmBlock -> [((BlockId, BlockId), EdgeInfo)]
getBlockEdges [CmmBlock]
blocks
  where
    D.CFGWeights
            { uncondWeight :: CfgWeights -> Int
D.uncondWeight = Int
uncondWeight
            , condBranchWeight :: CfgWeights -> Int
D.condBranchWeight = Int
condBranchWeight
            , switchWeight :: CfgWeights -> Int
D.switchWeight = Int
switchWeight
            , callWeight :: CfgWeights -> Int
D.callWeight = Int
callWeight
            , likelyCondWeight :: CfgWeights -> Int
D.likelyCondWeight = Int
likelyCondWeight
            , unlikelyCondWeight :: CfgWeights -> Int
D.unlikelyCondWeight = Int
unlikelyCondWeight
            --  Last two are used in other places
            --, D.infoTablePenalty = infoTablePenalty
            --, D.backEdgeBonus = backEdgeBonus
            } = CfgWeights
weights
    -- Explicitly add all nodes to the cfg to ensure they are part of the
    -- CFG.
    edgelessCfg :: CFG
edgelessCfg = [(KeyOf LabelMap, LabelMap EdgeInfo)] -> CFG
forall (map :: * -> *) a. IsMap map => [(KeyOf map, a)] -> map a
mapFromList ([(KeyOf LabelMap, LabelMap EdgeInfo)] -> CFG)
-> [(KeyOf LabelMap, LabelMap EdgeInfo)] -> CFG
forall a b. (a -> b) -> a -> b
$ [BlockId] -> [LabelMap EdgeInfo] -> [(BlockId, LabelMap EdgeInfo)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((CmmBlock -> BlockId) -> [CmmBlock] -> [BlockId]
forall a b. (a -> b) -> [a] -> [b]
map CmmBlock -> BlockId
forall (thing :: Extensibility -> Extensibility -> *)
       (x :: Extensibility).
NonLocal thing =>
thing C x -> BlockId
G.entryLabel [CmmBlock]
blocks) (LabelMap EdgeInfo -> [LabelMap EdgeInfo]
forall a. a -> [a]
repeat LabelMap EdgeInfo
forall (map :: * -> *) a. IsMap map => map a
mapEmpty)
    insertEdge :: CFG -> ((BlockId,BlockId),EdgeInfo) -> CFG
    insertEdge :: CFG -> ((BlockId, BlockId), EdgeInfo) -> CFG
insertEdge CFG
m ((BlockId
from,BlockId
to),EdgeInfo
weight) =
      (Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo))
-> KeyOf LabelMap -> CFG -> CFG
forall (map :: * -> *) a.
IsMap map =>
(Maybe a -> Maybe a) -> KeyOf map -> map a -> map a
mapAlter Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
f KeyOf LabelMap
BlockId
from CFG
m
        where
          f :: Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
          f :: Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
f Maybe (LabelMap EdgeInfo)
Nothing = LabelMap EdgeInfo -> Maybe (LabelMap EdgeInfo)
forall a. a -> Maybe a
Just (LabelMap EdgeInfo -> Maybe (LabelMap EdgeInfo))
-> LabelMap EdgeInfo -> Maybe (LabelMap EdgeInfo)
forall a b. (a -> b) -> a -> b
$ KeyOf LabelMap -> EdgeInfo -> LabelMap EdgeInfo
forall (map :: * -> *) a. IsMap map => KeyOf map -> a -> map a
mapSingleton KeyOf LabelMap
BlockId
to EdgeInfo
weight
          f (Just LabelMap EdgeInfo
destMap) = LabelMap EdgeInfo -> Maybe (LabelMap EdgeInfo)
forall a. a -> Maybe a
Just (LabelMap EdgeInfo -> Maybe (LabelMap EdgeInfo))
-> LabelMap EdgeInfo -> Maybe (LabelMap EdgeInfo)
forall a b. (a -> b) -> a -> b
$ KeyOf LabelMap
-> EdgeInfo -> LabelMap EdgeInfo -> LabelMap EdgeInfo
forall (map :: * -> *) a.
IsMap map =>
KeyOf map -> a -> map a -> map a
mapInsert KeyOf LabelMap
BlockId
to EdgeInfo
weight LabelMap EdgeInfo
destMap
    getBlockEdges :: CmmBlock -> [((BlockId,BlockId),EdgeInfo)]
    getBlockEdges :: CmmBlock -> [((BlockId, BlockId), EdgeInfo)]
getBlockEdges CmmBlock
block =
      case CmmNode O C
branch of
        CmmBranch BlockId
dest -> [BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
dest Int
uncondWeight]
        CmmCondBranch CmmExpr
cond BlockId
t BlockId
f Maybe Bool
l
          | Maybe Bool
l Maybe Bool -> Maybe Bool -> Bool
forall a. Eq a => a -> a -> Bool
== Maybe Bool
forall a. Maybe a
Nothing ->
              [BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
f Int
condBranchWeight,   BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
t Int
condBranchWeight]
          | Maybe Bool
l Maybe Bool -> Maybe Bool -> Bool
forall a. Eq a => a -> a -> Bool
== Bool -> Maybe Bool
forall a. a -> Maybe a
Just Bool
True ->
              [BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
f Int
unlikelyCondWeight, BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
t Int
likelyCondWeight]
          | Maybe Bool
l Maybe Bool -> Maybe Bool -> Bool
forall a. Eq a => a -> a -> Bool
== Bool -> Maybe Bool
forall a. a -> Maybe a
Just Bool
False ->
              [BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
f Int
likelyCondWeight,   BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
t Int
unlikelyCondWeight]
          where
            mkEdgeInfo :: Int -> EdgeInfo
mkEdgeInfo = -- pprTrace "Info" (ppr branchInfo <+> ppr cond)
                         TransitionSource -> EdgeWeight -> EdgeInfo
EdgeInfo (CmmNode O C -> BranchInfo -> TransitionSource
CmmSource CmmNode O C
branch BranchInfo
branchInfo) (EdgeWeight -> EdgeInfo) -> (Int -> EdgeWeight) -> Int -> EdgeInfo
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> EdgeWeight
forall a b. (Integral a, Num b) => a -> b
fromIntegral
            mkEdge :: BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
target Int
weight = ((BlockId
bid,BlockId
target), Int -> EdgeInfo
mkEdgeInfo Int
weight)
            branchInfo :: BranchInfo
branchInfo =
              DynFlags
-> (BranchInfo -> GlobalReg -> BranchInfo)
-> BranchInfo
-> CmmExpr
-> BranchInfo
forall r a b.
UserOfRegs r a =>
DynFlags -> (b -> r -> b) -> b -> a -> b
foldRegsUsed
                (String -> DynFlags
forall a. String -> a
panic String
"foldRegsDynFlags")
                (\BranchInfo
info GlobalReg
r -> if GlobalReg
r GlobalReg -> GlobalReg -> Bool
forall a. Eq a => a -> a -> Bool
== GlobalReg
SpLim Bool -> Bool -> Bool
|| GlobalReg
r GlobalReg -> GlobalReg -> Bool
forall a. Eq a => a -> a -> Bool
== GlobalReg
HpLim Bool -> Bool -> Bool
|| GlobalReg
r GlobalReg -> GlobalReg -> Bool
forall a. Eq a => a -> a -> Bool
== GlobalReg
BaseReg
                    then BranchInfo
HeapStackCheck else BranchInfo
info)
                BranchInfo
NoInfo CmmExpr
cond

        (CmmSwitch CmmExpr
_e SwitchTargets
ids) ->
          let switchTargets :: [BlockId]
switchTargets = SwitchTargets -> [BlockId]
switchTargetsToList SwitchTargets
ids
              --Compiler performance hack - for very wide switches don't
              --consider targets for layout.
              adjustedWeight :: Int
adjustedWeight =
                if ([BlockId] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BlockId]
switchTargets Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
10) then -Int
1 else Int
switchWeight
          in (BlockId -> ((BlockId, BlockId), EdgeInfo))
-> [BlockId] -> [((BlockId, BlockId), EdgeInfo)]
forall a b. (a -> b) -> [a] -> [b]
map (\BlockId
x -> BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
x Int
adjustedWeight) [BlockId]
switchTargets
        (CmmCall { cml_cont :: CmmNode O C -> Maybe BlockId
cml_cont = Just BlockId
cont})  -> [BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
cont Int
callWeight]
        (CmmForeignCall {succ :: CmmNode O C -> BlockId
Cmm.succ = BlockId
cont}) -> [BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
cont Int
callWeight]
        (CmmCall { cml_cont :: CmmNode O C -> Maybe BlockId
cml_cont = Maybe BlockId
Nothing })   -> []
        CmmNode O C
other ->
            String
-> [((BlockId, BlockId), EdgeInfo)]
-> [((BlockId, BlockId), EdgeInfo)]
forall a. String -> a
panic String
"Foo" ([((BlockId, BlockId), EdgeInfo)]
 -> [((BlockId, BlockId), EdgeInfo)])
-> [((BlockId, BlockId), EdgeInfo)]
-> [((BlockId, BlockId), EdgeInfo)]
forall a b. (a -> b) -> a -> b
$
            ASSERT2(False, ppr "Unkown successor cause:" <>
              (ppr branch <+> text "=>" <> ppr (G.successors other)))
            (BlockId -> ((BlockId, BlockId), EdgeInfo))
-> [BlockId] -> [((BlockId, BlockId), EdgeInfo)]
forall a b. (a -> b) -> [a] -> [b]
map (\BlockId
x -> ((BlockId
bid,BlockId
x),Int -> EdgeInfo
mkEdgeInfo Int
0)) ([BlockId] -> [((BlockId, BlockId), EdgeInfo)])
-> [BlockId] -> [((BlockId, BlockId), EdgeInfo)]
forall a b. (a -> b) -> a -> b
$ CmmNode O C -> [BlockId]
forall (thing :: Extensibility -> Extensibility -> *)
       (e :: Extensibility).
NonLocal thing =>
thing e C -> [BlockId]
G.successors CmmNode O C
other
      where
        bid :: BlockId
bid = CmmBlock -> BlockId
forall (thing :: Extensibility -> Extensibility -> *)
       (x :: Extensibility).
NonLocal thing =>
thing C x -> BlockId
G.entryLabel CmmBlock
block
        mkEdgeInfo :: Int -> EdgeInfo
mkEdgeInfo = TransitionSource -> EdgeWeight -> EdgeInfo
EdgeInfo (CmmNode O C -> BranchInfo -> TransitionSource
CmmSource CmmNode O C
branch BranchInfo
NoInfo) (EdgeWeight -> EdgeInfo) -> (Int -> EdgeWeight) -> Int -> EdgeInfo
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> EdgeWeight
forall a b. (Integral a, Num b) => a -> b
fromIntegral
        mkEdge :: BlockId -> Int -> ((BlockId, BlockId), EdgeInfo)
mkEdge BlockId
target Int
weight = ((BlockId
bid,BlockId
target), Int -> EdgeInfo
mkEdgeInfo Int
weight)
        branch :: CmmNode O C
branch = CmmBlock -> CmmNode O C
forall (n :: Extensibility -> Extensibility -> *)
       (x :: Extensibility).
Block n x C -> n O C
lastNode CmmBlock
block :: CmmNode O C

    blocks :: [CmmBlock]
blocks = CmmGraph -> [CmmBlock]
revPostorder CmmGraph
graph :: [CmmBlock]

-- | All edges of the CFG that 'classifyEdges' labels as 'Backward' when the
--   graph is explored from @root@ using the CFG's successor relation.
--
--   Edges are returned in the order 'classifyEdges' reports them.
findBackEdges :: HasDebugCallStack => BlockId -> CFG -> Edges
findBackEdges root cfg =
    --pprTraceIt "Backedges:" $
    [ edge | (edge, Backward) <- classified ]
  where
    classified :: [((BlockId, BlockId), EdgeType)]
    classified = classifyEdges root (getSuccessors cfg) (edgeList cfg)


-- | Tune the edge weights of a procedure's CFG before block layout.
--
--   Data sections leave the CFG untouched. For a 'CmmProc' three passes run,
--   in this order:
--
--   1. 'increaseBackEdgeWeight': back edges found from the procedure entry
--      gain 'D.backEdgeBonus'.
--   2. 'penalizeInfoTables': edges into blocks that have an info table lose
--      'D.infoTablePenalty'.
--   3. 'favourFewerPreds': equal-weight edges out of a two-successor block
--      are nudged apart in favour of the successor with fewer predecessors.
optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
optimizeCFG _ (CmmData {}) cfg = cfg
optimizeCFG weights (CmmProc info _lab _live graph) cfg =
    {-# SCC optimizeCFG #-}
    -- pprTrace "Initial:" (pprEdgeWeights cfg) $
    -- pprTrace "Initial:" (ppr $ mkGlobalWeights (g_entry graph) cfg) $
    -- pprTrace "LoopInfo:" (ppr $ loopInfo cfg (g_entry graph)) $
    let boosted   = increaseBackEdgeWeight (g_entry graph) cfg
        penalized = penalizeInfoTables info boosted
    in  favourFewerPreds penalized
  where
    -- | Make loop jump-backs the heaviest edges: every back edge gets the
    --   back-edge bonus added. A back edge whose weight is not positive is
    --   set to 0 instead, so irrelevant edges stay irrelevant.
    increaseBackEdgeWeight :: BlockId -> CFG -> CFG
    increaseBackEdgeWeight root g0 =
        foldl' (\g edge -> updateEdgeWeight boost edge g)
               g0
               (findBackEdges root g0)
      where
        boost :: EdgeWeight -> EdgeWeight
        boost w
          | w <= 0    = 0
          | otherwise = w + fromIntegral (D.backEdgeBonus weights)

    -- | We can't fall through into a block that has an info table, so every
    --   edge targeting such a block has 'D.infoTablePenalty' subtracted.
    penalizeInfoTables :: LabelMap a -> CFG -> CFG
    penalizeInfoTables infoTables = mapWeights penalize
      where
        penalize :: BlockId -> BlockId -> EdgeWeight -> EdgeWeight
        penalize _from to w
          | mapMember to infoTables
          = w - fromIntegral (D.infoTablePenalty weights)
          | otherwise
          = w

    -- | For each block with exactly two successors whose edges currently
    --   carry equal weight, add +1/-1 so the successor with fewer
    --   predecessors is preferred. Only predecessors whose edge could become
    --   a fall-through (see 'fallthroughTarget') are counted. Equal counts,
    --   or weights that already differ, leave the block's edges unchanged.
    favourFewerPreds :: CFG -> CFG
    favourFewerPreds g0 = foldl' tweak g0 (getCfgNodes g0)
      where
        -- Reversed graph of the fall-through-capable edges: a node's
        -- successors here are its predecessors in the original CFG.
        fallthroughPreds :: CFG
        fallthroughPreds =
          reverseEdges (filterEdges (const fallthroughTarget) g0)

        predCount :: BlockId -> Int
        predCount n = length (getSuccessorEdges fallthroughPreds n)

        modifiers :: Int -> Int -> (EdgeWeight, EdgeWeight)
        modifiers p1 p2 = case compare p1 p2 of
          LT -> ( 1, -1)
          EQ -> ( 0,  0)
          GT -> (-1,  1)

        tweak :: CFG -> BlockId -> CFG
        tweak g node = case getSuccessorEdges g node of
          [(s1, e1), (s2, e2)]
            | !w1 <- edgeWeight e1
            , !w2 <- edgeWeight e2
            -- Only change the weights if there isn't already an ordering.
            , w1 == w2
            , (d1, d2) <- modifiers (predCount s1) (predCount s2)
            -> adjustEdgeWeight (adjustEdgeWeight g (+ d1) node s1)
                                (+ d2) node s2
          _ -> g

        -- | Could an edge into @to@ become a fall-through? Never into a
        --   block with an info table; otherwise only for edges introduced by
        --   the code generator or by plain/conditional Cmm branches.
        fallthroughTarget :: BlockId -> EdgeInfo -> Bool
        fallthroughTarget to (EdgeInfo source _weight)
          | mapMember to info = False
          | otherwise = case source of
              AsmCodeGen                                     -> True
              CmmSource { trans_cmmNode = CmmBranch {} }     -> True
              CmmSource { trans_cmmNode = CmmCondBranch {} } -> True
              _                                              -> False

-- | Determine loop membership of blocks based on SCC analysis.
--   This is faster than 'loopInfo' but only gives yes/no answers: a block
--   maps to 'True' when its strongly connected component is cyclic and to
--   'False' when it is acyclic.
loopMembers :: HasDebugCallStack => CFG -> LabelMap Bool
loopMembers cfg = foldl' record mapEmpty sccs
  where
    sccs :: [SCC BlockId]
    sccs = stronglyConnCompFromEdgedVerticesOrd
             [ DigraphNode bid bid (getSuccessors cfg bid)
             | bid <- getCfgNodes cfg ]

    record :: LabelMap Bool -> SCC BlockId -> LabelMap Bool
    record acc (AcyclicSCC bid)  = mapInsert bid False acc
    record acc (CyclicSCC  bids) =
      foldl' (\m bid -> mapInsert bid True m) acc bids

-- | Loop nesting level of every CFG node, taken from the dominator-based
--   analysis 'loopInfo' rooted at @root@.
loopLevels :: CFG -> BlockId -> LabelMap Int
loopLevels cfg root = liLevels (loopInfo cfg root)

-- | Result of the dominator-based loop analysis performed by 'loopInfo'.
data LoopInfo = LoopInfo
  { liBackEdges :: [Edge]
    -- ^ Back edges: edges whose target dominates their source.
  , liLevels    :: LabelMap Int
    -- ^ BlockId -> loop nesting level mapping.
  , liLoops     :: [(Edge, LabelSet)]
    -- ^ One entry per back edge: the edge and its loop body.
    --   The body includes the loop header.
  }

-- | Prints only the loops (each back edge with its body); the back-edge list
--   and the per-block levels are omitted.
instance Outputable LoopInfo where
    ppr (LoopInfo _backEdges _levels loops) =
        text "Loops:(backEdge, bodyNodes)" $$ vcat (map ppr loops)

{-  Note [Determining the loop body]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    Starting with the knowledge that:
    * head dominates the loop
    * `tail` -> `head` is a backedge

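    We can determine all nodes by:
    * Deleting the loop head from the graph.
    * Collecting all blocks from which `tail` can be reached, i.e. all
      blocks reachable from `tail` in the reversed graph.
    * Adding the loop head back to the resulting set.

    We do so by performing a breadth-first search from the tail node over
    predecessor edges (the reversed CFG), never stepping into the head.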
 -}

-- | Determine loop membership of blocks based on Dominator analysis.
--   This is slower but gives loop levels instead of just loop membership.
--   However it only detects natural loops. Irreducible control flow is not
--   recognized even if it loops. But that is rare enough that we don't have
--   to care about that special case.
--
--   An edge @from -> to@ is a back edge when @to@ dominates @from@. Every
--   block in the dominator tree is dominated by itself and by @root@, so
--   self loops and loops headed by the entry block are found too. Blocks
--   that do not appear in the dominator tree never start a back edge.
--   A block's loop level is the number of distinct loop headers whose body
--   (see Note [Determining the loop body]) contains it.
loopInfo :: HasDebugCallStack => CFG -> BlockId -> LoopInfo
loopInfo cfg root = LoopInfo  { liBackEdges = backEdges
                              , liLevels = mapFromList loopCounts
                              , liLoops = loopBodies }
  where
    revCfg :: CFG
    revCfg = reverseEdges cfg

    -- Successor set of every node.
    graph :: LabelMap LabelSet
    graph = -- pprTrace "CFG - loopInfo" (pprEdgeWeights cfg) $
            fmap (setFromList . mapKeys) cfg

    --TODO - This should be a no op: Export constructors? Use unsafeCoerce? ...
    rooted :: (Int, IntMap IntSet)
    rooted = ( fromBlockId root
             , toIntMap $ fmap toIntSet graph )

    -- Dominator tree rooted at @root@, as computed by 'Dom.domTree'.
    tree :: Tree BlockId
    tree = fmap toBlockId $ Dom.domTree rooted

    -- Map from each node of the dominator tree to the set of ALL its
    -- dominators: the node itself plus every ancestor in the tree (which
    -- always includes root).
    domMap :: LabelMap LabelSet
    domMap = mkDomMap tree

    edges :: [(BlockId, BlockId)]
    edges = edgeList cfg
    -- We can't recompute nodes from edges, there might be blocks not connected via edges.
    nodes :: [BlockId]
    nodes = getCfgNodes cfg

    -- identify back edges: the target dominates the source
    isBackEdge :: (BlockId, BlockId) -> Bool
    isBackEdge (from, to)
      | Just doms <- mapLookup from domMap
      , setMember to doms
      = True
      | otherwise = False

    -- See Note [Determining the loop body]
    -- Get the loop body associated with a back edge.
    findBody :: (BlockId, BlockId) -> ((BlockId, BlockId), LabelSet)
    findBody edge@(tail, head)
      = ( edge, setInsert head $ go (setSingleton tail) (setSingleton tail) )
      where
        -- See Note [Determining the loop body]
        cfg' :: CFG
        cfg' = delNode head revCfg

        go :: LabelSet -> LabelSet -> LabelSet
        go found current
          | setNull current = found
          | otherwise = go  (setUnion newSuccessors found)
                            newSuccessors
          where
            -- Really predecessors, since we use the reversed cfg.
            newSuccessors :: LabelSet
            newSuccessors = setFilter (\n -> not $ setMember n found) successors
            successors :: LabelSet
            successors = setFromList $ concatMap
                                      (getSuccessors cfg')
                                      -- we filter head as it's no longer part of the cfg.
                                      (filter (/= head) $ setElems current)

    backEdges :: [(BlockId, BlockId)]
    backEdges = filter isBackEdge edges

    loopBodies :: [(Edge, LabelSet)]
    loopBodies = map findBody backEdges

    -- The loop nest level of block b is the number of distinct loop headers
    -- whose body contains b (several back edges into one header form a
    -- single loop).
    loopCounts :: [(BlockId, Int)]
    loopCounts =
      let bodies = map (first snd) loopBodies -- [(Header, Body)]
          loopCount n = length $ nub . map fst . filter (setMember n . snd) $ bodies
      in  map (\n -> (n, loopCount n)) nodes

    toIntSet :: LabelSet -> IntSet
    toIntSet s = IS.fromList . map fromBlockId . setElems $ s

    toIntMap :: LabelMap a -> IntMap a
    toIntMap m = IM.fromList $ map (\(x,y) -> (fromBlockId x,y)) $ mapToList m

    -- Flatten a dominator tree into a map from each node to all of its
    -- dominators. Every node, including the root and inner nodes, gets an
    -- entry, and every entry contains the node itself.
    mkDomMap :: Tree BlockId -> LabelMap LabelSet
    mkDomMap = mapFromList . go setEmpty
      where
        -- @strictDoms@ are the ancestors of the current node in the tree.
        go :: LabelSet -> Tree BlockId -> [(BlockId, LabelSet)]
        go strictDoms (Node lbl children) =
          let doms = setInsert lbl strictDoms
          in  (lbl, doms) : concatMap (go doms) children

    fromBlockId :: BlockId -> Int
    fromBlockId = getKey . getUnique

    toBlockId :: Int -> BlockId
    toBlockId = mkBlockId . mkUniqueGrimily

-- | A single CFG node in the shape Hoopl expects: the node's own label
--   paired with the labels of its successors. Wrapping the CFG this way
--   makes it a Hoopl graph, so we can reuse Hoopl's reverse post order
--   traversal.
newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])

instance G.NonLocal BlockNode where
  entryLabel (BN node) = fst node
  successors (BN node) = snd node

-- | Reverse post order of the CFG, with the traversal starting at @root@.
--
--   Every CFG node is wrapped as a 'BlockNode' (label plus successor
--   labels) so that Hoopl's 'G.revPostorderFrom' can do the work; the
--   labels are projected back out of the resulting node list.
revPostorderFrom :: HasDebugCallStack => CFG -> BlockId -> [BlockId]
revPostorderFrom cfg root =
    map (\(BN (lbl, _succs)) -> lbl) (G.revPostorderFrom graph root)
  where
    -- Every node of the CFG, keyed by its own label.
    graph :: LabelMap (BlockNode C C)
    graph = mapFromList [ (bid, asNode bid) | bid <- getCfgNodes cfg ]

    asNode :: BlockId -> BlockNode C C
    asNode bid = BN (bid, getSuccessors cfg bid)


-- | We take in a CFG which has on its edges weights which are
--   relative only to other edges originating from the same node.
--
--   We return GLOBAL frequencies, which are comparable across the whole
--   graph: block frequencies (first component) and edge frequencies (second
--   component, an outer map keyed by one endpoint and inner maps keyed by the
--   other; presumably source then target, mirroring 'cfgEdgeProbabilities' -
--   verify against 'calcFreqs').
--
--   For irreducible control flow results might be imprecise, otherwise they
--   are reliable.
--
--   The algorithm is based on the paper
--   "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus.
--   The only big change is that we go over the nodes in the body of loops in
--   reverse post order, which is required for diamond control flow to work
--   properly.
--
--   We also apply a few prediction heuristics (based on the same paper),
--   see 'staticBranchPrediction'.
--
--   Panics if the CFG is empty.
--
--   NOTE(review): only blocks returned by 'revPostorderFrom' get a vertex
--   number, but 'cfgEdgeProbabilities' applies 'toVertex' to every source
--   block of the CFG. A block missing from the reverse post order would make
--   'toVertex' fail in 'expectJust'. Presumably callers guarantee this cannot
--   happen - confirm.
{-# NOINLINE mkGlobalWeights #-}
{-# SCC mkGlobalWeights #-}
mkGlobalWeights :: HasDebugCallStack => BlockId -> CFG -> (LabelMap Double, LabelMap (LabelMap Double))
mkGlobalWeights root localCfg
  | null localCfg = panic "Error - Empty CFG"
  | otherwise     = (labelledBlockFreqs, labelledEdgeFreqs)
  where
    -- Vertex numbering: the i-th block of the reverse post order is vertex i.
    rpo :: [BlockId]
    rpo = revPostorderFrom localCfg root

    labelToVertex :: LabelMap Int
    labelToVertex = mapFromList (zip rpo [0..])

    vertexToLabel :: Array Int BlockId
    vertexToLabel = listArray (0, mapSize labelToVertex - 1) rpo

    -- Map from blockId to indices starting at zero
    toVertex :: BlockId -> Int
    toVertex blockId = expectJust "mkGlobalWeights" (mapLookup blockId labelToVertex)

    -- Map from indices starting at zero to blockIds
    fromVertex :: Int -> BlockId
    fromVertex vertex = vertexToLabel ! vertex

    fromVertexMap :: IM.IntMap x -> LabelMap x
    fromVertexMap = mapFromList . map (first fromVertex) . IM.toList

    -- Inputs to the fixpoint computation, all expressed as vertices.
    loopResults@(LoopInfo backedges _levels bodies) = loopInfo localCfg root

    -- Normalize the weights to probabilities and apply heuristics
    nodeProbs = cfgEdgeProbabilities (staticBranchPrediction root loopResults localCfg) toVertex

    -- Because vertices are numbered in reverse post order, sorting any subset
    -- of them also puts that subset into reverse post order.
    -- TODO: The sort is redundant if we can guarantee that setElems returns elements ascending
    loopBody (backedge, blocks) =
        (toVertex (snd backedge), sort (map toVertex (setElems blocks)))

    -- Calculate fixpoints
    (vertexFreqs, vertexEdgeFreqs) =
        calcFreqs nodeProbs
                  (map (bimap toVertex toVertex) backedges)
                  (map loopBody bodies)
                  (map toVertex rpo)

    -- Translate the results back from vertices to block labels.
    labelledBlockFreqs :: LabelMap Double
    labelledBlockFreqs = mapFromList (map (first fromVertex) (assocs vertexFreqs))

    labelledEdgeFreqs :: LabelMap (LabelMap Double)
    labelledEdgeFreqs = fmap fromVertexMap (fromVertexMap vertexEdgeFreqs)

{- Note [Static Branch Prediction]
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The work here has been based on the paper
"Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus.

The primary differences are that if we branch on the result of a heap
check we do not apply any of the heuristics.
The reason is simple: they look like loops in the control flow graph
but are usually never entered, and if they are, then at most once.

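Currently implemented are a heuristic to predict that we do not exit
loops (lehPredicts), one to predict that a comparison for equality with
a constant is unlikely to succeed (ohPredicts), and one to predict that
backedges are more likely than any other edge.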

The back edge case is special as it supersedes any other heuristic if it
applies.

Do NOT rely solely on nofib results for benchmarking this. I recommend at least
comparing the megaparsec and containers benchmarks. Nofib does not seem to have
many instances of "loopy" Cmm where these heuristics make a difference.

TODO:
* The paper contains more heuristics which should be implemented.
* If we turn the likelihood on if/else branches into a probability
  instead of true/false we could implement this as a Cmm pass.
  + The complete Cmm code still exists and can be accessed by the heuristics.
  + There is no chance of register allocation/codegen inserting branches/blocks
    and thereby making the TransitionSource info wrong.
  + Potential to use this information in CmmPasses.
  - Requires refactoring of all the code relying on the binary nature of likelihood.
  - Requires refactoring `loopInfo` to work on both Cmm Graphs and the backend CFG.
-}

-- | One successor of a branch: the target block paired with the
--   'EdgeInfo' of the edge that leads to it.
type TargetNodeInfo = (BlockId, EdgeInfo)


-- | Update branch weights based on certain heuristics.
--   See Note [Static Branch Prediction]
--
--   The CFG is folded over the nodes returned by 'getCfgNodes'. For each node
--   the weights of its outgoing edges are rewritten according to the first
--   case that applies:
--
--   * No successors: nothing to do.
--   * Some, but not all, successors are reached via back edges, and none of
--     the out-edges is a heap/stack check: the back edges share a total
--     weight of @pred_LBH@ and the remaining edges share @1 - pred_LBH@.
--   * Not exactly two successors: the weights are kept.
--   * Two successors, at least one of them via a back edge: the weights are
--     kept.
--   * Two successors, neither a back edge nor a heap/stack check: the weights
--     are normalised to sum to 1 and then refined by every heuristic that
--     makes a prediction.
--   * Otherwise (a binary branch involving a heap/stack check): the weights
--     are kept.
--
-- TODO: This should be combined with optimizeCFG
{-# SCC staticBranchPrediction #-}
staticBranchPrediction :: BlockId -> LoopInfo -> CFG -> CFG
staticBranchPrediction _root (LoopInfo l_backEdges loopLevels l_loops) cfg =
    -- pprTrace "staticEstimatesOn" (ppr (cfg)) $
    foldl' update cfg nodes
  where
    nodes = getCfgNodes cfg
    backedges = S.fromList $ l_backEdges
    -- Loops keyed by their back edge
    loops = M.fromList $ l_loops :: M.Map Edge LabelSet
    -- Loop headers, i.e. the targets of the back edges.
    loopHeads = S.fromList $ map snd $ M.keys loops

    -- Rewrite the weights of the out-edges of a single node. Every
    -- 'setEdgeWeight' below uses 'node' as the source, so only edges leaving
    -- 'node' are touched.
    update :: CFG -> BlockId -> CFG
    update cfg node
        -- No successors, nothing to do.
        | null successors = cfg

        -- Mix of backedges and others:
        -- Always predict the backedges.
        -- ('m' and 'not_m' partition 'successors', so both being non-empty
        -- is the same as 0 < length m < length successors.)
        | not (null m) && not (null not_m)
        -- Heap/Stack checks "loop", but only once.
        -- So we simply exclude any case involving them.
        , not hasHeapOrStackCheck
        = let   loopChance = repeat $! pred_LBH / (fromIntegral $ length m)
                exitChance = repeat $! (1 - pred_LBH) / fromIntegral (length not_m)
                updates = zip (map fst m) loopChance ++ zip (map fst not_m) exitChance
          in  -- pprTrace "mix" (ppr (node,successors)) $
              foldl' (\cfg (to,weight) -> setEdgeWeight cfg weight node to) cfg updates

        -- For (regular) non-binary branches we keep the weights from the STG -> Cmm translation.
        | length successors /= 2
        = cfg

        -- Binary branch involving a backedge (either both edges are
        -- backedges, or the mixed case above was skipped because of a
        -- heap/stack check) - no need to adjust.
        | not (null m)
        = cfg

        -- A regular binary branch, we can plug additional predictors in here.
        | [(s1,s1_info),(s2,s2_info)] <- successors
        , not hasHeapOrStackCheck
        = -- Normalize weights to total of 1
            let !w1 = max (edgeWeight s1_info) (0)
                !w2 = max (edgeWeight s2_info) (0)
                -- If both weights are <= 0 we set both to 0.5
                normalizeWeight w = if w1 + w2 == 0 then 0.5 else w/(w1+w2)
                !cfg'  = setEdgeWeight cfg  (normalizeWeight w1) node s1
                !cfg'' = setEdgeWeight cfg' (normalizeWeight w2) node s2

                -- Figure out which heuristics apply to these successors
                heuristics = map ($ ((s1,s1_info),(s2,s2_info)))
                            [lehPredicts, phPredicts, ohPredicts, ghPredicts, lhhPredicts, chPredicts
                            , shPredicts, rhPredicts]

                -- Apply result of a heuristic. Argument is the likelihood
                -- predicted for s1.
                applyHeuristic :: CFG -> Maybe Prob -> CFG
                applyHeuristic cfg Nothing = cfg
                applyHeuristic cfg (Just (s1_pred :: Double))
                  -- If either edge currently has weight 0 the heuristic is
                  -- not applied. (Heap/stack checks never reach this point:
                  -- the enclosing guard already excluded them.)
                  | s1_old == 0 || s2_old == 0
                  = cfg
                  | otherwise =
                    let -- Predictions from heuristic
                        s1_prob = EdgeWeight s1_pred :: EdgeWeight
                        s2_prob = 1.0 - s1_prob
                        -- Combine the old weights with the prediction; dividing
                        -- by 'd' makes the two new weights sum to 1 again.
                        d = (s1_old * s1_prob) + (s2_old * s2_prob) :: EdgeWeight
                        s1_prob' = s1_old * s1_prob / d
                        !s2_prob' = s2_old * s2_prob / d
                        !cfg_s1 = setEdgeWeight cfg    s1_prob' node s1
                    in  -- pprTrace "Applying heuristic!" (ppr (node,s1,s2) $$ ppr (s1_prob', s2_prob')) $
                        setEdgeWeight cfg_s1 s2_prob' node s2
                  where
                    -- Old weights
                    s1_old = getEdgeWeight cfg node s1
                    s2_old = getEdgeWeight cfg node s2

            in
            -- pprTraceIt "RegularCfgResult" $
            foldl' applyHeuristic cfg'' heuristics

        -- Binary branch on a heap/stack check: keep the weights.
        | otherwise = cfg

      where
        -- Chance that loops are taken.
        pred_LBH :: EdgeWeight
        pred_LBH = 0.875
        -- Out-edges of the node being updated.
        successors = getSuccessorEdges cfg node
        -- Successors reached via a backedge (m) and all others (not_m).
        (m,not_m) = partition (\(to, _) -> S.member (node, to) backedges) successors
        -- Whether any out-edge of the node is a heap/stack check.
        hasHeapOrStackCheck = any (isHeapOrStackCheck . transitionSource . snd) successors

        -- Heuristics return Nothing if they don't say anything about this branch
        -- or Just (prob_s1) where prob_s1 is the likelihood for s1 to be the
        -- taken branch. s1 is the branch in the true case.

        -- Loop exit heuristic.
        -- We are unlikely to leave a loop unless it's to enter another one.
        pred_LEH :: Double
        pred_LEH = 0.75
        -- If and only if no successor is a loopheader,
        -- then we will likely not exit the current loop body.
        lehPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
        lehPredicts ((s1,_s1_info),(s2,_s2_info))
          | S.member s1 loopHeads || S.member s2 loopHeads
          = Nothing

          | otherwise
          = --pprTrace "lehPredict:" (ppr $ compare s1Level s2Level) $
            case compare s1Level s2Level of
                EQ -> Nothing
                LT -> Just (1-pred_LEH) --s1 exits to a shallower loop level (exits loop)
                GT -> Just (pred_LEH)   --s1 exits to a deeper loop level
            where
                s1Level = mapLookup s1 loopLevels
                s2Level = mapLookup s2 loopLevels

        -- Comparing to a constant is unlikely to be equal.
        ohPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
        ohPredicts (s1,_s2)
            | CmmSource { trans_cmmNode = src1 } <- getTransitionSource node (fst s1) cfg
            , CmmCondBranch cond ltrue _lfalse likely <- src1
            , likely == Nothing
            , CmmMachOp mop args <- cond
            , MO_Eq {} <- mop
            , not (null [x | x@CmmLit{} <- args])
            = if fst s1 == ltrue then Just 0.3 else Just 0.7

            | otherwise
            = Nothing

        -- TODO: These are all the other heuristics from the paper.
        -- Not all will apply, for now we just stub them out as Nothing.
        phPredicts, ghPredicts, lhhPredicts, chPredicts, shPredicts, rhPredicts
            :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
        phPredicts = const Nothing
        ghPredicts = const Nothing
        lhhPredicts = const Nothing
        chPredicts = const Nothing
        shPredicts = const Nothing
        rhPredicts = const Nothing

-- | Normalize all edge weights of a CFG into branch probabilities between
-- 0 and 1. Ignoring rounding errors, the probabilities of all outgoing edges
-- of a block sum up to 1.
--
-- Every block that is a key of @cfg@ appears in the result under
-- @toVertex block@, mapped to the probabilities of its outgoing edges
-- (keyed by @toVertex target@):
--
-- * a block with at most one outgoing edge gives that edge (if any)
--   probability 1;
-- * negative weights are clamped to zero;
-- * if all clamped weights are zero, every edge gets the same probability;
-- * otherwise each edge gets its clamped weight divided by the total.
cfgEdgeProbabilities :: CFG -> (BlockId -> Int) -> IM.IntMap (IM.IntMap Prob)
cfgEdgeProbabilities cfg toVertex
    = mapFoldlWithKey foldEdges IM.empty cfg
  where
    foldEdges m from toMap = IM.insert (toVertex from) (normalize toMap) m

    normalize :: LabelMap EdgeInfo -> IM.IntMap Prob
    normalize weightMap
        | edgeCount <= 1   = fromWeights (const 1.0)
        | totalWeight == 0 = fromWeights (const (1.0 / fromIntegral edgeCount))
        | otherwise        = fromWeights (/ totalWeight)
      where
        edgeCount = mapSize weightMap
        -- Negative weights are generally allowed but are mapped to zero.
        -- If no edge has a non-zero weight afterwards we fall back to
        -- uniform probabilities (second guard above).
        minWeight = 0 :: Prob
        clamped :: LabelMap Prob
        clamped = fmap (\w -> max (weightToDouble (edgeWeight w)) minWeight) weightMap
        totalWeight = sum clamped
        -- Build the result by folding over the clamped weights themselves,
        -- so no per-key lookup (and no unreachable failure case) is needed.
        fromWeights :: (Prob -> Prob) -> IM.IntMap Prob
        fromWeights scale =
            mapFoldlWithKey (\m k w -> IM.insert (toVertex k) (scale w) m) IM.empty clamped

-- | Compute block and edge frequencies from static edge probabilities.
--
-- This is the fixpoint algorithm from
--   "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus
-- The adaption to Haskell is my own.
--
-- * @graph@ maps each source vertex to the probabilities of its outgoing
--   edges, keyed by target vertex.
-- * @backEdges@ lists the back edges as @(from, to)@ pairs.
-- * @loops@ lists @(header, body)@ pairs; they are processed in list order,
--   which is expected to be by nesting, inner to outer.
-- * @revPostOrder@ lists the reachable vertices in reverse post order; its
--   first element is used as the entry and gets frequency 1.
--
-- Returns the frequency of every vertex (array indices @0 .. nodeCount-1@)
-- together with the edge map after propagation: each edge handled by
-- @calcOutFreqs@ holds its frequency (source frequency times probability),
-- any other edge keeps its input probability.
calcFreqs :: IM.IntMap (IM.IntMap Prob) -> [(Int,Int)] -> [(Int, [Int])] -> [Int]
          -> (Array Int Double, IM.IntMap (IM.IntMap Prob))
calcFreqs graph backEdges loops revPostOrder = runST $ do
    visitedNodes <- newArray (0,nodeCount-1) False :: ST s (STUArray s Int Bool)
    blockFreqs <- newArray (0,nodeCount-1) 0.0 :: ST s (STUArray s Int Double)
    -- Edge frequencies; starts out as a copy of the edge probabilities.
    edgeProbs <- newSTRef graph
    -- Frequencies along back edges into loop headers; also starts out as a
    -- copy of the edge probabilities.
    edgeBackProbs <- newSTRef graph

    let  -- See #1600, we need to inline or unboxing makes perf worse.
        {-# INLINE visited #-}
        visited b = unsafeRead visitedNodes b
        getFreq b = unsafeRead blockFreqs b
        setFreq b f = unsafeWrite blockFreqs b f
        setVisited b = unsafeWrite visitedNodes b True

        -- Value stored for edge b1 -> b2 in one of the edge maps. The value
        -- is returned lazily, so a missing edge only fails once forced.
        getProb' arr b1 b2 = do
            g <- readSTRef arr
            return $ fromMaybe (error "calcFreqs: unknown edge target") $
                     IM.lookup b2 $
                     fromMaybe (error "calcFreqs: unknown edge source") $
                     IM.lookup b1 g
        -- Overwrite the value stored for edge b1 -> b2; the source b1 must
        -- already have an entry in the map.
        setProb' arr b1 b2 prob = do
            g <- readSTRef arr
            let !m = fromMaybe (error "calcFreqs: unknown edge source") $
                       IM.lookup b1 g
                !m' = IM.insert b2 prob m
            writeSTRef arr $! IM.insert b1 m' g

        getEdgeFreq b1 b2 = getProb' edgeProbs b1 b2
        setEdgeFreq b1 b2 = setProb' edgeProbs b1 b2
        getBackProb b1 b2 = getProb' edgeBackProbs b1 b2
        setBackProb b1 b2 = setProb' edgeBackProbs b1 b2

        -- Static probability of edge b1 -> b2, taken from the input graph.
        getProb b1 b2 = fromMaybe (error "getProb") $ do
            m' <- IM.lookup b1 graph
            IM.lookup b2 m'

    let -- Push the frequency of @block@ along each of its outgoing edges.
        -- An edge back into @loopHead@ also records its frequency as the
        -- back-edge probability read by propFreq.
        calcOutFreqs loopHead block = do
          !f <- getFreq block
          forM_ (successors block) $ \bi -> do
            let !prob = getProb block bi
                !succFreq = f * prob
            setEdgeFreq block bi succFreq
            when (bi == loopHead) $ setBackProb block bi succFreq

    let -- Compute the frequency of @block@ within the region headed by
        -- @loopHead@, then propagate it to the block's outgoing edges.
        propFreq block loopHead = do
            !v <- visited block
            if v then
                return () -- Don't look at nodes twice
            else if block == loopHead then
                setFreq block 1.0 -- Loop header frequency is always 1
            else do
                let preds = IS.elems $ predecessors block
                -- Some predecessor is still unvisited and does not reach
                -- this block over a back edge.
                irreducible <- fmap or $ forM preds $ \bp -> do
                    !bp_visited <- visited bp
                    let bp_backedge = isBackEdge bp block
                    return (not bp_visited && not bp_backedge)

                if irreducible
                then return () -- Rare we don't care
                else do
                    setFreq block 0
                    -- Forward edges add their frequency to this block;
                    -- back edges contribute to cycleProb instead.
                    !cycleProb <- sum <$> (forM preds $ \p ->
                        if isBackEdge p block
                            then getBackProb p block
                            else do
                                !f <- getFreq block
                                !prob <- getEdgeFreq p block
                                setFreq block $! f + prob
                                return 0)
                    -- Paper uses 1 - epsilon, but this works.
                    -- Determines how large likelihoods in loops can grow.
                    let limit = 1 - 1/512
                        !cappedCycleProb = min cycleProb limit

                    !f <- getFreq block
                    setFreq block (f / (1.0 - cappedCycleProb))

            -- Already-visited blocks still get their outgoing edges updated.
            setVisited block
            calcOutFreqs loopHead block

    -- Loops, by nesting, inner to outer
    forM_ loops $ \(loopHead, body) -> do
        -- Mark all nodes as visited, then un-mark the blocks of this loop's
        -- body so propagation stays inside the loop.
        forM_ [0 .. nodeCount - 1] $ \i -> unsafeWrite visitedNodes i True
        forM_ body $ \i -> unsafeWrite visitedNodes i False
        forM_ body $ \block -> propFreq block loopHead

    -- After dealing with all loops, deal with non-looping parts of the CFG.
    -- Everything in revPostOrder is reachable, so clear every visited mark.
    forM_ [0 .. nodeCount - 1] $ \i -> unsafeWrite visitedNodes i False
    case revPostOrder of
        [] -> return ()
        entry : _ -> forM_ revPostOrder $ \block -> propFreq block entry

    graph' <- readSTRef edgeProbs
    freqs' <- unsafeFreeze blockFreqs
    return (freqs', graph')
  where
    -- A vertex without incoming edges (A in the CFG [A -> B]) has no entry
    -- in revGraph, so the lookup defaults to the empty set.
    predecessors :: Int -> IS.IntSet
    predecessors b = fromMaybe IS.empty $ IM.lookup b revGraph

    successors :: Int -> [Int]
    successors b = maybe (lookupError "succ" b graph) IM.keys $ IM.lookup b graph

    lookupError :: Outputable a => String -> a -> IM.IntMap (IM.IntMap Prob) -> r
    lookupError s node g = pprPanic ("Lookup error " ++ s) $
                            ( text "node" <+> ppr node $$
                                text "graph" <+>
                                vcat (map (\(k,m) -> ppr (k,m :: IM.IntMap Double)) $ IM.toList g)
                            )

    -- Array size: every source vertex of 'graph', plus one slot for each
    -- edge whose target is not itself a key of 'graph' (a target reached by
    -- several such edges is counted once per edge).
    nodeCount :: Int
    nodeCount = IM.foldl' (\count toMap -> IM.foldlWithKey' countTargets count toMap) (IM.size graph) graph
      where
        countTargets count k _ = countNode k + count
        countNode n = if IM.member n graph then 0 else 1

    isBackEdge :: Int -> Int -> Bool
    isBackEdge from to = S.member (from,to) backEdgeSet

    backEdgeSet :: S.Set (Int, Int)
    backEdgeSet = S.fromList backEdges

    -- Reverse adjacency: each edge target mapped to the set of its sources.
    revGraph :: IntMap IntSet
    revGraph = IM.foldlWithKey' (\m from toMap -> addEdges m from toMap) IM.empty graph
        where
            addEdges m0 from toMap = IM.foldlWithKey' (\m k _ -> addEdge m from k) m0 toMap
            addEdge m0 from to = IM.insertWith IS.union to (IS.singleton from) m0