{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- Placeholder for new prims
module Language.HERMIT.Primitive.New where

import GhcPlugins as GHC hiding (varName)

import Control.Arrow
import Control.Monad (liftM)

import Data.List(transpose)
import Data.Set (intersection, unions, fromList, toList)
import qualified Data.Set as S

import Language.HERMIT.Context
import Language.HERMIT.Core
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.GHC
import Language.HERMIT.ParserCore

import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC
import Language.HERMIT.Primitive.Local
import Language.HERMIT.Primitive.Inline
import Language.HERMIT.Primitive.Unfold
-- import Language.HERMIT.Primitive.Debug

import qualified Language.Haskell.TH as TH

-- | The externals exposed by this module.
--   Every entry is additionally tagged 'Experiment' and 'TODO'.
externals :: [External]
externals = map ((.+ Experiment) . (.+ TODO))
    [ external "test" (testQuery :: RewriteH Core -> TranslateH Core String)
        [ "determines if a rewrite could be successfully applied" ]
    , external "push" (promoteExprR . push :: TH.Name -> RewriteH Core)
        [ "push a function into argument."
        , "Unsafe if f is not strict." ] .+ PreCondition
    , external "var" (promoteExprT . isVar :: TH.Name -> TranslateH Core ())
        [ "var 'v returns successfully for variable v, and fails otherwise."
        , "Useful in combination with \"when\", as in: when (var v) r" ] .+ Predicate
    , external "simplify" (simplifyR :: RewriteH Core)
        [ "innermost (unfold 'id <+ unfold '$ <+ unfold '. <+ beta-reduce-plus <+ safe-let-subst <+ case-reduce <+ dead-let-elimination)" ] .+ Bash
    , external "let-tuple" (promoteExprR . letTupleR :: TH.Name -> RewriteH Core)
        [ "let x = e1 in (let y = e2 in e) ==> let t = (e1,e2) in (let x = fst t in (let y = snd t in e))" ]
    , external "static-arg" (promoteDefR staticArg :: RewriteH Core)
        [ "perform the static argument transformation on a recursive function" ]
    , external "unsafe-replace" (promoteExprR . unsafeReplace :: CoreString -> RewriteH Core)
        [ "replace the currently focused expression with a new expression" ] .+ Unsafe
      -- Same command name as above on purpose: the two entries differ in argument type.
    , external "unsafe-replace" (promoteExprR . unsafeReplaceStash :: String -> RewriteH Core)
        [ "replace the currently focused expression with an expression from the stash"
        , "DOES NOT ensure that free variables in the replacement expression are in scope" ] .+ Unsafe
    , external "inline-all" (inlineAll :: [TH.Name] -> RewriteH Core)
        [ "inline all named functions in a bottom-up manner" ]
    ]

------------------------------------------------------------------------------------------------------

-- | Succeed (returning unit) if the focused expression is a variable, or a
--   type expression consisting of a type variable, that matches the given
--   name according to 'cmpTHName2Var'; fail otherwise.
isVar :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => TH.Name -> Translate c m CoreExpr ()
isVar nm =
    let matchName = arr (cmpTHName2Var nm)
     in (varT matchName <+ typeT (tyVarT matchName)) >>= guardM

------------------------------------------------------------------------------------------------------

-- | Repeatedly apply, via 'innermostR', the first applicable rewrite among:
--   unfolding @$@, @.@ and @id@, 'betaReducePlus', 'safeLetSubstR',
--   'caseReduce' and 'letElim'.
--   Any failure is reported as \"Simplify failed: nothing to simplify.\"
simplifyR :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM Core
simplifyR =
    setFailMsg "Simplify failed: nothing to simplify." $
        innermostR
            (promoteExprR
                (   unfoldNameR (TH.mkName "$")
                 <+ unfoldNameR (TH.mkName ".")
                 <+ unfoldNameR (TH.mkName "id")
                 <+ betaReducePlus
                 <+ safeLetSubstR
                 <+ caseReduce
                 <+ letElim
                ))

-- | Peel off a chain of nested non-recursive lets, returning the bindings in
--   order (outermost first) together with the innermost body.
--   Stops at the first expression that is not a non-recursive let.
collectLets :: CoreExpr -> ([(Var, CoreExpr)],CoreExpr)
collectLets (Let (NonRec x e1) e2) =
    let (bs,expr) = collectLets e2
     in ((x,e1):bs, expr)
collectLets expr = ([],expr)

-- | Combine nested non-recursive lets into case of a tuple.
--
--   Requires at least two non-recursive bindings, all of which bind
--   identifiers (not type variables), and fails if any binding's right-hand
--   side mentions a variable bound by an earlier binding in the chain.
--   The given name (via 'show') is used to name the fresh identifier passed
--   to 'mkSmallTupleCase'.
letTupleR :: TH.Name -> Rewrite c HermitM CoreExpr
letTupleR nm = prefixFailMsg "Let-tuple failed: " $
  do (bnds, body) <- arr collectLets
     let numBnds = length bnds
     guardMsg (numBnds > 1) "at least two non-recursive let bindings required."

     let (vs, rhss) = unzip bnds
     -- TODO: it'd be better if collectLets stopped on reaching a TyVar
     guardMsg (all isId vs) "cannot tuple type variables."

     -- check if tupling the bindings would cause unbound variables:
     -- the i-th right-hand side (i >= 1) must not mention any of the first i binders
     let frees = map coreExprFreeVars (drop 1 rhss)
         used  = unions $ zipWith intersection (map (fromList . (`take` vs)) [1..]) frees
     if S.null used
       then let rhs = mkCoreTup rhss
             in constT $ do wild <- newIdH (show nm) (exprType rhs)
                            return $ mkSmallTupleCase vs body wild rhs
       else fail $ "the following bound variables are used in subsequent bindings: " ++ showVars (toList used)

-- Others
-- let v = E1 in E2 E3 <=> (let v = E1 in E2) E3
-- let v = E1 in E2 E3 <=> E2 (let v = E1 in E3)

-- | Static argument transformation on a definition.
--
--   Binders of the right-hand side that are passed unchanged at every
--   recursive call are treated as static; the remaining (dynamic) binders
--   become the parameters of a fresh local recursive worker, named after the
--   original function with a trailing prime. Recursive calls in the body are
--   redirected to the worker, passing only the dynamic arguments.
--   Fails if the right-hand side has no leading binders.
staticArg :: forall c. (ExtendPath c Crumb, AddBindings c) => Rewrite c HermitM CoreDef
staticArg = prefixFailMsg "static-arg failed: " $ do
    Def f rhs <- idR
    let (bnds, body) = collectBinders rhs
    guardMsg (notNull bnds) "rhs is not a function"
    contextonlyT $ \ c -> do
        -- the body lives underneath all of the lambda binders
        let bodyContext = foldl (flip addLambdaBinding) c bnds

        -- argument lists of every call to f found in the body
        callPats <- apply (callsT (var2THName f) (callT >>> arr snd)) bodyContext (ExprCore body)
        let argExprs = transpose callPats   -- one list per argument position
            numCalls = length callPats
            -- ensure argument is present in every call (partial applications boo)
            (ps,dbnds) = unzip [ (i,b) | (i,b,exprs) <- zip3 [0..] bnds $ argExprs ++ repeat []
                                       , length exprs /= numCalls || isDynamic b exprs
                                       ]

            isDynamic _ []                      = False -- all were static, so static
            isDynamic b ((Var b'):es)           | b == b' = isDynamic b es
            isDynamic b ((Type (TyVarTy v)):es) | b == v  = isDynamic b es
            isDynamic _ _                       = True  -- not a simple repass, so dynamic

        wkr <- newIdH (var2String f ++ "'") (exprType (mkCoreLams dbnds body))

        -- rewrite a call so that it targets the worker and keeps only the
        -- arguments at dynamic positions
        let replaceCall :: Monad m => Rewrite c m CoreExpr
            replaceCall = do
                (_,exprs) <- callT
                return $ mkApps (Var wkr) [ e | (p,e) <- zip [0..] exprs, (p::Int) `elem` ps ]

        ExprCore body' <- apply (callsR (var2THName f) replaceCall) bodyContext (ExprCore body)

        return $ Def f $ mkCoreLams bnds $ Let (Rec [(wkr, mkCoreLams dbnds body')])
                                             $ mkApps (Var wkr) (varsToCoreExprs dbnds)

------------------------------------------------------------------------------------------------------

-- | Report, as a human-readable string, the boolean result of 'testM' on the
--   given rewrite.
testQuery :: MonadCatch m => Rewrite c m Core -> Translate c m Core String
testQuery r = f `liftM` testM r
  where
    f True  = "Rewrite would succeed."
    f False = "Rewrite would fail."

------------------------------------------------------------------------------------------------------

-- | Push a function through a Case or Let expression.
--   Unsafe if the function is not strict.
--
--   The focused expression must be an application headed by a variable
--   matching the given name, all arguments except the last must be type
--   arguments, and the last argument must be a @Case@ or a @Let@.
push :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => TH.Name -> Rewrite c HermitM CoreExpr
push nm = prefixFailMsg "push failed: " $
    do e <- idR
       case collectArgs e of
         (Var v,args) -> do
             guardMsg (nm `cmpTHName2Var` v) $ "cannot find name " ++ show nm
             guardMsg (not $ null args) $ "no argument for " ++ show nm
             -- 'init' and 'last' are only reached once the guard above has
             -- established that args is non-empty
             guardMsg (all isTypeArg $ init args) $ "initial arguments are not type arguments for " ++ show nm
             case last args of
               Case {} -> caseFloatArg
               Let {}  -> letFloatArg
               _       -> fail "argument is not a Case or Let."
         _ -> fail "no function to match."

------------------------------------------------------------------------------------------------------

-- The types of these can probably be generalised after the Core Parser is generalised.

-- | Parse a core string in the current context, ignoring the focused value.
parseCoreExprT :: CoreString -> TranslateH a CoreExpr
parseCoreExprT = contextonlyT . parseCore

-- | Replace the focused expression with the result of parsing the given core
--   string in the current context.
--   Fails if the two expressions' types are not equal according to 'eqType'.
--   Failures carry the same \"unsafe-replace failed: \" prefix as
--   'unsafeReplaceStash'.
unsafeReplace :: CoreString -> RewriteH CoreExpr
unsafeReplace core = prefixFailMsg "unsafe-replace failed: " $
    translate $ \ c e -> do
        e' <- parseCore core c
        guardMsg (eqType (exprType e) (exprType e')) "expression types differ."
        return e'

-- | Replace the focused expression with the right-hand side of the
--   definition stored under the given label (looked up with 'lookupDef').
--   Fails if the two expressions' types are not equal according to 'eqType'.
unsafeReplaceStash :: String -> RewriteH CoreExpr
unsafeReplaceStash label = prefixFailMsg "unsafe-replace failed: " $
    contextfreeT $ \ e -> do
        Def _ rhs <- lookupDef label
        guardMsg (eqType (exprType e) (exprType rhs)) "expression types differ."
        return rhs

------------------------------------------------------------------------------------------------------

-- | Apply, via 'innermostR', the first applicable 'inlineName' rewrite among
--   the given names (tried in list order). With an empty list the underlying
--   rewrite always fails with \"inline-all: nothing to do\".
inlineAll :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => [TH.Name] -> Rewrite c HermitM Core
inlineAll = innermostR . foldr (\ nm rr -> promoteExprR (inlineName nm) <+ rr) (fail "inline-all: nothing to do")

------------------------------------------------------------------------------------------------------