-- |Top-level interface to the vectorisation monad: re-exports the monad's
-- component modules and provides the entry point 'initV', access to the
-- builtins, variable lookup, and marking of global entities as parallel.
module Vectorise.Monad (
  module Vectorise.Monad.Base,
  module Vectorise.Monad.Naming,
  module Vectorise.Monad.Local,
  module Vectorise.Monad.Global,
  module Vectorise.Monad.InstEnv,
  initV,

  -- * Builtins
  liftBuiltinDs,
  builtin,
  builtins,

  -- * Variables
  lookupVar,
  lookupVar_maybe,
  addGlobalParallelVar,
  addGlobalParallelTyCon,
) where

import GhcPrelude

import Vectorise.Monad.Base
import Vectorise.Monad.Naming
import Vectorise.Monad.Local
import Vectorise.Monad.Global
import Vectorise.Monad.InstEnv

import Vectorise.Builtins
import Vectorise.Env

import CoreSyn
import TcRnMonad
import DsMonad
import HscTypes hiding ( MonadThings(..) )
import DynFlags
import InstEnv
import Class
import TyCon
import NameSet
import VarSet
import VarEnv
import Var
import Id
import Name
import ErrUtils
import Outputable
import Module

import Control.Monad (join)


-- |Run a vectorisation computation.
--
-- The computation is executed inside the desugarer monad (set up from the
-- given 'ModGuts') with a global environment built from the builtins, the
-- 'PA' and 'PR' class instances, and the incoming 'VectInfo'.
--
-- Returns 'Nothing' if the desugarer run or the vectorisation itself produced
-- no result; otherwise the updated 'VectInfo' paired with the computation's
-- result.  The incoming and outgoing 'VectInfo' (or a failure notice) are
-- dumped under 'Opt_D_dump_vt_trace'.
--
initV :: HscEnv
      -> ModGuts
      -> VectInfo
      -> VM a
      -> IO (Maybe (VectInfo, a))
initV hsc_env guts info thing_inside = do
    dumpIfVtTrace "Incoming VectInfo" (ppr info)

    (_, res) <- initDsWithModGuts hsc_env guts go
    -- 'res' is a nested 'Maybe'; collapse both layers once and reuse the result
    let outcome = join res

    case outcome of
      Just (info', _) -> dumpIfVtTrace "Outgoing VectInfo" (ppr info')
      Nothing         -> dumpIfVtTrace "Vectorisation FAILED!" empty

    return outcome
  where
    dflags = hsc_dflags hsc_env

    dumpIfVtTrace = dumpIfSet_dyn dflags Opt_D_dump_vt_trace

    -- binders introduced by a single top-level binding group
    bindersOfBind (NonRec b _) = [b]
    bindersOfBind (Rec pairs)  = [b | (b, _) <- pairs]

    -- all binders of the module's top-level bindings
    topLevelIds = concatMap bindersOfBind (mg_binds guts)

    go = do
      -- set up tables of builtin entities
      bi     <- initBuiltins
      biVars <- initBuiltinVars bi

      -- set up class and type family environments
      eps <- liftIO $ hscEPS hsc_env
      let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
          instEnvs    = InstEnvs (eps_inst_env eps)
                                 (mg_inst_env guts)
                                 (mkModuleSet (dep_orphs (mg_deps guts)))
          paDicts     = initClassDicts instEnvs (paClass bi)  -- all 'PA' instances and..
          prDicts     = initClassDicts instEnvs (prClass bi)  -- ..all 'PR' instances

      -- construct the initial global environment
      let initialGEnv = extendImportedVarsEnv biVars
                      . setPAFunsEnv paDicts
                      . setPRFunsEnv prDicts
                      $ initGlobalEnv (gopt Opt_VectorisationAvoidance dflags)
                                      info (mg_vect_decls guts) instEnvs famInstEnvs

      -- perform vectorisation
      result <- runVM thing_inside bi initialGEnv emptyLocalEnv
      case result of
        Yes finalGEnv _ x -> return $ Just (updatedInfo finalGEnv, x)
        No reason         -> do
          unqual <- mkPrintUnqualifiedDs
          liftIO $
            printOutputForUser dflags unqual $
              mkDumpDoc "Warning: vectorisation failure:" reason
          return Nothing

    -- 'VectInfo' derived from the final global environment
    updatedInfo finalGEnv
      = modVectInfo finalGEnv topLevelIds (mg_tcs guts) (mg_vect_decls guts) info

    -- For a given DPH class, produce a mapping from type constructor (in head position) to the
    -- instance dfun for that type constructor and class.  (DPH class instances cannot overlap in
    -- head constructors.)  Panics on an instance whose rough-match list is not exactly one known
    -- type constructor.
    --
    initClassDicts :: InstEnvs -> Class -> [(Name, Var)]
    initClassDicts insts cls = [dictFor inst | inst <- classInstances insts cls]
      where
        dictFor inst = case instanceRoughTcs inst of
          [Just tc] -> (tc, instanceDFunId inst)
          _         -> pprPanic invalidInstance (ppr inst)

    invalidInstance = "Invalid DPH instance (overlapping in head constructor)"


