module CFG
    ( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
    , TransitionSource(..)
    
    , addWeightEdge, addEdge
    , delEdge, delNode
    , addNodesBetween, shortcutWeightMap
    , reverseEdges, filterEdges
    , addImmediateSuccessor
    , mkWeightInfo, adjustEdgeWeight, setEdgeWeight
    
    , infoEdgeList, edgeList
    , getSuccessorEdges, getSuccessors
    , getSuccEdgesSorted
    , getEdgeInfo
    , getCfgNodes, hasNode
    
    , loopMembers, loopLevels, loopInfo
    --Construction/Misc
    , getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg
    
    , optimizeCFG
    , mkGlobalWeights
     )
where
#include "HsVersions.h"
import GhcPrelude
import BlockId
import Cmm
import CmmUtils
import CmmSwitch
import Hoopl.Collections
import Hoopl.Label
import Hoopl.Block
import qualified Hoopl.Graph as G
import Util
import Digraph
import Maybes
import Unique
import qualified Dominators as Dom
import Data.IntMap.Strict (IntMap)
import Data.IntSet (IntSet)
import qualified Data.IntMap.Strict as IM
import qualified Data.Map as M
import qualified Data.IntSet as IS
import qualified Data.Set as S
import Data.Tree
import Data.Bifunctor
import Outputable
import PprCmm () 
import qualified DynFlags as D
import Data.List (sort, nub, partition)
import Data.STRef.Strict
import Control.Monad.ST
import Data.Array.MArray
import Data.Array.ST
import Data.Array.IArray
import Data.Array.Unsafe (unsafeFreeze)
import Data.Array.Base (unsafeRead, unsafeWrite)
import Control.Monad
-- | A probability, represented as a plain 'Double'.
type Prob = Double
-- | A directed edge, given as (source block, target block).
type Edge = (BlockId, BlockId)
-- | A list of directed edges.
type Edges = [Edge]
-- | Weight attached to a control flow edge. A thin wrapper around 'Double'
-- whose numeric instances are derived from the wrapped value.
newtype EdgeWeight
  = EdgeWeight { weightToDouble :: Double }
  deriving (Eq,Ord,Enum,Num,Real,Fractional)

-- | Weights are rendered through 'doublePrec' with a precision argument of 5.
instance Outputable EdgeWeight where
  ppr = doublePrec 5 . weightToDouble
-- | Two-level map: source block -> (target block -> payload of that edge).
type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo)
-- | The control flow graph: an 'EdgeInfoMap' storing an 'EdgeInfo' per edge.
type CFG = EdgeInfoMap EdgeInfo
-- | A single edge of the CFG together with the information attached to it.
-- All fields are strict.
data CfgEdge
  = CfgEdge
  -- Source block of the edge.
  { edgeFrom :: !BlockId
  -- Target block of the edge.
  , edgeTo :: !BlockId
  -- Weight and origin of the edge.
  , edgeInfo :: !EdgeInfo
  }
-- | Two edges are equal when they connect the same pair of blocks; the
-- attached 'EdgeInfo' is not taken into account.
instance Eq CfgEdge where
  lhs == rhs =
    edgeFrom lhs == edgeFrom rhs && edgeTo lhs == edgeTo rhs
-- | Edges are ordered by weight first, then by source block, then by target
-- block. Anything that is neither smaller nor equal compares as 'GT'.
instance Ord CfgEdge where
  compare (CfgEdge from1 to1 info1) (CfgEdge from2 to2 info2)
    | weight1 <  weight2 = LT
    | weight1 /= weight2 = GT
    -- weights are equal from here on
    | from1 <  from2     = LT
    | from1 /= from2     = GT
    -- sources are equal from here on
    | to1 <  to2         = LT
    | to1 == to2         = EQ
    | otherwise          = GT
    where
      weight1 = edgeWeight info1
      weight2 = edgeWeight info2
-- | Renders an edge as @(from -(info)-> to)@.
instance Outputable CfgEdge where
  ppr (CfgEdge src dst info) =
    parens $ ppr src <+> text "-(" <> ppr info <> text ")->" <+> ppr dst
-- | Records where an edge of the CFG came from.
--
-- 'CmmSource' edges carry the Cmm node they were derived from plus a
-- 'BranchInfo' classification; 'AsmCodeGen' edges carry no such node
-- (this is what 'mkWeightInfo' uses).
data TransitionSource
  = CmmSource { trans_cmmNode :: (CmmNode O C)
              , trans_info :: BranchInfo }
  | AsmCodeGen
  deriving (Eq)
-- | Extra classification of a conditional branch.
data BranchInfo = NoInfo          -- nothing special is known about the branch
                | HeapStackCheck  -- condition reads SpLim, HpLim or BaseReg (see getCfg)
    deriving Eq

instance Outputable BranchInfo where
    ppr info = text $ case info of
        NoInfo         -> "regular"
        HeapStackCheck -> "heap/stack"
-- | True exactly for 'CmmSource' transitions tagged as 'HeapStackCheck'.
isHeapOrStackCheck :: TransitionSource -> Bool
isHeapOrStackCheck source =
  case source of
    CmmSource { trans_info = HeapStackCheck } -> True
    _                                         -> False
-- | Information stored for every edge: where it came from and its weight.
data EdgeInfo
  = EdgeInfo
  { transitionSource :: !TransitionSource
  , edgeWeight :: !EdgeWeight
  } deriving (Eq)

-- | Only the weight is printed; the transition source is omitted.
instance Outputable EdgeInfo where
  ppr info = text "weight:" <+> ppr weight
    where
      weight = edgeWeight info
-- | Build an 'EdgeInfo' with the given weight and 'AsmCodeGen' as its source.
mkWeightInfo :: EdgeWeight -> EdgeInfo
mkWeightInfo weight =
  EdgeInfo { transitionSource = AsmCodeGen, edgeWeight = weight }
-- | Apply a function to the weight of the edge @from -> to@.
-- The CFG is returned unchanged when no such edge exists.
adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
                 -> BlockId -> BlockId -> CFG
adjustEdgeWeight cfg f from to =
  case getEdgeInfo from to cfg of
    Nothing   -> cfg
    Just info ->
      -- Force both weights before rebuilding the edge.
      let !oldWeight = edgeWeight info
          !newWeight = f oldWeight
      in addEdge from to (info { edgeWeight = newWeight }) cfg
