{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.GHC.LoadModules
( loadModules
, ghcLibDir
)
where
#ifndef TOOL_VERSION_ghc
#error TOOL_VERSION_ghc undefined
#endif
import Control.Arrow (second)
import Data.Generics.Uniplate.DataOnly (transform)
import Data.List (foldl', lookup, nub)
import Data.Maybe (listToMaybe)
import Data.Word (Word8)
import Clash.Annotations.Primitive (HDL)
import Clash.Annotations.TopEntity (TopEntity (..))
import System.Exit (ExitCode (..))
import System.IO (hClose, hGetLine)
import System.IO.Error (tryIOError)
import System.Process (runInteractiveCommand,
                       waitForProcess)
import qualified Annotations
import qualified CoreSyn
import qualified CoreFVs
import qualified Digraph
import DynFlags (GeneralFlag (..))
import qualified DynFlags
import qualified GHC
import qualified HscMain
import qualified HscTypes
import qualified MonadUtils
import qualified Panic
#if MIN_VERSION_ghc(8,4,1)
import qualified GhcPlugins
#else
import qualified Serialized
#endif
import qualified TidyPgm
import qualified TcRnMonad
import qualified TcRnTypes
import qualified Unique
#if MIN_VERSION_ghc(8,2,0)
import qualified UniqDFM
#else
import qualified UniqFM
#endif
import qualified UniqSet
import qualified Var
import qualified FamInst
import qualified FamInstEnv
import qualified Name
import Outputable (ppr)
import qualified Outputable
import qualified OccName
import qualified GHC.LanguageExtensions as LangExt
import Clash.GHC.GHC2Core (modNameM)
import Clash.GHC.LoadInterfaceFiles
import Clash.Util (curLoc)
-- | Ask the exact GHC executable Clash was built with (@ghc-TOOL_VERSION_ghc@)
-- for its library directory, by running it with @--print-libdir@.
--
-- Panics with an explanatory message when that executable cannot be run
-- (exit status 127, or a successful run that printed nothing), and with the
-- raw exit status for any other failure.
ghcLibDir :: IO FilePath
ghcLibDir = do
  (libDirM, exitCode) <- getProcessOutput ghcCmd
  case (exitCode, libDirM) of
    (ExitSuccess,      Just libDir) -> return libDir
    (ExitSuccess,      Nothing)     -> Panic.pgmError noGHC
    -- 127 is treated as "the ghc-<version> executable was not found"
    (ExitFailure 127,  _)           -> Panic.pgmError noGHC
    (ExitFailure code, _)           ->
      Panic.pgmError $ "Calling GHC failed with error code: " ++ show code
  where
    ghcCmd = "ghc-" ++ TOOL_VERSION_ghc ++ " --print-libdir"
    noGHC  = "Clash needs the GHC compiler it was built with, ghc-" ++ TOOL_VERSION_ghc ++
             ", but it was not found. Make sure its location is in your PATH variable."
-- | Run a shell command and return the first line it writes to stdout
-- ('Nothing' if it wrote none) together with its exit code.
--
-- All three pipe handles created by 'runInteractiveCommand' are closed before
-- returning. Stdout is read /before/ waiting for the process: waiting first
-- can block forever if the child fills the stdout pipe buffer. Only one line is
-- read, so a child that keeps writing large amounts of output after its first
-- line could still block; this is fine for the short commands used here.
getProcessOutput :: String -> IO (Maybe String, ExitCode)
getProcessOutput command = do
  (pIn, pOut, pErr, handle) <- runInteractiveCommand command
  -- The child is never fed any input; close its stdin straight away.
  hClose pIn
  -- EOF (no output at all) surfaces as an IOError, mapped to Nothing.
  output   <- either (const Nothing) Just <$> tryIOError (hGetLine pOut)
  exitCode <- waitForProcess handle
  -- Close the read ends only after the child has exited, so it can never
  -- be killed by a broken pipe while still writing.
  hClose pOut
  hClose pErr
  return (output, exitCode)
-- | Load the module @modName@ and every module it transitively imports
-- through the GHC API, compile them to simplified and tidied Core, and gather
-- what the rest of Clash needs from the result.
--
-- Arguments:
--
--   * the HDL being targeted; only passed on to 'loadExternalExprs'
--   * the root module, handed as-is to 'GHC.guessTarget'
--   * optional session flags; when 'Nothing', the GHC session defaults are
--     extended with a fixed set of language extensions and plugin modules
--
-- The result tuple contains, in order:
--
--   1. the Core bindings of all loaded modules, followed by the bindings
--      returned by 'loadExternalExprs', grouped via 'makeRecursiveGroups'
--   2. the @clsOps@ returned by 'loadExternalExprs' (presumably class
--      operations paired with an index; confirm in LoadInterfaceFiles)
--   3. the @unlocatable@ binders returned by 'loadExternalExprs' (presumably
--      binders whose definition could not be found; confirm)
--   4. the family instance environments: the external one obtained from
--      'FamInst.tcGetFamInstEnvs', paired with the union of the per-module ones
--   5. the entities to synthesize: each binder with its optional 'TopEntity'
--      annotation and its optional test bench binder
--   6. the de-duplicated file paths returned by 'loadExternalExprs'
--
-- Panics (via 'Panic.pgmError') when the root module contains neither a
-- @topEntity@ nor a @Synthesize@-annotated function, or contains more than
-- one @topEntity@.
loadModules
  :: HDL
  -> String
  -> Maybe (DynFlags.DynFlags)
  -> IO ( [CoreSyn.CoreBind]
        , [(CoreSyn.CoreBndr,Int)]
        , [CoreSyn.CoreBndr]
        , FamInstEnv.FamInstEnvs
        , [( CoreSyn.CoreBndr
           , Maybe TopEntity
           , Maybe CoreSyn.CoreBndr)]
        , [FilePath]
        )
loadModules hdl modName dflagsM = do
  libDir <- MonadUtils.liftIO ghcLibDir
  GHC.runGhc (Just libDir) $ do
    -- Use the caller-supplied flags verbatim, or derive defaults from the
    -- session: enable the extensions below, disable ImplicitPrelude,
    -- MonomorphismRestriction, Strict and StrictData, and register the three
    -- GHC.TypeLits.* plugin modules.
    dflags <- case dflagsM of
      Just df -> return df
      Nothing -> do
        df <- GHC.getSessionDynFlags
        let dfEn = foldl DynFlags.xopt_set df
                    [ LangExt.TemplateHaskell
                    , LangExt.TemplateHaskellQuotes
                    , LangExt.DataKinds
                    , LangExt.MonoLocalBinds
                    , LangExt.TypeOperators
                    , LangExt.FlexibleContexts
                    , LangExt.ConstraintKinds
                    , LangExt.TypeFamilies
                    , LangExt.BinaryLiterals
                    , LangExt.ExplicitNamespaces
                    , LangExt.KindSignatures
                    , LangExt.DeriveLift
                    , LangExt.TypeApplications
                    , LangExt.ScopedTypeVariables
                    , LangExt.MagicHash
                    , LangExt.ExplicitForAll
                    ]
        let dfDis = foldl DynFlags.xopt_unset dfEn
                    [ LangExt.ImplicitPrelude
                    , LangExt.MonomorphismRestriction
                    , LangExt.Strict
                    , LangExt.StrictData
                    ]
        let ghcTyLitNormPlugin = GHC.mkModuleName "GHC.TypeLits.Normalise"
            ghcTyLitExtrPlugin = GHC.mkModuleName "GHC.TypeLits.Extra.Solver"
            ghcTyLitKNPlugin   = GHC.mkModuleName "GHC.TypeLits.KnownNat.Solver"
        -- nub: do not register a plugin twice if the session already lists it.
        let dfPlug = dfDis { DynFlags.pluginModNames = nub $
                                ghcTyLitNormPlugin : ghcTyLitExtrPlugin :
                                ghcTyLitKNPlugin : DynFlags.pluginModNames dfDis
                           }
        return dfPlug

    -- Overrides applied whatever the origin of the flags: a reduction
    -- (context stack) depth of 1000, optLevel 2, compilation-manager mode,
    -- in-memory linking, and the platform default object-code target.
    let dflags1 = dflags
#if __GLASGOW_HASKELL__ >= 711
          { DynFlags.reductionDepth = 1000
#else
          { DynFlags.ctxtStkDepth = 1000
#endif
          , DynFlags.optLevel = 2
          , DynFlags.ghcMode = GHC.CompManager
          , DynFlags.ghcLink = GHC.LinkInMemory
          , DynFlags.hscTarget = DynFlags.defaultObjectTarget
                                   (DynFlags.targetPlatform dflags)
          }
    -- Force the optimisation/strictness flag set Clash wants.
    let dflags2 = wantedOptimizationFlags dflags1
    -- When the running GHC reports itself as dynamically linked, also build
    -- dynamic objects (presumably so Template Haskell can load them; confirm).
    let ghcDynamic = case lookup "GHC Dynamic" (DynFlags.compilerInfo dflags) of
          Just "YES" -> True
          _          -> False
    let dflags3 = if ghcDynamic then DynFlags.gopt_set dflags2 DynFlags.Opt_BuildDynamicToo
                  else dflags2
    _ <- GHC.setSessionDynFlags dflags3
    target <- GHC.guessTarget modName Nothing
    GHC.setTargets [target]
    -- Dependency analysis of the target; every module summary gets the
    -- Clash flags, then the graph is flattened into dependency order.
    modGraph <- GHC.depanal [] False
#if MIN_VERSION_ghc(8,4,1)
    let modGraph' = GHC.mapMG disableOptimizationsFlags modGraph
#else
    let modGraph' = map disableOptimizationsFlags modGraph
#endif
        modGraph2 = Digraph.flattenSCCs (GHC.topSortModuleGraph True modGraph' Nothing)
    -- Per module: parse, strip strictness annotations from data declarations,
    -- typecheck, load, desugar, simplify and tidy; keep the tidied Core
    -- bindings and the module's family instance environment.
    tidiedMods <- mapM (\m -> do { pMod <- parseModule m
                                 ; tcMod <- GHC.typecheckModule (removeStrictnessAnnotations pMod)
                                 ; tcMod' <- GHC.loadModule tcMod
                                 ; dsMod <- fmap GHC.coreModule $ GHC.desugarModule tcMod'
                                 ; hsc_env <- GHC.getSession
#if MIN_VERSION_ghc(8,4,1)
                                 ; simpl_guts <- MonadUtils.liftIO $ HscMain.hscSimplify hsc_env [] dsMod
#else
                                 ; simpl_guts <- MonadUtils.liftIO $ HscMain.hscSimplify hsc_env dsMod
#endif
                                 ; (tidy_guts,_) <- MonadUtils.liftIO $ TidyPgm.tidyProgram hsc_env simpl_guts
                                 ; let pgm = HscTypes.cg_binds tidy_guts
                                 ; let modFamInstEnv = TcRnTypes.tcg_fam_inst_env $ fst $ GHC.tm_internals_ tcMod
                                 ; return (pgm,modFamInstEnv)
                                 }
                       ) modGraph2

    let (binders,modFamInstEnvs) = unzip tidiedMods
        bindersC = concat binders
        binderIds = map fst (CoreSyn.flattenBinds bindersC)
        -- Union of all per-module family instance environments.
#if MIN_VERSION_ghc(8,2,0)
        modFamInstEnvs' = foldr UniqDFM.plusUDFM UniqDFM.emptyUDFM modFamInstEnvs
#else
        modFamInstEnvs' = foldr UniqFM.plusUFM UniqFM.emptyUFM modFamInstEnvs
#endif

    -- The binders already defined locally are passed so that
    -- loadExternalExprs can (presumably) skip them and only load the rest.
    (externalBndrs,clsOps,unlocatable,pFP) <-
      loadExternalExprs hdl (UniqSet.mkUniqSet binderIds) bindersC

    hscEnv <- GHC.getSession
    famInstEnvs <- TcRnMonad.liftIO $ TcRnMonad.initTcForLookup hscEnv FamInst.tcGetFamInstEnvs

    -- The root module is taken to be the last one in dependency order.
    -- NOTE(review): `last` is partial; an empty module graph crashes here.
    let rootModule = GHC.ms_mod_name . last $ modGraph2
        rootIds = map fst . CoreSyn.flattenBinds $ last binders

    allSyn <- findSynthesizeAnnotations binderIds
    benchAnn <- findTestBenchAnnotations binderIds
    topSyn <- findSynthesizeAnnotations rootIds
    let varNameString = OccName.occNameString . Name.nameOccName . Var.varName
        topEntities = filter ((== "topEntity") . varNameString) rootIds
        benches = filter ((== "testBench") . varNameString) rootIds
        -- Attach each annotated binder's TestBench annotation, if any.
        mergeBench (x,y) = (x,y,lookup x benchAnn)
        allSyn' = map mergeBench allSyn
    -- No topEntity: the root module must have at least one Synthesize
    -- annotation, and all Synthesize-annotated binders are returned.
    -- One topEntity without a Synthesize annotation: it is prepended, with its
    -- TestBench annotation or else a root-module testBench binder.
    topEntities' <- case topEntities of
      [] -> case topSyn of
        [] -> Panic.pgmError $ "No 'topEntity', nor function with a 'Synthesize' annotation found in root module: " ++
                                (Outputable.showSDocUnsafe (ppr rootModule))
        _ -> return allSyn'
      [x] -> case lookup x topSyn of
        Nothing -> case lookup x benchAnn of
          Nothing -> return ((x,Nothing,listToMaybe benches):allSyn')
          Just y -> return ((x,Nothing,Just y):allSyn')
        -- Annotated topEntity: already present in allSyn'.
        Just _ -> return allSyn'
      _ -> Panic.pgmError $ $(curLoc) ++ "Multiple 'topEntities' found."
    return (bindersC ++ makeRecursiveGroups externalBndrs,clsOps,unlocatable,(fst famInstEnvs,modFamInstEnvs'),topEntities',nub pFP)
-- | Group a flat list of (binder, expression) pairs into Core bindings.
--
-- Each pair becomes a dependency-graph node keyed on the binder unique, with
-- edges to the uniques of the free identifiers of its expression. Every
-- strongly connected component becomes one binding: a singleton without a
-- cycle becomes 'CoreSyn.NonRec', anything else becomes 'CoreSyn.Rec'.
makeRecursiveGroups
  :: [(CoreSyn.CoreBndr,CoreSyn.CoreExpr)]
  -> [CoreSyn.CoreBind]
makeRecursiveGroups bndrExprs =
  [ sccToBind component
  | component <- Digraph.stronglyConnCompFromEdgedVerticesUniq (map toNode bndrExprs)
  ]
  where
    -- Node payload is the pair itself; key is the binder unique; edges point
    -- to every identifier free in the right-hand side.
    toNode
      :: (CoreSyn.CoreBndr,CoreSyn.CoreExpr)
      -> Digraph.Node Unique.Unique (CoreSyn.CoreBndr,CoreSyn.CoreExpr)
    toNode pair@(bndr,expr) =
#if MIN_VERSION_ghc(8,4,1)
        Digraph.DigraphNode pair key deps
#else
        (pair, key, deps)
#endif
      where
        key  = Var.varUnique bndr
        deps = UniqSet.nonDetKeysUniqSet (CoreFVs.exprFreeIds expr)

    sccToBind
      :: Digraph.SCC (CoreSyn.CoreBndr,CoreSyn.CoreExpr)
      -> CoreSyn.CoreBind
    sccToBind component = case component of
      Digraph.AcyclicSCC (bndr,expr) -> CoreSyn.NonRec bndr expr
      Digraph.CyclicSCC members      -> CoreSyn.Rec members
-- | Collect the binders that carry a 'Synthesize' 'TopEntity' annotation.
--
-- Returns every binder with at least one such annotation, paired with that
-- annotation (so the second component is always 'Just'). Binders without one
-- are left out. Panics when any binder carries more than one 'Synthesize'
-- annotation.
findSynthesizeAnnotations
  :: GHC.GhcMonad m
  => [CoreSyn.CoreBndr]
  -> m [(CoreSyn.CoreBndr,Maybe TopEntity)]
findSynthesizeAnnotations bndrs = do
#if MIN_VERSION_ghc(8,4,1)
  let deserializer = GhcPlugins.deserializeWithData :: ([Word8] -> TopEntity)
#else
  let deserializer = Serialized.deserializeWithData :: ([Word8] -> TopEntity)
#endif
  topAnns <- mapM (GHC.findGlobalAnns deserializer . Annotations.NamedTarget . Var.varName) bndrs
  let annotated  = [ (bndr, syns)
                   | (bndr, anns) <- zip bndrs topAnns
                   , let syns = filter isSynthesize anns
                   , not (null syns)
                   ]
      -- Binders with two or more Synthesize annotations.
      duplicates = [ bndr | (bndr, _:_:_) <- annotated ]
  if null duplicates
    then return (map (second listToMaybe) annotated)
    else Panic.pgmError $
           "The following functions have multiple 'Synthesize' annotations: " ++
           Outputable.showSDocUnsafe (ppr duplicates)
  where
    isSynthesize (Synthesize {}) = True
    isSynthesize _               = False
-- | Collect the binders that carry a 'TestBench' 'TopEntity' annotation,
-- each paired with the test bench binder that annotation names.
--
-- The named test bench is looked up among @bndrs@ by comparing the 'show'n
-- annotation name with the module-qualified name of each binder.
--
-- Panics, when the action runs, if any binder carries more than one
-- 'TestBench' annotation. This matches 'findSynthesizeAnnotations'. Before
-- this change the check lived in a lazy binding and only fired if a caller
-- happened to force the list. A 'TestBench' whose named binder is not among
-- @bndrs@ still panics lazily, only when that pair's test bench is used.
findTestBenchAnnotations
  :: GHC.GhcMonad m
  => [CoreSyn.CoreBndr]
  -> m [(CoreSyn.CoreBndr,CoreSyn.CoreBndr)]
findTestBenchAnnotations bndrs = do
#if MIN_VERSION_ghc(8,4,1)
  let deserializer = GhcPlugins.deserializeWithData :: ([Word8] -> TopEntity)
#else
  let deserializer = Serialized.deserializeWithData :: ([Word8] -> TopEntity)
#endif
      targets = map (Annotations.NamedTarget . Var.varName) bndrs
  anns <- mapM (GHC.findGlobalAnns deserializer) targets
  let isTB (TestBench {}) = True
      isTB _              = False
      anns'    = map (filter isTB) anns
      annBndrs = filter (not . null . snd) (zip bndrs anns')
  case filter ((> 1) . length . snd) annBndrs of
    -- Every remaining list is non-empty, so the pattern below always matches.
    [] -> return [ (bndr, findTB tb) | (bndr, tb:_) <- annBndrs ]
    as -> Panic.pgmError $
            "The following functions have multiple 'TestBench' annotations: " ++
            Outputable.showSDocUnsafe (ppr (map fst as))
  where
    -- Resolve the name stored in a TestBench annotation to a binder.
    findTB :: TopEntity -> CoreSyn.CoreBndr
    findTB (TestBench tb) = case listToMaybe (filter (eqNm tb) bndrs) of
      Just tb' -> tb'
      Nothing  -> Panic.pgmError $
                    "TestBench named: " ++ show tb ++ " not found"
    -- Unreachable: only TestBench annotations are passed in.
    findTB _ = Panic.pgmError "Unexpected Synthesize"

    -- Compare the shown annotation name with the binder name, qualified by
    -- its module name when 'modNameM' provides one.
    eqNm thNm bndr = show thNm == qualNm
      where
        bndrNm  = Var.varName bndr
        qualNm  = maybe occName (\modName -> modName ++ ('.':occName)) (modNameM bndrNm)
        occName = OccName.occNameString (Name.nameOccName bndrNm)
-- | Parse a module with 'GHC.parseModule', then replace the module summary
-- stored in the result with one processed by 'disableOptimizationsFlags'.
-- The parsed source, extra source files and annotations are kept unchanged.
parseModule :: GHC.GhcMonad m => GHC.ModSummary -> m GHC.ParsedModule
parseModule modSum = fmap withClashFlags (GHC.parseModule modSum)
  where
    withClashFlags parsed =
      parsed { GHC.pm_mod_summary = disableOptimizationsFlags (GHC.pm_mod_summary parsed) }
-- | Rewrite the preprocessed flags of a module summary to the Clash
-- optimisation set. Despite the name, this sets optLevel to 2 and a
-- reduction (context stack) depth of 1000, then applies
-- 'wantedOptimizationFlags', which enables some passes and disables others.
disableOptimizationsFlags :: GHC.ModSummary -> GHC.ModSummary
disableOptimizationsFlags modSum =
    modSum { GHC.ms_hspp_opts = wantedOptimizationFlags adjusted }
  where
    original = GHC.ms_hspp_opts modSum
    adjusted = original
#if __GLASGOW_HASKELL__ >= 711
      { DynFlags.reductionDepth = 1000
#else
      { DynFlags.ctxtStkDepth = 1000
#endif
      , DynFlags.optLevel = 2
      }
-- | Force the optimisation and strictness configuration Clash compiles with.
--
-- Three steps, applied in this order:
--
--   1. set every general flag in @wanted@
--   2. unset every general flag in @unwanted@
--   3. unset the @Strict@ and @StrictData@ language extensions
--
-- (@Opt_RegsGraph@ used to be listed twice in @unwanted@; the duplicate was a
-- no-op and has been removed.)
wantedOptimizationFlags :: GHC.DynFlags -> GHC.DynFlags
wantedOptimizationFlags df = withoutStrictLang
  where
    withWanted        = foldl' DynFlags.gopt_set   df         wanted
    withoutUnwanted   = foldl' DynFlags.gopt_unset withWanted unwanted
    withoutStrictLang = foldl' DynFlags.xopt_unset withoutUnwanted unwantedLang

    wanted = [ Opt_CSE
             , Opt_Specialise
             , Opt_DoLambdaEtaExpansion
             , Opt_CaseMerge
             , Opt_DictsCheap
             , Opt_ExposeAllUnfoldings
             , Opt_ForceRecomp
             , Opt_EnableRewriteRules
             , Opt_SimplPreInlining
             , Opt_StaticArgumentTransformation
             , Opt_FloatIn
             , Opt_DictsStrict
             , Opt_DmdTxDictSel
             , Opt_Strictness
             , Opt_SpecialiseAggressively
             , Opt_CrossModuleSpecialise
             ]

    unwanted = [ Opt_LiberateCase
               , Opt_SpecConstr
               , Opt_IgnoreAsserts
               , Opt_DoEtaReduction
               , Opt_UnboxStrictFields
               , Opt_UnboxSmallStrictFields
               , Opt_Vectorise
               , Opt_VectorisationAvoidance
               , Opt_RegsGraph
               , Opt_PedanticBottoms
               , Opt_LlvmTBAA
               , Opt_CmmSink
               , Opt_CmmElimCommonBlocks
               , Opt_OmitYields
               , Opt_IgnoreInterfacePragmas
               , Opt_OmitInterfacePragmas
               , Opt_IrrefutableTuples
               , Opt_Loopification
               , Opt_CprAnal
               , Opt_FullLaziness
               ]

    unwantedLang = [ LangExt.Strict
                   , LangExt.StrictData
                   ]
-- | Strip bang annotations from the field types of data constructors declared
-- in a parsed module. Every 'GHC.HsBangTy' wrapper found in a constructor
-- field type is replaced by the type it wraps.
--
-- Only @data@/@newtype@ declarations reached through 'GHC.TyClD' are visited.
-- NOTE(review): other declaration forms (e.g. data family instances) are
-- passed through untouched; confirm whether that is intended.
-- NOTE(review): HsBangTy presumably also carries UNPACK pragmas, which would
-- therefore be dropped as well; confirm.
removeStrictnessAnnotations ::
     GHC.ParsedModule
  -> GHC.ParsedModule
removeStrictnessAnnotations pm =
    pm {GHC.pm_parsed_source = fmap rmPS (GHC.pm_parsed_source pm)}
  where
    -- Rewrite every top-level declaration of the module.
    rmPS :: GHC.DataId name => GHC.HsModule name -> GHC.HsModule name
    rmPS hsm = hsm {GHC.hsmodDecls = (fmap . fmap) rmHSD (GHC.hsmodDecls hsm)}

    -- Only type and class declarations are inspected.
    rmHSD :: GHC.DataId name => GHC.HsDecl name -> GHC.HsDecl name
    rmHSD (GHC.TyClD tyClDecl) = GHC.TyClD (rmTyClD tyClDecl)
    rmHSD hsd                  = hsd

    -- Of those, only data declarations have constructors to rewrite.
    rmTyClD :: GHC.DataId name => GHC.TyClDecl name -> GHC.TyClDecl name
    rmTyClD dc@(GHC.DataDecl {}) = dc {GHC.tcdDataDefn = rmDataDefn (GHC.tcdDataDefn dc)}
    rmTyClD tyClD                = tyClD

    rmDataDefn :: GHC.DataId name => GHC.HsDataDefn name -> GHC.HsDataDefn name
    rmDataDefn hdf = hdf {GHC.dd_cons = (fmap . fmap) rmCD (GHC.dd_cons hdf)}

    -- GADT constructors keep their fields inside the signature type;
    -- Haskell98 constructors keep them in the constructor details.
    rmCD :: GHC.DataId name => GHC.ConDecl name -> GHC.ConDecl name
    rmCD gadt@(GHC.ConDeclGADT {}) = gadt {GHC.con_type = rmSigType (GHC.con_type gadt)}
    rmCD h98@(GHC.ConDeclH98 {})   = h98 {GHC.con_details = rmConDetails (GHC.con_details h98)}

    rmSigType :: GHC.DataId name => GHC.LHsSigType name -> GHC.LHsSigType name
    rmSigType hsIB = hsIB {GHC.hsib_body = rmHsType (GHC.hsib_body hsIB)}

    -- Prefix, record and infix constructors each store field types differently.
    rmConDetails :: GHC.DataId name => GHC.HsConDeclDetails name -> GHC.HsConDeclDetails name
    rmConDetails (GHC.PrefixCon args) = GHC.PrefixCon (fmap rmHsType args)
    rmConDetails (GHC.RecCon rec)     = GHC.RecCon ((fmap . fmap . fmap) rmConDeclF rec)
    rmConDetails (GHC.InfixCon l r)   = GHC.InfixCon (rmHsType l) (rmHsType r)

    -- 'transform' applies 'go' bottom-up to every located type inside the
    -- given one, so nested bangs (e.g. inside a GADT signature) are removed too.
    rmHsType :: GHC.DataId name => GHC.Located (GHC.HsType name) -> GHC.Located (GHC.HsType name)
    rmHsType = transform go
      where
        go (GHC.unLoc -> GHC.HsBangTy _ ty) = ty
        go ty                               = ty

    rmConDeclF :: GHC.DataId name => GHC.ConDeclField name -> GHC.ConDeclField name
    rmConDeclF cdf = cdf {GHC.cd_fld_type = rmHsType (GHC.cd_fld_type cdf)}