module CoreMonad (
    
    CoreToDo(..), runWhen, runMaybe,
    SimplMode(..),
    FloatOutSwitches(..),
    pprPassDetails,
    
    CorePluginPass, bindsOnlyPass,
    
    SimplCount, doSimplTick, doFreeSimplTick, simplCountN,
    pprSimplCount, plusSimplCount, zeroSimplCount,
    isZeroSimplCount, hasDetailedCounts, Tick(..),
    
    CoreM, runCoreM,
    
    getHscEnv, getRuleBase, getModule,
    getDynFlags, getOrigNameCache, getPackageFamInstEnv,
    getVisibleOrphanMods, getUniqMask,
    getPrintUnqualified, getSrcSpanM,
    
    addSimplCount,
    
    liftIO, liftIOWithCount,
    
    getAnnotations, getFirstAnnotations,
    
    putMsg, putMsgS, errorMsg, errorMsgS, warnMsg,
    fatalErrorMsg, fatalErrorMsgS,
    debugTraceMsg, debugTraceMsgS,
    dumpIfSet_dyn
  ) where
import GhcPrelude hiding ( read )
import CoreSyn
import HscTypes
import Module
import DynFlags
import BasicTypes       ( CompilerPhase(..) )
import Annotations
import IOEnv hiding     ( liftIO, failM, failWithM )
import qualified IOEnv  ( liftIO )
import Var
import Outputable
import FastString
import qualified ErrUtils as Err
import ErrUtils( Severity(..) )
import UniqSupply
import UniqFM       ( UniqFM, mapUFM, filterUFM )
import MonadUtils
import NameCache
import SrcLoc
import Data.List (intersperse, groupBy, sortBy)
import Data.Ord
import Data.Dynamic
import Data.IORef
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Strict as MapStrict
import Data.Word
import Control.Monad
import Control.Applicative ( Alternative(..) )
import Panic (throwGhcException, GhcException(..))
-- | One step of the Core-to-Core optimisation pipeline, or a structural
-- combinator over such steps ('CoreDoNothing', 'CoreDoPasses').  The
-- 'Outputable' instance gives each constructor a short human-readable label.
data CoreToDo
  = CoreDoSimplify
        Int                     -- maximum iterations (shown by 'pprPassDetails')
        SimplMode               -- settings for this simplifier run
  | CoreDoPluginPass String CorePluginPass  -- label shown by 'ppr', and the pass
  | CoreDoFloatInwards
  | CoreDoFloatOutwards FloatOutSwitches
  | CoreLiberateCase
  | CoreDoPrintCore
  | CoreDoStaticArgs
  | CoreDoCallArity
  | CoreDoExitify
  | CoreDoStrictness
  | CoreDoWorkerWrapper
  | CoreDoSpecialising
  | CoreDoSpecConstr
  | CoreCSE
  | CoreDoRuleCheck CompilerPhase String
        -- NOTE(review): the String's meaning is not visible in this module
        -- (presumably a rule-name prefix to check) -- confirm with callers.
  | CoreDoNothing                -- placeholder; 'runWhen'/'runMaybe' yield it for disabled passes
  | CoreDoPasses [CoreToDo]      -- a list of passes (presumably run in order)
  | CoreDesugar    -- printed as "Desugar (before optimization)"
  | CoreDesugarOpt -- printed as "Desugar (after optimization)"
  | CoreTidy
  | CorePrep
  | CoreOccurAnal