-- | Overwrite the weight of the edge @from -> to@.
-- The CFG is returned unchanged when no such edge exists.
setEdgeWeight :: CFG -> EdgeWeight
              -> BlockId -> BlockId -> CFG
setEdgeWeight cfg !weight from to =
  maybe cfg replaceWeight (getEdgeInfo from to cfg)
  where
    replaceWeight info = addEdge from to (info { edgeWeight = weight }) cfg
-- | All nodes of the CFG, i.e. the keys of the outer map.
getCfgNodes :: CFG -> [BlockId]
getCfgNodes = mapKeys
-- | Is the block a node of the CFG (a key of the outer map)?
--
-- The assertion checks that a block which is not a key of the outer map is
-- not the target of any edge either.
hasNode :: CFG -> BlockId -> Bool
hasNode m node =
  let isSource = mapMember node m
  in
  ASSERT( isSource || not (any (mapMember node) m))
  isSource
-- | Check that the nodes of the CFG are exactly the given set of blocks.
-- Returns 'True' when they match and panics (printing the symmetric
-- difference, both sets and the caller supplied message) otherwise.
sanityCheckCfg :: CFG -> LabelSet -> SDoc -> Bool
sanityCheckCfg m blockSet msg =
    blockSet == cfgNodes
      || pprPanic "Block list and cfg nodes don't match" report False
  where
    cfgNodes = setFromList (getCfgNodes m) :: LabelSet
    -- Blocks that occur in exactly one of the two sets.
    mismatch = (setUnion cfgNodes blockSet) `setDifference` (setIntersection cfgNodes blockSet) :: LabelSet
    report =
      text "difference:" <+> ppr mismatch $$
      text "blocks:" <+> ppr blockSet $$
      text "cfg:" <+> pprEdgeWeights m $$
      msg
-- | Keep only the edges for which the predicate (given source, target and
-- edge info) holds. Nodes themselves are never removed.
filterEdges :: (BlockId -> BlockId -> EdgeInfo -> Bool) -> CFG -> CFG
filterEdges f =
    mapMapWithKey (\from -> mapFilterWithKey (f from))
-- | Apply a set of shortcuts to the CFG.
--
-- For an entry @(from, Nothing)@ the block @from@ and every edge pointing to
-- it are removed. For @(from, Just to)@ the block @from@ is removed and every
-- edge that pointed to @from@ is redirected to @to@; if @to@ is itself a key
-- of @cuts@ its entry is applied right afterwards.
shortcutWeightMap :: LabelMap (Maybe BlockId) -> CFG -> CFG
shortcutWeightMap cuts cfg = foldl' applyCut cfg (mapToList cuts)
  where
    applyCut :: CFG -> (BlockId, Maybe BlockId) -> CFG
    applyCut m (from, target) =
      case target of
        Nothing -> mapDelete from (fmap (mapDelete from) m)
        Just to ->
          let redirected :: CFG
              redirected = fmap (redirect from to) (mapDelete from m)
          in maybe redirected
                   (\dest -> applyCut redirected (to, dest))
                   (mapLookup to cuts)

    -- Within one successor map, move the entry for @from@ (if any) to @to@.
    redirect :: BlockId -> BlockId -> LabelMap EdgeInfo -> LabelMap EdgeInfo
    redirect from to succs =
      case mapLookup from succs of
        Nothing   -> succs
        Just info -> mapInsert to info (mapDelete from succs)
-- | Insert @follower@ directly after @node@.
--
-- An edge @node -> follower@ with the unconditional-jump weight is added, the
-- edges from @node@ to its previous successors are deleted, and those
-- previous outgoing edges are re-attached to @follower@.
addImmediateSuccessor :: BlockId -> BlockId -> CFG -> CFG
addImmediateSuccessor node follower cfg =
    let withFollower = addWeightEdge node follower uncondWeight cfg
        detached     = foldl' (\m s -> delEdge node s m) withFollower oldTargets
    in foldl' (\m (t,info) -> addEdge follower t info m) detached oldEdges
  where
    -- Outgoing edges of @node@ as found in the original graph.
    oldEdges   = getSuccessorEdges cfg node
    oldTargets = map fst oldEdges :: [BlockId]
    uncondWeight = fromIntegral . D.uncondWeight .
                   D.cfgWeightInfo $ D.unsafeGlobalDynFlags
-- | Add (or overwrite) the edge @from -> to@ with the given info.
-- Both endpoints are guaranteed to be nodes of the resulting CFG.
addEdge :: BlockId -> BlockId -> EdgeInfo -> CFG -> CFG
addEdge from to info cfg =
    mapAlter insertTarget from (mapAlter ensureNode to cfg)
  where
    -- Put the target into the successor map of @from@, creating it if needed.
    insertTarget = Just . maybe (mapSingleton to info) (mapInsert to info)
    -- Make sure the destination exists as a node, keeping existing edges.
    ensureNode = Just . fromMaybe mapEmpty
-- | Like 'addEdge', but builds the edge info from a bare weight via
-- 'mkWeightInfo'.
addWeightEdge :: BlockId -> BlockId -> EdgeWeight -> CFG -> CFG
addWeightEdge from to weight =
    addEdge from to (mkWeightInfo weight)
-- | Remove the edge @from -> to@. Both nodes stay in the graph; nothing
-- happens when @from@ is not a node.
delEdge :: BlockId -> BlockId -> CFG -> CFG
delEdge from to =
    mapAlter (fmap (mapDelete to)) from
-- | Remove a node together with all edges leaving it and all edges
-- pointing to it.
delNode :: BlockId -> CFG -> CFG
delNode node =
  fmap (mapDelete node) . mapDelete node
-- | Outgoing edges of a block, heaviest first. A block that is not part of
-- the CFG yields the empty list.
getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccEdgesSorted m bid =
    sortWith (negate . edgeWeight . snd) . mapToList $
      mapFindWithDefault mapEmpty bid m
-- | Outgoing edges of a block. Panics when the block is not a node of the CFG.
getSuccessorEdges :: HasDebugCallStack => CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccessorEdges m bid =
  case mapLookup bid m of
    Just succs -> mapToList succs
    Nothing    -> pprPanic "getSuccessorEdges: Block does not exist" $
                    ppr bid <+> pprEdgeWeights m
