{-# LANGUAGE ScopedTypeVariables #-}
module RegAlloc.Graph.SpillCost (
        SpillCostRecord,
        plusSpillCostRecord,
        pprSpillCostRecord,
        SpillCostInfo,
        zeroSpillCostInfo,
        plusSpillCostInfo,
        slurpSpillCostInfo,
        chooseSpill,
        lifeMapFromSpillCostInfo
) where
import GhcPrelude
import RegAlloc.Liveness
import Instruction
import RegClass
import Reg
import GraphBase
import Hoopl.Collections (mapLookup)
import Cmm
import UniqFM
import UniqSet
import Digraph          (flattenSCCs)
import Outputable
import Platform
import State
import CFG
import Data.List        (nub, minimumBy)
import Data.Maybe
import Control.Monad (join)
-- | Per-register spill cost counters, accumulated by 'slurpSpillCostInfo'.
--   Every count is weighted by that function's loop factor.
type SpillCostRecord
 =      ( VirtualReg    -- the register these counts belong to
        , Int           -- weighted number of writes (defs) of this reg
        , Int           -- weighted number of reads (uses) of this reg
        , Int)          -- weighted number of instrs this reg was live on entry to
-- | Spill cost records for every virtual register seen, keyed by that
--   register.
type SpillCostInfo
        = UniqFM SpillCostRecord
-- | Whether a basic block is a member of a loop in the CFG.
type LoopMember = Bool
-- | State action used while accumulating spill cost records.
type SpillCostState = State (UniqFM SpillCostRecord) ()
-- | The empty spill cost table: no register has a recorded cost yet.
zeroSpillCostInfo :: SpillCostInfo
zeroSpillCostInfo = emptyUFM
-- | Merge two spill cost tables. A register present in both tables gets
--   its two records summed with 'plusSpillCostRecord'.
plusSpillCostInfo :: SpillCostInfo -> SpillCostInfo -> SpillCostInfo
plusSpillCostInfo = plusUFM_C plusSpillCostRecord
-- | Add two spill cost records field-by-field.
--   Both records must describe the same register; otherwise this panics.
plusSpillCostRecord :: SpillCostRecord -> SpillCostRecord -> SpillCostRecord
plusSpillCostRecord (r1, defs1, uses1, life1) (r2, defs2, uses2, life2)
        | r1 == r2      = (r1, defs1 + defs2, uses1 + uses2, life1 + life2)
        | otherwise     = error "RegAlloc.SpillCost.plusSpillCostRecord: regs don't match"
-- | Walk a 'LiveCmmDecl' and accumulate a 'SpillCostRecord' for every
--   virtual register it mentions.
--
--   Each instruction that carries liveness information updates the records
--   as follows:
--
--     * every virtual reg live on entry to it gets its lifetime count bumped,
--     * every distinct virtual reg it writes gets its def count bumped,
--     * every distinct virtual reg it reads gets its use count bumped.
--
--   Each bump is 10 when the enclosing block is a loop member according to
--   the (optional) CFG, and 1 otherwise. Without a CFG no block counts as a
--   loop member.
--
--   Meta instructions without liveness info are skipped. Any other
--   instruction without liveness info, or a block with no entry liveness,
--   is a panic.
slurpSpillCostInfo :: forall instr statics. (Outputable instr, Instruction instr)
                   => Platform
                   -> Maybe CFG
                   -> LiveCmmDecl statics instr
                   -> SpillCostInfo
slurpSpillCostInfo platform cfg cmm
        = execState (countCmm cmm) zeroSpillCostInfo
 where
        countCmm CmmData{}              = return ()
        countCmm (CmmProc info _ _ sccs)
                = mapM_ (countBlock info)
                $ flattenSCCs sccs

        -- Seed each block with the virtual regs live on entry to it, then
        -- walk its instructions.
        countBlock info (BasicBlock blockId instrs)
                | LiveInfo _ _ (Just blockLive) _ <- info
                , Just rsLiveEntry  <- mapLookup blockId blockLive
                , rsLiveEntry_virt  <- takeVirtuals rsLiveEntry
                = countLIs (loopMember blockId) rsLiveEntry_virt instrs
                | otherwise
                = error "RegAlloc.SpillCost.slurpSpillCostInfo: bad block"

        countLIs :: LoopMember -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState
        countLIs _      _      []
                = return ()

        -- Instructions without liveness info are only allowed for meta
        -- instructions, which do not affect the live set.
        countLIs inLoop rsLive (LiveInstr instr Nothing : lis)
                | isMetaInstr instr
                = countLIs inLoop rsLive lis
                | otherwise
                = pprPanic "RegSpillCost.slurpSpillCostInfo"
                $ text "no liveness information on instruction " <> ppr instr

        countLIs inLoop rsLiveEntry (LiveInstr instr (Just live) : lis)
         = do
                let weight = loopCount inLoop

                -- Each live reg gets the same increment, and record addition
                -- is commutative, so the non-deterministic element order
                -- cannot affect the result.
                mapM_ (incLifetime weight) $ nonDetEltsUniqSet rsLiveEntry

                -- Count each distinct virtual reg read/written by this instr once.
                let RU regsRead regsWritten = regUsageOfInstr platform instr
                mapM_ (incUses weight) $ mapMaybe takeVirtualReg $ nub regsRead
                mapM_ (incDefs weight) $ mapMaybe takeVirtualReg $ nub regsWritten

                -- Live set on entry to the next instruction: drop regs whose
                -- last read was here, add newly born regs, then drop regs
                -- written here that die immediately.
                let liveDieRead_virt    = takeVirtuals (liveDieRead  live)
                let liveDieWrite_virt   = takeVirtuals (liveDieWrite live)
                let liveBorn_virt       = takeVirtuals (liveBorn     live)
                let rsLiveAcross
                        = rsLiveEntry `minusUniqSet` liveDieRead_virt
                let rsLiveNext
                        = (rsLiveAcross `unionUniqSets` liveBorn_virt)
                                        `minusUniqSet`  liveDieWrite_virt
                countLIs inLoop rsLiveNext lis

        -- Weight applied to every count for instructions inside loops.
        loopCount inLoop
          | inLoop = 10
          | otherwise = 1

        incDefs     count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, count, 0, 0)
        incUses     count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, count, 0)
        incLifetime count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, count)

        loopBlocks = CFG.loopMembers <$> cfg

        -- A block is in a loop only if a CFG was supplied and it says so.
        loopMember bid
          = fromMaybe False $ join (mapLookup bid <$> loopBlocks)