-- | Short human-readable label for each pass.  Only 'CoreDoPluginPass'
-- (its name), 'CoreDoFloatOutwards' (its switches) and 'CoreDoPasses'
-- (the nested passes) show any of their payload; simplifier settings are
-- rendered separately by 'pprPassDetails'.
instance Outputable CoreToDo where
  ppr (CoreDoSimplify _ _)     = text "Simplifier"
  -- No trailing space inside the literal: '<+>' already inserts one, so
  -- "Core plugin: " used to render as "Core plugin:  <name>".
  ppr (CoreDoPluginPass s _)   = text "Core plugin:" <+> text s
  ppr CoreDoFloatInwards       = text "Float inwards"
  ppr (CoreDoFloatOutwards f)  = text "Float out" <> parens (ppr f)
  ppr CoreLiberateCase         = text "Liberate case"
  ppr CoreDoStaticArgs         = text "Static argument"
  ppr CoreDoCallArity          = text "Called arity analysis"
  ppr CoreDoExitify            = text "Exitification transformation"
  ppr CoreDoStrictness         = text "Demand analysis"
  ppr CoreDoWorkerWrapper      = text "Worker Wrapper binds"
  ppr CoreDoSpecialising       = text "Specialise"
  ppr CoreDoSpecConstr         = text "SpecConstr"
  ppr CoreCSE                  = text "Common sub-expression"
  ppr CoreDesugar              = text "Desugar (before optimization)"
  ppr CoreDesugarOpt           = text "Desugar (after optimization)"
  ppr CoreTidy                 = text "Tidy Core"
  ppr CorePrep                 = text "CorePrep"
  ppr CoreOccurAnal            = text "Occurrence analysis"
  ppr CoreDoPrintCore          = text "Print core"
  ppr (CoreDoRuleCheck {})     = text "Rule check"
  ppr CoreDoNothing            = text "CoreDoNothing"
  ppr (CoreDoPasses passes)    = text "CoreDoPasses" <+> ppr passes
-- | Extra detail to show next to a pass's name.  Only a simplifier pass has
-- any: its iteration limit followed by its 'SimplMode'.  Every other pass
-- yields an empty document.
pprPassDetails :: CoreToDo -> SDoc
pprPassDetails pass = case pass of
  CoreDoSimplify max_iters mode ->
    vcat [ text "Max iterations =" <+> int max_iters
         , ppr mode ]
  _ -> Outputable.empty
-- | Settings for one run of the simplifier (carried by 'CoreDoSimplify').
data SimplMode
  = SimplMode
        { sm_names      :: [String] -- ^ labels, shown comma-separated in brackets by 'ppr'
        , sm_phase      :: CompilerPhase -- ^ compiler phase of this run
        , sm_dflags     :: DynFlags -- ^ flags for this run
                                    --   (not shown by 'ppr')
        , sm_rules      :: Bool     -- ^ shown as "rules" / "no rules"; presumably enables rewrite rules
        , sm_inline     :: Bool     -- ^ shown as "inline" / "no inline"
        , sm_case_case  :: Bool     -- ^ shown as "case-of-case" / "no case-of-case"
        , sm_eta_expand :: Bool     -- ^ shown as "eta-expand" / "no eta-expand"
        }
-- | Renders as @SimplMode {Phase = p [names], inline, rules, ...}@.  A flag
-- that is off is prefixed with "no".
instance Outputable SimplMode where
    ppr sm
       = text "SimplMode" <+> braces (sep (punctuate comma (phase_doc : flag_docs)))
       where
         phase_doc = text "Phase =" <+> ppr (sm_phase sm)
                     <+> brackets (text (concat (intersperse "," (sm_names sm))))
         flag_docs = [ pp_flag (sm_inline sm)     (sLit "inline")
                     , pp_flag (sm_rules sm)      (sLit "rules")
                     , pp_flag (sm_eta_expand sm) (sLit "eta-expand")
                     , pp_flag (sm_case_case sm)  (sLit "case-of-case") ]
         -- "no" appears only when the flag is off.
         pp_flag on label = ppUnless on (text "no") <+> ptext label
-- | Settings for a 'CoreDoFloatOutwards' pass.
-- NOTE(review): the field descriptions below are inferred from the field
-- names; the float-out pass that consumes them is not in this module.
data FloatOutSwitches = FloatOutSwitches {
  floatOutLambdas   :: Maybe Int,  -- ^ lambda-floating setting; presumably
                                   --   'Nothing' means "no limit" and
                                   --   @Just n@ bounds how many lambdas are
                                   --   floated -- confirm against the pass
  floatOutConstants :: Bool,       -- ^ presumably: whether constant
                                   --   expressions are floated -- confirm
  floatOutOverSatApps :: Bool,     -- ^ presumably: whether over-saturated
                                   --   applications are floated -- confirm
  floatToTopLevelOnly :: Bool      -- ^ presumably: only float to top level
                                   --   -- confirm
  }
-- | Delegates to 'pprFloatOutSwitches'.
instance Outputable FloatOutSwitches where
    ppr switches = pprFloatOutSwitches switches