-- | Look up the info of the edge @from -> to@, if that edge exists.
-- The info is forced before it is returned.
getEdgeInfo :: BlockId -> BlockId -> CFG -> Maybe EdgeInfo
getEdgeInfo from to m =
  case mapLookup from m >>= mapLookup to of
    Just info -> Just $! info
    Nothing   -> Nothing
-- | Weight of the edge @from -> to@. Fails via 'expectJust' when the edge
-- does not exist.
getEdgeWeight :: CFG -> BlockId -> BlockId -> EdgeWeight
getEdgeWeight cfg from to = edgeWeight info
  where
    info = expectJust "Edgeweight for noexisting block"
                      (getEdgeInfo from to cfg)
-- | Origin of the edge @from -> to@. Fails via 'expectJust' when the edge
-- does not exist.
getTransitionSource :: BlockId -> BlockId -> CFG -> TransitionSource
getTransitionSource from to cfg = transitionSource info
  where
    info = expectJust "Source info for noexisting block"
                      (getEdgeInfo from to cfg)
-- | Flip the direction of every edge. Every node of the input is a node of
-- the result, even if it ends up without outgoing edges.
reverseEdges :: CFG -> CFG
reverseEdges cfg = mapFoldlWithKey flipOutgoing mapEmpty cfg
  where
    -- Register @from@ as a node (keeping edges collected so far) and then add
    -- the reversed version of each of its outgoing edges.
    flipOutgoing :: CFG -> BlockId -> LabelMap EdgeInfo -> CFG
    flipOutgoing acc from succs =
      mapFoldlWithKey (\acc' to info -> addEdge to from info acc')
                      (mapInsertWith mapUnion from mapEmpty acc)
                      succs
-- | All edges of the CFG including their info. Edges are accumulated by
-- consing, so the list is in reverse traversal order.
infoEdgeList :: CFG -> [CfgEdge]
infoEdgeList m = foldl' collect [] (mapToList m)
  where
    collect :: [CfgEdge] -> (BlockId, LabelMap EdgeInfo) -> [CfgEdge]
    collect acc (from, toMap) =
      foldl' (\acc' (to, info) -> CfgEdge from to info : acc')
             acc
             (mapToList toMap)
-- | All edges of the CFG as (source, target) pairs. Edges are accumulated by
-- consing, so the list is in reverse traversal order.
edgeList :: CFG -> [Edge]
edgeList m = foldl' collect [] (mapToList m)
  where
    collect :: [Edge] -> (BlockId, LabelMap EdgeInfo) -> [Edge]
    collect acc (from, toMap) =
      foldl' (\acc' to -> (from, to) : acc') acc (mapKeys toMap)
-- | Successor blocks of a node. Panics when the block is not a node of the CFG.
getSuccessors :: HasDebugCallStack => CFG -> BlockId -> [BlockId]
getSuccessors m bid = maybe missing mapKeys (mapLookup bid m)
  where
    missing = pprPanic "getSuccessors: Block does not exist" $
                ppr bid <+> pprEdgeWeights m
-- | Render the CFG as a @digraph { ... }@ description: one line per edge,
-- labelled with its weight, followed by one line for every node that has no
-- outgoing edges and does not appear in any edge.
pprEdgeWeights :: CFG -> SDoc
pprEdgeWeights m =
    text "digraph {\n" <>
        (joinDocs (map pprEdge sortedEdges)) <>
        (joinDocs (map pprIsolated isolated)) <>
    text "}\n"
  where
    joinDocs :: [SDoc] -> SDoc
    joinDocs = foldl' (<>) empty

    sortedEdges = sort (infoEdgeList m) :: [CfgEdge]

    pprEdge (CfgEdge from to info)
      = text "\t" <> ppr from <+> text "->" <+> ppr to <>
        text "[label=\"" <> ppr weight <> text "\",weight=\"" <>
        ppr weight <> text "\"];\n"
      where
        weight = edgeWeight info

    pprIsolated node
      = text "\t" <> ppr node <> text ";\n"

    -- Every block that is an endpoint of at least one edge.
    connected = setFromList
                  (concatMap (\(CfgEdge from to _) -> [from,to]) sortedEdges) :: LabelSet
    isolated = [ n | n <- mapKeys (mapFilter null m)
                   , not (setMember n connected) ]
 
-- | Apply a function to the weight of an existing edge.
-- Unlike 'adjustEdgeWeight' this panics when the edge does not exist.
updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
updateEdgeWeight f (from, to) cfg =
  case getEdgeInfo from to cfg of
    Nothing      -> panic "Trying to update invalid edge"
    Just oldInfo ->
      let !oldWeight = edgeWeight oldInfo
          !newWeight = f oldWeight
      in addEdge from to (oldInfo {edgeWeight = newWeight}) cfg
-- | Recompute the weight of every edge from its source, target and old weight.
mapWeights :: (BlockId -> BlockId -> EdgeWeight -> EdgeWeight) -> CFG -> CFG
mapWeights f cfg = foldl' reweigh cfg (infoEdgeList cfg)
  where
    reweigh acc (CfgEdge from to info) =
      addEdge from to (info { edgeWeight = f from to (edgeWeight info) }) acc
-- | Insert a block into existing edges.
--
-- For every triple @(from, between, old)@ the edge @from -> old@ is replaced
-- by @from -> between@ (carrying the info of the old edge) and
-- @between -> old@ (carrying the unconditional-jump weight). The info of each
-- old edge is looked up in the CFG as it was passed in; a missing edge is a
-- panic.
addNodesBetween :: CFG -> [(BlockId,BlockId,BlockId)] -> CFG
addNodesBetween m updates = foldl' splitEdge m (map annotate updates)
  where
    viaWeight = fromIntegral . D.uncondWeight .
                D.cfgWeightInfo $ D.unsafeGlobalDynFlags

    -- Attach the info of the edge being split to the update triple.
    annotate :: (BlockId,BlockId,BlockId) -> (BlockId,BlockId,BlockId,EdgeInfo)
    annotate triple@(from,between,old) =
      case getEdgeInfo from old m of
        Just info -> (from,between,old,info)
        Nothing   ->
          pprPanic "Can't find weight for edge that should have one" (
            text "triple" <+> ppr triple $$
            text "updates" <+> ppr updates $$
            text "cfg:" <+> pprEdgeWeights m )

    splitEdge :: CFG -> (BlockId,BlockId,BlockId,EdgeInfo) -> CFG
    splitEdge acc (from,between,old,info) =
      addEdge from between info
        (addWeightEdge between old viaWeight
          (delEdge from old acc))
