{-|
  Copyright   :  (C) 2017-2022, Google Inc.,
                     2021-2022, QBayLogic B.V.
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  QBayLogic B.V.

  Call-by-need evaluator based on the evaluator described in:

  Maximilian Bolingbroke, Simon Peyton Jones, "Supercompilation by evaluation",
  Haskell '10, Baltimore, Maryland, USA.
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.GHC.Evaluator where

import           Prelude hiding (lookup)

import           Control.Concurrent.Supply (Supply, freshId)
import           Data.Either (lefts,rights)
import           Data.List (foldl',mapAccumL)
import qualified Data.Primitive.ByteArray as BA
import qualified Data.Text as Text

#if MIN_VERSION_base(4,15,0)
import           GHC.Num.Integer (Integer (..))
#else
import           GHC.Integer.GMP.Internals (Integer (..), BigNat (..))
#endif

import           Clash.Core.DataCon
import           Clash.Core.Evaluator.Types
import           Clash.Core.HasFreeVars
import           Clash.Core.HasType
import           Clash.Core.Literal
import           Clash.Core.Name
import           Clash.Core.Pretty
import           Clash.Core.Subst
import           Clash.Core.Term
import           Clash.Core.TyCon
import           Clash.Core.Type
import           Clash.Core.Util
import           Clash.Core.Var
import           Clash.Core.VarEnv
import           Clash.Debug
import qualified Clash.Normalize.Primitives as NP (removedArg, undefined, undefinedX)
import           Clash.Unique
import           Clash.Util (curLoc)

import           Clash.GHC.Evaluator.Primitive

-- | The GHC-flavoured 'Evaluator'.
--
-- It wires the term-level small-step function ('ghcStep') and stack
-- unwinder ('ghcUnwind') defined in this module together with the
-- primitive handlers ('ghcPrimStep', 'ghcPrimUnwind') imported from
-- "Clash.GHC.Evaluator.Primitive".
evaluator :: Evaluator
evaluator =
  Evaluator
    { primStep   = ghcPrimStep
    , primUnwind = ghcPrimUnwind
    , step       = ghcStep
    , unwind     = ghcUnwind
    }

{- [Note: forcing special primitives]
Clash uses the `whnf` function in two places (for now):

  1. The case-of-known-constructor transformation
  2. The reduceConstant transformation

The first transformation is needed to reach the required normal form. The
second transformation is more of a cleanup transformation, so non-essential.

Normally, `whnf` would force the evaluation of all primitives, which is needed
in the `case-of-known-constructor` transformation. However, there are some
primitives which we want to leave unevaluated in the `reduceConstant`
transformation. Such primitives are:

  - Primitives such as `Clash.Sized.Vector.transpose`, `Clash.Sized.Vector.map`,
    etc. that do not reduce to an expression in normal form, whereas the
    `reduceConstant` transformation is supposed to be normal-form preserving.
  - Primitives such as `GHC.Int.I8#`, `GHC.Word.W32#`, etc. which seem like
    wrappers around a 64-bit literal, but actually perform truncation to the
    desired bit-size.

This is why the Primitive Evaluator gets a flag telling whether it should
evaluate these special primitives.
-}

-- | Step a variable.
--
-- The variable is looked up on the local heap first and, only when it is a
-- global 'Id', on the global heap. When a binding is found it is removed from
-- that heap, an 'Update' frame for the variable is pushed on the stack, and the
-- bound expression (wrapped in a 'NameMod' tick and de-shadowed against the
-- machine's in-scope set) becomes the term under evaluation. Returns 'Nothing'
-- when the variable is bound on neither heap.
stepVar :: Id -> Step
stepVar i m _
  | Just e <- heapLookup LocalId i m
  = go LocalId e

  | Just e <- heapLookup GlobalId i m
  , isGlobalId i
  = go GlobalId e

  | otherwise
  = Nothing
 where
  -- Removing the heap-bound value on a force ensures we do not get stuck on
  -- expressions such as: "let x = x in x"
  go s e =
    let term = deShadowTerm (mScopeNames m) (tickExpr e)
     in Just . setTerm term . stackPush (Update s i) $ heapDelete s i m

  -- A 'PrefixName' name-modifier tick carrying the unqualified name of 'i'
  -- followed by an underscore (presumably used to prefix names derived from
  -- the bound expression; confirm against the tick's consumer).
  tickExpr = Tick (NameMod PrefixName (LitTy . SymTy $ toStr i))
  -- Keep only the text after the last "." (drops any module qualifier).
  unQualName = snd . Text.breakOnEnd "."
  toStr = Text.unpack . unQualName . flip Text.snoc '_' . nameOcc . varName

-- | A data constructor on its own is a value: unwind with an unapplied 'DC'.
stepData :: DataCon -> Step
stepData dc = ghcUnwind (DC dc [])

-- | A literal is a value: unwind with it directly.
stepLiteral :: Literal -> Step
stepLiteral l = ghcUnwind (Lit l)

-- | Step a bare (unapplied) primitive.
--
-- @GHC.Prim.realWorld#@ is treated as a value. A primitive whose type has no
-- arguments is handed to the primitive evaluator; otherwise it is eta-expanded
-- with 'newBinder' over all of its argument types.
stepPrim :: PrimInfo -> Step
stepPrim pInfo m tcm
  | primName pInfo == "GHC.Prim.realWorld#" =
      ghcUnwind (PrimVal pInfo [] []) m tcm

  | otherwise =
      case fst $ splitFunForallTy (primType pInfo) of
        []  -> ghcPrimStep tcm (forcePrims m) pInfo [] [] m
        tys -> newBinder tys (Prim pInfo) m tcm

-- | A lambda is a value: unwind with a 'Lambda'.
stepLam :: Id -> Term -> Step
stepLam x e = ghcUnwind (Lambda x e)

-- | A type lambda is a value: unwind with a 'TyLambda'.
stepTyLam :: TyVar -> Term -> Step
stepTyLam x e = ghcUnwind (TyLambda x e)

-- | Step a term application @x y@.
--
-- The head of the application spine (after collecting arguments and ticks)
-- decides what happens:
--
--   * Data constructor: saturated -> 'DC' value; under-applied -> eta-expand
--     via 'newBinder'; over-applied -> error.
--   * Primitive: saturated -> evaluate its term arguments (see the cases
--     below); under-applied -> eta-expand; over-applied -> let-bind @y@ and
--     push an 'Apply' frame while evaluating @x@.
--   * Anything else: let-bind @y@ on the heap and push an 'Apply' frame while
--     evaluating @x@.
stepApp :: Term -> Term -> Step
stepApp x y m tcm =
  case term of
    Data dc ->
      let tys = fst $ splitFunForallTy (dcType dc)
       in case compare (length args) (length tys) of
            EQ -> ghcUnwind (DC dc args) m tcm
            LT -> newBinder tys' (App x y) m tcm
            GT -> error "Overapplied DC"

    Prim p ->
      let tys = fst $ splitFunForallTy (primType p)
       in case compare (length args) (length tys) of
            EQ -> case lefts args of
              -- We make boolean conjunction and disjunction extra lazy by
              -- deferring the evaluation of the arguments during the evaluation
              -- of the primop rule.
              --
              -- This allows us to implement:
              --
              -- x && True  --> x
              -- x && False --> False
              -- x || True  --> True
              -- x || False --> x
              --
              -- even when that 'x' is _|_. This makes the evaluation
              -- rule lazier than the actual Haskell implementations which
              -- are strict in the first argument and lazy in the second.
              [a0, a1] | primName p `elem` ["GHC.Classes.&&","GHC.Classes.||"] ->
                let (m0,i) = newLetBinding tcm m  a0
                    (m1,j) = newLetBinding tcm m0 a1
                in  ghcPrimStep tcm (forcePrims m) p [] [Suspend (Var i), Suspend (Var j)] m1

              (e':es)
                | primName p `elem` (undefinedXPrims ++ undefinedPrims)
                -- The above primitives are (bottoming) values, whose arguments
                -- are never used anywhere in the rest of the compiler. So
                -- instead of pushing a PrimApply frame on the stack to evaluate
                -- those arguments, we instead just unwind the stack with the
                -- primitive value and leave its arguments in an unevaluated
                -- state (Suspend).
                -> ghcUnwind (PrimVal p (rights args) (map Suspend (e':es))) m tcm

                -- Evaluate the first term argument; the remaining ones are
                -- kept in the PrimApply frame.
                | otherwise
                -> Just . setTerm e' $ stackPush (PrimApply p (rights args) [] es) m

              -- Unreachable in practice: 'y' itself is a term argument, so
              -- 'lefts args' is never empty here.
              _ -> error "internal error"

            LT -> newBinder tys' (App x y) m tcm

            GT -> let (m0, n) = newLetBinding tcm m y
                   in Just . setTerm x $ stackPush (Apply n) m0

    _ -> let (m0, n) = newLetBinding tcm m y
          in Just . setTerm x $ stackPush (Apply n) m0
 where
  (term, args, _) = collectArgsTicks (App x y)
  -- Argument types still missing from the application (used for eta-expansion).
  tys' = fst . splitFunForallTy . inferCoreTypeOf tcm $ App x y

-- | Step a type application @x \@ty@.
--
-- Mirrors 'stepApp': saturated constructors become 'DC' values, saturated
-- primitives either evaluate their first term argument (via a 'PrimApply'
-- frame) or, when they have only type arguments, are handed to the primitive
-- evaluator; @removedArg@ / @undefined@ / @undefinedX@ with only type arguments
-- are treated as values. Under-applied heads are eta-expanded; otherwise an
-- 'Instantiate' frame is pushed while evaluating @x@.
stepTyApp :: Term -> Type -> Step
stepTyApp x ty m tcm =
  case term of
    Data dc ->
      let tys = fst $ splitFunForallTy (dcType dc)
       in case compare (length args) (length tys) of
            EQ -> ghcUnwind (DC dc args) m tcm
            LT -> newBinder tys' (TyApp x ty) m tcm
            GT -> error "Overapplied DC"

    Prim p ->
      let tys = fst $ splitFunForallTy (primType p)
       in case compare (length args) (length tys) of
            EQ -> case lefts args of
              [] | primName p `elem` fmap primName [ NP.removedArg
                                                   , NP.undefined
                                                   , NP.undefinedX ] ->
                    ghcUnwind (PrimVal p (rights args) []) m tcm
                 | otherwise ->
                    ghcPrimStep tcm (forcePrims m) p (rights args) [] m
              (e':es) ->
                Just . setTerm e' $ stackPush (PrimApply p (rights args) [] es) m
            LT -> newBinder tys' (TyApp x ty) m tcm
            GT -> Just . setTerm x $ stackPush (Instantiate ty) m

    _ -> Just . setTerm x $ stackPush (Instantiate ty) m
 where
  (term, args, _) = collectArgsTicks (TyApp x ty)
  -- Argument types still missing from the application (used for eta-expansion).
  tys' = fst . splitFunForallTy . inferCoreTypeOf tcm $ TyApp x ty

-- | Step a let-expression: allocate its binding(s) on the local heap and
-- continue with the body (see 'allocate').
stepLet :: Bind Term -> Term -> Step
stepLet (NonRec i b) x m _ = Just (allocate [(i,b)] x m)
stepLet (Rec bs) x m _ = Just (allocate bs x m)

-- | Step a case-expression: push a 'Scrutinise' frame and evaluate the
-- scrutinee.
stepCase :: Term -> Type -> [Alt] -> Step
stepCase scrut ty alts m _ =
  Just . setTerm scrut $ stackPush (Scrutinise ty alts) m

-- TODO Support stepwise evaluation of casts.
--
-- Casts are currently not evaluated: a warning is emitted via 'trace' and
-- 'Nothing' is returned, signalling that no step can be taken.
stepCast :: Term -> Type -> Type -> Step
stepCast _ _ _ _ _ =
  flip trace Nothing $ unlines
    [ "WARNING: " <> $(curLoc) <> "Clash can't symbolically evaluate casts"
    , "Please file an issue at https://github.com/clash-lang/clash-compiler/issues"
    ]

-- | Step a ticked term: push a 'Tickish' frame and evaluate the inner term.
stepTick :: TickInfo -> Term -> Step
stepTick tick x m _ =
  Just . setTerm x $ stackPush (Tickish tick) m

-- | Small-step operational semantics.
--
-- Dispatches on the shape of the machine's current term.
ghcStep :: Step
ghcStep m = case mTerm m of
  Var i        -> stepVar i m
  Data dc      -> stepData dc m
  Literal l    -> stepLiteral l m
  Prim p       -> stepPrim p m
  Lam v x      -> stepLam v x m
  TyLam v x    -> stepTyLam v x m
  App x y      -> stepApp x y m
  TyApp x ty   -> stepTyApp x ty m
  Let bs x     -> stepLet bs x m
  Case s ty as -> stepCase s ty as m
  Cast x a b   -> stepCast x a b m
  Tick t x     -> stepTick t x m

-- | Take a list of types or type variables and create a lambda / type lambda
-- for each one around the given term, then take a step on the result.
--
-- The binders are created left to right and the term is applied to them in
-- the same order, i.e. for @[Right t1, Right t2]@ the result is
--
-- > \(x1 :: t1) -> \(x2 :: t2) -> term x1 x2
--
-- (an earlier version folded from the right, producing the ill-typed
-- @\\x1 -> (\\x2 -> term x2) x1@ whenever more than one argument was missing).
-- The supply and in-scope set returned by 'mkUniqSystemId' are threaded
-- through and stored back in the machine, as 'newLetBinding' does.
newBinder :: [Either TyVar Type] -> Term -> Step
newBinder tys x m tcm =
  let (s', iss', x') = mkAbstr (mSupply m, mScopeNames m, x) tys
      m' = m { mSupply = s', mScopeNames = iss', mTerm = x' }
   in ghcStep m' tcm
 where
  mkAbstr (s0, iss0, e0) = go s0 iss0 e0

  -- Wrap the remaining binders around the (already partially applied) term;
  -- outer binders correspond to earlier arguments.
  go s iss e [] = (s, iss, e)
  go s iss e (Left tv : rest) =
    let (s1, iss1, body) = go s iss (TyApp e (VarTy tv)) rest
     in (s1, iss1, TyLam tv body)
  go s iss e (Right ty : rest) =
    let ((s1, iss1), n) = mkUniqSystemId (s, iss) ("x", ty)
        (s2, iss2, body) = go s1 iss1 (App e (Var n)) rest
     in (s2, iss2, Lam n body)

-- | Bind a term on the local heap under a fresh 'Id' and return the updated
-- machine together with that 'Id'. A variable that is already bound on the
-- local heap is reused as-is instead of being re-bound.
newLetBinding
  :: TyConMap
  -> Machine
  -> Term
  -> (Machine, Id)
newLetBinding tcm m e
  | Var v <- e
  , heapContains LocalId v m
  = (m, v)

  | otherwise
  = let m' = heapInsert LocalId id_ e m
     in (m' { mSupply = ids', mScopeNames = is1 }, id_)
 where
  ty = inferCoreTypeOf tcm e
  ((ids', is1), id_) = mkUniqSystemId (mSupply m, mScopeNames m) ("x", ty)

-- | Unwind the stack by 1
--
-- Pops one frame and lets it consume the value @v@. Returns 'Nothing' when
-- 'stackPop' does (presumably an empty stack) or when the primitive unwinder
-- does.
ghcUnwind :: Unwind
ghcUnwind v m tcm = do
  (m', kf) <- stackPop m
  go kf m'
 where
  go (Update s x)             = return . update s x v
  go (Apply x)                = return . apply tcm v x
  go (Instantiate ty)         = return . instantiate tcm v ty
  go (PrimApply p tys vs tms) = ghcPrimUnwind tcm p tys vs v tms
  go (Scrutinise altTy as)    = return . scrutinise v altTy as
  go (Tickish _)              = return . setTerm (valToTerm v)

-- | Update the Heap with the evaluated term
--
-- The value (converted back to a term) is stored under @x@ in the given heap
-- scope and also becomes the current term.
update :: IdScope -> Id -> Value -> Machine -> Machine
update s x (valToTerm -> term) =
  setTerm term . heapInsert s x term

-- | Apply a value to a function
--
-- A 'Lambda' is beta-reduced by substituting the argument variable for its
-- binder. An undefined / undefinedX primitive value stays bottom: it is
-- replaced by @undefined \@ty@ (resp. @undefinedX \@ty@) where @ty@ is the
-- result type of the application. Anything else is an error.
apply :: TyConMap -> Value -> Id -> Machine -> Machine
apply _tcm (Lambda x' e) x m =
  setTerm (substTm "Evaluator.apply" subst e) m
 where
  subst  = extendIdSubst subst0 x' (Var x)
  subst0 = mkSubst $ extendInScopeSet (mScopeNames m) x

apply tcm pVal@(PrimVal (PrimInfo{primType}) tys vs) x m
  | isUndefinedXPrimVal pVal
  = setTerm (TyApp (Prim NP.undefinedX) ty) m

  | isUndefinedPrimVal pVal
  = setTerm (TyApp (Prim NP.undefined) ty) m
 where
  -- Type of: prim @tys.. vs.. x
  ty = piResultTys tcm primType (tys ++ map (inferCoreTypeOf tcm . valToTerm) vs ++ [varType x])

apply _ v _ m = error $ "Evaluator.apply: Not a lambda: " ++ show v ++ "\n" ++ show m

-- | Instantiate a type-abstraction
instantiate :: TyConMap -> Value -> Type -> Machine -> Machine
instantiate _tcm (TyLambda x e) ty m =
  setTerm (substTm "Evaluator.instantiate1" subst e) m
 where
  subst  = extendTvSubst subst0 x ty
  subst0 = mkSubst iss0
  iss0   = mkInScopeSet (freeVarsOf e <> freeVarsOf ty)

-- The evaluator is set up in such a way that under normal conditions anything
-- of type 'forall a . ty' must be a ty-lambda.
--
-- However, sometimes we evaluate to an error /value/. When this happens,
-- instead of doing a regular type substitution we:
--
-- 1. Calculate the 'forall a . ty' type of the error value
-- 2. Substitute the 'a' by the applied type.
-- 3. Create a new error value of the shape: 'undefined @substituted_type'
--    Where this particular 'undefined' has type 'forall a . a'. We distinguish
--    between error values throwing X exceptions and other error values, and
--    create appropriate error values that we return. We make this distinction
--    in order to enable conversion of X-exception throwing code to undefined
--    bitvectors.
instantiate tcm pVal@(PrimVal (PrimInfo{primType}) tys es) ty m
  | isUndefinedXPrimVal pVal
  = setTerm (TyApp (Prim NP.undefinedX) primType1) m

  | isUndefinedPrimVal pVal
  = setTerm (TyApp (Prim NP.undefined) primType1) m
 where
  esTys = map (inferCoreTypeOf tcm) es

  -- Calculate the type of: prim @ty0 .. @tyN e0 .. eN @ty
  --
  -- This combines the above-mentioned step 1 and 2
  primType1 = piResultTys tcm primType (tys ++ esTys ++ [ty])

instantiate _ p _ _ = error $ "Evaluator.instantiate: Not a tylambda: " ++ show p

-- | Evaluate a case-expression
--
-- Selects the alternative matching the scrutinised value and makes it (with
-- pattern variables substituted) the current term.
scrutinise :: Value -> Type -> [Alt] -> Machine -> Machine
scrutinise v _altTy [] m = setTerm (valToTerm v) m
-- [Note: empty case expressions]
--
-- Clash does not have empty case-expressions; instead, empty case-expressions
-- are used to indicate that the `whnf` function was called in the context of a
-- case-expression, which means certain special primitives must be forced.
-- See also [Note: forcing special primitives]

-- Literal scrutinee: match literal patterns, and match Integer / Natural
-- literals against their internal constructors (IS/IP/IN and NS/NB, selected
-- by constructor tag), binding the constructor field to an Int/Word literal or
-- to the underlying ByteArray.
scrutinise (Lit l) _altTy alts m = case alts of
  (DefaultPat, altE):alts1 -> setTerm (go altE alts1) m
  _ -> let term = go (error $ "Evaluator.scrutinise: no match "
                    <> showPpr (Case (valToTerm (Lit l)) (ConstTy Arrow) alts)) alts
       in  setTerm term m
 where
  go def [] = def
  go _   ((LitPat l1,altE):_) | l1 == l = altE
  go _   ((DataPat dc [] [x],altE):_)
    | IntegerLiteral l1 <- l
    , Just patE <- case dcTag dc of
       1 | l1 >= ((-2)^(63::Int)) &&  l1 < 2^(63::Int) ->
          Just (IntLiteral l1)
       2 | l1 >= (2^(63::Int)) ->
#if MIN_VERSION_base(4,15,0)
          let !(IP ba0) = l1
#else
          let !(Jp# !(BN# ba0)) = l1
#endif
              ba1 = BA.ByteArray ba0
          in  Just (ByteArrayLiteral ba1)
       3 | l1 < ((-2)^(63::Int)) ->
#if MIN_VERSION_base(4,15,0)
          let !(IN ba0) = l1
#else
          let !(Jn# !(BN# ba0)) = l1
#endif
              ba1 = BA.ByteArray ba0
          in  Just (ByteArrayLiteral ba1)
       _ -> Nothing
    = let inScope = freeVarsOf altE
          subst0  = mkSubst (mkInScopeSet inScope)
          subst1  = extendIdSubst subst0 x (Literal patE)
      in  substTm "Evaluator.scrutinise" subst1 altE
    | NaturalLiteral l1  <- l
    , Just patE <- case dcTag dc of
       1 | l1 >= 0 && l1 < 2^(64::Int) ->
          Just (WordLiteral l1)
       2 | l1 >= (2^(64::Int)) ->
#if MIN_VERSION_base(4,15,0)
          let !(IP ba0) = l1
#else
          let !(Jp# !(BN# ba0)) = l1
#endif
              ba1 = BA.ByteArray ba0
          in  Just (ByteArrayLiteral ba1)
       _ -> Nothing
    = let inScope = freeVarsOf altE
          subst0  = mkSubst (mkInScopeSet inScope)
          subst1  = extendIdSubst subst0 x (Literal patE)
      in  substTm "Evaluator.scrutinise" subst1 altE
  go def (_:alts1) = go def alts1

-- Constructor scrutinee: the first alternative with the same constructor
-- (arguments substituted in), otherwise the first default alternative.
scrutinise (DC dc xs) _altTy alts m
  | altE:_ <- [substInAlt altDc tvs pxs xs altE
              | (DataPat altDc tvs pxs,altE) <- alts, altDc == dc ] ++
              [altE | (DefaultPat,altE) <- alts ]
  = setTerm altE m

-- Primitive-value scrutinee: undefined / undefinedX stay bottom at the
-- alternatives' type; otherwise, if there are literal patterns, a literal is
-- recovered from a small set of known fromInteger primitives and matched.
scrutinise v@(PrimVal p _ vs) altTy alts m
  | isUndefinedXPrimVal v
  = setTerm (TyApp (Prim NP.undefinedX) altTy) m

  | isUndefinedPrimVal v
  = setTerm (TyApp (Prim NP.undefined) altTy) m

  | any (\case {(LitPat {},_) -> True; _ -> False}) alts
  = case alts of
      ((DefaultPat,altE):alts1) -> setTerm (go altE alts1) m
      _ -> let term = go (error $ "Evaluator.scrutinise: no match "
                         <> showPpr (Case (valToTerm v) (ConstTy Arrow) alts)) alts
           in  setTerm term m
 where
  go def [] = def
  go _   ((LitPat l1,altE):_) | l1 == l = altE
  go def (_:alts1) = go def alts1

  l = case primName p of
        "Clash.Sized.Internal.BitVector.fromInteger##"
          | [Lit (WordLiteral 0), Lit l0] <- vs -> l0
        "Clash.Sized.Internal.BitVector.fromInteger#"
          | [_,Lit (NaturalLiteral 0),Lit l0] <- vs -> l0
        "Clash.Sized.Internal.Index.fromInteger#"
          | [_,Lit l0] <- vs -> l0
        "Clash.Sized.Internal.Signed.fromInteger#"
          | [_,Lit l0] <- vs -> l0
        "Clash.Sized.Internal.Unsigned.fromInteger#"
          | [_,Lit l0] <- vs -> l0
        _ -> error ("scrutinise: " ++ showPpr (Case (valToTerm v) (ConstTy Arrow) alts))

scrutinise v _altTy alts _ =
  error ("scrutinise: " ++ showPpr (Case (valToTerm v) (ConstTy Arrow) alts))

-- | Substitute a constructor's arguments into a case alternative.
--
-- Type arguments beyond the constructor's universal type variables are bound
-- to the pattern's type variables; term arguments are bound to the pattern's
-- 'Id's.
substInAlt :: DataCon -> [TyVar] -> [Id] -> [Either Term Type] -> Term -> Term
substInAlt dc tvs xs args e = substTm "Evaluator.substInAlt" subst e
 where
  tys        = rights args
  tms        = lefts args
  substTyMap = zip tvs (drop (length (dcUnivTyVars dc)) tys)
  substTmMap = zip xs tms
  inScope    = freeVarsOf tys `unionVarSet` freeVarsOf (e:tms)
  subst      = extendTvSubstList (extendIdSubstList subst0 substTmMap) substTyMap
  subst0     = mkSubst (mkInScopeSet inScope)

-- | Allocate let-bindings on the heap
--
-- Every binder is renamed to a fresh 'Id' (see 'letSubst'); the renaming is
-- applied to the bound expressions and to the body, the renamed bindings are
-- added to the local heap and the renamed body becomes the current term.
allocate :: [LetBinding] -> Term -> Machine -> Machine
allocate xes e m =
  m { mHeapLocal = extendVarEnvList (mHeapLocal m) xes'
    , mSupply = ids'
    , mScopeNames = isN
    , mTerm = e'
    }
 where
  xNms      = fmap fst xes
  is1       = extendInScopeSetList (mScopeNames m) xNms
  (ids', s) = mapAccumL (letSubst (mHeapLocal m)) (mSupply m) xNms
  (nms, s') = unzip s
  isN       = extendInScopeSetList is1 nms
  subst     = extendIdSubstList subst0 s'
  subst0    = mkSubst (foldl' extendInScopeSet is1 nms)
  xes'      = zip nms (fmap (substTm "Evaluator.allocate0" subst . snd) xes)
  e'        = substTm "Evaluator.allocate1" subst e

-- | Create a unique name and substitution for a let-binder
--
-- Draws uniques from the supply until the renamed binder is not already bound
-- on the given heap, and returns the new supply, the new binder, and the
-- substitution pair @(old binder, Var new binder)@.
letSubst
  :: PureHeap
  -> Supply
  -> Id
  -> (Supply, (Id, (Id, Term)))
letSubst h acc id0 =
  let (acc',id1) = mkUniqueHeapId h acc id0
  in  (acc',(id1,(id0,Var id1)))
 where
  mkUniqueHeapId :: PureHeap -> Supply -> Id -> (Supply, Id)
  mkUniqueHeapId h' ids x =
    maybe (ids', x') (const $ mkUniqueHeapId h' ids' x) (lookupVarEnv x' h')
   where
    (i,ids') = freshId ids
    x'       = modifyVarName (`setUnique` i) x