{-# LANGUAGE CPP, ScopedTypeVariables, TypeFamilies #-}
{-# LANGUAGE DataKinds #-}

module GHC.Stg.Utils
    ( mkStgAltTypeFromStgAlts
    , bindersOf, bindersOfX, bindersOfTop, bindersOfTopBinds

    , stripStgTicksTop, stripStgTicksTopE
    , idArgs

    , mkUnarisedId, mkUnarisedIds
    ) where

import GHC.Prelude

import GHC.Types.Id
import GHC.Core.Type
import GHC.Core.TyCon
import GHC.Core.DataCon
import GHC.Core ( AltCon(..) )
import GHC.Types.Tickish
import GHC.Types.Unique.Supply

import GHC.Types.RepType
import GHC.Stg.Syntax

import GHC.Utils.Outputable

import GHC.Utils.Panic

import GHC.Data.FastString

-- | Create one system-local 'Id' per unary type in @tys@, all named with
-- the prefix @fs@, by mapping 'mkUnarisedId' over the list inside the
-- 'MonadUnique' monad.  The result has the same length and order as @tys@.
mkUnarisedIds :: MonadUnique m => FastString -> [UnaryType] -> m [Id]
mkUnarisedIds fs tys = mapM (mkUnarisedId fs) tys

-- | Create a single system-local 'Id' named @s@ with type @t@ and
-- multiplicity 'ManyTy', via 'mkSysLocalM' in the 'MonadUnique' monad.
mkUnarisedId :: MonadUnique m => FastString -> UnaryType -> m Id
mkUnarisedId s t = mkSysLocalM s ManyTy t

-- Checks whether an Id is a top-level error application.
-- (Unimplemented placeholder; the signature below is incomplete.)
-- isErrorAp_maybe :: Id ->

-- | Split off the DEFAULT alternative of an STG case, if there is one.
--
-- Only the /first/ alternative is inspected.  If it is a DEFAULT
-- alternative, the remaining alternatives are returned together with
-- 'Just' its right-hand side.  Otherwise the alternatives are returned
-- unchanged, paired with 'Nothing'.  A DEFAULT alternative appearing later
-- in the list is not detected; callers presumably rely on the convention
-- that DEFAULT, when present, comes first.
--
-- The DEFAULT alternative is asserted to bind no variables.
findDefaultStg
  :: [GenStgAlt p]
  -> ([GenStgAlt p], Maybe (GenStgExpr p))
findDefaultStg (GenStgAlt { alt_con   = DEFAULT
                          , alt_bndrs = args
                          , alt_rhs   = rhs } : alts)
  = assert (null args) (alts, Just rhs)
findDefaultStg alts
  = (alts, Nothing)

-- | Compute the 'AltType' of an STG case expression from its case binder
-- @bndr@ and its alternatives @alts@:
--
-- * unboxed tuple and unboxed sum binders always get 'MultiValAlt';
-- * a single GC-pointer rep gets 'AlgAlt' for an algebraic tycon and
--   'PolyAlt' otherwise (an abstract tycon first tries to recover a better
--   tycon from a 'DataAlt' alternative);
-- * a single non-GC-pointer rep gets 'PrimAlt';
-- * any other number of reps (including zero) gets 'MultiValAlt'.
mkStgAltTypeFromStgAlts :: forall p. Id -> [GenStgAlt p] -> AltType
mkStgAltTypeFromStgAlts bndr alts
  | isUnboxedTupleType bndr_ty || isUnboxedSumType bndr_ty
  = MultiValAlt (length prim_reps)  -- always use MultiValAlt for unboxed tuples and sums

  | otherwise
  = case prim_reps of
      [rep] | isGcPtrRep rep ->
        case tyConAppTyCon_maybe (unwrapType bndr_ty) of
          Just tc
            | isAbstractTyCon tc -> look_for_better_tycon
            | isAlgTyCon tc      -> AlgAlt tc
            | otherwise          -> assertPpr (_is_poly_alt_tycon tc) (ppr tc)
                                    PolyAlt
          Nothing                -> PolyAlt
      [non_gcd] -> PrimAlt non_gcd
      not_unary -> MultiValAlt (length not_unary)
  where
    bndr_ty   = idType bndr
    prim_reps = typePrimRep bndr_ty

    -- Tycons for which falling back to 'PolyAlt' is expected (asserted above).
    _is_poly_alt_tycon tc
      =  isPrimTyCon tc    -- "Any" is lifted but primitive
      || isFamilyTyCon tc  -- Type family; e.g. Any, or arising from strict
                           -- function application where argument has a
                           -- type-family type

    -- Sometimes the TyCon is an AbstractTyCon, which may not have any
    -- constructors inside it.  Then we may get a better TyCon by grabbing
    -- the one from a constructor alternative, if one exists.
    look_for_better_tycon
      | (DataAlt con : _) <- alt_con <$> data_alts
      = AlgAlt (dataConTyCon con)
      | otherwise
      = assert (null data_alts) PolyAlt
      where
        (data_alts, _deflt) = findDefaultStg alts

-- | The 'Id's bound by an STG binding, in binding order: the single binder
-- of a non-recursive binding, or every binder of a recursive group.
-- This is 'bindersOfX' specialised to passes whose binders are 'Id's, so
-- the two cannot drift apart.
bindersOf :: BinderP a ~ Id => GenStgBinding a -> [Id]
bindersOf bind = bindersOfX bind

-- | The binders of an STG binding, at whatever binder type the pass uses.
-- A non-recursive binding yields its single binder; a recursive group
-- yields all of its binders in order.
bindersOfX :: GenStgBinding a -> [BinderP a]
bindersOfX (StgNonRec binder _) = [binder]
bindersOfX (StgRec pairs)       = map fst pairs

-- | The 'Id's bound by a top-level STG binding: the binders of a lifted
-- binding, or the single binder of a top-level string literal.
bindersOfTop :: BinderP a ~ Id => GenStgTopBinding a -> [Id]
bindersOfTop (StgTopLifted bind)        = bindersOf bind
bindersOfTop (StgTopStringLit binder _) = [binder]

-- | All ids we bind something to on the top level, in binding order.
-- Results from the individual bindings are concatenated; duplicates are
-- not removed.
bindersOfTopBinds :: BinderP a ~ Id => [GenStgTopBinding a] -> [Id]
-- Set-returning alternative (commented out):
-- bindersOfTopBinds binds = mapUnionVarSet (mkVarSet . bindersOfTop) binds
bindersOfTopBinds binds = concatMap bindersOfTop binds

-- | The variable arguments of an argument list, in order.  Any argument
-- that is not a 'StgVarArg' (e.g. a literal) is dropped.
idArgs :: [StgArg] -> [Id]
idArgs args = [v | StgVarArg v <- args]

-- | Strip ticks of a given type from the top of an STG expression.
--
-- Consecutive outer 'StgTick's whose tick satisfies @p@ are removed.
-- Stripping stops at the first non-tick expression or the first tick that
-- does not satisfy @p@.  Returns the removed ticks, outermost first,
-- together with the remaining expression.
stripStgTicksTop :: (StgTickish -> Bool) -> GenStgExpr p -> ([StgTickish], GenStgExpr p)
stripStgTicksTop p = go []
  where
    -- Ticks are accumulated innermost-at-head and reversed at the end.
    go ts (StgTick t e) | p t = go (t:ts) e
    -- This special case avoids building a thunk for "reverse ts" when there are no ticks
    go [] other               = ([], other)
    go ts other               = (reverse ts, other)

-- | Strip ticks of a given type from the top of an STG expression,
-- returning only the expression.  Consecutive outer 'StgTick's whose tick
-- satisfies @p@ are dropped.  Stripping stops at the first non-tick
-- expression or the first tick that does not satisfy @p@.
stripStgTicksTopE :: (StgTickish -> Bool) -> GenStgExpr p -> GenStgExpr p
stripStgTicksTopE p = go
  where
    go (StgTick t e) | p t = go e
    go other               = other