-- | CFG of a top-level declaration: empty for data, the CFG of the graph
-- (see 'getCfg') for a procedure.
getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
getCfgProc weights decl =
  case decl of
    CmmData {}                     -> mapEmpty
    CmmProc _info _lab _live graph -> getCfg weights graph
-- | Build the control flow graph of a Cmm graph.
--
-- Every block becomes a node. The outgoing edges of a block are derived from
-- its last node and weighted with the static estimates taken from @weights@.
getCfg :: D.CfgWeights -> CmmGraph -> CFG
getCfg weights graph =
  foldl' insertEdge edgelessCfg $ concatMap getBlockEdges blocks
  where
    D.CFGWeights
            { D.uncondWeight = uncondWeight
            , D.condBranchWeight = condBranchWeight
            , D.switchWeight = switchWeight
            , D.callWeight = callWeight
            , D.likelyCondWeight = likelyCondWeight
            , D.unlikelyCondWeight = unlikelyCondWeight
            } = weights
    -- Start with every block present as a node without edges, so blocks
    -- without successors are still part of the result.
    edgelessCfg = mapFromList $ zip (map G.entryLabel blocks) (repeat mapEmpty)
    insertEdge :: CFG -> ((BlockId,BlockId),EdgeInfo) -> CFG
    insertEdge m ((from,to),weight) =
      mapAlter f from m
        where
          f :: Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
          f Nothing = Just $ mapSingleton to weight
          f (Just destMap) = Just $ mapInsert to weight destMap
    -- Edges leaving a block, determined by the kind of its last node.
    getBlockEdges :: CmmBlock -> [((BlockId,BlockId),EdgeInfo)]
    getBlockEdges block =
      case branch of
        CmmBranch dest -> [mkEdge dest uncondWeight]
        CmmCondBranch cond t f l
          | l == Nothing ->
              [mkEdge f condBranchWeight,   mkEdge t condBranchWeight]
          | l == Just True ->
              [mkEdge f unlikelyCondWeight, mkEdge t likelyCondWeight]
          | l == Just False ->
              [mkEdge f likelyCondWeight,   mkEdge t unlikelyCondWeight]
          where
            -- These shadow the outer helpers so that conditional branches
            -- record the computed 'BranchInfo' instead of 'NoInfo'.
            mkEdgeInfo =
                         EdgeInfo (CmmSource branch branchInfo) . fromIntegral
            mkEdge target weight = ((bid,target), mkEdgeInfo weight)
            -- A condition that mentions SpLim, HpLim or BaseReg marks the
            -- branch as a heap/stack check.
            branchInfo =
              foldRegsUsed
                (panic "foldRegsDynFlags")
                (\info r -> if r == SpLim || r == HpLim || r == BaseReg
                    then HeapStackCheck else info)
                NoInfo cond
        (CmmSwitch _e ids) ->
          let switchTargets = switchTargetsToList ids
              -- Very wide switches get a negative weight on all their edges.
              -- Weights <= 0 are kept at zero by optimizeCFG and clamped to
              -- zero by cfgEdgeProbabilities, so such edges carry no
              -- preference.
              adjustedWeight =
                if (length switchTargets > 10) then -1 else switchWeight
          in map (\x -> mkEdge x adjustedWeight) switchTargets
        (CmmCall { cml_cont = Just cont})  -> [mkEdge cont callWeight]
        (CmmForeignCall {Cmm.succ = cont}) -> [mkEdge cont callWeight]
        (CmmCall { cml_cont = Nothing })   -> []
        other ->
            panic "Foo" $
            ASSERT2(False, ppr "Unkown successor cause:" <>
              (ppr branch <+> text "=>" <> ppr (G.successors other)))
            map (\x -> ((bid,x),mkEdgeInfo 0)) $ G.successors other
      where
        bid = G.entryLabel block
        mkEdgeInfo = EdgeInfo (CmmSource branch NoInfo) . fromIntegral
        mkEdge target weight = ((bid,target), mkEdgeInfo weight)
        branch = lastNode block :: CmmNode O C
    blocks = revPostorder graph :: [CmmBlock]
-- | The edges that 'classifyEdges', started at @root@, labels as 'Backward'.
findBackEdges :: HasDebugCallStack => BlockId -> CFG -> Edges
findBackEdges root cfg =
    [ edge | (edge, Backward) <- classified ]
  where
    classified :: [((BlockId,BlockId),EdgeType)]
    classified = classifyEdges root (getSuccessors cfg) (edgeList cfg)