-- | Render float-out settings as @FOS {Lam = .., Consts = .., ...}@.
-- Every field of 'FloatOutSwitches' is shown.  That includes
-- 'floatToTopLevelOnly', so two passes that differ only in that switch no
-- longer print identically.
pprFloatOutSwitches :: FloatOutSwitches -> SDoc
pprFloatOutSwitches sw
  = text "FOS" <+> (braces $
     sep $ punctuate comma $
     [ text "Lam ="            <+> ppr (floatOutLambdas sw)
     , text "Consts ="         <+> ppr (floatOutConstants sw)
     , text "OverSatApps ="    <+> ppr (floatOutOverSatApps sw)
     , text "ToTopLevelOnly =" <+> ppr (floatToTopLevelOnly sw) ])
-- | Keep a pass only when the condition holds; otherwise replace it with the
-- no-op 'CoreDoNothing'.
runWhen :: Bool -> CoreToDo -> CoreToDo
runWhen enabled pass = if enabled then pass else CoreDoNothing
-- | Build a pass from an optional setting.  'Nothing' yields the no-op
-- 'CoreDoNothing'.
runMaybe :: Maybe a -> (a -> CoreToDo) -> CoreToDo
runMaybe setting mk_pass = maybe CoreDoNothing mk_pass setting
-- | The pass type carried by 'CoreDoPluginPass': a 'CoreM' transformation of
-- the whole 'ModGuts'.
type CorePluginPass = ModGuts -> CoreM ModGuts
-- | Lift a pass over the top-level bindings into a whole-'ModGuts' pass.
-- Only 'mg_binds' is replaced; every other field is left untouched.
bindsOnlyPass :: (CoreProgram -> CoreM CoreProgram) -> ModGuts -> CoreM ModGuts
bindsOnlyPass pass guts = do
    new_binds <- pass (mg_binds guts)
    return (guts { mg_binds = new_binds })
-- | Pick the rendering of simplifier stats depending on whether verbose
-- stats are wanted.  This is currently tied to 'getPprDebug'.
getVerboseSimplStats :: (Bool -> SDoc) -> SDoc
getVerboseSimplStats render = getPprDebug render
-- | An empty count.  It is detailed ('SimplCount') when
-- 'Opt_D_dump_simpl_stats' is set, and a bare 'VerySimplCount' otherwise.
zeroSimplCount     :: DynFlags -> SimplCount
-- | True when the total number of ticks is zero.
isZeroSimplCount   :: SimplCount -> Bool
-- | True for the detailed 'SimplCount' representation.
hasDetailedCounts  :: SimplCount -> Bool
-- | Render a count.  The recent-tick log is included only when
-- 'getVerboseSimplStats' reports verbose output.
pprSimplCount      :: SimplCount -> SDoc
-- | Record one tick.  This bumps the total and, for detailed counts, also
-- updates the per-tick table and the bounded history log.
doSimplTick        :: DynFlags -> Tick -> SimplCount -> SimplCount
-- | Record a tick in the per-tick table only.  The total and the log are
-- not changed.  This is a no-op on 'VerySimplCount'.
doFreeSimplTick    ::             Tick -> SimplCount -> SimplCount
-- | Combine two counts.  Both must use the same representation; mixing
-- 'VerySimplCount' with 'SimplCount' throws a 'PprProgramError'.
plusSimplCount     :: SimplCount -> SimplCount -> SimplCount
-- | Counts of simplifier ticks.  Which representation is used is decided
-- once by 'zeroSimplCount'.
data SimplCount
   = VerySimplCount !Int        -- ^ total tick count only
   | SimplCount {
        ticks   :: !Int,        -- ^ total number of ticks
        details :: !TickCounts, -- ^ occurrences of each distinct 'Tick'
        n_log   :: !Int,        -- ^ number of entries in 'log1'
        log1    :: [Tick],      -- ^ recent ticks, most recent first;
                                --   restarted once 'n_log' reaches 'historySize'
        log2    :: [Tick]       -- ^ the previous 'log1', saved when
                                --   'log1' was restarted (see 'doSimplTick');
                                --   printed after 'log1'
     }
-- | Occurrence count for each distinct tick.
type TickCounts = Map Tick Int
-- | Total number of ticks, whichever representation is in use.
simplCountN :: SimplCount -> Int
simplCountN sc = case sc of
  VerySimplCount total -> total
  SimplCount {}        -> ticks sc
-- A detailed, empty count when simplifier stats will be dumped.
-- Otherwise a cheap bare counter starting at zero.
zeroSimplCount dflags =
  if dopt Opt_D_dump_simpl_stats dflags
    then SimplCount { ticks = 0, details = Map.empty
                    , n_log = 0, log1 = [], log2 = [] }
    else VerySimplCount 0