-- | Restrict a register set to its virtual registers, discarding real ones.
takeVirtuals :: UniqSet Reg -> UniqSet VirtualReg
takeVirtuals set = mkUniqSet (concatMap onlyVirtual (nonDetEltsUniqSet set))
  where onlyVirtual (RegVirtual vr) = [vr]
        onlyVirtual _               = []
  
-- | Choose a register to spill from the interference graph.
--
--   Every node is scored with 'spillCost_length' and the node with the
--   lowest cost is returned. When several nodes share that cost, the first
--   one in the (non-deterministic) node order wins. Each node's cost is
--   computed once rather than on every comparison.
--
--   Panics with a descriptive message if the graph has no nodes. Previously
--   this surfaced only as a generic 'minimumBy' empty-structure error.
chooseSpill
        :: SpillCostInfo
        -> Graph VirtualReg RegClass RealReg
        -> VirtualReg
chooseSpill info graph
 = case nonDetEltsUFM (graphMap graph) of
        []    -> error "RegAlloc.SpillCost.chooseSpill: graph has no nodes to spill"
        nodes -> snd
               $ minimumBy (\(c1, _) (c2, _) -> compare c1 c2)
                 [ (cost reg, reg) | node <- nodes, let reg = nodeId node ]
 where  cost = spillCost_length info graph
-- | Spill cost of a register, the reciprocal of its recorded lifetime count.
--
--   A register whose lifetime is at most 1, or which has no record at all,
--   gets an infinite cost. The graph argument is ignored.
spillCost_length
        :: SpillCostInfo
        -> Graph VirtualReg RegClass RealReg
        -> VirtualReg
        -> Float
spillCost_length info _ reg
        = case maybe 0 lifetimeOf (lookupUFM info reg) of
                life | life <= 1 -> 1/0
                     | otherwise -> recip (fromIntegral life)
        where lifetimeOf (_, _, _, life) = life
-- | Project each spill cost record down to its register and lifetime count,
--   keyed by the register.
lifeMapFromSpillCostInfo :: SpillCostInfo -> UniqFM (VirtualReg, Int)
lifeMapFromSpillCostInfo info
        = listToUFM [ (reg, (reg, life)) | (reg, _, _, life) <- nonDetEltsUFM info ]
        
-- | Degree of a register's node in the interference graph.
--
--   The degree is the number of conflicting virtual registers in the same
--   class as @reg@, plus the size of the node's exclusion set. A register
--   with no node in the graph has degree 0.
nodeDegree
        :: (VirtualReg -> RegClass)
        -> Graph VirtualReg RegClass RealReg
        -> VirtualReg
        -> Int
nodeDegree classOfVirtualReg graph reg
        = maybe 0 degreeOf (lookupUFM (graphMap graph) reg)
        where
                targetClass     = classOfVirtualReg reg
                sameClass r     = classOfVirtualReg r == targetClass
                -- Only the count matters, so element order is irrelevant.
                degreeOf node   = length (filter sameClass (nonDetEltsUniqSet (nodeConflicts node)))
                                + sizeUniqSet (nodeExclusions node)
-- | Render a spill cost record as one line of debugging output.
--
--   The fields, in order, are:
--
--     * the register,
--     * its def count,
--     * its use count,
--     * its lifetime count,
--     * its node degree,
--     * the ratio (defs + uses) / degree.
--
--   The second and third tuple components are the def and use counts
--   respectively, matching how 'slurpSpillCostInfo' fills them in. With a
--   degree of 0 the ratio shows as Infinity (or NaN when there are also no
--   defs or uses).
pprSpillCostRecord
        :: (VirtualReg -> RegClass)
        -> (Reg -> SDoc)
        -> Graph VirtualReg RegClass RealReg
        -> SpillCostRecord
        -> SDoc
pprSpillCostRecord regClass pprReg graph (reg, defs, uses, life)
        =  hsep
        [ pprReg (RegVirtual reg)
        , ppr defs
        , ppr uses
        , ppr life
        , ppr degree
        , text $ show $ (fromIntegral (defs + uses)
                       / fromIntegral degree :: Float) ]
        where degree = nodeDegree regClass graph reg