-- | Adjust the edge weights of a procedure's CFG using static heuristics.
--
-- Three passes run in this order: back edges receive a bonus, edges whose
-- target is a key of the procedure's info map are penalized, and for two-way
-- branches with equal weights the successor with fewer predecessors is
-- favoured. For data declarations the CFG is returned unchanged.
optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
optimizeCFG _ (CmmData {}) cfg = cfg
optimizeCFG weights (CmmProc info _lab _live graph) cfg =
    favourFewerPreds  .
    penalizeInfoTables info .
    increaseBackEdgeWeight (g_entry graph) $ cfg
  where
    -- Add the back edge bonus to every back edge found from @root@.
    increaseBackEdgeWeight :: BlockId -> CFG -> CFG
    increaseBackEdgeWeight root cfg =
        let backedges = findBackEdges root cfg
            update weight
              -- Edges with a non-positive weight are set to zero rather
              -- than receiving the bonus.
              | weight <= 0 = 0
              | otherwise
              = weight + fromIntegral (D.backEdgeBonus weights)
        in  foldl'  (\cfg edge -> updateEdgeWeight update edge cfg)
                    cfg backedges

    -- Subtract the info table penalty from every edge whose target is a key
    -- of the given map.
    penalizeInfoTables :: LabelMap a -> CFG -> CFG
    penalizeInfoTables info cfg =
        mapWeights fupdate cfg
      where
        fupdate :: BlockId -> BlockId -> EdgeWeight -> EdgeWeight
        fupdate _ to weight
          | mapMember to info
          = weight - (fromIntegral $ D.infoTablePenalty weights)
          | otherwise = weight

    -- For a block with exactly two equally weighted successors, shift one
    -- unit of weight towards the successor with fewer predecessors.
    favourFewerPreds :: CFG -> CFG
    favourFewerPreds cfg =
        let
            -- Reversed graph restricted to the edges accepted by
            -- 'fallthroughTarget'; its successors are predecessors here.
            revCfg =
              reverseEdges $ filterEdges
                              (\_from -> fallthroughTarget)  cfg
            predCount n = length $ getSuccessorEdges revCfg n
            nodes = getCfgNodes cfg
            -- Weight adjustments for (first successor, second successor).
            modifiers :: Int -> Int -> (EdgeWeight, EdgeWeight)
            modifiers preds1 preds2
              | preds1 <  preds2 = ( 1,-1)
              | preds1 == preds2 = ( 0, 0)
              | otherwise        = (-1, 1)
            update :: CFG -> BlockId -> CFG
            update cfg node
              | [(s1,e1),(s2,e2)] <- getSuccessorEdges cfg node
              , !w1 <- edgeWeight e1
              , !w2 <- edgeWeight e2
              -- Only break ties; differing weights are left alone.
              , w1 == w2
              , (mod1,mod2) <- modifiers (predCount s1) (predCount s2)
              = (\cfg' ->
                  (adjustEdgeWeight cfg' (+mod2) node s2))
                    (adjustEdgeWeight cfg  (+mod1) node s1)
              | otherwise
              = cfg
        in foldl' update cfg nodes
      where
        -- Edges counted as predecessors above: the target must not be a key
        -- of the info map, and the edge must stem from the code generator,
        -- an unconditional branch or a conditional branch.
        fallthroughTarget :: BlockId -> EdgeInfo -> Bool
        fallthroughTarget to (EdgeInfo source _weight)
          | mapMember to info = False
          | AsmCodeGen <- source = True
          | CmmSource { trans_cmmNode = CmmBranch {} } <- source = True
          | CmmSource { trans_cmmNode = CmmCondBranch {} } <- source = True
          | otherwise = False
-- | For every node, whether it is part of a cyclic strongly connected
-- component of the CFG.
loopMembers :: HasDebugCallStack => CFG -> LabelMap Bool
loopMembers cfg =
    foldl' markScc mapEmpty components
  where
    components :: [SCC BlockId]
    components = stronglyConnCompFromEdgedVerticesOrd
                   [ DigraphNode bid bid (getSuccessors cfg bid)
                   | bid <- getCfgNodes cfg ]

    markScc :: LabelMap Bool -> SCC BlockId -> LabelMap Bool
    markScc acc scc =
      case scc of
        AcyclicSCC bid -> mapInsert bid False acc
        CyclicSCC bids -> foldl' (\acc' bid -> mapInsert bid True acc') acc bids
-- | The 'liLevels' component of 'loopInfo' for the given CFG and root.
loopLevels :: CFG -> BlockId -> LabelMap Int
loopLevels cfg root = liLevels (loopInfo cfg root)
-- | Result of the loop analysis performed by 'loopInfo'.
data LoopInfo = LoopInfo
  { liBackEdges :: [(Edge)]          -- all back edges that were found
  , liLevels :: LabelMap Int         -- per block: number of loops containing it
  , liLoops :: [(Edge, LabelSet)]    -- per back edge: the blocks of its loop body
  }

-- | Only the loops are printed, one @(backEdge, bodyNodes)@ pair per line.
instance Outputable LoopInfo where
    ppr (LoopInfo { liLoops = loops }) =
        text "Loops:(backEdge, bodyNodes)" $$
            vcat (map ppr loops)
-- | Compute loop information for a CFG, starting at @root@.
--
-- The result contains the back edges, for every back edge the set of blocks
-- collected as its loop body, and for every node the number of distinct loop
-- heads whose body contains it.
loopInfo :: HasDebugCallStack => CFG -> BlockId -> LoopInfo
loopInfo cfg root = LoopInfo  { liBackEdges = backEdges
                              , liLevels = mapFromList loopCounts
                              , liLoops = loopBodies }
  where
    revCfg = reverseEdges cfg
    -- Successor sets per node, without the edge information.
    graph =
            fmap (setFromList . mapKeys ) cfg :: LabelMap LabelSet
    -- The same graph keyed by Int, in the form passed to Dom.domTree.
    rooted = ( fromBlockId root
              , toIntMap $ fmap toIntSet graph) :: (Int, IntMap IntSet)
    tree = fmap toBlockId $ Dom.domTree rooted :: Tree BlockId
    -- Flattened view of the dominator tree, see mkDomMap below.
    domMap :: LabelMap LabelSet
    domMap = mkDomMap tree
    edges = edgeList cfg :: [(BlockId, BlockId)]
    nodes = getCfgNodes cfg :: [BlockId]
    -- An edge counts as a back edge when its target is a member of the set
    -- recorded in domMap for its source.
    isBackEdge (from,to)
      | Just doms <- mapLookup from domMap
      , setMember to doms
      = True
      | otherwise = False
    -- Collect the body of the loop closed by a back edge: starting from the
    -- edge's source, repeatedly follow edges of the reversed graph (with the
    -- loop head removed) until no new block is found; the head is added last.
    findBody edge@(tail, head)
      = ( edge, setInsert head $ go (setSingleton tail) (setSingleton tail) )
      where
        -- Reversed CFG without the loop head, so the search stops there.
        cfg' = delNode head revCfg
        go :: LabelSet -> LabelSet -> LabelSet
        go found current
          | setNull current = found
          | otherwise = go  (setUnion newSuccessors found)
                            newSuccessors
          where
            -- Only blocks not seen before take part in the next round.
            newSuccessors = setFilter (\n -> not $ setMember n found) successors :: LabelSet
            successors = setFromList $ concatMap
                                      (getSuccessors cfg')
                                      (filter (/= head) $ setElems current) :: LabelSet
    backEdges = filter isBackEdge edges
    loopBodies = map findBody backEdges :: [(Edge, LabelSet)]
    -- For every node, count the distinct loop heads whose body contains it.
    loopCounts =
      let bodies = map (first snd) loopBodies
          loopCount n = length $ nub . map fst . filter (setMember n . snd) $ bodies
      in  map (\n -> (n, loopCount n)) $ nodes :: [(BlockId, Int)]
    toIntSet :: LabelSet -> IntSet
    toIntSet s = IS.fromList . map fromBlockId . setElems $ s
    toIntMap :: LabelMap a -> IntMap a
    toIntMap m = IM.fromList $ map (\(x,y) -> (fromBlockId x,y)) $ mapToList m
    -- Turn the tree into a map from block to a set of blocks. Each child is
    -- paired with the set accumulated on the way down; descending into a
    -- child adds that child to the set. Later list entries win in
    -- mapFromList.
    -- NOTE(review): the label of the tree's root is never inserted into any
    -- set here - confirm this is intended.
    mkDomMap :: Tree BlockId -> LabelMap LabelSet
    mkDomMap root = mapFromList $ go setEmpty root
      where
        go :: LabelSet -> Tree BlockId -> [(Label,LabelSet)]
        go parents (Node lbl [])
          =  [(lbl, parents)]
        go parents (Node _ leaves)
          = let nodes = map rootLabel leaves
                entries = map (\x -> (x,parents)) nodes
            in  entries ++ concatMap
                            (\n -> go (setInsert (rootLabel n) parents) n)
                            leaves
    -- Conversions between block ids and the Int keys of their uniques.
    fromBlockId :: BlockId -> Int
    fromBlockId = getKey . getUnique
    toBlockId :: Int -> BlockId
    toBlockId = mkBlockId . mkUniqueGrimily
-- | A block id paired with the ids of its successors. The two phantom
-- parameters give it the kind expected by 'G.NonLocal'.
newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])
-- | The entry label is the stored block id, the successors are the stored list.
instance G.NonLocal (BlockNode) where
  entryLabel (BN (lbl,_))   = lbl
  successors (BN (_,succs)) = succs
