{-|
  Copyright  :  (C) 2012-2016, University of Twente
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Transformation process for normalization
-}

module Clash.Normalize.Strategy where

import Clash.Normalize.Transformations
import Clash.Normalize.Types
import Clash.Rewrite.Combinators
import Clash.Rewrite.Types
import Clash.Rewrite.Util

-- [Note] bottomup traversal evalConst
--
-- 2-May-2019: There is a bug in the evaluator where all data constructors are
-- considered lazy, even though their declaration says they have strict fields.
-- This causes some reductions to fail because the term under the constructor is
-- not in WHNF, which is what some of the evaluation rules for certain primitive
-- operations expect. Using a bottom-up traversal works around this bug by
-- ensuring that the values under the constructor are in WHNF.
--
-- Using a bottom-up traversal ensures that constants are reduced to NF, even if
-- constructors are lazy, thus ensuring more sensible/smaller generated HDL.

-- | Normalisation transformation.
--
-- The complete pipeline, in order: dead-code removal, constant propagation,
-- removal of unused expressions, ANF conversion, constant binding, top-level
-- let handling, constant evaluation, CSE, cleanup, X-optimisation, a second
-- cleanup, recursion-to-letrec conversion and finally argument separation.
normalization :: NormRewrite
normalization =
  rmDeadcode >-> constantPropagation >-> rmUnusedExpr >-!-> anf >-!-> rmDeadcode >->
  bindConst >-> letTL >-> evalConst >-!-> cse >-!-> cleanup >->
  xOptim >-> rmDeadcode >->
  cleanup >-> recLetRec >-> splitArgs
  where
    -- ANF conversion: deal with non-representable values first (top-down),
    -- run makeANF, then a top-down caseCon pass over the result.
    anf =
      topdownR (apply "nonRepANF" nonRepANF) >->
      apply "ANF" makeANF >->
      topdownR (apply "caseCon" caseCon)

    -- Top-down, but stops descending below the first node where topLet fires.
    letTL = topdownSucR (apply "topLet" topLet)

    -- Applied once at the root only (no traversal combinator).
    recLetRec = apply "recToLetRec" recToLetRec

    rmUnusedExpr = bottomupR (apply "removeUnusedExpr" removeUnusedExpr)

    rmDeadcode = bottomupR (apply "deadcode" deadCode)

    bindConst = topdownR (apply "bindConstantVar" bindConstantVar)

    -- See [Note] bottomup traversal evalConst
    evalConst = bottomupR (apply "evalConst" reduceConst)

    cse = topdownR (apply "CSE" simpleCSE)

    xOptim = bottomupR (apply "xOptimize" xOptimize)

    -- Eta-expand, then inline-cleanup; the innerMost caseCon/bindConstantVar/
    -- letFlat pass is guarded by '!->' on the inlineCleanup step. Operator
    -- order here is significant (fixities come from Clash.Rewrite.Combinators),
    -- so do not regroup.
    cleanup =
      topdownR (apply "etaExpandSyn" etaExpandSyn) >->
      topdownSucR (apply "inlineCleanup" inlineCleanup) !->
      innerMost (applyMany [ ("caseCon"        , caseCon)
                           , ("bindConstantVar", bindConstantVar)
                           , ("letFlat"        , flattenLet)
                           ])
      >-> rmDeadcode >-> letTL

    splitArgs =
      topdownR (apply "separateArguments" separateArguments) !->
      topdownR (apply "caseCon" caseCon)