-- No ticks at all, in either representation.
isZeroSimplCount sc = simplCountN sc == 0
-- Only the record form keeps per-tick details.
hasDetailedCounts sc = case sc of
  SimplCount {}     -> True
  VerySimplCount {} -> False
-- A "free" tick lands in the per-tick table only.  Neither the total nor
-- the history log changes, and a bare counter is returned as-is.
doFreeSimplTick tick sc = case sc of
  SimplCount { details = table } -> sc { details = table `addTick` tick }
  _                              -> sc
-- Count one tick.  A bare counter just increments.  A detailed count also
-- updates the per-tick table and pushes the tick onto log1.  Once n_log
-- reaches historySize dflags, log1 moves to log2 and a fresh log1 starts.
doSimplTick _ _ (VerySimplCount n) = VerySimplCount (n + 1)
doSimplTick dflags tick
    sc@(SimplCount { ticks = total, details = table, n_log = logged, log1 = recent })
  = if logged >= historySize dflags
      then bumped { n_log = 1, log1 = [tick], log2 = recent }
      else bumped { n_log = logged + 1, log1 = tick : recent }
  where
    bumped = sc { ticks = total + 1, details = table `addTick` tick }
-- | Increment a tick's count, starting at 1 for a tick not seen before.
-- The strict map keeps counts evaluated.
addTick :: TickCounts -> Tick -> TickCounts
addTick counts tick = MapStrict.alter (Just . maybe 1 (+ 1)) tick counts
-- Combine two counts.  Totals and per-tick tables are summed.  The history
-- log favours the second argument, which in CoreM's (>>=) is the count of
-- the later computation.
plusSimplCount sc1@(SimplCount { ticks = tks1, details = dts1 })
               sc2@(SimplCount { ticks = tks2, details = dts2 })
  = log_base { ticks = tks1 + tks2
             , details = MapStrict.unionWith (+) dts1 dts2 }
  where
    -- Choose whose history to keep (only the log fields matter here):
    log_base | null (log1 sc2) = sc1    -- sc2 has no recent ticks: keep sc1's log
             | null (log2 sc2) = sc2 { log2 = log1 sc1 }  -- back sc2's short log with sc1's recent ticks
             | otherwise       = sc2    -- sc2 has a full history of its own
plusSimplCount (VerySimplCount n) (VerySimplCount m) = VerySimplCount (n+m)
plusSimplCount lhs                rhs                =
  throwGhcException . PprProgramError "plusSimplCount" $ vcat
    [ text "lhs"
    , pprSimplCount lhs
    , text "rhs"
    , pprSimplCount rhs
    ]
       -- Mixing representations is reported as a PprProgramError; callers are
       -- expected to use one representation consistently (see zeroSimplCount).
-- Render a count.  The detailed form shows the total, then the per-tick
-- table grouped by kind.  When getVerboseSimplStats passes True, it also
-- shows the history log, newest entries first.
pprSimplCount sc = case sc of
  VerySimplCount total -> text "Total ticks:" <+> int total
  SimplCount { ticks = total, details = table, log1 = recent, log2 = older } ->
    vcat [ text "Total ticks:    " <+> int total
         , blankLine
         , pprTickCounts table
         , getVerboseSimplStats log_section ]
    where
      log_section verbose
        | verbose   = vcat [ blankLine
                           , text "Log (most recent first)"
                           , nest 4 (vcat (map ppr recent) $$ vcat (map ppr older)) ]
        | otherwise = Outputable.empty
-- | Print the per-tick table, one group per tick kind.
pprTickCounts :: Map Tick Int -> SDoc
pprTickCounts counts = vcat [ pprTickGroup grp | grp <- tag_groups ]
  where
    -- Map.toList returns keys in 'Ord Tick' order.  That order compares
    -- tickToTag first, so ticks of the same kind are adjacent, and a single
    -- groupBy pass collects them.
    tag_groups :: [[(Tick, Int)]]
    tag_groups = groupBy (\(t1, _) (t2, _) -> tickToTag t1 == tickToTag t2)
                         (Map.toList counts)
