{-# LANGUAGE TemplateHaskell #-}
module Loopbreaker.InlineRecCalls (action) where
import Control.Arrow hiding ((<+>))
import Data.Generics
import Data.Kind
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as S
import Bag
import ErrUtils
import GhcPlugins hiding ((<>), debugTraceMsg)
import HsSyn
import MonadUtils
-- | Capabilities the loop-breaking pass relies on: 'DynFlags' access for
-- options and logging, 'IO' for emitting trace/pass messages, and fresh
-- uniques for naming generated loopbreakers.
type MonadInline m = ((HasDynFlags m, MonadIO m, MonadUnique m) :: Constraint)
-- | Entry point of the pass: rewrite the value bindings of a renamed group.
--
-- The group is returned untouched when the plugin options contain
-- @"disable"@ or when the optimisation level is 0; otherwise the pass is
-- announced via 'showPass' and 'inlineRecCalls' is run on 'hs_valds'.
action :: MonadInline m
  => [CommandLineOption] -> HsGroup GhcRn -> m (HsGroup GhcRn)
action opts group@HsGroup{ hs_valds } = do
  dflags <- getDynFlags
  let skipPass = "disable" `elem` opts || optLevel dflags <= 0
  if skipPass
    then pure group
    else do
      liftIO (showPass dflags "Break loops")
      new_valds <- inlineRecCalls hs_valds
      pure group{ hs_valds = new_valds }
action _ _ = error "Loopbreaker.InlineRecCalls.action: expected renamed group"
-- | Run 'inlineRecCall' over every binding group of renamed value bindings.
--
-- Type signatures and INLINE pragmas are collected from this same set of
-- bindings. The signatures generated for any loopbreakers are appended
-- after the existing ones. Bindings that are not in the renamed
-- 'XValBindsLR' form are returned unchanged.
inlineRecCalls :: MonadInline m => HsValBinds GhcRn -> m (HsValBinds GhcRn)
inlineRecCalls (XValBindsLR (NValBinds binds sigs)) = do
  let plain_sigs = map unLoc sigs
      types = typesFromSigs plain_sigs
      inlined = inlinedFromSigs plain_sigs
  results <- traverse (inlineRecCall types inlined) binds
  let new_binds = map fst results
      extra_sigs = concatMap snd results
  pure (XValBindsLR (NValBinds new_binds (sigs ++ extra_sigs)))
inlineRecCalls val_binds = pure val_binds
-- | Map every name declared by a 'TypeSig' to that signature's type.
--
-- A single signature may declare several names (@f, g :: T@); each gets
-- its own entry. Non-'TypeSig' signatures are ignored. If a name occurs in
-- more than one 'TypeSig', the later one wins ('M.fromList' semantics).
typesFromSigs :: Ord (IdP p) => [Sig p] -> Map (IdP p) (LHsSigWcType p)
typesFromSigs sigs = M.fromList
  [ (unLoc located_name, sig_type)
  | TypeSig _ located_names sig_type <- sigs
  , located_name <- located_names
  ]
-- | The set of names targeted by an 'InlineSig' whose pragma satisfies
-- 'isInlinePragma'.
--
-- NOTE(review): per GHC's 'isInlinePragma', this presumably admits only
-- plain @INLINE@ (not @INLINABLE@ / @NOINLINE@) — confirm against the GHC
-- version in use.
inlinedFromSigs :: Ord (IdP p) => [Sig p] -> Set (IdP p)
inlinedFromSigs = S.fromList . mapMaybe nameIfInlineSig
  where
    -- Extract the target name of an inline-style pragma, if it is one.
    nameIfInlineSig (InlineSig _ (L _ name) pragma)
      | isInlinePragma pragma = Just name
    nameIfInlineSig _ = Nothing
-- | Break the self-recursion of a single INLINE function.
--
-- Only a 'Recursive' group that consists of exactly one 'FunBind', whose
-- name is in the @inlined@ set, is rewritten. Every other group is
-- returned unchanged, with no extra signatures. For a matching function @f@:
--
-- * a fresh loopbreaker binding @f__Loopbreaker = f@ is created;
-- * inside @f@'s matches, nested local binding groups are processed
--   recursively and every 'HsVar' reference to @f@ is redirected to the
--   loopbreaker;
-- * the loopbreaker is placed in the same recursive group as @f@.
--
-- The returned signatures are a NOINLINE pragma for the loopbreaker,
-- followed by a copy of @f@'s type signature when @types@ contains one.
-- Presumably the intent is for GHC to choose the NOINLINE loopbreaker,
-- rather than @f@, as the loop breaker of the group, so that @f@ itself
-- stays inlinable.
inlineRecCall
  :: MonadInline m
  => Map Name (LHsSigWcType GhcRn)  -- ^ type signatures, keyed by name
  -> Set Name                       -- ^ names that carry an INLINE pragma
  -> (RecFlag, LHsBinds GhcRn)      -- ^ one binding group and its rec flag
  -> m ((RecFlag, LHsBinds GhcRn), [LSig GhcRn])
inlineRecCall types inlined (Recursive, binds)
  | [L fun_loc fun_bind@FunBind{ fun_id = L _ fun_name, fun_matches }]
      <- bagToList binds
  , fun_name `S.member` inlined
  = do
      dflags <- getDynFlags
      liftIO . debugTraceMsg dflags 2 $ text "Loopbreaker:" <+> ppr fun_name
      (loopb_name, loopb_decl) <- loopbreaker fun_name
      -- Both transformations run at every node of the matches; everywhereM
      -- visits children before their parents.
      new_matches <-
        everywhereM
          (fmap (replaceVarNamesT fun_name loopb_name) . inlineLocalRecCallsM)
          fun_matches
      let rewritten_bind = L fun_loc fun_bind{ fun_matches = new_matches }
          loopb_type_sigs =
            [ loopbreakerSig loopb_name fun_type
            | Just fun_type <- [M.lookup fun_name types]
            ]
      pure
        ( (Recursive, listToBag [rewritten_bind, loopb_decl])
        , inlineSig noInlinePragma loopb_name : loopb_type_sigs
        )
inlineRecCall _ _ bind_group = pure (bind_group, [])
-- | Allocate a fresh loopbreaker name for @fun_name@, and return it together
-- with the binding @loopbreaker = fun_name@.
loopbreaker :: MonadUnique m => Name -> m (Name, LHsBind GhcRn)
loopbreaker fun_name = do
  loopb_name <- loopbreakerName fun_name
  pure (loopb_name, loopbreakerDecl fun_name loopb_name)
-- | A fresh system variable name whose occurrence text is the original
-- name's text with @__Loopbreaker@ appended.
loopbreakerName :: MonadUnique m => Name -> m Name
loopbreakerName fun_name = do
  uniq <- getUniqueM
  let loopb_fs = occNameFS (occName fun_name) <> "__Loopbreaker"
  pure (mkSystemVarName uniq loopb_fs)
-- | The compiler-generated binding @loopb_name = fun_name@: a single match
-- with no argument patterns, whose right-hand side is just a reference to
-- the original function.
loopbreakerDecl :: Name -> Name -> LHsBind GhcRn
loopbreakerDecl fun_name loopb_name =
  noLoc (mkTopFunBind Generated loopb_lname [alias_match])
  where
    loopb_lname = noLoc loopb_name
    alias_match =
      mkSimpleMatch (mkPrefixFunRhs loopb_lname) [] (nlHsVar fun_name)
-- | A type signature that gives the loopbreaker the given type (in practice,
-- the original function's signature, reused as-is).
loopbreakerSig :: Name -> LHsSigWcType GhcRn -> LSig GhcRn
loopbreakerSig loopb_name fun_type = noLoc type_sig
  where
    type_sig = TypeSig NoExt [noLoc loopb_name] fun_type
-- | An unlocated inline-family pragma signature (@INLINE@, @NOINLINE@, …,
-- whichever the 'InlinePragma' encodes) for @name@.
inlineSig :: (XInlineSig p ~ NoExt) => InlinePragma -> IdP p -> LSig p
inlineSig how name =
  let located_name = noLoc name
  in noLoc $ InlineSig NoExt located_name how
-- | The default pragma, switched to 'NoInline' and made 'NeverActive'.
noInlinePragma :: InlinePragma
noInlinePragma =
  defaultInlinePragma { inl_act = NeverActive, inl_inline = NoInline }
-- | Generic monadic step for use with 'everywhereM'. On a local value
-- binding group ('HsValBinds' inside 'HsLocalBinds' 'GhcRn'), it runs
-- 'inlineRecCalls' on the bindings; every other node is returned unchanged.
inlineLocalRecCallsM :: (MonadInline m, Typeable a) => a -> m a
inlineLocalRecCallsM = mkM onLocalBinds
  where
    onLocalBinds
      :: MonadInline n => HsLocalBinds GhcRn -> n (HsLocalBinds GhcRn)
    onLocalBinds (HsValBinds NoExt binds) =
      HsValBinds NoExt <$> inlineRecCalls binds
    onLocalBinds other = pure other
-- | Generic transformation for use with 'everywhere'/'everywhereM'. It
-- rewrites a variable reference ('HsVar') to @from@ into one to @to@,
-- keeping the source location; every other node is left alone.
replaceVarNamesT :: Typeable a => Name -> Name -> a -> a
replaceVarNamesT from to = mkT swapVar
  where
    swapVar :: HsExpr GhcRn -> HsExpr GhcRn
    swapVar (HsVar NoExt (L loc name))
      | name == from = HsVar NoExt (L loc to)
    swapVar expr = expr