{-# LANGUAGE CPP #-}
module GHC.StgToCmm.TagCheck
( emitTagAssertion, emitArgTagCheck, checkArg, whenCheckTags,
checkArgStatic, checkFunctionArgTags,checkConArgsStatic,checkConArgsDyn) where
#include "ClosureTypes.h"
import GHC.Prelude
import GHC.StgToCmm.Env
import GHC.StgToCmm.Monad
import GHC.StgToCmm.Utils
import GHC.Cmm
import GHC.Cmm.BlockId
import GHC.Cmm.Graph as CmmGraph
import GHC.Core.Type
import GHC.Types.Id
import GHC.Utils.Misc
import GHC.Utils.Outputable
import GHC.Core.DataCon
import Control.Monad
import GHC.StgToCmm.Types
import GHC.Utils.Panic (pprPanic)
import GHC.Utils.Panic.Plain (panic)
import GHC.Stg.Syntax
import GHC.StgToCmm.Closure
import GHC.Cmm.Switch (mkSwitchTargets)
import GHC.Cmm.Info (cmmGetClosureType)
import GHC.Types.RepType (dataConRuntimeRepStrictness)
import GHC.Types.Basic
import GHC.Data.FastString (mkFastString)
import qualified Data.Map as M
-- | When tag checking is enabled, emit a runtime tag assertion for every
-- argument of @f@ that @f@'s call-by-value marks say must be tagged.
--
-- Nothing is emitted when @f@ has no CBV marks ('idCbvMarks_maybe' returns
-- 'Nothing'). Of the arguments whose mark satisfies 'isMarkedCbv', only those
-- of boxed type (i.e. represented by a heap pointer) are checked. Every
-- assertion reports the rendered @msg@ on failure.
checkFunctionArgTags :: SDoc -> Id -> [Id] -> FCode ()
checkFunctionArgTags msg f args = whenCheckTags $
  onJust (return ()) (idCbvMarks_maybe f) $ \marks -> do
    -- NB: filter on the *value* type with 'isBoxedType'. 'isLiftedRuntimeRep'
    -- classifies types of kind RuntimeRep; applied to an argument's type it
    -- is always False and would silently drop every argument.
    let cbv_args = filter (isBoxedType . idType) $
                     filterByList (map isMarkedCbv marks) args
    -- Cmm addresses of the arguments to check.
    arg_infos <- mapM getCgIdInfo cbv_args
    let arg_cmms = map idInfoToAmode arg_infos
    mapM_ (emitTagAssertion (showPprUnsafe msg)) arg_cmms
-- | When tag checking is enabled, verify *at compile time* (via
-- 'checkArgStatic') each constructor argument against its runtime
-- strictness mark.
--
-- The marks come from 'dataConRuntimeRepStrictness' and are paired
-- positionally with @args@; surplus elements on either side are ignored.
checkConArgsStatic :: SDoc -> DataCon -> [StgArg] -> FCode ()
checkConArgsStatic msg con args =
  whenCheckTags $
    sequence_ (zipWith (checkArgStatic msg) (dataConRuntimeRepStrictness con) args)
-- | When tag checking is enabled, emit *runtime* tag checks (via 'checkArg')
-- for constructor arguments. Each argument's runtime strictness mark is
-- converted with 'cbvFromStrictMark' to decide whether it needs a check.
--
-- The marks come from 'dataConRuntimeRepStrictness' and are paired
-- positionally with @args@; surplus elements on either side are ignored.
checkConArgsDyn :: SDoc -> DataCon -> [StgArg] -> FCode ()
checkConArgsDyn msg con args =
  whenCheckTags $
    forM_ (zip (dataConRuntimeRepStrictness con) args) $ \(mark, arg) ->
      checkArg msg (cbvFromStrictMark mark) arg
-- | Run @act@ only when tag checking is enabled, i.e. when
-- 'stgToCmmDoTagCheck' is set in the current 'StgToCmmConfig'.
-- Otherwise do nothing.
whenCheckTags :: FCode () -> FCode ()
whenCheckTags act = do
  cfg <- getStgToCmmConfig
  if stgToCmmDoTagCheck cfg
    then act
    else return ()
-- | Emit a runtime assertion on the closure pointer @fun@. If the assertion
-- fails, @barf@ is called with a message that ends in @onWhat@.
--
-- The generated code:
--
--   * continues immediately if @fun@ carries a pointer tag ('cmmIsTagged');
--   * otherwise dispatches on the closure type via 'needsArgTag', continuing
--     for the closure types that may legitimately be untagged (partial
--     applications, byte-code objects and function closures);
--   * calls 'emitBarf' for any other untagged closure.
--
-- All passing paths join at the label emitted last.
emitTagAssertion :: String -> CmmExpr -> FCode ()
emitTagAssertion onWhat fun = do
  { platform <- getPlatform
  ; lret     <- newBlockId
  ; lno_tag  <- newBlockId
  ; lbarf    <- newBlockId
    -- Any tag present: the assertion holds.
  ; emit $ mkCbranch (cmmIsTagged platform fun) lret lno_tag (Just True)
    -- No tag: accept only closure types that may legitimately be untagged.
  ; emitLabel lno_tag
  ; emitComment (mkFastString "closureTypeCheck")
  ; needsArgTag fun lbarf lret
    -- Anything else means tag inference predicted wrongly; abort loudly.
  ; emitLabel lbarf
  ; emitBarf ("Tag inference failed on: " ++ onWhat)
  ; emitLabel lret
  }
-- | Emit a switch on the closure type of @closure@. Closure types that may
-- legitimately be untagged (partial applications, byte-code objects and the
-- function closure types) jump to @lpass@. Every other closure type takes
-- the switch default, @lfail@.
--
-- The closure type is read with 'cmmGetClosureType', honouring the
-- 'stgToCmmAlignCheck' setting of the current configuration.
needsArgTag :: CmmExpr -> BlockId -> BlockId -> FCode ()
needsArgTag closure lfail lpass = do
  profile <- getProfile
  cfg     <- getStgToCmmConfig
  let clo_ty_e    = cmmGetClosureType profile (stgToCmmAlignCheck cfg) closure
      -- Closure types that need no pointer tag.
      untagged_ok = [ PAP, BCO
                    , FUN, FUN_1_0, FUN_0_1, FUN_2_0, FUN_1_1, FUN_0_2
                    , FUN_STATIC
                    ]
      targets     = mkSwitchTargets False
                                    (INVALID_OBJECT, N_CLOSURE_TYPES)
                                    (Just lfail)
                                    (M.fromList [ (ty, lpass) | ty <- untagged_ok ])
  emit $ mkSwitch clo_ty_e targets
  -- NOTE(review): the switch above already has a default target, so this
  -- branch looks unreachable; confirm before removing.
  emit $ mkBranch lpass
-- | When tag checking is enabled, emit a runtime tag assertion for every
-- argument in @args@ whose corresponding entry in @marks@ satisfies
-- 'isMarkedCbv' and whose type is boxed (i.e. represented by a heap pointer).
--
-- Each failure message names the current module, @info@ and the argument
-- that was found untagged.
emitArgTagCheck :: SDoc -> [CbvMark] -> [Id] -> FCode ()
emitArgTagCheck info marks args = whenCheckTags $ do
  this_mod <- getModuleName
  -- NB: filter on the *value* type with 'isBoxedType'. 'isLiftedRuntimeRep'
  -- classifies types of kind RuntimeRep and is always False for an
  -- argument's type.
  let cbv_args = filter (isBoxedType . idType) $
                   filterByList (map isMarkedCbv marks) args
  arg_infos <- mapM getCgIdInfo cbv_args
  let arg_cmms = map idInfoToAmode arg_infos
      mk_msg arg = showPprUnsafe (text "Untagged arg:" <> (ppr this_mod) <> char ':' <> info <+> ppr arg)
  -- The addresses in arg_cmms are derived from cbv_args, so the messages must
  -- be too. Building them from the unfiltered args would attach the wrong
  -- argument name to an assertion whenever any argument was filtered out.
  zipWithM_ emitTagAssertion (map mk_msg cbv_args) arg_cmms
-- | Decide statically, from the 'LambdaFormInfo' recorded in a 'CgIdInfo',
-- whether the binding counts as tagged.
--
-- 'LFCon', 'LFReEntrant' and 'LFUnlifted' count as tagged. 'LFThunk' and
-- 'LFUnknown' do not. A let-no-escape binding is not expected here and
-- causes a panic.
taggedCgInfo :: CgIdInfo -> Bool
taggedCgInfo = lf_is_tagged . cg_lf
  where
    lf_is_tagged LFCon{}       = True
    lf_is_tagged LFReEntrant{} = True
    lf_is_tagged LFUnlifted{}  = True
    lf_is_tagged LFThunk{}     = False
    lf_is_tagged LFUnknown{}   = False
    lf_is_tagged LFLetNoEscape = panic "Let no escape binding passed to top level con"
-- | Check that a single argument is tagged, if its 'CbvMark' requires it.
--
-- Arguments marked 'NotMarkedCbv' and literal arguments are never checked.
-- For a variable, nothing is emitted when 'taggedCgInfo' already reports its
-- binding as tagged. Otherwise a runtime assertion ('emitTagAssertion') is
-- emitted against the variable's Cmm location. The check is only generated
-- when tag checking is enabled. Panics if the variable has a let-no-escape
-- location.
checkArg :: SDoc -> CbvMark -> StgArg -> FCode ()
checkArg _   NotMarkedCbv _   = return ()
checkArg msg MarkedCbv    arg = whenCheckTags (check_arg arg)
  where
    check_arg (StgLitArg _) = return ()
    check_arg (StgVarArg v) = do
      info <- getCgIdInfo v
      unless (taggedCgInfo info) $
        case cg_loc info of
          CmmLoc loc -> emitTagAssertion failure_msg loc
          LneLoc {}  -> panic "LNE-arg"

    failure_msg = showPprUnsafe (msg <+> text "arg:" <> ppr arg)
-- | Check *at compile time* that an argument marked 'MarkedStrict' is known
-- to be tagged.
--
-- 'NotMarkedStrict' and literal arguments are accepted without inspection.
-- For a variable, 'taggedCgInfo' must report its binding as tagged;
-- otherwise the compiler panics, reporting @msg@ and the offending argument.
-- The check only runs when tag checking is enabled ('whenCheckTags').
checkArgStatic :: SDoc -> StrictnessMark -> StgArg -> FCode ()
checkArgStatic _   NotMarkedStrict _   = return ()
checkArgStatic msg MarkedStrict    arg = whenCheckTags $
  case arg of
    StgLitArg _ -> return ()
    StgVarArg v -> do
      info <- getCgIdInfo v
      unless (taggedCgInfo info) $
        pprPanic "Arg not tagged as expected" (ppr msg <+> ppr arg)