-- | Print one kind of tick.  The header gives the group total and the
-- kind's name.  Below it, each entry is listed with the largest count
-- first.  Panics on an empty group.
pprTickGroup :: [(Tick, Int)] -> SDoc
pprTickGroup [] = panic "pprTickGroup"
pprTickGroup grp@((first_tick, _) : _)
  = hang header 2 (vcat (map pp_entry by_count_desc))
  where
    header             = int (sum (map snd grp)) <+> text (tickString first_tick)
    by_count_desc      = sortBy (flip (comparing snd)) grp
    pp_entry (tick, n) = int n <+> pprTickCts tick
-- | One simplifier transformation, recorded in a 'SimplCount'.  The payload,
-- if any, is what 'pprTickCts' prints.  'tickToTag' fixes the order in which
-- tick kinds are grouped and compared.
data Tick
  = PreInlineUnconditionally    Id
  | PostInlineUnconditionally   Id
  | UnfoldingDone               Id
  | RuleFired                   FastString      -- presumably the rule's name -- confirm
  | LetFloatFromLet                             -- no payload
  | EtaExpansion                Id
  | EtaReduction                Id
  | BetaReduction               Id
  | CaseOfCase                  Id
  | KnownBranch                 Id
  | CaseMerge                   Id
  | AltMerge                    Id
  | CaseElim                    Id
  | CaseIdentity                Id
  | FillInCaseDefault           Id
  | SimplifierDone                              -- no payload
-- | A tick prints as its kind name followed by its payload, if any.
instance Outputable Tick where
  ppr t = kind_name <+> payload
    where
      kind_name = text (tickString t)
      payload   = pprTickCts t
-- | Equality is defined through 'cmpTick', so it always agrees with 'Ord'.
instance Eq Tick where
  t1 == t2 = cmpTick t1 t2 == EQ
-- | Order by tick kind first, then by payload (see 'cmpTick').
instance Ord Tick where
  compare t1 t2 = cmpTick t1 t2
-- | Numeric kind of a tick.  'cmpTick' compares these first, so they
-- determine how ticks are grouped and ordered in 'pprTickCounts'.
-- Values 14 and 15 are not used by any constructor.
tickToTag :: Tick -> Int
tickToTag tick = case tick of
  PreInlineUnconditionally _  -> 0
  PostInlineUnconditionally _ -> 1
  UnfoldingDone _             -> 2
  RuleFired _                 -> 3
  LetFloatFromLet             -> 4
  EtaExpansion _              -> 5
  EtaReduction _              -> 6
  BetaReduction _             -> 7
  CaseOfCase _                -> 8
  KnownBranch _               -> 9
  CaseMerge _                 -> 10
  AltMerge _                  -> 17
  CaseElim _                  -> 11
  CaseIdentity _              -> 12
  FillInCaseDefault _         -> 13
  SimplifierDone              -> 16
-- | Name of a tick's kind, ignoring its payload.
tickString :: Tick -> String
tickString tick = case tick of
  PreInlineUnconditionally _  -> "PreInlineUnconditionally"
  PostInlineUnconditionally _ -> "PostInlineUnconditionally"
  UnfoldingDone _             -> "UnfoldingDone"
  RuleFired _                 -> "RuleFired"
  LetFloatFromLet             -> "LetFloatFromLet"
  EtaExpansion _              -> "EtaExpansion"
  EtaReduction _              -> "EtaReduction"
  BetaReduction _             -> "BetaReduction"
  CaseOfCase _                -> "CaseOfCase"
  KnownBranch _               -> "KnownBranch"
  CaseMerge _                 -> "CaseMerge"
  AltMerge _                  -> "AltMerge"
  CaseElim _                  -> "CaseElim"
  CaseIdentity _              -> "CaseIdentity"
  FillInCaseDefault _         -> "FillInCaseDefault"
  SimplifierDone              -> "SimplifierDone"
-- | Render a tick's payload.  Ticks with no payload render as empty.
pprTickCts :: Tick -> SDoc
pprTickCts tick = case tick of
  PreInlineUnconditionally v  -> ppr v
  PostInlineUnconditionally v -> ppr v
  UnfoldingDone v             -> ppr v
  RuleFired rule              -> ppr rule
  EtaExpansion v              -> ppr v
  EtaReduction v              -> ppr v
  BetaReduction v             -> ppr v
  CaseOfCase v                -> ppr v
  KnownBranch v               -> ppr v
  CaseMerge v                 -> ppr v
  AltMerge v                  -> ppr v
  CaseElim v                  -> ppr v
  CaseIdentity v              -> ppr v
  FillInCaseDefault v         -> ppr v
  LetFloatFromLet             -> Outputable.empty
  SimplifierDone              -> Outputable.empty