-- | The block ids returned by 'G.revPostorderFrom' for the CFG, started at
-- @root@. The CFG is first wrapped into a map of 'BlockNode's.
revPostorderFrom :: HasDebugCallStack => CFG -> BlockId -> [BlockId]
revPostorderFrom cfg root =
    [ bid | BN (bid, _) <- G.revPostorderFrom hooplGraph root ]
  where
    hooplGraph :: LabelMap (BlockNode C C)
    hooplGraph = mapFromList
      [ (bid, BN (bid, getSuccessors cfg bid)) | bid <- getCfgNodes cfg ]
-- | Compute block and edge frequencies for a CFG, starting at @root@.
--
-- The CFG is first run through 'staticBranchPrediction', its edge weights are
-- turned into probabilities by 'cfgEdgeProbabilities', and 'calcFreqs'
-- derives the frequencies. Blocks are numbered by their position in the
-- reverse postorder from @root@ while computing.
--
-- Returns the frequency of every block and, per source block, the frequency
-- of each outgoing edge. Panics on an empty CFG.
mkGlobalWeights :: HasDebugCallStack => BlockId -> CFG -> (LabelMap Double, LabelMap (LabelMap Double))
mkGlobalWeights root localCfg
  | null localCfg = panic "Error - Empty CFG"
  | otherwise
  = (blockFreqs', edgeFreqs')
  where
    -- Results keyed by vertex number ...
    (blockFreqs, edgeFreqs) = calcFreqs nodeProbs backEdges' bodies' revOrder'
    -- ... and translated back to block ids.
    blockFreqs' = mapFromList $ map (first fromVertex) (assocs blockFreqs) :: LabelMap Double
    edgeFreqs' = fmap fromVertexMap $ fromVertexMap edgeFreqs
    fromVertexMap :: IM.IntMap x -> LabelMap x
    fromVertexMap m = mapFromList . map (first fromVertex) $ IM.toList m
    revOrder = revPostorderFrom localCfg root :: [BlockId]
    loopResults@(LoopInfo backedges _levels bodies) = loopInfo localCfg root
    revOrder' = map toVertex revOrder
    backEdges' = map (bimap toVertex toVertex) backedges
    bodies' = map calcBody bodies
    estimatedCfg = staticBranchPrediction root loopResults localCfg
    -- Edge probabilities keyed by vertex number.
    nodeProbs = cfgEdgeProbabilities estimatedCfg toVertex
    -- A loop is passed on as (head vertex, sorted body vertices); the head is
    -- the target of the back edge.
    calcBody (backedge, blocks) =
        (toVertex $ snd backedge, sort . map toVertex $ (setElems blocks))
    vertexMapping = mapFromList $ zip revOrder [0..] :: LabelMap Int
    -- Vertices are numbered from 0, so the last index is the size minus one.
    blockMapping = listArray (0,mapSize vertexMapping - 1) revOrder :: Array Int BlockId
    toVertex :: BlockId -> Int
    toVertex   blockId  = expectJust "mkGlobalWeights" $ mapLookup blockId vertexMapping
    fromVertex :: Int -> BlockId
    fromVertex vertex   = blockMapping ! vertex
