{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016     , Myrtle Software Ltd,
                    2017     , Google Inc.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Utilities for rewriting: e.g. inlining, specialisation, etc.
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Rewrite.Util where

import           Control.Monad.Extra         (andM, eitherM)
import           Control.Concurrent.Supply   (splitSupply)
import           Control.DeepSeq
import           Control.Exception           (throw)
import           Control.Lens
  (Lens', (%=), (+=), (^.), _Left)
import qualified Control.Lens                as Lens
import qualified Control.Monad               as Monad
#if !MIN_VERSION_base(4,13,0)
import           Control.Monad.Fail          (MonadFail)
#endif
import qualified Control.Monad.State.Strict  as State
import qualified Control.Monad.Writer        as Writer
import           Data.Bool                   (bool)
import           Data.Bifunctor              (bimap)
import           Data.Coerce                 (coerce)
import           Data.Functor.Const          (Const (..))
import           Data.List                   (group, partition, sort)
import qualified Data.List                   as List
import qualified Data.List.Extra             as List
import           Data.List.Extra             (allM, partitionM)
import qualified Data.Map                    as Map
import           Data.Maybe
  (catMaybes, isJust, mapMaybe, fromMaybe)
import qualified Data.Monoid                 as Monoid
import qualified Data.Set                    as Set
import qualified Data.Set.Lens               as Lens
import qualified Data.Set.Ordered            as OSet
import qualified Data.Set.Ordered.Extra      as OSet
import           Data.Text                   (Text)
import qualified Data.Text                   as Text

#ifdef HISTORY
import           Data.Binary                 (encode)
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Lazy        as BL
import           System.IO.Unsafe            (unsafePerformIO)
#endif

import           BasicTypes                  (InlineSpec (..))

import           Clash.Core.DataCon          (dcExtTyVars)
import           Clash.Core.Evaluator        (whnf')
import           Clash.Core.Evaluator.Types  (PureHeap)
import           Clash.Core.FreeVars
  (freeLocalVars, hasLocalFreeVars, localIdDoesNotOccurIn, localIdOccursIn,
   typeFreeVars, termFreeVars', freeLocalIds, globalIdOccursIn)
import           Clash.Core.Name
import           Clash.Core.Pretty           (showPpr)
import           Clash.Core.Subst
  (substTmEnv, aeqTerm, aeqType, extendIdSubst, mkSubst, substTm)
import           Clash.Core.Term
import           Clash.Core.TermInfo
import           Clash.Core.TyCon
  (TyConMap, tyConDataCons)
import           Clash.Core.Type             (KindOrType, Type (..),
                                              TypeView (..), coreView1,
                                              normalizeType,
                                              typeKind, tyView, isPolyFunTy)
import           Clash.Core.Util
  (dataConInstArgTysE, isClockOrReset, isEnable)
import           Clash.Core.Var
  (Id, IdScope (..), TyVar, Var (..), isLocalId, mkGlobalId, mkLocalId, mkTyVar)
import           Clash.Core.VarEnv
  (InScopeSet, VarEnv, elemVarSet, extendInScopeSetList, mkInScopeSet,
   uniqAway, uniqAway', mapVarEnv, eltsVarEnv, unitVarSet, emptyVarEnv,
   mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv)
import           Clash.Debug
import           Clash.Driver.Types
  (DebugLevel (..), BindingMap, Binding(..))
import           Clash.Netlist.Util          (representableType)
import           Clash.Pretty                (clashPretty, showDoc)
import           Clash.Rewrite.Types
import           Clash.Unique
import           Clash.Util
import qualified Clash.Util.Interpolate as I

-- | Lift an action working in the '_extra' state to the 'RewriteMonad'
--
-- The 'State.State' action is run on the current '_extra' field of the
-- rewrite state, and its final state is stored back into that field. The
-- rewrite environment and the accumulated change flag pass through untouched.
zoomExtra :: State.State extra a
          -> RewriteMonad extra a
zoomExtra act = R $ \_env st flag ->
  -- Matching with 'case' (rather than a lazy 'let') forces the pair produced
  -- by the strict State action as soon as the result triple is demanded.
  case State.runState act (st ^. extra) of
    (res, extra') -> (res, st { _extra = extra' }, flag)

-- | Some transformations might erroneously introduce shadowing. For example,
-- a transformation might result in:
--
--   let a = ...
--       b = ...
--       a = ...
--
-- where the last 'a' shadows the first, while Clash assumes that this can't
-- happen. This function finds those constructs and returns the duplicates it
-- found: one group per duplicated binder, holding every occurrence of that
-- binder within a single let-binding group or case-alternative pattern.
--
-- The whole term is traversed, including the right-hand sides of let
-- bindings, so duplicates nested inside a bound expression are reported too.
findAccidentialShadows :: Term -> [[Id]]
findAccidentialShadows =
  \case
    Var {}      -> []
    Data {}     -> []
    Literal {}  -> []
    Prim {}     -> []
    Lam _ t     -> findAccidentialShadows t
    TyLam _ t   -> findAccidentialShadows t
    App t1 t2   -> concatMap findAccidentialShadows [t1, t2]
    TyApp t _   -> findAccidentialShadows t
    Cast t _ _  -> findAccidentialShadows t
    Tick _ t    -> findAccidentialShadows t
    Case t _ as ->
      concatMap (findInPat . fst) as ++
        concatMap findAccidentialShadows (t : map snd as)
    Letrec bs t ->
      -- Check the binders of this group, then descend into both the body and
      -- every bound expression (the latter was previously skipped).
      findDups (map fst bs) ++
        concatMap findAccidentialShadows (t : map snd bs)

 where
  -- Only data-constructor patterns bind variables.
  findInPat :: Pat -> [[Id]]
  findInPat (LitPat _)        = []
  findInPat (DefaultPat)      = []
  findInPat (DataPat _ _ ids) = findDups ids

  -- Groups of equal binders that occur more than once.
  findDups :: [Id] -> [[Id]]
  findDups ids = filter ((1 <) . length) (group (sort ids))


-- | Record if a transformation is successfully applied
--
-- Runs the rewrite and listens to the change flag it emits. If a change was
-- reported, the transform counter is incremented and the rewritten term is
-- returned; otherwise the original term is returned. For any debug level
-- other than 'DebugNone', the step is additionally handed to 'applyDebug'.
apply
  :: String
  -- ^ Name of the transformation
  -> Rewrite extra
  -- ^ Transformation to be applied
  -> Rewrite extra
apply = \s rewrite ctx expr0 -> do
  lvl <- Lens.view dbgLevel
  dbgTranss <- Lens.view dbgTransformations
  let isTryLvl = lvl == DebugTry || lvl >= DebugAll
      isRelevantTrans = s `Set.member` dbgTranss || Set.null dbgTranss
  traceIf (isTryLvl && isRelevantTrans) ("Trying: " ++ s) (pure ())

  (expr1,anyChanged) <- Writer.listen (rewrite ctx expr0)
  let hasChanged = Monoid.getAny anyChanged
      -- Keep the original term when no change was reported.
      !expr2     = if hasChanged then expr1 else expr0
  Monad.when hasChanged (transformCounter += 1)
#ifdef HISTORY
  -- NB: When HISTORY is on, emit binary data holding the recorded rewrite steps
  Monad.when hasChanged $ do
    (curBndr, _) <- Lens.use curFun
    let !_ = unsafePerformIO
             $ BS.appendFile "history.dat"
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx    = tfContext ctx
                 , t_name   = s
                 , t_bndrS  = showPpr (varName curBndr)
                 , t_before = expr0
                 , t_after  = expr1
                 }
    return ()
#endif

  dbgFrom <- Lens.view dbgTransformationsFrom
  dbgLimit <- Lens.view dbgTransformationsLimit
  -- (0, maxBound) means that no transformation window was requested.
  let fromLimit =
        if (dbgFrom, dbgLimit) == (0, maxBound)
        then Nothing
        else Just (dbgFrom, dbgLimit)

  if lvl == DebugNone
    then return expr2
    -- Give 'applyDebug' the term the rewrite actually produced ('expr1'), not
    -- 'expr2': when no change was reported 'expr2' is 'expr0', which made the
    -- "Expression changed without notice" check compare 'expr0' with itself.
    -- The value returned by 'apply' is still 'expr2'.
    else expr2 <$ applyDebug lvl dbgTranss fromLimit s expr0 hasChanged expr1
{-# INLINE apply #-}

-- | Debugging back end of 'apply'.
--
-- Depending on the debug level, it checks the result of one transformation
-- step (new free variables, accidental shadowing, type changes, changes that
-- were not reported) and emits trace output. It returns the new expression,
-- or aborts with 'error' when one of the fatal checks fails.
applyDebug
  :: DebugLevel
  -- ^ The current debugging level
  -> Set.Set String
  -- ^ Transformations to debug
  -> Maybe (Int, Int)
  -- ^ Only print debug information for transformations [n, n+limit]. See flag
  -- documentation of "-fclash-debug-transformations-from" and
  -- "-fclash-debug-transformations-limit"
  -> String
  -- ^ Name of the transformation
  -> Term
  -- ^ Original expression
  -> Bool
  -- ^ Whether the rewrite indicated change
  -> Term
  -- ^ New expression
  -> RewriteMonad extra Term
-- Transformation window: abort once the counter is more than 'limit' past
-- 'from', debug normally once it is past 'from', and stay silent before that.
applyDebug lvl transformations fromLimit name exprOld hasChanged exprNew
  | Just (from, limit) <- fromLimit = do
    nTrans <- Lens.use transformCounter
    if | nTrans - from > limit ->
          error "-fclash-debug-transformations-limit exceeded"
       | nTrans > from ->
          applyDebug lvl transformations Nothing name exprOld hasChanged exprNew
       | otherwise ->
          pure exprNew

-- Transformation filter: when specific transformations were requested, all
-- other transformations are handled as if the level were 'DebugNone'.
applyDebug lvl transformations fromLimit name exprOld hasChanged exprNew
  | not (Set.null transformations) =
    let newLvl = bool DebugNone lvl (name `Set.member` transformations) in
    applyDebug newLvl Set.empty fromLimit name exprOld hasChanged exprNew

applyDebug lvl _transformations _fromLimit name exprOld hasChanged exprNew =
 traceIf (lvl >= DebugAll) ("Tried: " ++ name ++ " on:\n" ++ before) $ do
  -- 'apply' increments the counter before calling us when the term changed,
  -- so the count before this step is one less than the current value.
  nTrans <- pred <$> Lens.use transformCounter
  Monad.when (lvl > DebugNone && hasChanged) $ do
    tcm                  <- Lens.view tcCache
    let beforeTy          = termType tcm exprOld
        beforeFV          = Lens.setOf freeLocalVars exprOld
        afterTy           = termType tcm exprNew
        afterFV           = Lens.setOf freeLocalVars exprNew
        newFV             = not (afterFV `Set.isSubsetOf` beforeFV)
        accidentalShadows = findAccidentialShadows exprNew

    -- Fatal: the rewrite introduced free variables.
    Monad.when newFV $
            error ( concat [ $(curLoc)
                           , "Error when applying rewrite ", name
                           , " to:\n" , before
                           , "\nResult:\n" ++ after ++ "\n"
                           , "It introduces free variables."
                           , "\nBefore: " ++ showPpr (Set.toList beforeFV)
                           , "\nAfter: " ++ showPpr (Set.toList afterFV)
                           ]
                  )

    -- Fatal: the rewrite introduced duplicate let/case binders.
    Monad.unless (null accidentalShadows) $
      error ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "It accidentally creates shadowing let/case-bindings:\n"
                     , " ", showPpr accidentalShadows, "\n"
                     , "This usually means that a transformation did not extend "
                     , "or incorrectly extended its InScopeSet before applying a "
                     , "substitution."
                     ])

    -- Non-fatal: a change of type is only traced.
    traceIf (lvl >= DebugApplied && not (beforeTy `aeqType` afterTy))
            ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "Changes type from:\n", showPpr beforeTy
                     , "\nto:\n", showPpr afterTy
                     ]
            ) (return ())

  -- Fatal: a rewrite that reported no change returned a term that is not
  -- alpha-equivalent to its input.
  Monad.when (lvl >= DebugApplied && not hasChanged && not (exprOld `aeqTerm` exprNew)) $
    error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++  "): before:\n"
                      ++ before ++ "\nafter:\n" ++ after

  traceIf (lvl >= DebugName && hasChanged) (name <> " {" <> show nTrans <> "}") $
    traceIf (lvl >= DebugApplied && hasChanged) ("Changes when applying rewrite to:\n"
                      ++ before ++ "\nResult:\n" ++ after ++ "\n") $
      traceIf (lvl >= DebugAll && not hasChanged) ("No changes when applying rewrite "
                        ++ name ++ " to:\n" ++ after ++ "\n") $
        return exprNew
 where
  before = showPpr exprOld
  after  = showPpr exprNew