-- | Order ticks by kind ('tickToTag').  Only same-kind ticks fall through
-- to a payload comparison ('cmpEqTick').
cmpTick :: Tick -> Tick -> Ordering
cmpTick a b = case comparing tickToTag a b of
                EQ      -> cmpEqTick a b
                decided -> decided
-- | Compare two ticks of the same kind by their payloads.  Payload-free
-- kinds, and any pair whose constructors differ, compare as 'EQ'.
cmpEqTick :: Tick -> Tick -> Ordering
cmpEqTick t1 t2 = case (t1, t2) of
  (PreInlineUnconditionally x,  PreInlineUnconditionally y)  -> compare x y
  (PostInlineUnconditionally x, PostInlineUnconditionally y) -> compare x y
  (UnfoldingDone x,             UnfoldingDone y)             -> compare x y
  (RuleFired x,                 RuleFired y)                 -> compare x y
  (EtaExpansion x,              EtaExpansion y)              -> compare x y
  (EtaReduction x,              EtaReduction y)              -> compare x y
  (BetaReduction x,             BetaReduction y)             -> compare x y
  (CaseOfCase x,                CaseOfCase y)                -> compare x y
  (KnownBranch x,               KnownBranch y)               -> compare x y
  (CaseMerge x,                 CaseMerge y)                 -> compare x y
  (AltMerge x,                  AltMerge y)                  -> compare x y
  (CaseElim x,                  CaseElim y)                  -> compare x y
  (CaseIdentity x,              CaseIdentity y)              -> compare x y
  (FillInCaseDefault x,         FillInCaseDefault y)         -> compare x y
  _                                                          -> EQ
-- | Read-only environment of 'CoreM'.  It is built by 'runCoreM' and
-- accessed via 'read'.
data CoreReader = CoreReader {
        cr_hsc_env             :: HscEnv,      -- source of DynFlags ('getDynFlags', 'nop')
        cr_rule_base           :: RuleBase,    -- returned by 'getRuleBase'
        cr_module              :: Module,      -- returned by 'getModule'
        cr_print_unqual        :: PrintUnqualified, -- styles for 'msg' / 'dumpIfSet_dyn'
        cr_loc                 :: SrcSpan,     -- location attached to messages by 'msg'
        -- orphan-module set passed to runCoreM; presumably the orphans
        -- visible in this module -- confirm with callers
        cr_visible_orphan_mods :: !ModuleSet,
        cr_uniq_mask           :: !Char        -- mask for the 'MonadUnique' instance
}
-- | Output accumulated across 'CoreM' actions (combined with 'plusWriter'
-- in '>>='); currently just the simplifier tick count.
newtype CoreWriter = CoreWriter {
        cw_simpl_count :: SimplCount
}
-- | A writer holding an empty tick count.  The count's representation is
-- chosen by 'zeroSimplCount' from the flags.
emptyWriter :: DynFlags -> CoreWriter
emptyWriter = CoreWriter . zeroSimplCount
-- | Combine two writers by adding their tick counts.  The second writer
-- comes from the later computation.
plusWriter :: CoreWriter -> CoreWriter -> CoreWriter
plusWriter (CoreWriter earlier) (CoreWriter later) =
    CoreWriter (earlier `plusSimplCount` later)
-- | 'IOEnv' specialised to the 'CoreReader' environment.
type CoreIOEnv = IOEnv CoreReader
-- | The monad Core passes run in.  It reads a 'CoreReader' (via 'IOEnv'),
-- and every action also returns a 'CoreWriter' that is combined across
-- binds.
newtype CoreM a = CoreM { unCoreM :: CoreIOEnv (a, CoreWriter) }
    deriving (Functor)
-- | Sequencing runs both actions in the same 'CoreReader' environment and
-- combines their writers, earlier first, with 'plusWriter'.
instance Monad CoreM where
    mx >>= f = CoreM $ do
            (x, w1) <- unCoreM mx
            (y, w2) <- unCoreM (f x)
            let w = w1 `plusWriter` w2
            return $ seq w (y, w)
            -- The combined writer is forced to WHNF whenever the result pair
            -- is demanded, presumably so long bind chains do not build up
            -- unevaluated plusWriter thunks.