-- | A successor block together with the info of the edge leading to it.
type TargetNodeInfo = (BlockId, EdgeInfo)
-- | Replace edge weights by statically estimated branch probabilities.
--
-- Every node is visited once:
--
-- * If some but not all outgoing edges are back edges (and no edge is a
--   heap/stack check), the back edges share a probability of 'pred_LBH' and
--   the remaining edges share the rest.
-- * Otherwise, for nodes with exactly two successors and no back edge, the
--   weights are normalised to probabilities and then refined by a list of
--   heuristics.
-- * All other nodes are left untouched.
staticBranchPrediction :: BlockId -> LoopInfo -> CFG -> CFG
staticBranchPrediction _root (LoopInfo l_backEdges loopLevels l_loops) cfg =
    foldl' update cfg nodes
  where
    nodes = getCfgNodes cfg
    backedges = S.fromList $ l_backEdges
    loops = M.fromList $ l_loops :: M.Map Edge LabelSet
    -- Targets of the back edges that close a loop.
    loopHeads = S.fromList $ map snd $ M.keys loops
    update :: CFG -> BlockId -> CFG
    update cfg node
        -- No successors, nothing to predict.
        | null successors = cfg
        -- Mix of back edges and other edges: the back edges share pred_LBH,
        -- the other edges share what is left.
        | not (null m) && length m < length successors
        -- Heap/stack checks are never re-weighted.
        , not $ any (isHeapOrStackCheck  . transitionSource . snd) successors
        = let   loopChance = repeat $! pred_LBH / (fromIntegral $ length m)
                exitChance = repeat $! (1 - pred_LBH) / fromIntegral (length not_m)
                updates = zip (map fst m) loopChance ++ zip (map fst not_m) exitChance
          in
            foldl' (\cfg (to,weight) -> setEdgeWeight cfg weight node to) cfg updates
        -- The heuristics below only handle two-way branches ...
        | length successors /= 2
        = cfg
        -- ... without back edges.
        | not (null m)
        = cfg
        | [(s1,s1_info),(s2,s2_info)] <- successors
        , not $ any (isHeapOrStackCheck  . transitionSource . snd) successors
        =
            let !w1 = max (edgeWeight s1_info) (0)
                !w2 = max (edgeWeight s2_info) (0)
                -- Turn the clamped weights into probabilities; two zero
                -- weights become an even split.
                normalizeWeight w = if w1 + w2 == 0 then 0.5 else w/(w1+w2)
                !cfg'  = setEdgeWeight cfg  (normalizeWeight w1) node s1
                !cfg'' = setEdgeWeight cfg' (normalizeWeight w2) node s2
                -- Each heuristic optionally predicts the probability of
                -- taking the first successor.
                heuristics = map ($ ((s1,s1_info),(s2,s2_info)))
                            [lehPredicts, phPredicts, ohPredicts, ghPredicts, lhhPredicts, chPredicts
                            , shPredicts, rhPredicts]
                -- Combine one prediction with the probabilities currently
                -- stored in the graph.
                applyHeuristic :: CFG -> Maybe Prob -> CFG
                applyHeuristic cfg Nothing = cfg
                applyHeuristic cfg (Just (s1_pred :: Double))
                  -- A zero probability stays zero, and heap/stack checks are
                  -- left alone.
                  | s1_old == 0 || s2_old == 0 ||
                    isHeapOrStackCheck (transitionSource s1_info) ||
                    isHeapOrStackCheck (transitionSource s2_info)
                  = cfg
                  | otherwise =
                    let
                        s1_prob = EdgeWeight s1_pred :: EdgeWeight
                        s2_prob = 1.0 - s1_prob
                        -- Normalising factor, so the new values sum to one.
                        d = (s1_old * s1_prob) + (s2_old * s2_prob) :: EdgeWeight
                        s1_prob' = s1_old * s1_prob / d
                        !s2_prob' = s2_old * s2_prob / d
                        !cfg_s1 = setEdgeWeight cfg    s1_prob' node s1
                    in
                        setEdgeWeight cfg_s1 s2_prob' node s2
                  where
                    -- Probabilities assigned so far.
                    s1_old = getEdgeWeight cfg node s1
                    s2_old = getEdgeWeight cfg node s2
            in
            foldl' applyHeuristic cfg'' heuristics
        | otherwise = cfg
      where
        -- Probability shared by the back edges of a node.
        pred_LBH = 0.875
        successors = getSuccessorEdges cfg node
        -- Outgoing edges split into back edges (m) and all others (not_m).
        (m,not_m) = partition (\succ -> S.member (node, fst succ) backedges) successors
        -- Probability given to the successor with the higher loop level.
        pred_LEH = 0.75
        -- Compare the loop levels of both successors. No prediction when one
        -- of them is a loop head or the levels are equal; otherwise the
        -- successor with the higher level is predicted with pred_LEH.
        lehPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
        lehPredicts ((s1,_s1_info),(s2,_s2_info))
          | S.member s1 loopHeads || S.member s2 loopHeads
          = Nothing
          | otherwise
          =
            case compare s1Level s2Level of
                EQ -> Nothing
                LT -> Just (1-pred_LEH)
                GT -> Just (pred_LEH)
            where
                s1Level = mapLookup s1 loopLevels
                s2Level = mapLookup s2 loopLevels
        -- For a conditional branch without a likeliness hint whose condition
        -- is an equality test involving a literal, predict the true-target
        -- with 0.3 (and hence the other target with 0.7).
        ohPredicts (s1,_s2)
            | CmmSource { trans_cmmNode = src1 } <- getTransitionSource node (fst s1) cfg
            , CmmCondBranch cond ltrue _lfalse likely <- src1
            , likely == Nothing
            , CmmMachOp mop args <- cond
            , MO_Eq {} <- mop
            , not (null [x | x@CmmLit{} <- args])
            = if fst s1 == ltrue then Just 0.3 else Just 0.7
            | otherwise
            = Nothing
        -- The remaining heuristics never make a prediction.
        phPredicts = const Nothing
        ghPredicts = const Nothing
        lhhPredicts = const Nothing
        chPredicts = const Nothing
        shPredicts = const Nothing
        rhPredicts = const Nothing
-- | Convert the edge weights of every node into probabilities, keyed by the
-- vertex numbers produced by @toVertex@.
--
-- A node with at most one outgoing edge gives that edge probability 1.
-- Otherwise negative weights are clamped to zero and each weight is divided
-- by the node's total; if the total is zero all edges are equally likely.
cfgEdgeProbabilities :: CFG -> (BlockId -> Int) -> IM.IntMap (IM.IntMap Prob)
cfgEdgeProbabilities cfg toVertex =
    mapFoldlWithKey addSource IM.empty cfg
  where
    addSource acc from succs =
      IM.insert (toVertex from) (normalize succs) acc

    normalize :: (LabelMap EdgeInfo) -> (IM.IntMap Prob)
    normalize succs =
        mapFoldlWithKey
          (\acc to _ -> IM.insert (toVertex to) (probOf to) acc)
          IM.empty
          succs
      where
        edgeCount = mapSize succs
        minWeight = 0 :: Prob
        clamped = fmap (\info -> max (weightToDouble . edgeWeight $ info) minWeight) succs
        total = sum clamped

        probOf :: BlockId -> Prob
        probOf to
          | edgeCount <= 1 = 1.0
          | total == 0     = 1.0 / fromIntegral edgeCount
          | Just w <- mapLookup to clamped
          = w/total
          | otherwise = panic "impossible"