-- | Perform a transformation on a Term
--
-- The term is treated as a top-level term: the transformation context holds
-- the given in-scope set and an empty surrounding context.
runRewrite
  :: String
  -- ^ Name of the transformation
  -> InScopeSet
  -- ^ In-scope set to start the transformation context with
  -> Rewrite extra
  -- ^ Transformation to perform
  -> Term
  -- ^ Term to transform
  -> RewriteMonad extra Term
runRewrite name is rewrite = apply name rewrite topLevelCtx
 where
  topLevelCtx = TransformContext is []

-- | Run a 'RewriteMonad' computation from the given environment and initial
-- state, returning only its result (the final state and change flag are
-- dropped).
--
-- When the debug level is above 'DebugNone', the final value of the
-- transform counter is traced as the number of applied transformations.
runRewriteSession :: RewriteEnv
                  -> RewriteState extra
                  -> RewriteMonad extra a
                  -> a
runRewriteSession r s m =
  let (result, finalState, _) = runR m r s
      report = "Clash: Applied "
            ++ show (finalState ^. transformCounter)
            ++ " transformations"
  in  traceIf (_dbgLevel r > DebugNone) report result

-- | Notify that a transformation has changed the expression
--
-- Writes @'Monoid.Any' 'True'@ to the writer output, which is what 'apply'
-- inspects through 'Writer.listen'.
setChanged :: RewriteMonad extra ()
setChanged = Writer.tell changedMark
 where
  changedMark = Monoid.Any True

-- | Identity function that additionally notifies that a transformation has
-- changed the expression
--
-- Records the change flag first, then yields the given value.
changed :: a -> RewriteMonad extra a
changed val = val <$ Writer.tell (Monoid.Any True)

-- | The binder of the first 'LetBinding' entry in the context, if there is
-- one. The context is scanned from its head, so this is the nearest enclosing
-- let binder when the head of the context is the innermost entry.
closestLetBinder :: Context -> Maybe Id
closestLetBinder = foldr pick Nothing
 where
  -- Stop at the first let binding; otherwise keep looking further out.
  pick (LetBinding bndr _) _    = Just bndr
  pick _                   rest = rest

-- | Derive a name from the closest let binder in the transformation context:
-- that binder's name with @'_' : suffix@ appended. Without an enclosing let
-- binder, an internal name built from the suffix alone (unique 0) is used.
mkDerivedName :: TransformContext -> OccName -> TmName
mkDerivedName (TransformContext _ ctx) sf =
  maybe (mkUnsafeInternalName sf 0)
        (\bndr -> appendToName (varName bndr) ('_' `Text.cons` sf))
        (closestLetBinder ctx)