-- Builtins -------------------------------------------------------------------

-- |Lift a desugaring computation using the `Builtins` into the vectorisation monad.
--
-- The global and local environments are passed through unchanged.
--
liftBuiltinDs :: (Builtins -> DsM a) -> VM a
liftBuiltinDs p = VM $ \bi genv lenv -> fmap (Yes genv lenv) (p bi)

-- |Project something from the set of builtins.
--
builtin :: (Builtins -> a) -> VM a
builtin f = VM $ \bi genv lenv ->
  let projected = f bi
  in  return (Yes genv lenv projected)

-- |Lift a function using the `Builtins` into the vectorisation monad.
--
-- The resulting function has the builtins already supplied as the second
-- argument of @f@.
--
builtins :: (a -> Builtins -> b) -> VM (a -> b)
builtins f = VM $ \bi genv lenv -> return (Yes genv lenv (\x -> f x bi))


-- Var ------------------------------------------------------------------------

-- |Lookup the vectorised, and if local, also the lifted version of a variable.
--
-- * If it's in the global environment we get the vectorised version.
-- * If it's in the local environment we get both the vectorised and lifted version.
--
-- A variable found in neither environment is reported through 'dumpVar'.
--
lookupVar :: Var -> VM (Scope Var (Var, Var))
lookupVar v = lookupVar_maybe v >>= maybe notVectorised return
  where
    notVectorised = getDynFlags >>= \dflags -> dumpVar dflags v

-- |Like 'lookupVar', but yields 'Nothing' when the variable is in neither the
-- local nor the global environment.  The local environment takes precedence.
--
lookupVar_maybe :: Var -> VM (Maybe (Scope Var (Var, Var)))
lookupVar_maybe v = do
  mbLocal <- readLEnv $ \lenv -> lookupVarEnv (local_vars lenv) v
  case mbLocal of
    Just vs -> return $ Just (Local vs)
    Nothing -> do
      mbGlobal <- readGEnv $ \genv -> lookupVarEnv (global_vars genv) v
      return (fmap Global mbGlobal)

-- |Report, via 'cantVectorise', a variable that has no vectorised version,
-- distinguishing class operations from other variables in the message.
--
dumpVar :: DynFlags -> Var -> a
dumpVar dflags var = cantVectorise dflags message (ppr var)
  where
    message = case isClassOpId_maybe var of
      Just _  -> "ClassOpId not vectorised:"
      Nothing -> "Variable not vectorised:"


-- Global parallel entities ----------------------------------------------------

-- |Mark the given variable as parallel — i.e., executing the associated code might involve
-- parallel array computations.
--
addGlobalParallelVar :: Var -> VM ()
addGlobalParallelVar var = do
  traceVt "addGlobalParallelVar" (ppr var)
  updGEnv $ \genv ->
    genv { global_parallel_vars = extendDVarSet (global_parallel_vars genv) var }

-- |Mark the given type constructor as parallel — i.e., its values might embed parallel arrays.
--
addGlobalParallelTyCon :: TyCon -> VM ()
addGlobalParallelTyCon tycon = do
  traceVt "addGlobalParallelTyCon" (ppr tycon)
  updGEnv $ \genv ->
    genv { global_parallel_tycons =
             extendNameSet (global_parallel_tycons genv) (tyConName tycon) }