{-# LANGUAGE TypeFamilies, ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module CFG
( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
, TransitionSource(..)
, addWeightEdge, addEdge, delEdge
, addNodesBetween, shortcutWeightMap
, reverseEdges, filterEdges
, addImmediateSuccessor
, mkWeightInfo, adjustEdgeWeight
, infoEdgeList, edgeList
, getSuccessorEdges, getSuccessors
, getSuccEdgesSorted, weightedEdgeList
, getEdgeInfo
, getCfgNodes, hasNode
, loopMembers
, getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg
, optimizeCFG )
where
#include "HsVersions.h"
import GhcPrelude
import BlockId
import Cmm ( RawCmmDecl, GenCmmDecl( .. ), CmmBlock, succ, g_entry
, CmmGraph )
import CmmNode
import CmmUtils
import CmmSwitch
import Hoopl.Collections
import Hoopl.Label
import Hoopl.Block
import qualified Hoopl.Graph as G
import Util
import Digraph
import Outputable
import PprCmm ()
import qualified DynFlags as D
import Data.List
-- | A directed edge between two blocks, written as (source, target).
type Edge = (BlockId, BlockId)
-- | A list of 'Edge's.
type Edges = [Edge]
-- | Weight of a CFG edge.
--
-- The numeric classes are derived through GeneralizedNewtypeDeriving, so
-- weights can be used directly in arithmetic (e.g. @weight + bonus@) and
-- converted with 'fromIntegral'.
newtype EdgeWeight
  = EdgeWeight Int
  deriving (Eq,Ord,Enum,Num,Real,Integral)

-- | Prints only the underlying 'Int'.
instance Outputable EdgeWeight where
  ppr (EdgeWeight w) = ppr w
-- | Nested map: source block -> (target block -> per-edge data).
type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo)

-- | A control-flow graph mapping each source block to its outgoing edges.
-- A block without outgoing edges may still be present as a key mapped to an
-- empty map ('getCfg' builds its result that way).
type CFG = EdgeInfoMap EdgeInfo
-- | A single CFG edge in flattened form, together with its 'EdgeInfo'.
data CfgEdge
  = CfgEdge
  { edgeFrom :: !BlockId  -- ^ source block
  , edgeTo :: !BlockId    -- ^ target block
  , edgeInfo :: !EdgeInfo -- ^ origin and weight of the edge
  }
-- | Two edges are equal when they connect the same source and target.
--
-- The 'EdgeInfo' (and therefore the weight) is deliberately ignored: a CFG
-- holds at most one edge per (source, target) pair.
instance Eq CfgEdge where
  lhs == rhs =
    edgeFrom lhs == edgeFrom rhs && edgeTo lhs == edgeTo rhs
-- | Edges are ordered lexicographically by weight, then source, then target,
-- all ascending.
--
-- NOTE: the 'Eq' instance for 'CfgEdge' ignores the weight while this
-- ordering uses it, so two edges can be '==' yet compare as 'LT' or 'GT'.
instance Ord CfgEdge where
  compare lhs rhs = compare (sortKey lhs) (sortKey rhs)
    where
      sortKey e = (edgeWeight (edgeInfo e), edgeFrom e, edgeTo e)
-- | Renders an edge as @(from -(weight:w)-> to)@.
instance Outputable CfgEdge where
  ppr e = parens $
    ppr (edgeFrom e) <+> text "-(" <> ppr (edgeInfo e) <> text ")->" <+> ppr (edgeTo e)