-- | Make a new binder and variable reference for a term
--
-- The binder gets a fresh name (cloned from the given one with respect to the
-- in-scope set) and the type of the bound term.
mkTmBinderFor
  :: (MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Name a -- ^ Name of the new binder
  -> Term -- ^ Term to bind
  -> m Id
mkTmBinderFor is tcm name term = do
  -- A 'Left' (term) request always yields a 'Left' (Id) binder in
  -- 'mkBinderFor', so this pattern does not fail in practice; it is the only
  -- reason for the 'MonadFail' constraint.
  Left bndr <- mkBinderFor is tcm name (Left term)
  pure bndr

-- | Make a new binder and variable reference for either a term or a type
--
-- In both cases the given name is cloned with respect to the in-scope set.
-- A term yields a local 'Id' typed by the term's type; a type yields a
-- 'TyVar' whose kind is the kind of that type.
mkBinderFor
  :: (MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Name a -- ^ Name of the new binder
  -> Either Term Type -- ^ Type or Term to bind
  -> m (Either Id TyVar)
mkBinderFor is tcm name tmOrTy = case tmOrTy of
  Left term -> Left . mkLocalId (termType tcm term) . coerce <$> freshName
  Right ty  -> Right . mkTyVar (typeKind tcm ty) . coerce <$> freshName
 where
  freshName = cloneNameWithInScopeSet is name

-- | Make a new local identifier with a fresh unique.
--
-- The name is built from the given 'OccName' and a unique obtained through
-- 'getUniqueM'. The resulting 'Id' (of the given type) is then passed through
-- 'uniqAway' so that it does not clash with anything in the 'InScopeSet'.
mkInternalVar
  :: (MonadUnique m)
  => InScopeSet
  -> OccName
  -- ^ Name of the identifier
  -> KindOrType
  -> m Id
mkInternalVar inScope name ty = fmap toVar getUniqueM
 where
  toVar u = uniqAway inScope (mkLocalId ty (mkUnsafeInternalName name u))

-- | Inline the binders in a let-binding that have a certain property.
--
-- The property test is given the whole let-expression and one of its
-- bindings. All bindings that pass the test are substituted into the other
-- bindings and into the body using 'substituteBinders'. Bindings selected for
-- inlining that turn out to be self-recursive are kept as let-bindings.
--
-- When no binding passes the test, the expression is returned as-is.
-- Otherwise the result is reported via 'changed'. The let-expression is
-- dropped altogether when no bindings remain. Terms that are not a
-- let-expression are returned unchanged.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> Rewrite extra
inlineBinders condition (TransformContext inScope0 _) expr@(Letrec xes res) = do
  (toInline, toKeep) <- partitionM (condition expr) xes
  if null toInline
    then return expr
    else do
      -- Every let-bound name scopes over all bindings and the body
      let inScope1 = extendInScopeSetList inScope0 (map fst xes)
          (toInlRec, (toKeep1, res1)) =
            substituteBinders inScope1 toInline toKeep res
      case toInlRec ++ toKeep1 of
        []        -> changed res1
        remaining -> changed (Letrec remaining res1)

inlineBinders _ _ e = return e

-- | Determine whether a binder is a join-point created for a complex case
-- expression.
--
-- A join-point is when a local function only occurs in tail-call positions,
-- and when it does, more than once. Using 'tailCalls': the binder is a
-- join-point exactly when it is only ever tail-called, and at least twice.
isJoinPointIn :: Id   -- ^ 'Id' of the local binder
              -> Term -- ^ Expression in which the binder is bound
              -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)

-- | Count the number of (only) tail calls of a function in an expression.
-- 'Nothing' indicates that the function was used in a non-tail call position.
--
-- Occurrences are looked for through lambdas, type abstractions, type
-- applications, the head of term applications, case alternatives and
-- let-expressions. An occurrence in an application argument or in a case
-- scrutinee yields 'Nothing'. Any other kind of term counts as zero calls.
tailCalls :: Id   -- ^ Function to check
          -> Term -- ^ Expression to check it in
          -> Maybe Int
tailCalls id_ = go
 where
  go :: Term -> Maybe Int
  go = \case
    Var nm
      | id_ == nm -> Just 1
      | otherwise -> Just 0
    Lam _ e   -> go e
    TyLam _ e -> go e
    -- The function may appear in the head, but never in the argument
    App l r -> case go r of
      Just 0 -> go l
      _      -> Nothing
    TyApp l _ -> go l
    Letrec bs e -> goLetrec bs e
    -- The function must not occur in the scrutinee. The alternatives are all
    -- tail positions, so their counts are summed.
    Case scrut _ alts -> case go scrut of
      Just 0 -> sum <$> traverse (go . snd) alts
      _      -> Nothing
    _ -> Just 0

  goLetrec :: [LetBinding] -> Term -> Maybe Int
  goLetrec bs body
    -- A non-tail occurrence inside any binding spoils the whole expression
    | not (all isJust bndrTls) = Nothing
    -- No binding calls the function, so only the body matters
    | all (== 0) bndrCounts = go body
    -- Otherwise the selected binders must themselves only be tail-called in
    -- the body; the calls in the bindings are added to those of the body
    | all isJust usedTls = (sum bndrCounts +) <$> go body
    | otherwise = Nothing
   where
    (bndrIds, bndrExprs) = unzip bs
    bndrTls    = map go bndrExprs
    bndrCounts = catMaybes bndrTls
    -- NOTE(review): this selects every binder whose count is a 'Just',
    -- including binders that make zero calls, which makes the check more
    -- conservative than strictly necessary. Kept as-is to preserve behaviour.
    usedIds = [ i | (i, Just _) <- zip bndrIds bndrTls ]
    usedTls = map (`tailCalls` body) usedIds

-- | Determines whether a function has the following shape:
--
-- > \(w :: Void) -> f a b c
--
-- i.e. is a wrapper around a (partially) applied function 'f', where the
-- introduced argument 'w' is not used by 'f'.
--
-- Only the shape is checked: a lambda whose body is a variable applied to
-- zero or more arguments, in which the lambda-bound variable does not occur.
-- The type of the lambda-bound variable is not inspected here.
isVoidWrapper :: Term -> Bool
isVoidWrapper = \case
  Lam bndr e
    | (Var _, _) <- collectArgs e -> bndr `localIdDoesNotOccurIn` e
  _ -> False

-- | Inline the first set of binders into the second set of binders and into
-- the body of the original let expression.
--
-- The binders to inline are processed one at a time. The substitution built
-- so far is applied to each binder's term. If the binder still refers to
-- itself afterwards, it is recursive and is set aside instead of being added
-- to the substitution.
substituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to substitute
  -> [LetBinding]
  -- ^ Let-binders where substitution takes place
  -> Term
  -- ^ Body where substitution takes place
  -> ([LetBinding],([LetBinding],Term))
  -- ^
  -- 1. Let-bindings that we wanted to substitute, but turned out to be recursive
  -- 2.1 Let-binders where substitution took place
  -- 2.2 Body where substitution took place
substituteBinders inScope toInline toKeep body =
  ( map (second (substTm "substToInlRec" finalSubst)) toInlRec
  , ( map (second (substTm "substToKeep" finalSubst)) toKeep
    , substTm "substBody" finalSubst body
    )
  )
 where
  (finalSubst, toInlRec) = go (mkSubst inScope) [] toInline

  go :: Subst -> [LetBinding] -> [LetBinding] -> (Subst, [LetBinding])
  go sub recs [] = (sub, recs)
  go !sub !recs ((x, e) : rest)
    -- Still refers to itself after substitution: recursive, so not inlined
    | x `localIdOccursIn` e1 = go sub ((x, e1) : recs) rest
    | otherwise              = go sub2 recs rest
   where
    e1 = substTm "substInl" sub e
    -- A substitution that only replaces x by e1
    onlyX = extendIdSubst (mkSubst inScope) x e1
    -- Replace x inside the terms already in the substitution, then add x
    sub1 = sub { substTmEnv = mapVarEnv (substTm "substSubst" onlyX)
                                        (substTmEnv sub) }
    sub2 = extendIdSubst sub1 x e1

-- | Lift the first set of binders to the level of global bindings, and
-- substitute these lifted bindings into the second set of binders and the
-- body of the original let expression.
--
-- The binders to lift are processed one at a time:
--
-- 1. The substitution built so far is applied to the binder's term.
-- 2. The result is lifted with 'liftBinding'.
-- 3. The lifted term is added to the substitution.
--
-- If a lifted term still refers to its own binder, a 'ClashException' is
-- thrown, located at the source span of the current function ('curFun').
liftAndSubsituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to lift, and substitute the lifted result
  -> [LetBinding]
  -- ^ Let-binders where substitution takes place
  -> Term
  -- ^ Body where substitution takes place
  -> RewriteMonad extra ([LetBinding],Term)
liftAndSubsituteBinders inScope toLift toKeep body = do
  subst <- go (mkSubst inScope) toLift
  pure ( map (second (substTm "liftToKeep" subst)) toKeep
       , substTm "keepBody" subst body
       )
 where
  go :: Subst -> [LetBinding] -> RewriteMonad extra Subst
  go subst [] = pure subst
  go !subst ((x,e):inl) = do
    let e1 = substTm "liftInl" subst e
    (_,e2) <- liftBinding (x,e1)
    -- A substitution that only replaces x by its lifted term
    let substE = extendIdSubst (mkSubst inScope) x e2
        -- Replace x inside the terms already in the substitution, then add x
        subst1 = subst { substTmEnv = mapVarEnv (substTm "liftSubst" substE)
                                                (substTmEnv subst) }
        subst2 = extendIdSubst subst1 x e2
    if x `localIdOccursIn` e2 then do
      (_,sp) <- Lens.use curFun
      throw (ClashException sp [I.i|
        Internal error: inlineOrLiftBinders failed on:

        #{showPpr (x,e)}

        creating a self-recursive let-binding:

        #{showPpr (x,e2)}

        given the already built substitution:

        #{showDoc (clashPretty (substTmEnv subst))}
      |] Nothing)
    else
      go subst2 inl

-- | Determines whether a global binder is work free. Errors if binder does
-- not exist.
--
-- The result is cached per binder in 'workFreeBinders'. A binder whose
-- definition refers to the binder itself is not considered work free.
-- Otherwise its definition is checked with 'isWorkFree'.
isWorkFreeBinder :: HasCallStack => Id -> RewriteMonad extra Bool
isWorkFreeBinder bndr =
  makeCachedU bndr workFreeBinders $
    Lens.use bindings >>= \envBinds ->
      case lookupVarEnv bndr envBinds of
        Nothing ->
          error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr)
        Just b
          | bndr `globalIdOccursIn` bindingTerm b -> pure False
          | otherwise                              -> isWorkFree (bindingTerm b)

