{-|
  Copyright  :  (C) 2012-2016, University of Twente
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Transformation process for normalization
-}

module Clash.Normalize.Strategy where

import Clash.Normalize.Transformations
import Clash.Normalize.Types
import Clash.Rewrite.Combinators
import Clash.Rewrite.Types
import Clash.Rewrite.Util

-- [Note] bottomup traversal evalConst
--
-- 2-May-2019: There is a bug in the evaluator where all data constructors are
-- considered lazy, even though their declaration says they have strict fields.
-- This causes some reductions to fail because the term under the constructor is
-- not in WHNF, which is what some of the evaluation rules for certain primitive
-- operations expect. Using a bottom-up traversal works around this bug by
-- ensuring that the values under the constructor are in WHNF.
--
-- Using a bottomup traversal ensures that constants are reduced to NF, even if
-- constructors are lazy, thus ensuring more sensible/smaller generated HDL.

-- | Normalisation transformation
--
-- The complete normalisation strategy, expressed as a fixed pipeline of
-- rewrite passes. Every pass that is not a top-level definition is defined in
-- the @where@ clause below as a traversal ('topdownR', 'bottomupR',
-- 'topdownSucR' or 'innerMost') wrapped around one or more named
-- transformations.
--
-- NOTE(review): the precise success/failure semantics of the sequencing
-- operators (@>->@, @>-!->@, @!->@) are defined in "Clash.Rewrite.Combinators";
-- consult that module before reordering or regrouping passes.
normalization :: NormRewrite
normalization = rmDeadcode >-> constantPropagation >-> etaTL >-> rmUnusedExpr >-!-> anf >-!-> rmDeadcode >->
                bindConst >-> letTL >-> evalConst >-!-> cse >-!-> cleanup >-> recLetRec
  where
    -- Each string passed to 'apply' / 'applyMany' is the name under which the
    -- accompanying transformation is applied.
    etaTL      = apply "etaTL" etaExpansionTL !-> topdownR (apply "applicationPropagation" appPropFast)
    anf        = topdownR (apply "nonRepANF" nonRepANF) >-> apply "ANF" makeANF >-> topdownR (apply "caseCon" caseCon)
    letTL      = topdownSucR (apply "topLet" topLet)
    recLetRec  = apply "recToLetRec" recToLetRec
    rmUnusedExpr = bottomupR (apply "removeUnusedExpr" removeUnusedExpr)
    rmDeadcode = bottomupR (apply "deadcode" deadCode)
    bindConst  = topdownR (apply "bindConstantVar" bindConstantVar)
    -- See [Note] bottomup traversal evalConst:
    evalConst  = bottomupR (apply "evalConst" reduceConst)
    cse        = topdownR (apply "CSE" simpleCSE)
    -- Final clean-up; note that it re-runs 'rmDeadcode' and 'letTL' at the end.
    cleanup    = topdownR (apply "etaExpandSyn" etaExpandSyn) >->
                 topdownSucR (apply "inlineCleanup" inlineCleanup) !->
                 innerMost (applyMany [("caseCon"        , caseCon)
                                      ,("bindConstantVar", bindConstantVar)
                                      ,("letFlat"        , flattenLet)])
                 >-> rmDeadcode >-> letTL


-- | Constant propagation phase of the normalisation strategy
--
-- Runs, in order: 'inlineAndPropagate', 'caseFlattening', 'dec', 'spec',
-- 'dec' (a second time) and 'conSpec'; all of which are defined in the
-- @where@ clause below.
constantPropagation :: NormRewrite
constantPropagation = inlineAndPropagate >->
                     caseFlattening >-> dec >-> spec >-> dec >->
                     conSpec
  where
    inlineAndPropagate = repeatR (topdownR (applyMany transPropagateAndInline) >-> inlineNR)
    spec               = bottomupR (applyMany specTransformations)
    caseFlattening     = repeatR (topdownR (apply "caseFlat" caseFlat))
    dec                = repeatR (topdownR (apply "DEC" disjointExpressionConsolidation))
    conSpec            = bottomupR (apply "constantSpec" constantSpec)

    -- Named transformations that 'inlineAndPropagate' applies together in a
    -- single top-down traversal.
    transPropagateAndInline :: [(String,NormRewrite)]
    transPropagateAndInline =
      [ ("applicationPropagation", appPropFast          )
      , ("bindConstantVar"       , bindConstantVar      )
      , ("caseLet"               , caseLet              )
      , ("caseCase"              , caseCase             )
      , ("caseCon"               , caseCon              )
      , ("elemExistentials"      , elemExistentials     )
      , ("caseElemNonReachable"  , caseElemNonReachable )
      , ("removeUnusedExpr"      , removeUnusedExpr     )
      -- These transformations can safely be applied in a top-down traversal as
      -- they themselves check whether the to-be-inlined binder is recursive or not.
      , ("inlineWorkFree"  , inlineWorkFree)
      , ("inlineSmall"     , inlineSmall)
      , ("bindOrLiftNonRep", inlineOrLiftNonRep) -- See: [Note] bindNonRep before liftNonRep
                                                 -- See: [Note] bottom-up traversal for liftNonRep
      , ("reduceNonRepPrim", reduceNonRepPrim)

      , ("caseCast"        , caseCast)
      , ("letCast"         , letCast)
      , ("splitCastWork"   , splitCastWork)
      , ("argCastSpec"     , argCastSpec)
      , ("inlineCast"      , inlineCast)
      , ("eliminateCastCast",eliminateCastCast)
      ]

    -- InlineNonRep cannot be applied in a top-down traversal, as the
    -- non-representable binder might be recursive. The idea is that if the
    -- recursive non-representable binder is inlined once, we can get rid of the
    -- recursive aspect using the case-of-known-constructor transformation.
    inlineNR :: NormRewrite
    inlineNR = bottomupR (apply "inlineNonRep" inlineNonRep)

    -- Named specialisation transformations that 'spec' applies together in a
    -- single bottom-up traversal.
    specTransformations :: [(String,NormRewrite)]
    specTransformations =
      [ ("typeSpec"    , typeSpec)
      , ("nonRepSpec"  , nonRepSpec)
      ]

