{-|
  Copyright  :  (C) 2012-2016, University of Twente
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Utility functions used by the normalisation transformations
-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Clash.Normalize.Util
 ( ConstantSpecInfo(..)
 , isConstantArg
 , shouldReduce
 , alreadyInlined
 , addNewInline
 , specializeNorm
 , isRecursiveBndr
 , isClosed
 , callGraph
 , collectCallGraphUniques
 , classifyFunction
 , isCheapFunction
 , isNonRecursiveGlobalVar
 , constantSpecInfo
 , normalizeTopLvlBndr
 , rewriteExpr
 , removedTm
 , mkInlineTick
 , substWithTyEq
 , tvSubstWithTyEq
 )
 where

import           Control.Lens            ((&),(+~),(%=),(.=))
import qualified Control.Lens            as Lens
import           Data.Bifunctor          (bimap)
import           Data.Either             (lefts)
import qualified Data.List               as List
import qualified Data.List.Extra         as List
import qualified Data.Map                as Map
import qualified Data.HashMap.Strict     as HashMapS
import qualified Data.HashSet            as HashSet
import           Data.Text               (Text)
import qualified Data.Text as Text

import           PrelNames               (eqTyConKey)
import           Unique                  (getKey)

import           Clash.Annotations.Primitive (extractPrim)
import           Clash.Core.FreeVars
  (globalIds, hasLocalFreeVars, globalIdOccursIn)
import           Clash.Core.Name         (Name(nameOcc,nameUniq))
import           Clash.Core.Pretty       (showPpr)
import           Clash.Core.Subst
  (deShadowTerm, extendTvSubst, extendTvSubstList, mkSubst, substTm, substTy,
   substId, extendIdSubst)
import           Clash.Core.Term
import           Clash.Core.TermInfo     (isPolyFun, termType)
import           Clash.Core.TyCon        (TyConMap)
import           Clash.Core.Type
  (Type(LitTy, VarTy), LitTy(SymTy), TypeView (..), tyView, undefinedTy,
   splitFunForallTy, splitTyConAppM, mkPolyFunTy)
import           Clash.Core.Util
  (isClockOrReset)
import           Clash.Core.Var          (Id, TyVar, Var (..), isGlobalId)
import           Clash.Core.VarEnv
  (VarEnv, emptyInScopeSet, emptyVarEnv, extendVarEnv, extendVarEnvWith,
   lookupVarEnv, unionVarEnvWith, unitVarEnv, extendInScopeSetList)
import           Clash.Debug             (traceIf)
import           Clash.Driver.Types      (BindingMap, Binding(..), DebugLevel (..))
import {-# SOURCE #-} Clash.Normalize.Strategy (normalization)
import           Clash.Normalize.Types
import           Clash.Primitives.Util   (constantArgs)
import           Clash.Rewrite.Types
  (RewriteMonad, TransformContext(..), bindings, curFun, dbgLevel, extra,
   tcCache)
import           Clash.Rewrite.Util
  (runRewrite, specialise, mkTmBinderFor, mkDerivedName)
import           Clash.Unique
import           Clash.Util              (SrcSpan, makeCachedU)

-- | Determine if argument should reduce to a constant given a primitive and
-- an argument number. The set of constant arguments of a primitive is
-- computed with 'constantArgs' and cached in 'primitiveArgs'. Primitives that
-- cannot be found are /not/ cached, so their lookup is repeated on every call.
--
-- @Clash.Explicit.SimIO.mealyIO@ is special-cased: its arguments 2 and 3 are
-- always considered constant, without consulting the primitive map.
isConstantArg
  :: Text
  -- ^ Primitive name
  -> Int
  -- ^ Argument number
  -> RewriteMonad NormalizeState Bool
  -- ^ Yields 'False' if the given primitive name is not found, if the
  -- argument does not exist, or if the argument is not among the constant
  -- arguments computed by 'constantArgs' for the primitive.
isConstantArg "Clash.Explicit.SimIO.mealyIO" i = pure (i == 2 || i == 3)
isConstantArg nm i = do
  argMap <- Lens.use (extra.primitiveArgs)
  case Map.lookup nm argMap of
    Just m ->
      -- Cached version found
      pure (i `elem` m)
    Nothing -> do
      -- Constant args not yet calculated, or primitive does not exist
      prims <- Lens.use (extra.primitives)
      case extractPrim =<< HashMapS.lookup nm prims of
        Nothing ->
          -- Primitive does not exist (or 'extractPrim' yields nothing usable).
          -- This outcome is not cached.
          pure False
        Just p -> do
          -- Calculate constant arguments once and cache them
          let m = constantArgs nm p
          extra.primitiveArgs %= Map.insert nm m
          pure (i `elem` m)

-- | Given a transformation context (a list of 'CoreContext' entries),
-- determine whether any entry indicates that the current argument is to be
-- reduced to a constant / literal. Only 'AppArg' entries that carry a
-- primitive name and an argument index are consulted (via 'isConstantArg');
-- every other entry counts as "no".
shouldReduce
  :: Context
  -- ^ ..in the current transformcontext
  -> RewriteMonad NormalizeState Bool
shouldReduce = List.anyM $ \case
  AppArg (Just (primNm, _, argIx)) -> isConstantArg primNm argIx
  _ -> pure False

-- | Look up how many times function @f@ has already been inlined into
-- function @cf@, according to the inline history ('inlineHistory') kept in
-- the 'NormalizeMonad' state. Yields 'Nothing' when no inlining of @f@ into
-- @cf@ has been recorded. The counts are maintained by 'addNewInline'.
alreadyInlined
  :: Id
  -- ^ Function we want to inline
  -> Id
  -- ^ Function in which we want to perform the inlining
  -> NormalizeMonad (Maybe Int)
alreadyInlined f cf = do
  inlinedHM <- Lens.use inlineHistory
  -- First find the history of the caller, then the count for the callee
  pure (lookupVarEnv cf inlinedHM >>= lookupVarEnv f)

-- | Record in the inline history ('inlineHistory') that function @f@ was
-- inlined into function @cf@ once more. When @cf@ has no history yet, it
-- gets a fresh inner map in which @f@ maps to 1.
addNewInline
  :: Id
  -- ^ Function we want to inline
  -> Id
  -- ^ Function in which we want to perform the inlining
  -> NormalizeMonad ()
addNewInline f cf = inlineHistory %= extendVarEnvWith cf firstInline bump
 where
  -- Inner map used when @cf@ has no entry yet
  firstInline :: VarEnv Int
  firstInline = unitVarEnv f 1

  -- Combining function used when @cf@ already has an entry: the first
  -- argument is ignored and @f@'s count in the second one is bumped by 1.
  -- NOTE(review): this presumes 'extendVarEnvWith' passes the fresh value
  -- first and the existing inner map second -- confirm.
  bump :: VarEnv Int -> VarEnv Int -> VarEnv Int
  bump _ inner = extendVarEnvWith f 1 (+) inner

-- | Specialize under the Normalization Monad: 'specialise' instantiated with
-- the specialisation cache, history and limit stored in 'NormalizeState'.
specializeNorm :: NormRewrite
specializeNorm =
  specialise
    specialisationCache
    specialisationHistory
    specialisationLimit

-- | Determine if a term is closed. Note that this is not a free-variable
-- check: a term counts as closed here exactly when 'isPolyFun' returns
-- 'False' for it (presumably meaning its type is neither polymorphic nor a
-- function type; see 'isPolyFun').
isClosed
  :: TyConMap
  -> Term
  -> Bool
isClosed tcm e = not (isPolyFun tcm e)

-- | Test whether a given term is a global variable, possibly applied to
-- arguments (which are ignored), that is not recursive according to
-- 'isRecursiveBndr'. Any other shape of term yields 'False'.
isNonRecursiveGlobalVar
  :: Term
  -> NormalizeSession Bool
isNonRecursiveGlobalVar e =
  case collectArgs e of
    (Var i, _) -> do
      isRec <- isRecursiveBndr i
      pure (isGlobalId i && not isRec)
    _ ->
      pure False

-- | Assert whether a name is a reference to a recursive binder.
--
-- The answer is looked up in the 'recursiveComponents' cache first. On a miss,
-- the binder's definition is looked up in 'bindings'. A name without a
-- binding yields 'False' and is not cached. Otherwise, the binder is recursive
-- iff it occurs in its own body, and that answer is cached.
isRecursiveBndr
  :: Id
  -> NormalizeSession Bool
isRecursiveBndr f = do
  cache <- Lens.use (extra.recursiveComponents)
  maybe computeAndCache pure (lookupVarEnv f cache)
 where
  computeAndCache :: NormalizeSession Bool
  computeAndCache = do
    globals <- Lens.use bindings
    case lookupVarEnv f globals of
      Nothing -> pure False
      Just b -> do
        -- There are no global mutually-recursive functions, only
        -- self-recursive ones, so checking whether 'f' is part of the free
        -- variables of the body of 'f' is sufficient.
        let isRec = f `globalIdOccursIn` bindingTerm b
        extra.recursiveComponents %= extendVarEnv f isRec
        pure isRec

-- | Result of 'constantSpecInfo': a term reduced to its constant parts, plus
-- the let-bindings for the non-constant parts that were cut out of it.
data ConstantSpecInfo =
  ConstantSpecInfo
    { csrNewBindings :: [(Id, Term)]
      -- ^ New let-bindings to be created for all the non-constants found
    , csrNewTerm :: !Term
      -- ^ A term where all the non-constant constructs are replaced by
      -- variable references (found in 'csrNewBindings')
    , csrFoundConstant :: !Bool
      -- ^ Whether the algorithm found a constant at all. (If it didn't, it's
      -- no use creating any new let-bindings!)
    }
  deriving (Show)

-- | Indicate term is fully constant (don't bind anything): no new bindings,
-- the term itself is kept unchanged, and 'csrFoundConstant' is set.
constantCsr :: Term -> ConstantSpecInfo
constantCsr t =
  ConstantSpecInfo
    { csrNewBindings = []
    , csrNewTerm = t
    , csrFoundConstant = True
    }

-- | Bind the given term to a fresh variable and report it as fully
-- non-constant. The result has a single new let-binding for the term, uses
-- the fresh variable as its new term, and has 'csrFoundConstant' set to
-- 'False'.
bindCsr
  :: TransformContext
  -> Term
  -> RewriteMonad NormalizeState ConstantSpecInfo
bindCsr ctx@(TransformContext inScope _) oldTerm = do
  -- TODO: Seems like the need to put global ids in scope has been made obsolete
  -- TODO: by a recent change in Clash. Investigate whether this is true.
  tcm <- Lens.view tcCache
  newId <-
    mkTmBinderFor inScope tcm (mkDerivedName ctx "bindCsr") oldTerm
  pure
    ConstantSpecInfo
      { csrNewBindings = [(newId, oldTerm)]
      , csrNewTerm = Var newId
      , csrFoundConstant = False
      }

-- | Combine the constant spec info of the sub-terms (arguments) of a term.
--
-- If any arguments are constant (and hence can be constant specced), a new
-- term is created with these constants left in, but variable parts let-bound.
-- There's one edge case: whenever a term has _no_ term arguments. This happens
-- for constructors without fields, or -depending on their WorkInfo-
-- primitives without args. We still set 'csrFoundConstant', because we know
-- the newly proposed term will be fully constant.
--
-- Otherwise, the whole old term is bound to a single new variable
-- (see 'bindCsr').
mergeCsrs
  :: TransformContext
  -> [TickInfo]
  -- ^ Ticks to wrap around proposed new term
  -> Term
  -- ^ "Old" term
  -> ([Either Term Type] -> Term)
  -- ^ Proposed new term in case any constants were found
  -> [Either Term Type]
  -- ^ Subterms
  -> RewriteMonad NormalizeState ConstantSpecInfo
mergeCsrs ctx ticks oldTerm proposedTerm subTerms = do
  subCsrs <- snd <$> List.mapAccumLM specArg ctx subTerms
  let termCsrs = lefts subCsrs
      anyArgsOrResultConstant =
        null termCsrs || any csrFoundConstant termCsrs
  if not anyArgsOrResultConstant
    then
      -- No constructs were found to be constant, so we might as well refer to
      -- the whole thing with a new let-binding (instead of creating a number
      -- of "smaller" let-bindings)
      bindCsr ctx oldTerm
    else do
      let newTerm = proposedTerm (bimap csrNewTerm id <$> subCsrs)
      pure
        ConstantSpecInfo
          { csrNewBindings = concatMap csrNewBindings termCsrs
          , csrNewTerm = mkTicks newTerm ticks
          , csrFoundConstant = True
          }
 where
  -- Analyse a single sub-term. Type arguments pass through untouched. For a
  -- term argument, the in-scope set is extended with the binders created for
  -- it, so later arguments are analysed with those binders in scope.
  specArg
    :: TransformContext
    -> Either Term Type
    -> RewriteMonad NormalizeState (TransformContext, Either ConstantSpecInfo Type)
  specArg localCtx (Right ty) = pure (localCtx, Right ty)
  specArg localCtx@(TransformContext is0 tfCtx) (Left tm) = do
    info <- constantSpecInfo localCtx tm
    let is1 = extendInScopeSetList is0 (fst <$> csrNewBindings info)
    pure (TransformContext is1 tfCtx, Left info)


-- | Calculate constant spec info. The goal of this function is to analyze a
-- given term and yield a new term that:
--
--  * Leaves all the constant parts as they were.
--  * Has all _variable_ parts replaced by a newly generated identifier.
--
-- The result structure will additionally contain:
--
--  * Whether the function found any constant parts at all
--  * A list of let-bindings binding the aforementioned identifiers with
--    the term they replaced.
--
-- This can be used in functions wanting to constant specialize over
-- partially constant data structures.
constantSpecInfo
  :: TransformContext
  -> Term
  -> RewriteMonad NormalizeState ConstantSpecInfo
constantSpecInfo ctx e = do
  tcm <- Lens.view tcCache
  -- Don't constant spec clocks or resets, they're either:
  --
  --  * A simple wire (Var), therefore not interesting to spec
  --  * A clock/reset generator, and speccing a generator weirds out HDL simulators.
  --
  -- I believe we can remove this special case in the future by looking at the
  -- primitive's workinfo.
  if isClockOrReset tcm (termType tcm e)
    then clockOrResetCsr
    else otherCsr (collectArgsTicks e)
 where
  -- Clocks and resets: only a removed argument counts as constant, anything
  -- else is let-bound.
  clockOrResetCsr :: RewriteMonad NormalizeState ConstantSpecInfo
  clockOrResetCsr =
    case collectArgs e of
      (Prim p, _)
        | primName p == "Clash.Transformations.removedArg" ->
            pure (constantCsr e)
      _ ->
        bindCsr ctx e

  -- All other terms, dispatched on the head of the application.
  otherCsr
    :: (Term, [Either Term Type], [TickInfo])
    -> RewriteMonad NormalizeState ConstantSpecInfo
  otherCsr = \case
    (dc@(Data _), args, ticks) ->
      mergeCsrs ctx ticks e (mkApps dc) args

    -- TODO: Work with prim's WorkInfo?
    (prim@(Prim _), args, ticks) ->
      mergeOrBind prim args ticks

    -- A lambda is constant only when it has no local free variables
    (Lam _ _, _, _)
      | hasLocalFreeVars e -> bindCsr ctx e
      | otherwise -> pure (constantCsr e)

    -- Only non-recursive global functions other than the current function
    -- ('curFun') are looked into
    (var@(Var f), args, ticks) -> do
      (curF, _) <- Lens.use curFun
      nonRecGlobal <- isNonRecursiveGlobalVar e
      if nonRecGlobal && f /= curF
        then mergeOrBind var args ticks
        else bindCsr ctx e

    (Literal _, _, _) ->
      pure (constantCsr e)

    _ ->
      bindCsr ctx e

  -- Merge the arguments' spec info under the given head. Keep the result only
  -- if it introduced no new bindings; otherwise bind the whole term instead.
  mergeOrBind
    :: Term
    -> [Either Term Type]
    -> [TickInfo]
    -> RewriteMonad NormalizeState ConstantSpecInfo
  mergeOrBind hd args ticks = do
    csr <- mergeCsrs ctx ticks e (mkApps hd) args
    if null (csrNewBindings csr)
      then pure csr
      else bindCsr ctx e

-- | A call graph counts the number of times each function @g@ is used in a
-- function @f@. (Presumably the outer map is keyed by the caller @f@ and the
-- inner maps by the callees @g@; confirm against 'callGraph'.)
type CallGraph = VarEnv (VarEnv Word)

-- | Collect every binder mentioned in a 'CallGraph' into a 'HashSet': both the
-- callers (keys of the outer map) and the callees (keys of every inner map).
collectCallGraphUniques :: CallGraph -> HashSet.HashSet Unique
collectCallGraphUniques cg =
  HashSet.fromList (keysUniqMap cg ++ calleeUniques)
 where
  calleeUniques = concatMap keysUniqMap (eltsUniqMap cg)

-- | Create a call graph for a set of global binders, given a root.
--
-- Starting from the root, the term of each binder found in the 'BindingMap'
-- is scanned for the global identifiers it mentions. The binder gets an entry
-- mapping each of those identifiers to its number of occurrences, after which
-- every such identifier is visited in turn. A binder that already has an entry
-- in the graph, or that has no binding in the 'BindingMap', is not visited.
callGraph
  :: BindingMap
  -> Id
  -> CallGraph
callGraph bndrs rt = visit emptyVarEnv (varUniq rt)
  where
    visit graph u = case (lookupUniqMap u graph, lookupUniqMap u bndrs) of
      (Nothing, Just binding) ->
        let -- Occurrence count of every global identifier in the binding
            callees = Lens.foldMapByOf globalIds (unionVarEnvWith (+))
                        emptyVarEnv (\i -> unitUniqMap i 1) (bindingTerm binding)
        in  List.foldl' visit (extendUniqMap u callees graph) (keysUniqMap callees)
      -- Already visited, or no binding available for this unique
      _ -> graph

-- | Give a "performance/size" classification of a function in normal form.
--
-- Lambdas, type lambdas and ticks are looked through. A let is classified by
-- folding over its right-hand sides (its body is not inspected). An
-- application headed by a primitive counts as a primitive, one headed by a
-- variable counts as a function, and a case with two or more alternatives
-- counts as a selection. Anything else contributes nothing.
classifyFunction
  :: Term
  -> TermClassification
classifyFunction = classify (TermClassification 0 0 0)
  where
    classify :: TermClassification -> Term -> TermClassification
    classify !acc term = case term of
      Lam _ body       -> classify acc body
      TyLam _ body     -> classify acc body
      Tick _ body      -> classify acc body
      Letrec binds _   -> List.foldl' classify acc (map snd binds)
      App {}           -> case fst (collectArgs term) of
        Prim {} -> acc & primitive +~ 1
        Var {}  -> acc & function +~ 1
        _       -> acc
      Case _ _ (_:_:_) -> acc & selection +~ 1
      _                -> acc

-- | Determine whether a function adds a lot of hardware or not.
--
-- It is considered expensive when it has 2 or more of the following components:
--
-- * functions
-- * primitives
-- * selections (multiplexers)
--
-- Conversely, it is cheap when it contains at most one of these components in
-- total, as counted by 'classifyFunction'.
isCheapFunction
  :: Term
  -> Bool
isCheapFunction tm = case classifyFunction tm of
  TermClassification {..} ->
    -- The counts start at 0 and are only ever incremented, so this sum is at
    -- most 1 exactly when at most one component is present. (The previous
    -- guard chain let `_function <= 1` capture every term without function
    -- calls, wrongly rejecting a lone primitive or a lone selection.)
    _function + _primitive + _selection <= 1

-- | Normalize the term of a global binder.
--
-- The work is wrapped in 'makeCachedU', keyed on @nm@, with the 'normalized'
-- map as the cache. When @isTop@ is 'True', type equality constraints are
-- additionally eliminated with 'substWithTyEq' before rewriting. The type of
-- the returned binder is recomputed from the normalized term, and the
-- 'curFun' that was current before the call is restored afterwards.
normalizeTopLvlBndr
  :: Bool
  -> Id
  -> Binding
  -> NormalizeSession Binding
normalizeTopLvlBndr isTop nm (Binding nm' sp inl tm) =
  makeCachedU nm (extra . normalized) $ do
    tcm <- Lens.view tcCache
    let nmS = showPpr (varName nm)
        -- We deshadow the term because sometimes GHC gives us
        -- code where a local binder has the same unique as a
        -- global binder, sometimes causing the inliner to go
        -- into a loop. Deshadowing freshens all the bindings
        -- to avoid this.
        deshadowed = deShadowTerm emptyInScopeSet tm
        prepared
          | isTop     = substWithTyEq deshadowed
          | otherwise = deshadowed
    -- rewriteExpr overwrites curFun, so remember it and restore it afterwards
    savedFun <- Lens.use curFun
    normTm <- rewriteExpr ("normalization", normalization) (nmS, prepared) (nm', sp)
    curFun .= savedFun
    pure (Binding nm'{varType = termType tcm normTm} sp inl normTm)

-- | Turn type equality constraints into substitutions and apply them.
--
-- So given:
--
-- > /\dom . \(eq : dom ~ "System") . \(eta : Signal dom Bool) . eta
--
-- we create the substitution [dom := "System"] and apply it to create:
--
-- > \(eq : "System" ~ "System") . \(eta : Signal "System" Bool) . eta
--
-- Type variables that are not eliminated by an equality constraint are
-- re-bound outside the rebuilt term lambdas, so they keep scoping over the
-- types of those lambdas. If no constraint is eliminated, the original term
-- is returned unchanged.
--
-- __NB:__ Users of this function should ensure it's only applied to TopEntities
substWithTyEq
  :: Term
  -> Term
substWithTyEq e0 = go [] False [] e0
 where
  -- Walk the leading (type) lambdas. The accumulators are, in order: the type
  -- binders that survive (innermost first), whether any equality constraint
  -- was eliminated, and the term binders seen so far (innermost first).
  go
    :: [TyVar]
    -> Bool
    -> [Id]
    -> Term
    -> Term
  go tvs changed ids_ (TyLam tv e) = go (tv:tvs) changed ids_ e
  go tvs changed ids_ (Lam v e)
    | TyConApp (nameUniq -> tcUniq) (tvFirst -> Just (tv, ty)) <- tyView (varType v)
    , tcUniq == getKey eqTyConKey
    , tv `elem` tvs
    = let
        subst0 = extendTvSubst (mkSubst emptyInScopeSet) tv ty
        -- Uses of the equality evidence are replaced by a placeholder term
        subst1 = extendIdSubst subst0 v (removedTm (varType v))
      in go (tvs List.\\ [tv]) True (substId subst0 v : ids_) (substTm "substWithTyEq e" subst1 e)
    | otherwise = go tvs changed (v:ids_) e
  go tvs True ids_ e =
    let
      -- Rebuild the term lambdas first and only then wrap the surviving type
      -- lambdas around them: the types of the term binders may mention those
      -- type variables, so the type lambdas must be outermost. This mirrors
      -- 'tvSubstWithTyEq', which keeps the foralls in their original position.
      e1 = List.foldl' (flip Lam) e ids_
      e2 = List.foldl' (flip TyLam) e1 tvs
    in e2
  go _ False _ _ = e0

-- Type equality (~) is symmetrical, so users could write: (dom ~ System) or
-- (System ~ dom). Given the arguments of an equality constraint, return the
-- type variable together with the type it is equated to. The first argument
-- is ignored (presumably the kind of both sides; TODO confirm). When both
-- sides are type variables, the left-hand one is returned.
tvFirst :: [Type] -> Maybe (TyVar, Type)
tvFirst tys = case tys of
  [_, VarTy tv, other] -> Just (tv, other)
  [_, other, VarTy tv] -> Just (tv, other)
  _                    -> Nothing

-- | The type equivalent of 'substWithTyEq'.
--
-- Every argument of the function type that is an equality constraint between
-- a type variable and a type becomes a substitution. The corresponding forall
-- binders are dropped, the type is rebuilt from the remaining arguments, and
-- the substitution is applied to it. If there are no such constraints, the
-- original type is returned unchanged.
tvSubstWithTyEq
  :: Type
  -> Type
tvSubstWithTyEq ty0
  | null eqs  = ty0 -- no eq constraints, returning original type
  | otherwise = substTy subst (mkPolyFunTy tyRes remaining)
 where
  (args0, tyRes) = splitFunForallTy ty0

  -- Equality constraints found among the arguments, last one first
  eqs :: [(TyVar, Type)]
  eqs = List.foldl' collect [] args0

  collect acc (Right arg)
    | Just (tc, tcArgs) <- splitTyConAppM arg
    , nameUniq tc == getKey eqTyConKey
    , Just eq <- tvFirst tcArgs
    = eq : acc
  collect acc _ = acc

  subst = extendTvSubstList (mkSubst emptyInScopeSet) eqs

  -- Drop the (ForAll) binders of the substituted type variables
  remaining = args0 List.\\ map (Left . fst) eqs

-- | Rewrite a term according to the provided transformation.
--
-- 'curFun' is set to the given binder before rewriting and is not restored
-- afterwards. When the debug level is at least 'DebugFinal', the term is
-- traced both before and after the transformation.
rewriteExpr :: (String,NormRewrite) -- ^ Transformation to apply
            -> (String,Term)        -- ^ Term to transform
            -> (Id, SrcSpan)        -- ^ Renew current function being rewritten
            -> NormalizeSession Term
rewriteExpr (nrwS, nrw) (bndrS, expr) (nm, sp) = do
  curFun .= (nm, sp)
  verbose <- (>= DebugFinal) <$> Lens.view dbgLevel
  let dump stage tm =
        bndrS ++ stage ++ nrwS ++ ":\n\n" ++ showPpr tm ++ "\n"
  rewritten <- runRewrite nrwS emptyInScopeSet nrw
                 (traceIf verbose (dump " before " expr) expr)
  traceIf verbose (dump " after " rewritten) (pure rewritten)

-- | A placeholder term standing in for a removed argument: the primitive
-- @Clash.Transformations.removedArg@ applied to the given type. (Used, for
-- instance, by 'substWithTyEq' for an eliminated equality-constraint binder.)
removedTm
  :: Type
  -> Term
removedTm ty = TyApp removedArg ty
 where
  removedArg =
    Prim (PrimInfo "Clash.Transformations.removedArg" undefinedTy WorkNever)

-- | A tick to prefix an inlined expression with its original name.
-- For example, given
--
--     foo = bar  -- ...
--     bar = baz  -- ...
--     baz = quuz -- ...
--
-- if bar is inlined into foo, then the name of the component should contain
-- the name of the inlined component. This tick ensures that the component in
-- foo is called bar_baz instead of just baz.
--
-- Only the unqualified part of the name (everything after the last @.@) is
-- used as the prefix.
mkInlineTick :: Id -> TickInfo
mkInlineTick n = NameMod PrefixName (LitTy (SymTy (unqualified n)))
 where
  unqualified v =
    Text.unpack (snd (Text.breakOnEnd "." (nameOcc (varName v))))