-- | Determine whether a term does any work, i.e. adds to the size of the
-- circuit.
--
-- The term is split into a head and its arguments with 'collectArgs'.
--
-- * Local variables, literals and 'WorkConstant' primitives are work free,
--   whatever their arguments.
-- * Variables with a polymorphic function type, 'WorkAlways' primitives and
--   any head not listed here are not work free.
-- * 'WorkVariable' primitives are work free only when all their term
--   arguments are constant (see 'isConstant').
-- * Global variables need a work-free definition (see 'isWorkFreeBinder') and
--   work-free arguments.
-- * Data constructors and 'WorkNever' primitives need work-free arguments.
-- * Lambdas, type lambdas, let-expressions, single-alternative case
--   expressions and casts need work-free sub-terms and work-free arguments.
--
-- Type arguments are always considered work free.
isWorkFree
  :: Term
  -> RewriteMonad extra Bool
isWorkFree (collectArgs -> (fun, args)) = case fun of
  Var i
    | isPolyFunTy (varType i) -> pure False
    | isLocalId i             -> pure True
    | otherwise               -> andM [isWorkFreeBinder i, argsWorkFree]
  Data {}    -> argsWorkFree
  Literal {} -> pure True
  Prim pInfo -> case primWorkInfo pInfo of
    -- The primitive outputs a constant whatever its arguments are
    WorkConstant -> pure True
    WorkNever    -> argsWorkFree
    WorkVariable -> pure (all isConstantArg args)
    -- Things like clock or reset generators always perform work
    WorkAlways   -> pure False
  Lam _ e     -> andM [isWorkFree e, argsWorkFree]
  TyLam _ e   -> andM [isWorkFree e, argsWorkFree]
  Letrec bs e ->
    andM [isWorkFree e, allM (isWorkFree . snd) bs, argsWorkFree]
  Case s _ [(_, a)] ->
    andM [isWorkFree s, isWorkFree a, argsWorkFree]
  Cast e _ _ -> andM [isWorkFree e, argsWorkFree]
  _ -> pure False
 where
  argsWorkFree = allM isWorkFreeArg args

  isWorkFreeArg :: Either Term b -> RewriteMonad extra Bool
  isWorkFreeArg = eitherM isWorkFree (const (pure True)) . pure

  isConstantArg :: Either Term b -> Bool
  isConstantArg = either isConstant (const True)

-- | Whether the given name is one of the @fromInteger@ primitives of the sized
-- number types (BitVector, Index, Signed, Unsigned). The names are compared
-- exactly.
isFromInt :: Text -> Bool
isFromInt nm = nm `elem` fromIntegerPrims
 where
  fromIntegerPrims :: [Text]
  fromIntegerPrims =
    [ "Clash.Sized.Internal.BitVector.fromInteger##"
    , "Clash.Sized.Internal.BitVector.fromInteger#"
    , "Clash.Sized.Internal.Index.fromInteger#"
    , "Clash.Sized.Internal.Signed.fromInteger#"
    , "Clash.Sized.Internal.Unsigned.fromInteger#"
    ]

-- | Determine if a term represents a constant.
--
-- Literals are constant. Data constructors and primitives are constant when
-- all their term arguments are (type arguments always are). A lambda is
-- constant when the whole term has no local free variables. Anything else is
-- not constant.
isConstant :: Term -> Bool
isConstant e = case collectArgs e of
  (Data _, args)  -> all constantArg args
  (Prim _, args)  -> all constantArg args
  (Lam _ _, _)    -> not (hasLocalFreeVars e)
  (Literal _, _)  -> True
  _               -> False
 where
  constantArg :: Either Term Type -> Bool
  constantArg = either isConstant (const True)

-- | Like 'isConstant', except for terms of a clock or reset type. Those only
-- count as constant when their head is the
-- @Clash.Transformations.removedArg@ primitive, whatever its arguments.
isConstantNotClockReset
  :: Term
  -> RewriteMonad extra Bool
isConstantNotClockReset e = do
  tcm <- Lens.view tcCache
  if not (isClockOrReset tcm (termType tcm e))
    then pure (isConstant e)
    else pure $ case collectArgs e of
      (Prim p, _) -> primName p == "Clash.Transformations.removedArg"
      _           -> False

-- TODO: Remove function after using WorkInfo in 'isWorkFreeIsh'

-- | Decide whether a term of a clock, reset or enable type is work free.
--
-- Returns 'Nothing' when the term has none of those types. Otherwise the term
-- is work free when its head is:
--
-- * the @Clash.Transformations.removedArg@ primitive, with any arguments,
-- * an unapplied variable,
-- * an unapplied data constructor (e.g. an Enable of True/False), or
-- * a literal.
isWorkFreeClockOrResetOrEnable
  :: TyConMap
  -> Term
  -> Maybe Bool
isWorkFreeClockOrResetOrEnable tcm e
  | isClockOrReset tcm eTy || isEnable tcm eTy = Just $ case collectArgs e of
      (Prim p, _)    -> primName p == "Clash.Transformations.removedArg"
      (Var _, [])    -> True
      (Data _, [])   -> True -- For Enable True/False
      (Literal _, _) -> True
      _              -> False
  | otherwise = Nothing
 where
  eTy = termType tcm e

-- | A conservative version of 'isWorkFree'. It is used in 'bindConstantVar' to
-- decide whether an expression can be "bound" (locally inlined). Binding
-- work-free expressions won't result in extra work for the circuit, but it
-- might very well cause extra work for Clash. In fact, using 'isWorkFree' in
-- 'bindConstantVar' makes Clash two orders of magnitude slower for some of our
-- test cases.
--
-- In effect, this function is a version of 'isConstant' that also considers
-- references to clocks and resets constant. This allows us to bind
-- HiddenClock(ResetEnable) constructs, allowing Clash to constant spec
-- subconstants - most notably KnownDomain. Doing that enables Clash to
-- eliminate any case-constructs on it.
--
-- Terms of a clock, reset or enable type are decided by
-- 'isWorkFreeClockOrResetOrEnable'. Other terms are decided as follows:
--
-- * Data constructors need work-free-ish term arguments.
-- * 'WorkAlways' primitives are never work free.
-- * 'WorkVariable' primitives need constant term arguments.
-- * Other primitives need work-free-ish term arguments.
-- * Lambdas must have no local free variables.
-- * Literals are work free.
-- * Anything else is not work free.
isWorkFreeIsh
  :: Term
  -> RewriteMonad extra Bool