{- [Note] bottom-up traversal for liftNonRep
We used to say:

"The liftNonRep transformation must be applied in a topDown traversal because
of what Clash considers tail calls in its join-point analysis."

Consider:

> let fail = \x -> ...
> in  case ... of
>       A -> let fail1 = \y -> case ... of
>                                 X -> fail ...
>                                 Y -> ...
>            in case ... of
>                 P -> fail1 ...
>                 Q -> ...
>       B -> fail ...

under "normal" tail call rules, the local 'fail' function is not a join-point
because it is used in a let-binding. However, we apply "special" tail call rules
in Clash. Because 'fail' is used in a TC position within 'fail1', and 'fail1' is
only used in a TC position, in Clash, we consider 'fail' also only to be used
in a TC position.

Now imagine we apply 'liftNonRep' in a bottom up traversal, we will end up with:

> fail1 = \fail y -> case ... of
>   X -> fail ...
>   Y -> ...

> let fail = \x -> ...
> in  case ... of
>       A -> case ... of
>                 P -> fail1 fail ...
>                 Q -> ...
>       B -> fail ...

Suddenly, 'fail' ends up in an argument position, because it occurred as a
_locally_ bound variable within 'fail1'. And because of that 'fail' stops being
a join-point.

However, when we apply 'liftNonRep' in a top down traversal we end up with:

> fail = \x -> ...
>
> fail1 = \y -> case ... of
>   X -> fail ...
>   Y -> ...
>
> let ...
> in  case ... of
>       A -> let
>            in case ... of
>                 P -> fail1 ...
>                 Q -> ...
>       B -> fail ...

and all is well with the world.

UPDATE:
We can now just perform liftNonRep in a bottom-up traversal again, because
liftNonRep no longer checks whether the binding that is lifted is a join-point.
However, for this to work, bindNonRep must always have been exhaustively applied
before liftNonRep. See also: [Note] bindNonRep before liftNonRep.
-}

{- [Note] bindNonRep before liftNonRep
The combination of liftNonRep and nonRepSpec can lead to non-termination in an
unchecked rewrite system (without termination measures in place) on the
following:

> main = f not
> f    = \a x -> (a x) && (f a x)

nonRepSpec will lead to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = (\a x -> (a x) && (f a x)) not

then lamApp leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = let a = not in (\x -> (a x) && (f a x))

then liftNonRep leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (g x) && (f g x)
> g    = not

and nonRepSpec leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (g x) && (f'' g x)
> g    = not
> f''  = (\a x -> (a x) && (f a x)) g

This cycle continues indefinitely, as liftNonRep creates a new global variable,
which is never alpha-equivalent to the previous global variable introduced by
liftNonRep.

That is why bindNonRep must always be applied before liftNonRep. When we end up
in the situation after lamApp:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = let a = not in (\x -> (a x) && (f a x))

bindNonRep will now lead to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (not x) && (f not x)

Because `f` has already been specialized on the alpha-equivalent-to-itself `not`
function, liftNonRep leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (not x) && (f' x)

And there is no non-terminating rewriting cycle.

That is why bindNonRep must always be exhaustively applied before we apply
liftNonRep.
-}

-- | Topdown traversal, stops upon first success
--
-- Combines the given rewrite @r@, via @>-!@, with a recursive use of
-- 'topdownSucR' that is pushed down the term through 'allR'.
--
-- NOTE(review): the exact semantics of @>-!@ and 'allR' live in
-- "Clash.Rewrite.Combinators" -- confirm there before changing this.
topdownSucR :: Rewrite extra -> Rewrite extra
topdownSucR r = r >-! (allR (topdownSucR r))
{-# INLINE topdownSucR #-}

-- | Topdown traversal of the given rewrite, wrapped in 'repeatR'
--
-- NOTE(review): presumably this repeats the top-down traversal for as long as
-- it keeps changing the term -- confirm against the definition of 'repeatR'.
topdownRR :: Rewrite extra -> Rewrite extra
topdownRR r = repeatR (topdownR r)
{-# INLINE topdownRR #-}

-- | Bottom-up traversal that combines the given rewrite @r@, via @!->@, with
-- a recursive 'innerMost' traversal using the same rewrite.
--
-- NOTE(review): presumably this re-traverses a sub-term whenever @r@ succeeded
-- on it -- confirm against the definition of @!->@.
innerMost :: Rewrite extra -> Rewrite extra
-- The worker is bound with @let@ and returned unapplied; this definition shape
-- is kept exactly as it was.
innerMost = let go r = bottomupR (r !-> innerMost r) in go
{-# INLINE innerMost #-}

-- | Turn a list of named rewrites into a single rewrite
--
-- Every @(name, rewrite)@ pair is passed to 'apply', and the results are
-- chained left-to-right with @>->@.
--
-- The list must be non-empty: the chain is built with 'foldr1', which raises
-- an error when given an empty list.
applyMany :: [(String,Rewrite extra)] -> Rewrite extra
applyMany = foldr1 (>->) . map (uncurry apply)
{-# INLINE applyMany #-}