module GHC.Core.LateCC.Utils
  ( -- * Inserting cost centres
    doLateCostCenters -- Might be useful for API users

    -- ** Helpers for defining insertion methods
  , getCCFlavour
  , insertCC
  ) where

import GHC.Prelude

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict
import qualified Data.Set as S

import GHC.Core
import GHC.Core.LateCC.Types
import GHC.Core.Utils
import GHC.Data.FastString
import GHC.Types.CostCentre
import GHC.Types.CostCentre.State
import GHC.Types.SrcLoc
import GHC.Types.Tickish

-- | Insert cost centres into the 'CoreProgram' using the provided environment,
-- initial state, and insertion method.
--
-- The insertion method is applied to each binding in order (via 'mapM'), with
-- the 'LateCCState' threaded from one binding to the next. The rewritten
-- bindings are returned together with the final state.
doLateCostCenters
  :: LateCCEnv
  -- ^ Environment to run the insertion in
  -> LateCCState s
  -- ^ Initial state to run the insertion with
  -> (CoreBind -> LateCCM s CoreBind)
  -- ^ Insertion method
  -> CoreProgram
  -- ^ Bindings to consider
  -> (CoreProgram, LateCCState s)
-- The initial state is named @st0@ rather than @state@ so it does not shadow
-- 'state' from "Control.Monad.Trans.State.Strict", which is imported unqualified.
doLateCostCenters env st0 method binds =
    runLateCC env st0 (mapM method binds)

-- | Evaluate a late cost centre insertion computation.
--
-- The 'LateCCM' action is first run as a reader over @env@. The resulting
-- state computation is then run from the initial state @st0@. The action's
-- result is returned paired with the final 'LateCCState'.
runLateCC :: LateCCEnv -> LateCCState s -> LateCCM s a -> (a, LateCCState s)
runLateCC env st0 action = runState (runReaderT action env) st0

-- | Given the name of a cost centre, get its flavour.
--
-- Steps:
--
-- * Read the 'CostCentreState' stored in the late cost centre state
--   ('lateCCState_ccState').
-- * Ask 'getCCIndex' for an index for @name@.
-- * Write the updated 'CostCentreState' back.
-- * Wrap the index with 'mkLateCCFlavour'.
--
-- Because the state is written back, later calls observe the effect of this
-- one. NOTE(review): presumably 'getCCIndex' hands out successive per-name
-- indices; confirm against "GHC.Types.CostCentre.State".
getCCFlavour :: FastString -> LateCCM s CCFlavour
getCCFlavour name = lift $ do
    cc_state <- gets lateCCState_ccState
    let (index, cc_state') = getCCIndex name cc_state
    -- modify' forces the updated state record to WHNF, so chains of nested
    -- record-update thunks do not build up across many calls.
    modify' $ \s -> s { lateCCState_ccState = cc_state' }
    return (mkLateCCFlavour index)

-- | Insert a cost centre with the specified name and source span on the given
-- expression. The inserted cost centre will be appropriately tracked in the
-- late cost centre state.
--
-- How the cost centre is built and attached:
--
-- * It is a 'NormalCC'. Its flavour comes from 'getCCFlavour' (which updates
--   the state) and its module comes from 'lateCCEnv_module'.
-- * It is attached to the expression as a 'ProfNote' tick via 'mkTick'.
-- * The tick's second argument is taken from 'lateCCEnv_countEntries'.
-- * When 'lateCCEnv_collectCCs' is set, the cost centre is also added to the
--   'lateCCState_ccs' set.
insertCC
  :: FastString
  -- ^ Name of the cost centre to insert
  -> SrcSpan
  -- ^ Source location to associate with the cost centre
  -> CoreExpr
  -- ^ Expression to wrap in the cost centre
  -> LateCCM s CoreExpr
insertCC cc_name cc_loc expr = do
    cc_flavour <- getCCFlavour cc_name
    env <- ask
    let cc_mod = lateCCEnv_module env
        cc     = NormalCC cc_flavour cc_name cc_mod cc_loc
        -- NOTE(review): the trailing 'True' is presumably the ProfNote
        -- scoping flag; confirm against "GHC.Types.Tickish".
        note   = ProfNote cc (lateCCEnv_countEntries env) True
    when (lateCCEnv_collectCCs env) $
        -- modify' forces the updated state record to WHNF on each insertion.
        lift . modify' $ \s ->
            s { lateCCState_ccs = S.insert cc (lateCCState_ccs s) }
    return (mkTick note expr)