isWorkFreeIsh e = do
  tcm <- Lens.view tcCache
  maybe notClockResetEnable pure (isWorkFreeClockOrResetOrEnable tcm e)
 where
  notClockResetEnable = case collectArgs e of
    (Data _, args) -> allM isWorkFreeIshArg args
    (Prim pInfo, args) -> case primWorkInfo pInfo of
      -- Things like clock or reset generators always perform work
      WorkAlways   -> pure False
      WorkVariable -> pure (all isConstantArg args)
      _            -> allM isWorkFreeIshArg args
    (Lam _ _, _)   -> pure (not (hasLocalFreeVars e))
    (Literal _, _) -> pure True
    _              -> pure False

  isWorkFreeIshArg :: Either Term b -> RewriteMonad extra Bool
  isWorkFreeIshArg = either isWorkFreeIsh (pure . const True)

  isConstantArg :: Either Term b -> Bool
  isConstantArg = either isConstant (const True)

-- | Rewrite a 'Letrec' by inlining or lifting the binders that satisfy the
-- (monadic) property test.
--
-- Binders passing the property test are split with the second predicate:
-- those for which it returns 'True' are substituted (inlined), the others are
-- lifted via 'liftAndSubsituteBinders'. Binders failing the property test stay
-- in the 'Letrec'; if none remain, only the rewritten body is returned.
--
-- When no binder passes the property test, or the term is not a 'Letrec',
-- the term is returned as-is without being marked as 'changed'.
inlineOrLiftBinders
  :: (LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> (Term -> LetBinding -> Bool)
  -- ^ Test whether to lift or inline
  --
  -- * True: inline
  -- * False: lift
  -> Rewrite extra
inlineOrLiftBinders condition inlineOrLift (TransformContext inScope0 _) e@(Letrec bndrs body) = do
  (selected, untouched) <- partitionM condition bndrs
  if null selected
    then return e
    else do
      -- All let-bound variables are in scope for the substitutions below.
      let inScope1 = extendInScopeSetList inScope0 (map fst bndrs)
          (inlined, lifted) = partition (inlineOrLift e) selected
          -- Substitute the binders we inline into the binders we intend to
          -- lift, into the binders we keep, and into the body. 'liftedExtra'
          -- is handed to the lifting step alongside the binders to lift.
          (liftedExtra, (rest1, body1)) =
            substituteBinders inScope1 inlined (lifted ++ untouched) body
          -- The split assumes 'rest1' preserves the order of
          -- 'lifted ++ untouched'.
          (lifted1, untouched1) = splitAt (length lifted) rest1
      -- Lift the chosen binders and substitute them in the kept binders and
      -- the body.
      (untouched2, body2) <-
        liftAndSubsituteBinders inScope1 (liftedExtra ++ lifted1) untouched1 body1
      changed (if null untouched2 then body2 else Letrec untouched2 body2)

inlineOrLiftBinders _ _ _ e = return e

-- | Create a global function for a Let-binding and return a Let-binding where
-- the RHS is a reference to the new global function applied to the free
-- variables of the original RHS.
--
-- The new global abstracts over the local free type variables and free term
-- variables of the RHS (global ids and the binder itself are excluded).
-- Recursive references to the binder inside the RHS are replaced by the
-- application of the new global. If a global binding whose term is
-- alpha-equivalent to the new body already exists, that binding is referenced
-- instead and no new global is added.
liftBinding :: LetBinding
            -> RewriteMonad extra LetBinding
liftBinding (bndr@Id {varName = bndrNm}, rhs) = do
  -- Collect all local FVs of the RHS, excluding the let-binder itself
  let unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
      unitFV v@(Id {})    = Const (emptyUniqSet,unitUniqSet (coerce v))
      unitFV v@(TyVar {}) = Const (unitUniqSet (coerce v),emptyUniqSet)

      interesting :: Var a -> Bool
      interesting Id {idScope = GlobalId} = False
      interesting v@(Id {idScope = LocalId}) = varUniq v /= varUniq bndr
      interesting _ = True

      (freeTyVarSet,freeVarSet) =
        getConst (Lens.foldMapOf (termFreeVars' interesting) unitFV rhs)
      freeTyVars = eltsUniqSet freeTyVarSet
      freeVars   = eltsUniqSet freeVarSet

      -- Apply a function to the free type variables, then to the free term
      -- variables, in the same order in which the new body abstracts them.
      applyToFVs :: Id -> Term
      applyToFVs f =
        mkTmApps (mkTyApps (Var f) (map VarTy freeTyVars)) (map Var freeVars)

  -- Make a new global ID, named after the current function and the binder
  tcm       <- Lens.view tcCache
  (cf,sp)   <- Lens.use curFun
  globals   <- Lens.use bindings
  newBodyNm <-
    cloneNameWithBindingMap
      globals
      (appendToName (varName cf) ("_" `Text.append` nameOcc bndrNm))
  let newBodyTy = termType tcm $ mkTyLams (mkLams rhs freeVars) freeTyVars
      newBodyId = mkGlobalId newBodyTy newBodyNm {nameSort = Internal}

      -- The lifted function applied to the free variables of the RHS
      newExpr  = applyToFVs newBodyId
      inScope1 = extendInScopeSetList (mkInScopeSet (coerce freeVarSet))
                                      [bndr,newBodyId]
      -- Replace recursive references to the binder by the new expression
      rhs'     = substTm "liftBinding"
                   (extendIdSubst (mkSubst inScope1) bndr newExpr)
                   rhs
      -- The new global body abstracts over the free variables
      newBody  = mkTyLams (mkLams rhs' freeVars) freeTyVars

  -- Look for an existing global binding that is alpha-equivalent
  existing <- Lens.uses bindings
                (eltsUniqMap . filterUniqMap ((`aeqTerm` newBody) . bindingTerm))
  case existing of
    -- Reuse the existing global binding
    (b:_) -> return (bndr, applyToFVs (bindingId b))
    -- Otherwise register the new global function. It is marked as internal
    -- so that it can be inlined at the very end of the normalisation
    -- pipeline as part of the flattening pass. It is not inlined right away
    -- because it is being lifted for a reason (termination, CSE and DEC
    -- opportunities, etc.).
    [] -> do
      bindings %= extendUniqMap newBodyNm
                    (Binding
                      newBodyId
                      sp
#if MIN_VERSION_ghc(8,4,1)
                      NoUserInline
#else
                      EmptyInlineSpec
#endif
                      newBody)
      return (bndr, newExpr)

liftBinding _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"

-- | Ensure that the 'Unique' of a variable does not occur in the 'BindingMap'
--
-- The name keeps its current unique unless that unique is already a key of
-- the binding map, in which case 'uniqAway'' picks a different one.
uniqAwayBinder
  :: BindingMap
  -> Name a
  -> Name a
uniqAwayBinder binders nm = uniqAway' isBound (nameUniq nm) nm
 where
  isBound u = elemUniqMapDirectly u binders

-- | Make a global function for a name-term tuple
--
-- The function's type is computed from the term. Its name is a clone of the
-- given name whose unique does not occur in the current global 'bindings'.
-- The binding is registered with 'addGlobalBind'.
mkFunction
  :: TmName
  -- ^ Name of the function
  -> SrcSpan
  -> InlineSpec
  -> Term
  -- ^ Term bound to the function
  -> RewriteMonad extra Id
  -- ^ Name with a proper unique and the type of the function
mkFunction bndrNm sp inl body = do
  bodyTy <- (`termType` body) <$> Lens.view tcCache
  bodyNm <- Lens.use bindings >>= (`cloneNameWithBindingMap` bndrNm)
  addGlobalBind bodyNm bodyTy sp inl body
  pure (mkGlobalId bodyTy bodyNm)

-- | Add a function to the set of global binders
--
-- The type and the body are forced to normal form ('deepseq') when this
-- action is evaluated. The binding is then inserted into 'bindings' under the
-- given name, replacing any existing entry for that name.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vNm ty sp inl body =
  (ty, body) `deepseq` (bindings %= extendUniqMap vNm newBinding)
 where
  newBinding = Binding (mkGlobalId ty vNm) sp inl body

-- | Create a new name out of the given name, but with another unique. Resulting
-- unique is guaranteed to not be in the given InScopeSet.
--
-- A fresh unique is drawn from 'getUniqueM' and, if needed, moved away from
-- the in-scope set with 'uniqAway'.
cloneNameWithInScopeSet
  :: (MonadUnique m)
  => InScopeSet
  -> Name a
  -> m (Name a)
cloneNameWithInScopeSet is nm =
  fmap (uniqAway is . setUnique nm) getUniqueM

-- | Create a new name out of the given name, but with another unique. Resulting
-- unique is guaranteed to not be in the given BindingMap.
--
-- A fresh unique is drawn from 'getUniqueM' and, if it is already a key of
-- the binding map, moved away with 'uniqAway''.
cloneNameWithBindingMap
  :: (MonadUnique m)
  => BindingMap
  -> Name a
  -> m (Name a)
cloneNameWithBindingMap binders nm = getUniqueM >>= \u ->
  let taken k = elemUniqMapDirectly k binders
  in  return (uniqAway' taken u (setUnique nm u))

{-# INLINE isUntranslatable #-}
-- | Determine if a term cannot be represented in hardware
--
-- Computes the type of the term with the 'tcCache' and defers to
-- 'isUntranslatableType', so that the term and type checks share a single
-- implementation.
isUntranslatable
  :: Bool
  -- ^ String representable
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm <- Lens.view tcCache
  isUntranslatableType stringRepresentable (termType tcm tm)

{-# INLINE isUntranslatableType #-}
-- | Determine if a type cannot be represented in hardware
--
-- This is the negation of 'representableType', evaluated with the
-- 'typeTranslator', 'customReprs' and 'tcCache' of the rewrite environment.
isUntranslatableType
  :: Bool
  -- ^ String representable
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translator <- Lens.view typeTranslator
  reprs      <- Lens.view customReprs
  tcm        <- Lens.view tcCache
  pure (not (representableType translator reprs stringRepresentable tcm ty))

-- | Make a binder that should not be referenced
--
-- The binder is an internal variable named @wild@ of the given type, created
-- with 'mkInternalVar' relative to the given 'InScopeSet'.
mkWildValBinder
  :: (MonadUnique m)
  => InScopeSet
  -> Type
  -> m Id
mkWildValBinder is ty = mkInternalVar is "wild" ty

-- | Make a case-decomposition that extracts a field out of a (Sum-of-)Product type
--
-- The data constructor index is 1-based (the constructor is taken from
-- position @dcI-1@ of the type's constructors), while the field index is
-- 0-based. The scrutinee's type is unfolded with 'coreView1' until a type
-- constructor application is found.
--
-- Calls 'error' (mentioning the caller and both indices) when the subject's
-- type is not a data type, the type has no constructors, either index is out
-- of range, or the constructor's field types cannot be instantiated.
mkSelectorCase
  :: HasCallStack
  => (Functor m, MonadUnique m)
  => String -- ^ Name of the caller of this function
  -> InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Term -- ^ Subject of the case-composition
  -> Int -- n'th DataCon
  -> Int -- n'th field
  -> m Term
mkSelectorCase caller inScope tcm scrut dcI fieldI = go (termType tcm scrut)
  where
    go (coreView1 tcm -> Just ty') = go ty'
    go scrutTy@(tyView -> TyConApp tc args) =
      case tyConDataCons (lookupUniqMap' tcm tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showPpr tc) scrutTy
        dcs
          -- Constructor indices are 1-based
          | dcI < 1 ->
              cantCreate $(curLoc) "DC index must be at least 1" scrutTy
          | dcI > length dcs ->
              cantCreate $(curLoc) "DC index exceeds max" scrutTy
          | otherwise ->
              let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
              in  case dataConInstArgTysE inScope tcm dc args of
                    -- Previously an irrefutable 'Just' pattern, which failed
                    -- with an uninformative pattern-match error
                    Nothing ->
                      cantCreate $(curLoc) "Could not instantiate the field types of the DC" scrutTy
                    Just fieldTys
                      -- A negative field index would otherwise silently build
                      -- a pattern with the wrong number of binders
                      | fieldI < 0 ->
                          cantCreate $(curLoc) "Field index is negative" scrutTy
                      | fieldI >= length fieldTys ->
                          cantCreate $(curLoc) "Field index exceed max" scrutTy
                      | otherwise -> do
                          wildBndrs <- mapM (mkWildValBinder inScope) fieldTys
                          let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
                          selBndr <- mkInternalVar inScope "sel" ty
                          -- Wildcards for every field except the selected one
                          let bndrs  = take fieldI wildBndrs ++ [selBndr] ++ drop (fieldI+1) wildBndrs
                              pat    = DataPat dc (dcExtTyVars dc) bndrs
                              retVal = Case scrut ty [ (pat, Var selBndr) ]
                          return retVal
    go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showPpr scrutTy) scrutTy

    cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showPpr scrut ++ " :: " ++ showPpr scrutTy ++ ")\nAdditional info: " ++ info

-- | Specialise an application on its argument.
--
-- Only a term whose outermost node is a type application ('TyApp') or a term
-- application ('App') is considered. Its function part is split into head,
-- arguments and ticks with 'collectArgsTicks' and handed to 'specialise'',
-- together with the argument to specialise on: 'Right' for a type argument,
-- 'Left' for a term argument. Any other term is returned unchanged.
specialise :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
           -> Lens' extra (VarEnv Int) -- ^ Lens into the specialisation history
           -> Lens' extra Int -- ^ Lens into the specialisation limit
           -> Rewrite extra
specialise specMapLbl specHistLbl specLimitLbl ctx e
  | TyApp fun ty <- e = specialiseOn fun (Right ty)
  | App fun arg  <- e = specialiseOn fun (Left arg)
  | otherwise         = return e
  where
    -- Delegate to 'specialise'' with the function part split into its
    -- head, applied arguments and ticks.
    specialiseOn fun =
      specialise' specMapLbl specHistLbl specLimitLbl ctx e (collectArgsTicks fun)

-- | Specialise an application on its argument
--
-- The behaviour depends on the head of the function part:
--
-- * A global variable @f@ that is a top entity is never specialised. For a
--   term argument the original term is returned untouched. For a type argument
--   the type application is dropped, and the type of @f@ is replaced by
--   @piResultTy@ of its type and that argument.
--
-- * Any other variable @f@ with a binding is specialised. A previous
--   specialisation is reused when one was recorded for the same @f@, number of
--   preceding arguments, and (type-normalized, abstracted) argument. Otherwise
--   a new binding is built from the body of @f@ and recorded in both the
--   specialisation map and the specialisation history. When the history count
--   for @f@ already exceeds the specialisation limit, a 'ClashException' is
--   thrown instead. A variable without a binding is left unchanged.
--
-- * Any other head with a term argument: the argument is abstracted over its
--   free variables and turned into a global function. An existing
--   alpha-equivalent binding is reused if there is one. The argument is then
--   replaced by that function applied to those free variables.
--
-- * Anything else is returned unchanged.
specialise' :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
            -> Lens' extra (VarEnv Int) -- ^ Lens into specialisation history
            -> Lens' extra Int -- ^ Lens into the specialisation limit
            -> TransformContext -- Transformation context
            -> Term -- ^ Original term
            -> (Term, [Either Term Type], [TickInfo]) -- ^ Function part of the term, split into root and applied arguments
            -> Either Term Type -- ^ Argument to specialize on
            -> RewriteMonad extra Term
specialise' specMapLbl specHistLbl specLimitLbl (TransformContext is0 _) e (Var f,args,ticks) specArgIn = do
  lvl     <- Lens.view dbgLevel
  tcm     <- Lens.view tcCache
  topEnts <- Lens.view topEntities
  -- Don't specialise TopEntities
  if f `elemVarSet` topEnts
    then keepTopEntity lvl tcm
    else specialiseBinding lvl tcm
  where
    -- 'f' is a top entity: never create a specialised copy of it.
    keepTopEntity lvl tcm = case specArgIn of
      Left _ ->
        -- NOTE(review): 'lvl >= DebugNone' looks like it always holds (if
        -- 'DebugNone' is the lowest debug level), so this message would be
        -- shown at every debug level. Confirm whether that is intended.
        traceIf (lvl >= DebugNone)
          ("Not specializing TopEntity: " ++ showPpr (varName f))
          (return e)
      Right tyArg ->
        traceIf (lvl >= DebugApplied)
          ("Dropping type application on TopEntity: " ++ showPpr (varName f)
            ++ "\ntype:\n" ++ showPpr tyArg) $
        -- TopEntities aren't allowed to be semantically polymorphic.
        -- But using type equality constraints they may be syntactically polymorphic.
        -- > topEntity :: forall dom . (dom ~ "System") => Signal dom Bool -> Signal dom Bool
        -- The TyLam's in the body will have been removed by 'Clash.Normalize.Util.substWithTyEq'.
        -- So we drop the TyApp ("specialising" on it) and change the varType to match.
        let fInst = f {varType = piResultTy tcm (varType f) tyArg}
        in  changed (mkApps (mkTicks (Var fInst) ticks) args)

    -- 'f' is an ordinary global: reuse or create a specialisation of it.
    specialiseBinding lvl tcm = do
      let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn
          -- Create binders and variable references for free variables in 'specArg'
          (specBndrsIn,specVars) = specArgBndrsAndVars specArg
          specBndrs :: [Either Id TyVar]
          specBndrs = map (Lens.over _Left (normalizeId tcm)) specBndrsIn
          -- The specialised-on argument, abstracted over 'specBndrs' when it
          -- is a term; a type argument is kept as is
          specAbs :: Either Term Type
          specAbs = bimap (`mkAbstraction` specBndrs) id specArg
          argLen  = length args
          specKey = (f,argLen,specAbs)
          -- Apply a specialised function to the original arguments followed
          -- by the variable references for the free variables of 'specArg'
          callSpecialised g = mkApps (mkTicks (Var g) ticks) (args ++ specVars)
      -- Determine if 'f' has already been specialized on (a type-normalized) 'specArg'
      specMap <- Lens.use (extra . specMapLbl)
      case Map.lookup specKey specMap of
        -- Use previously specialized function
        Just f' ->
          traceIf (lvl >= DebugApplied)
            ("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
              either showPpr showPpr specAbs ++ ": " ++ showPpr (varName f')) $
          changed (callSpecialised f')
        -- Create new specialized function, provided 'f' has a binding
        Nothing -> do
          bodyMaybe <- lookupUniqMap (varName f) <$> Lens.use bindings
          case bodyMaybe of
            Nothing -> return e
            Just (Binding _ sp inl bodyTm) -> do
              -- Determine if we see a sequence of specialisations on a growing argument
              specHistM <- lookupUniqMap f <$> Lens.use (extra . specHistLbl)
              specLim   <- Lens.use (extra . specLimitLbl)
              case specHistM of
                Just count | count > specLim ->
                  throw (ClashException
                          sp
                          (unlines [ "Hit specialisation limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
                                   , "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
                                   , "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
                                   , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ either showPpr showPpr specArg
                                   , "Run with '-fclash-spec-limit=N' to increase the specialisation limit to N."
                                   ])
                          Nothing)
                _ -> do
                  -- Prefer the lambda binder names of 'bodyTm'; fall back to
                  -- generated "pTS<i>" names for any remaining arguments
                  let existingNames = collectBndrsMinusApps bodyTm
                      newNames      = map (\i -> mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show i)) i)
                                          [(0::Int)..]
                  -- Make new binders for existing arguments
                  binders <- Monad.zipWithM (mkBinderFor is0 tcm) (existingNames ++ newNames) args
                  let (boundArgs,argVars) =
                        unzip (map (either (Left &&& Left . Var) (Right &&& Right . VarTy)) binders)
                  -- Determine name the resulting specialized function, and the
                  -- form of the specialized-on argument
                  (fId,inl',specArg') <- case specArg of
                    Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks))
                      | isPolyFun tcm a -> do
                        -- In case we are specialising on an argument that is a
                        -- global function then we use that function's name as the
                        -- name of the specialized higher-order function.
                        -- Additionally, we will return the body of the global
                        -- function, instead of a variable reference to the
                        -- global function.
                        --
                        -- This will turn things like @mealy g k@ into a new
                        -- binding @g'@ where both the body of @mealy@ and @g@
                        -- are inlined, meaning the state-transition-function
                        -- and the memory element will be in a single function.
                        gTmM <- lookupUniqMap (varName g) <$> Lens.use bindings
                        return ( g
                               , maybe inl bindingSpec gTmM
                               , maybe specArg (Left . (`mkApps` gArgs) . bindingTerm) gTmM
                               )
                    _ -> return (f,inl,specArg)
                  -- Create specialized functions
                  let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
                  newf <- mkFunction (varName fId) sp inl' newBody
                  -- Remember specialization
                  (extra . specHistLbl) %= extendUniqMapWith f 1 (+)
                  (extra . specMapLbl)  %= Map.insert specKey newf
                  -- use specialized function
                  newf `deepseq` changed (callSpecialised newf)

    -- Names of the leading lambda binders of a term. Binders consumed by an
    -- application at the head of the term are left out.
    collectBndrsMinusApps :: Term -> [Name a]
    collectBndrsMinusApps = reverse . go []
      where
        go bs (Lam v e')    = go (coerce (varName v):bs) e'
        go bs (TyLam tv e') = go (coerce (varName tv):bs) e'
        go bs (App e' _)    = dropApplied bs e'
        go bs (TyApp e' _)  = dropApplied bs e'
        go bs _             = bs

        -- The applied function's binders, minus the one consumed by the
        -- application, followed by the binders collected so far
        dropApplied bs e' = case go [] e' of
          []  -> bs
          bs' -> init bs' ++ bs

specialise' _ _ _ _ctx _ (appE,args,ticks) (Left specArg) = do
  -- Create binders and variable references for free variables in 'specArg'
  let (specBndrs,specVars) = specArgBndrsAndVars (Left specArg)
      -- Create specialized function
      newBody = mkAbstraction specArg specBndrs
      -- Whether a global binding's body is alpha-equivalent to 'newBody'
      isAlphaEquivalent = (`aeqTerm` newBody) . bindingTerm
  -- See if there's an existing binder that's alpha-equivalent to the
  -- specialized function
  existing <- eltsUniqMap . filterUniqMap isAlphaEquivalent <$> Lens.use bindings
  -- Create a new function if an alpha-equivalent binder doesn't exist
  newf <- case existing of
    (b:_) -> return (bindingId b)
    []    -> do
      (cf,sp) <- Lens.use curFun
      mkFunction (appendToName (varName cf) "_specF")
                 sp
#if MIN_VERSION_ghc(8,4,1)
                 NoUserInline
#else
                 EmptyInlineSpec
#endif
                 newBody
  -- Pass the new function, applied to the free variables of 'specArg',
  -- in place of 'specArg'
  changed (mkApps (mkTicks appE ticks) (args ++ [Left (mkApps (Var newf) specVars)]))

specialise' _ _ _ _ e _ _ = return e

-- | Normalize the types found at the top of a term.
--
-- A 'Cast' gets both its source and target type normalized, and its body is
-- normalized recursively. A 'Var' gets the type of its identifier normalized
-- (via 'normalizeId'). Every other kind of term is returned untouched.
normalizeTermTypes :: TyConMap -> Term -> Term
normalizeTermTypes tcm = \case
  Cast body fromTy toTy ->
    Cast (normalizeTermTypes tcm body) (normTy fromTy) (normTy toTy)
  Var v ->
    Var (normalizeId tcm v)
  -- TODO other terms?
  other ->
    other
 where
  normTy = normalizeType tcm

-- | Normalize the type annotation of a variable built with the 'Id'
-- constructor. A variable built with any other constructor is returned as-is.
normalizeId :: TyConMap -> Id -> Id
normalizeId tcm v = case v of
  Id {} -> v { varType = normalizeType tcm (varType v) }
  _     -> v

-- Note [Collect free-variables in an insertion-ordered set]
--
-- In order for the specialization cache to work, 'specArgBndrsAndVars' should
-- yield (alpha equivalent) results for the same specialization. While collecting
-- free variables in a given term or type it should therefore keep a stable
-- ordering based on the order in which it finds free vars. To see why,
-- consider the following two pseudo-code calls to 'specialise':
--
--     specialise {f ('a', x[123], y[456])}
--     specialise {f ('b', x[456], y[123])}
--
-- Collecting the binders in a VarSet would yield the following (unique ordered)
-- sets:
--
--     {x[123], y[456]}
--     {y[123], x[456]}
--
-- ...which breaks specialization caching. We therefore track them in
-- insertion-ordered sets, yielding:
--
--     {x[123], y[456]}
--     {x[456], y[123]}
--

-- | Create binders and variable references for the free variables of
-- 'specArg'.
--
-- The result is a pair @(binders, references)@. Both lists place the free
-- type variables first, followed by the free term variables, in the same
-- order, so the i-th reference refers to the i-th binder.
specArgBndrsAndVars
  :: Either Term Type
  -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars specArg =
  ( map Right tyFVs ++ map Left tmFVs
  , map (Right . VarTy) tyFVs ++ map (Left . Var) tmFVs
  )
 where
  -- See Note [Collect free-variables in an insertion-ordered set]
  (tyFVs, tmFVs) = case specArg of
    Left tm ->
      bimap OSet.toListL OSet.toListL
        (getConst (Lens.foldMapOf freeLocalVars unitFV tm))
    Right ty ->
      -- NOTE(review): free type variables of a type argument are gathered in
      -- a 'UniqSet', not an insertion-ordered set, so the Note's ordering
      -- argument does not obviously apply here; confirm callers don't rely
      -- on it for type arguments.
      (eltsUniqSet (Lens.foldMapOf typeFreeVars unitUniqSet ty), [] :: [Id])

  -- Inject a single free variable into the matching half of the
  -- (type variables, term variables) pair of left-biased ordered sets.
  unitFV :: Var a -> Const (OSet.OLSet TyVar, OSet.OLSet Id) (Var a)
  unitFV v = case v of
    Id {}    -> Const (mempty, coerce (OSet.singleton (coerce v :: Id)))
    TyVar {} -> Const (coerce (OSet.singleton (coerce v :: TyVar)), mempty)

-- | Reduce an expression to weak-head normal form (WHNF) using the evaluator,
-- then run the given rewrite on the reduced expression.
--
-- The evaluator gets one half of the split unique supply, and the other half
-- is stored back in the rewrite state. The global heap it returns replaces
-- the one in the state. Its pure heap is handed to 'bindPureHeap' together
-- with the rewrite.
whnfRW
  :: Bool
  -- ^ Whether the expression we're reducing to WHNF is the subject of a
  -- case expression.
  -> TransformContext
  -> Term
  -> Rewrite extra
  -> RewriteMonad extra Term
whnfRW isSubj ctx@(TransformContext is0 _) e rw = do
  tcm <- Lens.view tcCache
  (primEval, primUnwind) <- Lens.view evaluator
  globalBinders <- Lens.use bindings
  (evalSupply, restSupply) <- splitSupply <$> Lens.use uniqSupply
  uniqSupply Lens..= restSupply
  heap0 <- Lens.use globalHeap
  case whnf' primEval primUnwind globalBinders tcm heap0 evalSupply is0 isSubj e of
    (!heap1, pureHeap, whnfTerm) -> do
      globalHeap Lens..= heap1
      bindPureHeap tcm pureHeap rw ctx whnfTerm
{-# SCC whnfRW #-}

-- | Run a rewrite on a term as if it were the body of a let-binding of every
-- entry of the given pure heap. If the rewrite reports a change, wrap the
-- heap entries around the result as a 'Letrec'.
--
-- To prevent unnecessary rewrites, the heap is only bound when the rewrite
-- reported a change and the heap is non-empty. Otherwise the rewrite's
-- result is returned as-is.
bindPureHeap
  :: TyConMap
  -> PureHeap
  -> Rewrite extra
  -> Rewrite extra
bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do
  (e1, anyChange) <- Writer.listen (rw ctxHeap e)
  if not (Monoid.getAny anyChange) || null heapBinds then
    return e1
  else do
    -- The evaluator results are post-processed with two operations:
    --
    --   1. Inline work free binders. We've seen cases in the wild† where the
    --      evaluator (or rather, 'bindPureHeap') would let-bind work-free
    --      binders that were crucial for eliminating case constructs. If these
    --      case constructs were used in a self-referential (but terminating)
    --      manner, Clash would get stuck in an infinite loop. The proper
    --      solution would be to use 'isWorkFree', instead of 'isWorkFreeIsh',
    --      in 'bindConstantVar' such that these work free constructs would get
    --      inlined again. However, this incurs a great performance penalty so
    --      we opt to prevent the evaluator from introducing this situation in
    --      the first place.
    --
    --      I'd like to stress that this is not a proper solution though, as GHC
    --      might produce a similar situation. We plan on properly solving this
    --      by eliminating the current lift/bind/eval strategy, instead replacing
    --      it by a partial evaluator‡.
    --
    --   2. Remove any unused let-bindings. Similar to (1), we risk Clash getting
    --      stuck in an infinite loop if we don't remove unused (eliminated by
    --      evaluation!) binders.
    --
    -- † https://github.com/clash-lang/clash-compiler/pull/1354#issuecomment-635430374
    -- ‡ https://www.microsoft.com/en-us/research/wp-content/uploads/2016/07/supercomp-by-eval.pdf
    e2 <- inlineBinders inlineTest ctx0 (Letrec heapBinds e1)
    pure (pruneUnused e2)
 where
  -- Heap entries turned into let-bindings, and the binders they introduce
  heapBinds = map toLetBinding (toListUniqMap heap)
  heapIds = map fst heapBinds

  -- The rewrite runs as if inside the body of a let binding the heap entries
  ctxHeap =
    TransformContext (extendInScopeSetList is0 heapIds) (LetBody heapIds : hist)

  -- Drop let-bindings that became unused; keep the term when nothing changes
  pruneUnused t@(Letrec bnds body) = fromMaybe t (removeUnusedBinders bnds body)
  pruneUnused t = t

  toLetBinding :: (Unique,Term) -> LetBinding
  toLetBinding (uniq, term) =
    -- See [Note: Name re-creation]
    (mkLocalId (termType tcm term) (mkUnsafeSystemName "x" uniq), term)

  inlineTest _ (i, stripTicks -> e_)
    -- Don't inline `let x = x in x`, it throws us in an infinite loop
    | isLocalVar e_ = pure (i `localIdDoesNotOccurIn` e_)
    | otherwise     = isWorkFree e_

-- | Remove the unused bindings from a let-binding group.
--
-- A binding counts as used when it is reachable from the free local
-- variables of @body@, directly or through the right-hand sides of other
-- used bindings. The result is:
--
-- * @'Just' body@ when no binding is used (including when @binds@ is empty);
-- * @'Just' ('Letrec' used body)@ when only some bindings are used;
-- * 'Nothing' when every binding is used, i.e. nothing could be removed.
removeUnusedBinders
  :: [LetBinding]
  -> Term
  -> Maybe Term
removeUnusedBinders binds body
  | null live                    = Just body
  | List.equalLength live binds  = Nothing
  | otherwise                    = Just (Letrec live body)
 where
  live = eltsVarEnv reachable

  -- All bindings reachable from the body, keyed on their binder
  reachable = List.foldl' visit emptyVarEnv (eltsVarSet (localFVs body))

  -- Lookup table from binder to its complete binding
  binderEnv = mkVarEnv [ (x, bnd) | bnd@(x, _) <- binds ]

  localFVs = Lens.foldMapOf freeLocalIds unitVarSet

  -- Depth-first traversal: record a binding once, then visit the free
  -- variables of its right-hand side. Variables not bound here are ignored.
  visit seen v
    | v `elemVarEnv` seen
    = seen
    | Just bnd@(x, rhs) <- lookupVarEnv v binderEnv
    = List.foldl' visit (extendVarEnv x bnd seen) (eltsVarSet (localFVs rhs))
    | otherwise
    = seen