-- | Constant propagation phase of 'normalization'.
--
-- Runs, in order: the inline-and-propagate fixpoint, case flattening,
-- top-level eta expansion, disjoint-expression consolidation, specialisation,
-- a second consolidation pass and finally constant specialisation.
constantPropagation :: NormRewrite
constantPropagation =
  inlineAndPropagate >-> caseFlattening >-> etaTL >-> dec >-> spec >-> dec >->
  conSpec
  where
    -- Eta-expand at the top level; the application-propagation pass is
    -- guarded by '!->' on the eta-expansion step.
    etaTL =
      apply "etaTL" etaExpansionTL !->
      topdownR (apply "applicationPropagation" appPropFast)

    -- Repeat (top-down propagate/inline pass, then bottom-up inlineNonRep).
    inlineAndPropagate =
      repeatR (topdownR (applyMany transPropagateAndInline) >-> inlineNR)

    spec = bottomupR (applyMany specTransformations)

    caseFlattening = repeatR (topdownR (apply "caseFlat" caseFlat))

    dec = repeatR (topdownR (apply "DEC" disjointExpressionConsolidation))

    conSpec =
      bottomupR ((apply "appPropCS" appPropFast !->
                  bottomupR (apply "constantSpec" constantSpec)) >-!
                 apply "constantSpec" constantSpec)

-- | Transformations applied, in this order, at every node of the top-down
-- traversal inside 'constantPropagation''s inline-and-propagate fixpoint.
transPropagateAndInline :: [(String,NormRewrite)]
transPropagateAndInline =
  [ ("applicationPropagation", appPropFast          )
  , ("bindConstantVar"       , bindConstantVar      )
  , ("caseLet"               , caseLet              )
  , ("caseCase"              , caseCase             )
  , ("caseCon"               , caseCon              )
  , ("elemExistentials"      , elemExistentials     )
  , ("caseElemNonReachable"  , caseElemNonReachable )
  , ("removeUnusedExpr"      , removeUnusedExpr     )
  -- These transformations can safely be applied in a top-down traversal as
  -- they themselves check whether the to-be-inlined binder is recursive or not.
  , ("inlineWorkFree"  , inlineWorkFree)
  , ("inlineSmall"     , inlineSmall)
  , ("bindOrLiftNonRep", inlineOrLiftNonRep) -- See: [Note] bindNonRep before liftNonRep
                                             -- See: [Note] bottom-up traversal for liftNonRep
  , ("reduceNonRepPrim", reduceNonRepPrim)
  , ("caseCast"        , caseCast)
  , ("letCast"         , letCast)
  , ("splitCastWork"   , splitCastWork)
  , ("argCastSpec"     , argCastSpec)
  , ("inlineCast"      , inlineCast)
  , ("eliminateCastCast",eliminateCastCast)
  ]

-- | inlineNonRep cannot be applied in a top-down traversal, as the
-- non-representable binder might be recursive. The idea is that if the
-- recursive non-representable binder is inlined once, we can get rid of the
-- recursive aspect using the case-of-known-constructor transformation.
inlineNR :: NormRewrite
inlineNR = bottomupR (apply "inlineNonRep" inlineNonRep)

-- | Specialisation transformations, applied bottom-up by
-- 'constantPropagation'.
specTransformations :: [(String,NormRewrite)]
specTransformations =
  [ ("typeSpec"   , typeSpec)
  , ("nonRepSpec" , nonRepSpec)
  ]