-- | Compute block and edge frequencies from edge probabilities.
--
-- Arguments, all expressed in Int vertex numbers:
--
-- * @graph@: for each vertex the probability of each outgoing edge,
-- * @backEdges@: the back edges as (source, target),
-- * @loops@: the loops as (head, body vertices),
-- * @revPostOrder@: all vertices in reverse postorder; its first element is
--   treated as the head of the final pass over the whole graph.
--
-- Frequencies are first propagated inside every loop, then once over the
-- whole graph. The result is the frequency of every vertex and the frequency
-- of every edge.
calcFreqs :: IM.IntMap (IM.IntMap Prob) -> [(Int,Int)] -> [(Int, [Int])] -> [Int]
          -> (Array Int Double, IM.IntMap (IM.IntMap Prob))
calcFreqs graph backEdges loops revPostOrder = runST $ do
    -- Both arrays are indexed 0 .. nodeCount-1.
    visitedNodes <- newArray (0,nodeCount-1) False :: ST s (STUArray s Int Bool)
    blockFreqs <- newArray (0,nodeCount-1) 0.0 :: ST s (STUArray s Int Double)
    -- Both references start out as the probability graph and are overwritten
    -- edge by edge.
    edgeProbs <- newSTRef graph
    edgeBackProbs <- newSTRef graph

    let
        visited b = unsafeRead visitedNodes b
        getFreq b = unsafeRead blockFreqs b
        setFreq b f = unsafeWrite blockFreqs b f
        setVisited b = unsafeWrite visitedNodes b True

        -- Read the entry for edge b1 -> b2 from one of the references.
        getProb' arr b1 b2 = readSTRef arr >>=
            (\graph ->
                return .
                        fromMaybe (error "getFreq 1") .
                        IM.lookup b2 .
                        fromMaybe (error "getFreq 2") $
                        (IM.lookup b1 graph)
            )
        -- Overwrite the entry for edge b1 -> b2 in one of the references.
        setProb' arr b1 b2 prob = do
          g <- readSTRef arr
          let !m = fromMaybe (error "Foo") $ IM.lookup b1 g
              !m' = IM.insert b2 prob m
          writeSTRef arr $! (IM.insert b1 m' g)

        getEdgeFreq b1 b2 = getProb' edgeProbs b1 b2
        setEdgeFreq b1 b2 = setProb' edgeProbs b1 b2
        -- Probability of an edge as given in the (immutable) input graph.
        getProb b1 b2 = fromMaybe (error "getProb") $ do
            m' <- IM.lookup b1 graph
            IM.lookup b2 m'

        getBackProb b1 b2 = getProb' edgeBackProbs b1 b2
        setBackProb b1 b2 = setProb' edgeBackProbs b1 b2

    let
        -- Distribute the frequency of a block over its outgoing edges.
        calcOutFreqs bhead block = do
          !f <- getFreq block
          forM_ (successors block) $ \bi -> do
            let !prob = getProb block bi
            let !succFreq = f * prob
            setEdgeFreq block bi succFreq
            -- Edges leading back to the head are recorded separately.
            when (bi == bhead) $ setBackProb block bi succFreq

    let propFreq block head = do
            !v <- visited block
            if v then
                -- The frequency of an already visited block is kept.
                return ()
            else if block == head then
                -- The head of the current region gets frequency 1.
                setFreq block 1.0
            else do
                let preds = IS.elems $ predecessors block
                -- A predecessor that is unvisited and not reached via a
                -- back edge means the block cannot be computed here.
                irreducible <- (fmap or) $ forM preds $ \bp -> do
                    !bp_visited <- visited bp
                    let bp_backedge = isBackEdge bp block
                    return (not bp_visited && not bp_backedge)

                if irreducible
                then return ()
                else do
                    setFreq block 0
                    -- Sum the frequencies of the incoming non-back edges into
                    -- the block frequency and collect the probability
                    -- flowing in via back edges.
                    !cycleProb <- sum <$> (forM preds $ \pred -> do
                        if isBackEdge pred block
                            then
                                getBackProb pred block
                            else do
                                !f <- getFreq block
                                !prob <- getEdgeFreq pred block
                                setFreq block $! f + prob
                                return 0)

                    -- Cap the cyclic probability just below 1 so the division
                    -- below stays finite.
                    let limit = 1 - 1/512
                    !cycleProb <- return $ min cycleProb limit

                    !f <- getFreq block
                    setFreq block (f / (1.0 - cycleProb))

            setVisited block
            calcOutFreqs head block

    -- Process every loop: mark everything visited, unmark the loop body, then
    -- propagate through the body with the loop head as head.
    forM_ loops $ \(head, body) -> do
        forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i True)
        forM_ body (\i -> unsafeWrite visitedNodes i False)
        forM_ body $ \block -> propFreq block head

    -- Final pass over the whole graph in reverse postorder.
    forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i False)
    forM_ revPostOrder $ \block -> propFreq block (head revPostOrder)

    graph' <- readSTRef edgeProbs
    freqs' <- unsafeFreeze  blockFreqs

    return (freqs', graph')
  where
    predecessors :: Int -> IS.IntSet
    predecessors b = fromMaybe IS.empty $ IM.lookup b revGraph
    successors :: Int -> [Int]
    successors b = fromMaybe (lookupError "succ" b graph)$ IM.keys <$> IM.lookup b graph
    lookupError s b g = pprPanic ("Lookup error " ++ s) $
                            ( text "node" <+> ppr b $$
                                text "graph" <+>
                                vcat (map (\(k,m) -> ppr (k,m :: IM.IntMap Double)) $ IM.toList g)
                            )

    -- Number of keys of the graph, plus one for every edge whose target is
    -- not a key of the graph.
    nodeCount = IM.foldl' (\count toMap -> IM.foldlWithKey' countTargets count toMap) (IM.size graph) graph
      where
        countTargets = (\count k _ -> countNode k + count )
        countNode n = if IM.member n graph then 0 else 1

    isBackEdge from to = S.member (from,to) backEdgeSet
    backEdgeSet = S.fromList backEdges

    -- For every vertex the set of its predecessors.
    revGraph :: IntMap IntSet
    revGraph = IM.foldlWithKey' (\m from toMap -> addEdges m from toMap) IM.empty graph
        where
            addEdges m0 from toMap = IM.foldlWithKey' (\m k _ -> addEdge m from k) m0 toMap
            addEdge m0 from to = IM.insertWith IS.union to (IS.singleton from) m0