-- | Where a CFG edge originated.
data TransitionSource
  = CmmSource (CmmNode O C) -- ^ the Cmm node ending the source block ('getCfg' stores the block's last node)
  | AsmCodeGen              -- ^ no Cmm node attached, e.g. edges built with 'mkWeightInfo'
  deriving (Eq)
-- | Information attached to each CFG edge.
data EdgeInfo
  = EdgeInfo
  { transitionSource :: !TransitionSource -- ^ what created the edge
  , edgeWeight :: !EdgeWeight             -- ^ edge weight; 'getSuccEdgesSorted' lists heavier edges first
  } deriving (Eq)
-- | Only the weight is printed; the transition source is omitted.
instance Outputable EdgeInfo where
  ppr (EdgeInfo { edgeWeight = w }) = text "weight:" <+> ppr w
-- | Build an 'EdgeInfo' with the given weight for an edge that has no
-- originating Cmm node (its source is 'AsmCodeGen').
{-# INLINEABLE mkWeightInfo #-}
mkWeightInfo :: Integral n => n -> EdgeInfo
mkWeightInfo n =
  EdgeInfo { transitionSource = AsmCodeGen, edgeWeight = fromIntegral n }
-- | Apply @f@ to the weight of the edge @from -> to@.
--
-- If the edge does not exist the CFG is returned unchanged (compare
-- 'updateEdgeWeight', which panics in that case).
adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
                 -> BlockId -> BlockId -> CFG
adjustEdgeWeight cfg f from to =
  case getEdgeInfo from to cfg of
    Nothing   -> cfg
    Just info -> addEdge from to (info { edgeWeight = f (edgeWeight info) }) cfg
-- | All nodes of the CFG: every source key plus every edge target.
getCfgNodes :: CFG -> LabelSet
getCfgNodes m =
  setFromList [ node | (src, succs) <- mapToList m, node <- src : mapKeys succs ]
-- | Whether @node@ occurs in the CFG, either as a source key or as the
-- target of some edge.
hasNode :: CFG -> BlockId -> Bool
hasNode m node = isSource || isTarget
  where
    isSource = mapMember node m
    isTarget = any (mapMember node) (mapElems m)
-- | Check that the nodes of a CFG are exactly the given set of blocks.
--
-- Returns 'True' when the two sets agree and panics otherwise.  The panic
-- message shows the blocks present in only one of the sets, the block set,
-- the CFG, and the caller-supplied @msg@.
sanityCheckCfg :: CFG -> LabelSet -> SDoc -> Bool
sanityCheckCfg m blockSet msg =
    blockSet == cfgNodes ||
      pprPanic "Block list and cfg nodes don't match"
        (text "difference:" <+> ppr onlyInOne $$
         text "blocks:" <+> ppr blockSet $$
         text "cfg:" <+> ppr m $$
         msg)
  where
    cfgNodes = getCfgNodes m :: LabelSet
    -- Symmetric difference: nodes that occur in exactly one of the sets.
    onlyInOne = setDifference (setUnion cfgNodes blockSet)
                              (setIntersection cfgNodes blockSet) :: LabelSet
-- | Keep only the edges for which the predicate (given source, target and
-- edge info) holds.  Source keys are kept even if all their edges are
-- removed.
filterEdges :: (BlockId -> BlockId -> EdgeInfo -> Bool) -> CFG -> CFG
filterEdges keep = mapMapWithKey (\from -> mapFilterWithKey (keep from))
-- | Update a CFG after blocks have been shortcut.
--
-- For each entry of @cuts@:
--
--   * @b -> Nothing@: block @b@ is removed, with its outgoing edges and
--     every edge that pointed at it.
--   * @b -> Just t@: block @b@ and its outgoing edges are removed, and each
--     edge that pointed at @b@ is redirected to @t@, keeping its 'EdgeInfo'.
--     If @t@ is itself in @cuts@ the redirection continues along the chain.
--
-- NOTE(review): a cycle in @cuts@ (e.g. @a -> Just b@, @b -> Just a@) makes
-- the chain-following recursion non-terminating; presumably callers never
-- produce one. Confirm.
shortcutWeightMap :: CFG -> LabelMap (Maybe BlockId) -> CFG
shortcutWeightMap cfg cuts = foldl' applyCut cfg (mapToList cuts)
  where
    applyCut :: CFG -> (BlockId, Maybe BlockId) -> CFG
    applyCut g (from, Nothing) = mapDelete from (fmap (mapDelete from) g)
    applyCut g (from, Just to) =
      let redirected = fmap (redirect from to) (mapDelete from g)
      in maybe redirected (\next -> applyCut redirected (to, next)) (mapLookup to cuts)

    -- Re-point an edge to @from@ (if present) at @to@, keeping its info.
    redirect :: BlockId -> BlockId -> LabelMap EdgeInfo -> LabelMap EdgeInfo
    redirect from to succs =
      case mapLookup from succs of
        Nothing   -> succs
        Just info -> mapInsert to info (mapDelete from succs)
-- | Insert @follower@ directly after @node@.
--
-- @node@ gets an edge to @follower@ weighted with @uncondWeight@ from the
-- global DynFlags, and all of @node@'s previous outgoing edges are moved so
-- they start at @follower@ instead, keeping their 'EdgeInfo'.
--
-- NOTE(review): if @follower@ was already a successor of @node@, the new
-- @node -> follower@ edge is deleted again and @follower@ gets a self-edge.
-- Presumably @follower@ is always a fresh block. Confirm.
addImmediateSuccessor :: BlockId -> BlockId -> CFG -> CFG
addImmediateSuccessor node follower cfg =
    foldl' moveToFollower withoutOldSuccs oldSuccEdges
  where
    uncondWeight :: EdgeWeight
    uncondWeight =
      fromIntegral . D.uncondWeight . D.cfgWeightInfo $ D.unsafeGlobalDynFlags

    -- Outgoing edges of node before the update.
    oldSuccEdges = getSuccessorEdges cfg node
    withLink = addWeightEdge node follower uncondWeight cfg
    withoutOldSuccs =
      foldl' (\acc (s, _) -> delEdge node s acc) withLink oldSuccEdges
    moveToFollower acc (t, info) = addEdge follower t info acc
-- | Insert (or overwrite) the edge @from -> to@ with the given info.
-- The source node is created if it was not yet a key.
addEdge :: BlockId -> BlockId -> EdgeInfo -> CFG -> CFG
addEdge from to info = mapAlter insertInto from
  where
    insertInto :: Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
    insertInto = Just . maybe (mapSingleton to info) (mapInsert to info)
-- | Insert (or overwrite) the edge @from -> to@ with the given weight and an
-- 'AsmCodeGen' transition source.
addWeightEdge :: BlockId -> BlockId -> EdgeWeight -> CFG -> CFG
addWeightEdge from to weight = addEdge from to (mkWeightInfo weight)
-- | Remove the edge @from -> to@ if present.  The source node stays in the
-- CFG even if this was its last edge; an absent source stays absent.
delEdge :: BlockId -> BlockId -> CFG -> CFG
delEdge from to = mapAlter (fmap (mapDelete to)) from
-- | Outgoing edges of @bid@, heaviest first.  Returns an empty list when
-- @bid@ is not a key of the CFG.
getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccEdgesSorted m bid =
    sortWith (\(_, info) -> negate (edgeWeight info)) (mapToList succs)
  where
    succs = mapFindWithDefault mapEmpty bid m
-- | Outgoing edges of @bid@ with their info (empty if @bid@ is unknown).
getSuccessorEdges :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccessorEdges m bid =
  case mapLookup bid m of
    Nothing    -> []
    Just succs -> mapToList succs
-- | Look up the info of edge @from -> to@.  The info is forced to WHNF
-- before it is returned.
getEdgeInfo :: BlockId -> BlockId -> CFG -> Maybe EdgeInfo
getEdgeInfo from to m = do
  succs <- mapLookup from m
  info  <- mapLookup to succs
  return $! info
-- | Invert every edge of a CFG, keeping each edge's 'EdgeInfo'.
--
-- Every key of the input CFG is also a key of the result, even when it has
-- no incoming edges in the input and so no outgoing edges after reversal.
-- Previously such nodes were dropped, which broke the invariant that every
-- block is a node of the CFG. Nodes that only occur as edge targets in the
-- input become keys through their reversed edges.
reverseEdges :: CFG -> CFG
reverseEdges cfg = foldl' addReversed nodesOnly (mapToList cfg)
  where
    -- All source nodes of the input, present but without edges.
    nodesOnly :: CFG
    nodesOnly = fmap (const mapEmpty) cfg

    addReversed :: CFG -> (BlockId, LabelMap EdgeInfo) -> CFG
    addReversed acc (from, succs) =
      foldl' (\acc' (to, info) -> addEdge to from info acc') acc (mapToList succs)
-- | All edges of the CFG as 'CfgEdge' values, grouped by source node.
infoEdgeList :: CFG -> [CfgEdge]
infoEdgeList = mapFoldMapWithKey edgesFrom
  where
    edgesFrom :: BlockId -> LabelMap EdgeInfo -> [CfgEdge]
    edgesFrom from succs = [ CfgEdge from to info | (to, info) <- mapToList succs ]
-- | All edges of the CFG as (source, target, weight) triples.
weightedEdgeList :: CFG -> [(BlockId,BlockId,EdgeWeight)]
weightedEdgeList = mapFoldMapWithKey weightsFrom
  where
    weightsFrom :: BlockId -> LabelMap EdgeInfo -> [(BlockId,BlockId,EdgeWeight)]
    weightsFrom from succs =
      [ (from, to, edgeWeight info) | (to, info) <- mapToList succs ]
-- | All edges of the CFG as (source, target) pairs, without their info.
edgeList :: CFG -> [Edge]
edgeList = mapFoldMapWithKey pairsFrom
  where
    pairsFrom :: BlockId -> LabelMap EdgeInfo -> [Edge]
    pairsFrom from succs = [ (from, to) | to <- mapKeys succs ]
-- | Targets of all outgoing edges of @bid@ (empty if @bid@ is unknown).
getSuccessors :: CFG -> BlockId -> [BlockId]
getSuccessors m bid = maybe [] mapKeys (mapLookup bid m)
-- | Render the CFG's weighted edges in Graphviz DOT syntax.
--
-- Edges are emitted in sorted order, each labelled and weighted with its
-- edge weight.  Keys with no outgoing edges that are not mentioned by any
-- edge are emitted as standalone node statements so they still appear.
pprEdgeWeights :: CFG -> SDoc
pprEdgeWeights m =
    text "digraph {\n" <>
    foldl' (<>) empty (map edgeLine sortedEdges) <>
    foldl' (<>) empty (map nodeLine isolatedNodes) <>
    text "}\n"
  where
    sortedEdges :: [(BlockId, BlockId, EdgeWeight)]
    sortedEdges = sort (weightedEdgeList m)

    edgeLine :: (BlockId, BlockId, EdgeWeight) -> SDoc
    edgeLine (src, dst, w) =
      text "\t" <> ppr src <+> text "->" <+> ppr dst <>
      text "[label=\"" <> ppr w <> text "\",weight=\"" <>
      ppr w <> text "\"];\n"

    nodeLine :: BlockId -> SDoc
    nodeLine node = text "\t" <> ppr node <> text ";\n"

    -- Every node mentioned by at least one edge.
    connected :: LabelSet
    connected = setFromList (concatMap (\(src, dst, _) -> [src, dst]) sortedEdges)

    -- Keys with an empty successor map that no edge mentions.
    isolatedNodes :: [BlockId]
    isolatedNodes =
      filter (\n -> not (setMember n connected)) (mapKeys (mapFilter null m))
-- | Apply @f@ to the weight of an existing edge.  Panics if the edge is not
-- in the CFG (compare 'adjustEdgeWeight', which ignores missing edges).
{-# INLINE updateEdgeWeight #-}
updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
updateEdgeWeight f (from, to) cfg =
    maybe (panic "Trying to update invalid edge") reweigh (getEdgeInfo from to cfg)
  where
    reweigh info = addEdge from to (info { edgeWeight = f (edgeWeight info) }) cfg
-- | Rewrite the weight of every edge in the CFG.
--
-- The function receives the edge's source, target and current weight and
-- returns the new weight.  Each edge keeps its 'TransitionSource', and no
-- edges or nodes are added or removed.  The CFG's shape is mapped in place
-- rather than rebuilt with one insertion per edge.
mapWeights :: (BlockId -> BlockId -> EdgeWeight -> EdgeWeight) -> CFG -> CFG
mapWeights f = mapMapWithKey reweighFrom
  where
    reweighFrom :: BlockId -> LabelMap EdgeInfo -> LabelMap EdgeInfo
    reweighFrom from =
      mapMapWithKey (\to info -> info { edgeWeight = f from to (edgeWeight info) })
-- | Insert new blocks on existing edges.
--
-- Each update @(from, between, old)@ replaces the edge @from -> old@ with
-- @from -> between@, which keeps the original edge's info, plus
-- @between -> old@, weighted with @uncondWeight@ from the global DynFlags.
-- Edge infos are looked up in the input CFG, not in the partially updated
-- one. A missing @from -> old@ edge panics.
addNodesBetween :: CFG -> [(BlockId,BlockId,BlockId)] -> CFG
addNodesBetween m updates =
    foldl' splitEdge m (map withInfo updates)
  where
    newEdgeWeight :: EdgeWeight
    newEdgeWeight =
      fromIntegral . D.uncondWeight . D.cfgWeightInfo $ D.unsafeGlobalDynFlags

    withInfo :: (BlockId,BlockId,BlockId) -> (BlockId,BlockId,BlockId,EdgeInfo)
    withInfo (from,between,old) =
      case getEdgeInfo from old m of
        Just info -> (from,between,old,info)
        Nothing -> pprPanic "Can't find weight for edge that should have one"
                     (text "triple" <+> ppr (from,between,old) $$
                      text "updates" <+> ppr updates)

    splitEdge :: CFG -> (BlockId,BlockId,BlockId,EdgeInfo) -> CFG
    splitEdge acc (from,between,old,info) =
      addEdge from between info
        (addWeightEdge between old newEdgeWeight (delEdge from old acc))
-- | CFG of a top-level declaration.  Data sections and procedures without
-- blocks yield an empty CFG; other procedures are handled by 'getCfg'.
getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
getCfgProc weights decl =
  case decl of
    CmmData {} -> mapEmpty
    CmmProc _info _lab _live graph
      | null (toBlockList graph) -> mapEmpty
      | otherwise                -> getCfg weights graph
-- | Build the CFG of a 'CmmGraph'.
--
-- Every block in @revPostorder graph@ becomes a node, even if it has no
-- successors.  Edge weights come from the supplied 'D.CfgWeights' and
-- depend on the node that ends the source block:
--
--   * unconditional branch -> @uncondWeight@
--   * conditional branch without a likelihood hint -> @condBranchWeight@
--     on both arms
--   * conditional branch with a hint -> @likelyCondWeight@ on the likely arm
--     and @unlikelyCondWeight@ on the other
--   * switch -> @switchWeight@ on every target, or -1 when the switch has
--     more than 'bigSwitchThreshold' targets
--   * call / foreign call -> @callWeight@ on the edge to the continuation
--   * call without a continuation -> no edge
--
-- Any other terminator panics with a description of the node.
getCfg :: D.CfgWeights -> CmmGraph -> CFG
getCfg weights graph =
    foldl' insertEdge edgelessCfg $ concatMap getBlockEdges blocks
  where
    D.CFGWeights
            { D.uncondWeight = uncondWeight
            , D.condBranchWeight = condBranchWeight
            , D.switchWeight = switchWeight
            , D.callWeight = callWeight
            , D.likelyCondWeight = likelyCondWeight
            , D.unlikelyCondWeight = unlikelyCondWeight
            } = weights

    -- Switches with more targets than this get weight -1 on every edge.
    bigSwitchThreshold :: Int
    bigSwitchThreshold = 10

    blocks = revPostorder graph :: [CmmBlock]

    -- Every block starts out as a node, so blocks without successors are
    -- still present in the result.
    edgelessCfg :: CFG
    edgelessCfg = mapFromList [ (G.entryLabel b, mapEmpty) | b <- blocks ]

    insertEdge :: CFG -> ((BlockId,BlockId),EdgeInfo) -> CFG
    insertEdge m ((from,to),info) = addEdge from to info m

    getBlockEdges :: CmmBlock -> [((BlockId,BlockId),EdgeInfo)]
    getBlockEdges block =
      case branch of
        CmmBranch dest -> [mkEdge dest uncondWeight]
        CmmCondBranch _c t f likely -> case likely of
          Nothing    -> [mkEdge f condBranchWeight,   mkEdge t condBranchWeight]
          Just True  -> [mkEdge f unlikelyCondWeight, mkEdge t likelyCondWeight]
          Just False -> [mkEdge f likelyCondWeight,   mkEdge t unlikelyCondWeight]
        CmmSwitch _e ids ->
          let switchTargets = switchTargetsToList ids
              adjustedWeight
                | length switchTargets > bigSwitchThreshold = -1
                | otherwise = switchWeight
          in map (\x -> mkEdge x adjustedWeight) switchTargets
        CmmCall { cml_cont = Just cont } -> [mkEdge cont callWeight]
        CmmForeignCall { Cmm.succ = cont } -> [mkEdge cont callWeight]
        CmmCall { cml_cont = Nothing } -> []
        other ->
          -- Previously `panic "Foo"`, which hid the offending node.
          pprPanic "getCfg: unknown successor cause"
            (ppr branch <+> text "=>" <> ppr (G.successors other))
      where
        bid = G.entryLabel block
        mkEdgeInfo = EdgeInfo (CmmSource branch) . fromIntegral
        mkEdge target weight = ((bid,target), mkEdgeInfo weight)
        branch = lastNode block :: CmmNode O C
-- | Edges of the CFG that 'classifyEdges' marks as 'Backward' when the graph
-- is explored from @root@.
findBackEdges :: BlockId -> CFG -> Edges
findBackEdges root cfg =
  [ edge | (edge, Backward) <- classifyEdges root (getSuccessors cfg) (edgeList cfg) ]
-- | Heuristically adjust the edge weights of a procedure's CFG.
--
-- Data declarations leave the CFG unchanged.  For a procedure, three passes
-- run in this order (the composition below reads right to left):
--
--   1. 'increaseBackEdgeWeight': back edges found from the entry block get
--      @backEdgeBonus@ added when their weight is positive; non-positive
--      weights are set to 0.
--   2. 'penalizeInfoTables': every edge whose target is a key of the
--      procedure's info-table map loses @infoTablePenalty@.
--   3. 'favourFewerPreds': for blocks with exactly two equally weighted
--      successors, shift weight by one towards the successor with fewer
--      fall-through-eligible predecessors.
optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
optimizeCFG _ (CmmData {}) cfg = cfg
optimizeCFG weights (CmmProc info _lab _live graph) cfg =
    favourFewerPreds .
    penalizeInfoTables info .
    increaseBackEdgeWeight (g_entry graph) $ cfg
  where
    -- Add the configured bonus to every back edge (as found by
    -- findBackEdges from root).
    increaseBackEdgeWeight :: BlockId -> CFG -> CFG
    increaseBackEdgeWeight root cfg =
        let backedges = findBackEdges root cfg
            update weight
              -- Non-positive weights are clamped to 0 instead of boosted.
              | weight <= 0 = 0
              | otherwise
              = weight + fromIntegral (D.backEdgeBonus weights)
        in foldl' (\cfg edge -> updateEdgeWeight update edge cfg)
                  cfg backedges

    -- Subtract the info-table penalty from every edge whose target has an
    -- entry in info.  NOTE(review): presumably because control cannot fall
    -- through into a block that is preceded by an info table. Confirm.
    penalizeInfoTables :: LabelMap a -> CFG -> CFG
    penalizeInfoTables info cfg =
        mapWeights fupdate cfg
      where
        fupdate :: BlockId -> BlockId -> EdgeWeight -> EdgeWeight
        fupdate _ to weight
          | mapMember to info
          = weight - (fromIntegral $ D.infoTablePenalty weights)
          | otherwise = weight

    -- For each block with exactly two successors of equal weight, give +1 to
    -- the successor with fewer predecessors and -1 to the other (no change
    -- on a tie).  Predecessors are counted only over edges accepted by
    -- fallthroughTarget.  The fold threads the updated CFG, but successor
    -- lists are re-read from it at each node.
    favourFewerPreds :: CFG -> CFG
    favourFewerPreds cfg =
        let
            -- Predecessor map restricted to fall-through-eligible edges.
            revCfg =
              reverseEdges $ filterEdges
                              (\_from -> fallthroughTarget) cfg
            predCount n = length $ getSuccessorEdges revCfg n
            nodes = getCfgNodes cfg
            modifiers :: Int -> Int -> (EdgeWeight, EdgeWeight)
            modifiers preds1 preds2
              | preds1 < preds2 = ( 1,-1)
              | preds1 == preds2 = ( 0, 0)
              | otherwise = (-1, 1)
            update cfg node
              | [(s1,e1),(s2,e2)] <- getSuccessorEdges cfg node
              , w1 <- edgeWeight e1
              , w2 <- edgeWeight e2
              -- Only break ties; an existing ordering is left alone.
              , w1 == w2
              , (mod1,mod2) <- modifiers (predCount s1) (predCount s2)
              = (\cfg' ->
                  (adjustEdgeWeight cfg' (+mod2) node s2))
                    (adjustEdgeWeight cfg (+mod1) node s1)
              | otherwise
              = cfg
        in setFoldl update cfg nodes
      where
        -- An edge is fall-through-eligible unless its target has an info
        -- table; of the rest, only edges with an AsmCodeGen source or a
        -- Cmm branch / conditional branch source qualify.
        fallthroughTarget :: BlockId -> EdgeInfo -> Bool
        fallthroughTarget to (EdgeInfo source _weight)
          | mapMember to info = False
          | AsmCodeGen <- source = True
          | CmmSource (CmmBranch {}) <- source = True
          | CmmSource (CmmCondBranch {}) <- source = True
          | otherwise = False
-- | Map every CFG node to 'True' if it belongs to a cyclic strongly
-- connected component of the CFG, and to 'False' otherwise.
loopMembers :: CFG -> LabelMap Bool
loopMembers cfg = foldl' record mapEmpty sccs
  where
    sccs = stronglyConnCompFromEdgedVerticesOrd
             [ DigraphNode b b (getSuccessors cfg b) | b <- setElems (getCfgNodes cfg) ]

    record :: LabelMap Bool -> SCC BlockId -> LabelMap Bool
    record acc (AcyclicSCC b) = mapInsert b False acc
    record acc (CyclicSCC bs) = foldl' (\acc' b -> mapInsert b True acc') acc bs