{- [Note] bottom-up traversal for liftNonRep
We used to say:

"The liftNonRep transformation must be applied in a topDown traversal because
of what Clash considers tail calls in its join-point analysis."

Consider:

> let fail = \x -> ...
> in  case ... of
>       A -> let fail1 = \y -> case ... of
>                                X -> fail ...
>                                Y -> ...
>            in case ... of
>                 P -> fail1 ...
>                 Q -> ...
>       B -> fail ...

Under "normal" tail call rules, the local 'fail' function is not a join-point
because it is used in a let-binding. However, we apply "special" tail call rules
in Clash. Because 'fail' is used in a TC position within 'fail1', and 'fail1' is
only used in a TC position, in Clash, we consider 'fail' also only to be used in
a TC position.

Now imagine we apply 'liftNonRep' in a bottom-up traversal; we will end up with:

> fail1 = \fail y -> case ... of
>   X -> fail ...
>   Y -> ...
>
> let fail = \x -> ...
> in  case ... of
>       A -> case ... of
>              P -> fail1 fail ...
>              Q -> ...
>       B -> fail ...

Suddenly, 'fail' ends up in an argument position, because it occurred as a
_locally_ bound variable within 'fail1'. And because of that 'fail' stops being
a join-point.

However, when we apply 'liftNonRep' in a top-down traversal we end up with:

> fail = \x -> ...
>
> fail1 = \y -> case ... of
>   X -> fail ...
>   Y -> ...
>
> let ...
> in  case ... of
>       A -> let
>            in case ... of
>                 P -> fail1 ...
>                 Q -> ...
>       B -> fail ...

and all is well with the world.

UPDATE:
We can now just perform liftNonRep in a bottom-up traversal again, because
liftNonRep no longer checks whether the binding that is lifted is a join-point.
However, for this to work, bindNonRep must always have been exhaustively applied
before liftNonRep. See also: [Note] bindNonRep before liftNonRep.
-}

{- [Note] bindNonRep before liftNonRep
The combination of liftNonRep and nonRepSpec can lead to non-termination in an
unchecked rewrite system (without termination measures in place) on the
following:

> main = f not
> f    = \a x -> (a x) && (f a x)

nonRepSpec will lead to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = (\a x -> (a x) && (f a x)) not

then lamApp leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = let a = not in (\x -> (a x) && (f a x))

then liftNonRep leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (g x) && (f g x)
> g    = not

and nonRepSpec leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (g x) && (f'' g x)
> g    = not
> f''  = (\a x -> (a x) && (f a x)) g

This cycle continues indefinitely, as liftNonRep creates a new global variable,
which is never alpha-equivalent to the previous global variable introduced by
liftNonRep.

That is why bindNonRep must always be applied before liftNonRep. When we end up
in the situation after lamApp:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = let a = not in (\x -> (a x) && (f a x))

bindNonRep will now lead to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (not x) && (f not x)

Because `f` has already been specialized on the alpha-equivalent-to-itself `not`
function, nonRepSpec reuses that specialisation and leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (not x) && (f' x)

And there is no non-terminating rewriting cycle.

That is why bindNonRep must always be exhaustively applied before we apply
liftNonRep.
-}

-- | Topdown traversal, stops upon first success.
--
-- @r@ is tried at the current node; the children are only visited (through
-- 'allR') when @r@ did not succeed here, so along any root-to-leaf path @r@
-- fires at most once.
topdownSucR :: Rewrite extra -> Rewrite extra
topdownSucR r = r >-! (allR (topdownSucR r))
{-# INLINE topdownSucR #-}

-- | Repeatedly run a 'topdownR' traversal of @r@ (via 'repeatR').
topdownRR :: Rewrite extra -> Rewrite extra
topdownRR r = repeatR (topdownR r)
{-# INLINE topdownRR #-}

-- | Bottom-up traversal in which every node that @r@ rewrites is immediately
-- re-processed by @innerMost r@ (the recursive call is guarded by '!->',
-- presumably so it only runs when @r@ actually changed the term — confirm
-- against "Clash.Rewrite.Combinators").
innerMost :: Rewrite extra -> Rewrite extra
innerMost r = bottomupR (r !-> innerMost r)
{-# INLINE innerMost #-}

-- | Combine a list of named rewrites into one rewrite: each entry is wrapped
-- with 'apply' under its name and the results are chained, in list order,
-- with '>->'.
--
-- An empty list yields the identity rewrite (the term is returned unchanged)
-- instead of crashing in 'foldr1'.
applyMany :: [(String,Rewrite extra)] -> Rewrite extra
applyMany []  = \_ctx term -> return term
applyMany rws = foldr1 (>->) (map (uncurry apply) rws)
{-# INLINE applyMany #-}