-- | 'pure' yields a value with an empty writer (see 'nop').  The other
-- operations are defined through the 'Monad' instance.
instance Applicative CoreM where
    pure   = CoreM . nop
    (<*>)  = ap
    m *> k = m >>= const k
-- | Choice and failure are inherited directly from the underlying 'IOEnv'.
instance Alternative CoreM where
    empty               = CoreM Control.Applicative.empty
    CoreM l <|> CoreM r = CoreM (l <|> r)
instance MonadPlus CoreM
-- | Uniques are made in 'IO' using the mask stored in the 'CoreReader'.
instance MonadUnique CoreM where
    getUniqueSupplyM =
        read cr_uniq_mask >>= \uniq_mask -> liftIO $! mkSplitUniqSupply uniq_mask
    getUniqueM =
        read cr_uniq_mask >>= \uniq_mask -> liftIO $! uniqFromMask uniq_mask
-- | Run a 'CoreM' computation in 'IO'.  Returns its result together with
-- the simplifier tick count it accumulated.
runCoreM :: HscEnv
         -> RuleBase
         -> Char -- ^ mask for generating uniques
         -> Module
         -> ModuleSet -- ^ stored as the visible orphan modules
         -> PrintUnqualified
         -> SrcSpan -- ^ location attached to messages
         -> CoreM a
         -> IO (a, SimplCount)
runCoreM hsc_env rule_base mask mod orph_imps print_unqual loc m
  = fmap split_writer (runIOEnv env (unCoreM m))
  where
    env = CoreReader
      { cr_hsc_env             = hsc_env
      , cr_rule_base           = rule_base
      , cr_module              = mod
      , cr_print_unqual        = print_unqual
      , cr_loc                 = loc
      , cr_visible_orphan_mods = orph_imps
      , cr_uniq_mask           = mask
      }
    split_writer :: (a, CoreWriter) -> (a, SimplCount)
    split_writer (result, writer) = (result, cw_simpl_count writer)
-- | Return a value paired with an empty writer.  The writer's tick-count
-- representation comes from the environment's flags.
nop :: a -> CoreIOEnv (a, CoreWriter)
nop x = do
    env <- getEnv
    let dflags = hsc_dflags (cr_hsc_env env)
    return (x, emptyWriter dflags)
-- | Project a field out of the 'CoreReader', contributing an empty writer.
read :: (CoreReader -> a) -> CoreM a
read field = CoreM (getEnv >>= nop . field)
-- | Emit a writer value with no result.
write :: CoreWriter -> CoreM ()
write w = CoreM (pure ((), w))
-- | Run a raw 'CoreIOEnv' action inside 'CoreM', contributing an empty writer.
liftIOEnv :: CoreIOEnv a -> CoreM a
liftIOEnv action = CoreM (action >>= nop)
-- | 'IO' is lifted through the underlying 'IOEnv'.
instance MonadIO CoreM where
    liftIO io = liftIOEnv (IOEnv.liftIO io)
-- | Run an 'IO' action that reports its own tick count.  That count is
-- added to the monad's and the remaining result is returned.
liftIOWithCount :: IO (SimplCount, a) -> CoreM a
liftIOWithCount what = do
    (count, x) <- liftIO what
    addSimplCount count
    return x
-- | The 'HscEnv' supplied to 'runCoreM'.
getHscEnv :: CoreM HscEnv
getHscEnv = read cr_hsc_env
-- | The 'RuleBase' supplied to 'runCoreM'.
getRuleBase :: CoreM RuleBase
getRuleBase = read cr_rule_base
-- | The orphan-module set supplied to 'runCoreM'.  This is presumably the
-- set of orphan modules visible here; confirm with callers.
getVisibleOrphanMods :: CoreM ModuleSet
getVisibleOrphanMods = read cr_visible_orphan_mods
-- | The 'PrintUnqualified' used when rendering messages and dumps.
getPrintUnqualified :: CoreM PrintUnqualified
getPrintUnqualified = read cr_print_unqual
-- | The source span that 'msg' attaches to log messages.
getSrcSpanM :: CoreM SrcSpan
getSrcSpanM = read cr_loc
-- | Add a tick count to the monad's running total.
addSimplCount :: SimplCount -> CoreM ()
addSimplCount = write . CoreWriter
-- | The mask character supplied to 'runCoreM'.  The 'MonadUnique' instance
-- uses it when making uniques.
getUniqMask :: CoreM Char
getUniqMask = read cr_uniq_mask
-- | The flags are taken from the environment's 'HscEnv'.
instance HasDynFlags CoreM where
    getDynFlags = hsc_dflags <$> getHscEnv
-- | 'getModule' returns the module supplied to 'runCoreM'.
instance HasModule CoreM where
    getModule = read cr_module
-- | Snapshot of the original-name cache.  It is read from the 'HscEnv's
-- name-cache reference.
getOrigNameCache :: CoreM OrigNameCache
getOrigNameCache = do
    hsc_env <- getHscEnv
    liftIO (nsNames <$> readIORef (hsc_NC hsc_env))
-- | The family-instance environment held in the external package state.
getPackageFamInstEnv :: CoreM PackageFamInstEnv
getPackageFamInstEnv =
    getHscEnv >>= \hsc_env ->
    liftIO (hscEPS hsc_env) >>= \eps ->
    return (eps_fam_inst_env eps)
-- | All annotations for the given module guts.  Each annotation is decoded
-- from its serialised bytes with the supplied function.
getAnnotations :: Typeable a => ([Word8] -> a) -> ModGuts -> CoreM (UniqFM [a])
getAnnotations deserialize guts =
     getHscEnv >>= \hsc_env ->
     liftIO (prepareAnnotations hsc_env (Just guts)) >>= \ann_env ->
     return (deserializeAnns deserialize ann_env)
-- | Like 'getAnnotations', but keeps only the first annotation per key.
-- Empty entries are dropped first, which makes 'head' safe here.
getFirstAnnotations :: Typeable a => ([Word8] -> a) -> ModGuts -> CoreM (UniqFM a)
getFirstAnnotations deserialize guts = do
    anns <- getAnnotations deserialize guts
    return (mapUFM head (filterUFM (not . null) anns))
-- | Log a message at the current source span.  Errors and warnings use the
-- error style, dumps use the dump style, and everything else uses the
-- user style.  This only logs; it does not interrupt the computation.
msg :: Severity -> WarnReason -> SDoc -> CoreM ()
msg sev reason doc = do
    dflags <- getDynFlags
    loc    <- getSrcSpanM
    unqual <- getPrintUnqualified
    let style_for SevError   = mkErrStyle dflags unqual
        style_for SevWarning = mkErrStyle dflags unqual
        style_for SevDump    = mkDumpStyle dflags unqual
        style_for _          = mkUserStyle dflags unqual AllTheWay
    liftIO (putLogMsg dflags reason sev loc (style_for sev) doc)
-- | 'putMsg' for a plain string.
putMsgS :: String -> CoreM ()
putMsgS str = putMsg (text str)
-- | Log an informational ('SevInfo') message.
putMsg :: SDoc -> CoreM ()
putMsg doc = msg SevInfo NoReason doc
-- | 'errorMsg' for a plain string.
errorMsgS :: String -> CoreM ()
errorMsgS str = errorMsg (text str)
-- | Log an error-severity message (logging only; computation continues).
errorMsg :: SDoc -> CoreM ()
errorMsg doc = msg SevError NoReason doc
-- | Log a warning with the given reason.
warnMsg :: WarnReason -> SDoc -> CoreM ()
warnMsg reason doc = msg SevWarning reason doc
-- | 'fatalErrorMsg' for a plain string.
fatalErrorMsgS :: String -> CoreM ()
fatalErrorMsgS str = fatalErrorMsg (text str)
-- | Log a 'SevFatal' message.  Like the other helpers it only logs, via 'msg'.
fatalErrorMsg :: SDoc -> CoreM ()
fatalErrorMsg doc = msg SevFatal NoReason doc
-- | 'debugTraceMsg' for a plain string.
debugTraceMsgS :: String -> CoreM ()
debugTraceMsgS str = debugTraceMsg (text str)
-- | Log a 'SevDump' message, rendered in dump style by 'msg'.
debugTraceMsg :: SDoc -> CoreM ()
debugTraceMsg doc = msg SevDump NoReason doc
-- | Emit a titled dump, but only when the given dump flag is enabled.
dumpIfSet_dyn :: DumpFlag -> String -> SDoc -> CoreM ()
dumpIfSet_dyn flag str doc = do
    dflags <- getDynFlags
    unqual <- getPrintUnqualified
    if dopt flag dflags
      then liftIO (Err.dumpSDoc dflags unqual flag str doc)
      else return ()