{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016     , Myrtle Software Ltd,
                    2017     , Google Inc.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Utilities for rewriting: e.g. inlining, specialisation, etc.
-}

{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE CPP                      #-}
{-# LANGUAGE LambdaCase               #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE OverloadedStrings        #-}
{-# LANGUAGE Rank2Types               #-}
{-# LANGUAGE FlexibleContexts         #-}
{-# LANGUAGE TemplateHaskell          #-}
{-# LANGUAGE ViewPatterns             #-}

{-# OPTIONS_GHC -Wno-unused-imports #-}

module Clash.Rewrite.Util where

import           Control.DeepSeq
import           Control.Exception           (throw)
import           Control.Lens
  (Lens', (%=), (+=), (^.), _3, _4, _Left)
import qualified Control.Lens                as Lens
import qualified Control.Monad               as Monad
#if !MIN_VERSION_base(4,13,0)
import           Control.Monad.Fail          (MonadFail)
#endif
import qualified Control.Monad.State.Strict  as State
import qualified Control.Monad.Writer        as Writer
import           Data.Bifunctor              (bimap)
import           Data.Coerce                 (coerce)
import           Data.Functor.Const          (Const (..))
import           Data.List                   (group, sort)
import qualified Data.Map                    as Map
import           Data.Maybe                  (catMaybes,isJust,mapMaybe)
import qualified Data.Monoid                 as Monoid
import qualified Data.Set                    as Set
import qualified Data.Set.Lens               as Lens
import           Data.Text                   (Text)
import qualified Data.Text                   as Text

#ifdef HISTORY
import           Data.Binary                 (encode)
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Lazy        as BL
import           System.IO.Unsafe            (unsafePerformIO)
#endif

import           BasicTypes                  (InlineSpec (..))

import           Clash.Core.DataCon          (dcExtTyVars)
import           Clash.Core.FreeVars
  (freeLocalVars, hasLocalFreeVars, localIdDoesNotOccurIn, localIdOccursIn,
   typeFreeVars, termFreeVars', freeIds)
import           Clash.Core.Name
import           Clash.Core.Pretty           (showPpr)
import           Clash.Core.Subst
  (aeqTerm, aeqType, extendIdSubst, mkSubst, substTm)
import           Clash.Core.Term
  (LetBinding, Pat (..), Term (..), CoreContext (..), Context, PrimInfo (..),
   TmName, WorkInfo (..), TickInfo, collectArgs, collectArgsTicks)
import           Clash.Core.TyCon
  (TyConMap, tyConDataCons)
import           Clash.Core.Type             (KindOrType, Type (..),
                                              TypeView (..), coreView1,
                                              normalizeType,
                                              typeKind, tyView, isPolyFunTy)
import           Clash.Core.Util
  (isPolyFun, mkAbstraction, mkApps, mkLams, mkTicks,
   mkTmApps, mkTyApps, mkTyLams, termType, dataConInstArgTysE, isClockOrReset)
import           Clash.Core.Var
  (Id, IdScope (..), TyVar, Var (..), isLocalId, mkGlobalId, mkLocalId, mkTyVar)
import           Clash.Core.VarEnv
  (InScopeSet, VarEnv, elemVarSet, extendInScopeSetList, mkInScopeSet,
   notElemVarEnv, uniqAway, uniqAway')
import           Clash.Driver.Types
  (DebugLevel (..), BindingMap)
import           Clash.Netlist.Util          (representableType)
import           Clash.Rewrite.Types
import           Clash.Unique
import           Clash.Util

-- | Lift an action working in the '_extra' state to the 'RewriteMonad'.
--
-- The state action is run starting from the current '_extra' field of the
-- rewrite state, and its final state is stored back into that field. The
-- rewrite environment is ignored and the writer output is passed through
-- unchanged.
zoomExtra :: State.State extra a
          -> RewriteMonad extra a
zoomExtra m = R step
 where
  -- The 'case' (rather than a lazy 'let') keeps the result tuple of
  -- 'State.runState' forced exactly as strictly as before.
  step _env st w = case State.runState m (st ^. extra) of
    (res, extraNew) -> (res, st { _extra = extraNew }, w)

-- | Some transformations might erroneously introduce shadowing. For example,
-- a transformation might result in:
--
--   let a = ...
--       b = ...
--       a = ...
--
-- where the last 'a' shadows the first, while Clash assumes that this can't
-- happen. This function finds such constructs and returns every group of
-- duplicate binders it encounters.
--
-- Duplicates are looked for among the binders of a single let-group and among
-- the binders of a single case-alternative pattern. The search descends into
-- every sub-term, including the right-hand sides of let-bindings.
findAccidentialShadows :: Term -> [[Id]]
findAccidentialShadows =
  \case
    Var {}      -> []
    Data {}     -> []
    Literal {}  -> []
    Prim {}     -> []
    Lam _ t     -> findAccidentialShadows t
    TyLam _ t   -> findAccidentialShadows t
    App t1 t2   -> concatMap findAccidentialShadows [t1, t2]
    TyApp t _   -> findAccidentialShadows t
    Cast t _ _  -> findAccidentialShadows t
    Tick _ t    -> findAccidentialShadows t
    Case t _ as ->
      concatMap (findInPat . fst) as ++
        concatMap findAccidentialShadows (t : map snd as)
    Letrec bs t ->
      -- Also search the bound expressions: a let nested inside the RHS of a
      -- binding can contain duplicate binders just as well as the body can.
      findDups (map fst bs) ++
        concatMap findAccidentialShadows (t : map snd bs)
 where
  -- Duplicate binders introduced by a single case-alternative pattern.
  findInPat :: Pat -> [[Id]]
  findInPat (LitPat _)        = []
  findInPat DefaultPat        = []
  findInPat (DataPat _ _ ids) = findDups ids

  -- Groups of identifiers that occur more than once in the given list.
  findDups :: [Id] -> [[Id]]
  findDups ids = filter ((1 <) . length) (group (sort ids))


-- | Run a transformation and record whether it was successfully applied.
--
-- The transformation signals a change through the writer's 'Monoid.Any'
-- output. If it did, the transformation counter is incremented and the
-- rewritten term is kept. Otherwise the original term is kept and the
-- transformation's result is discarded. For any debug level other than
-- 'DebugNone' the outcome is additionally checked and traced by 'applyDebug'.
apply
  :: String
  -- ^ Name of the transformation
  -> Rewrite extra
  -- ^ Transformation to be applied
  -> Rewrite extra
-- NB: keep the lambda on the right-hand side. GHC only inlines an INLINE
-- function once it is applied to as many arguments as appear on its
-- left-hand side, so this form lets 'apply' inline even when it is only
-- partially applied.
apply = \name rewrite ctx exprOld -> do
  lvl <- Lens.view dbgLevel
  (exprRewritten, Monoid.Any hasChanged) <- Writer.listen (rewrite ctx exprOld)
  -- A rewrite that does not report a change must not alter the term.
  let !exprNew = if hasChanged then exprRewritten else exprOld
  Monad.when hasChanged (transformCounter += 1)
#ifdef HISTORY
  -- NB: When HISTORY is on, emit binary data holding the recorded rewrite steps
  Monad.when hasChanged $ do
    (curBndr, _) <- Lens.use curFun
    let !_ = unsafePerformIO
             $ BS.appendFile "history.dat"
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx    = tfContext ctx
                 , t_name   = name
                 , t_bndrS  = showPpr (varName curBndr)
                 , t_before = exprOld
                 , t_after  = exprRewritten
                 }
    return ()
#endif
  if lvl == DebugNone
    then return exprNew
    else applyDebug lvl name exprOld hasChanged exprNew
{-# INLINE apply #-}

-- | Debug companion of 'apply': traces the rewrite that was tried and
-- sanity-checks its result, depending on the debug level.
--
-- * At 'DebugAll' it traces the transformation name and the input term.
-- * When the rewrite reported a change (and the level is above 'DebugNone'),
--   it fails with 'error' if the new term has local free variables the old one
--   did not have, or if it contains accidental shadowing (see
--   'findAccidentialShadows'). At 'DebugAll' it also traces a message when the
--   type of the term changed (i.e. the types are not alpha-equivalent).
-- * From 'DebugApplied' onwards it fails with 'error' when the term changed
--   although the rewrite did not report a change.
-- * Finally it traces the name ('DebugName'), the before/after terms
--   ('DebugApplied') or a "no changes" message ('DebugAll') as appropriate.
--
-- Returns the new expression.
applyDebug
  :: DebugLevel
  -- ^ The current debugging level
  -> String
  -- ^ Name of the transformation
  -> Term
  -- ^ Original expression
  -> Bool
  -- ^ Whether the rewrite indicated change
  -> Term
  -- ^ New expression
  -> RewriteMonad extra Term
applyDebug lvl name exprOld hasChanged exprNew =
  traceIf (lvl >= DebugAll) ("Trying: " ++ name ++ " on:\n" ++ before) $ do
    Monad.when (lvl > DebugNone && hasChanged) $ do
      tcm <- Lens.view tcCache
      let beforeTy          = termType tcm exprOld
          beforeFV          = Lens.setOf freeLocalVars exprOld
          afterTy           = termType tcm exprNew
          afterFV           = Lens.setOf freeLocalVars exprNew
          newFV             = not (afterFV `Set.isSubsetOf` beforeFV)
          accidentalShadows = findAccidentialShadows exprNew

      Monad.when newFV $
        error (concat [ $(curLoc)
                      , "Error when applying rewrite ", name
                      , " to:\n" , before
                      , "\nResult:\n" ++ after ++ "\n"
                      , "It introduces free variables."
                      , "\nBefore: " ++ showPpr (Set.toList beforeFV)
                      , "\nAfter: " ++ showPpr (Set.toList afterFV)
                      ])

      Monad.unless (null accidentalShadows) $
        error (concat [ $(curLoc)
                      , "Error when applying rewrite ", name
                      , " to:\n" , before
                      , "\nResult:\n" ++ after ++ "\n"
                      , "It accidentally creates shadowing let/case-bindings:\n"
                      , " ", showPpr accidentalShadows, "\n"
                      , "This usually means that a transformation did not extend "
                      , "or incorrectly extended its InScopeSet before applying a "
                      , "substitution."
                      ])

      -- Report a type change only when the types are NOT alpha-equivalent;
      -- previously this fired for every type-preserving rewrite instead.
      traceIf (lvl >= DebugAll && not (beforeTy `aeqType` afterTy))
              (concat [ $(curLoc)
                      , "Error when applying rewrite ", name
                      , " to:\n" , before
                      , "\nResult:\n" ++ after ++ "\n"
                      , "Changes type from:\n", showPpr beforeTy
                      , "\nto:\n", showPpr afterTy
                      ]
              ) (return ())

    Monad.when (lvl >= DebugApplied && not hasChanged && not (exprOld `aeqTerm` exprNew)) $
      error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++ "): before:\n"
                        ++ before ++ "\nafter:\n" ++ after

    traceIf (lvl >= DebugName && hasChanged) name $
      traceIf (lvl >= DebugApplied && hasChanged)
              ("Changes when applying rewrite to:\n"
                 ++ before ++ "\nResult:\n" ++ after ++ "\n") $
        traceIf (lvl >= DebugAll && not hasChanged)
                ("No changes when applying rewrite "
                   ++ name ++ " to:\n" ++ after ++ "\n") $
          return exprNew
  where
    -- Pretty-printed forms, only computed when a trace/error demands them.
    before = showPpr exprOld
    after  = showPpr exprNew

-- | Perform a transformation on a Term, using 'apply' with an empty context
-- and the given 'InScopeSet'.
runRewrite
  :: String
  -- ^ Name of the transformation
  -> InScopeSet
  -- ^ Set of in-scope variables, passed on in the 'TransformContext'
  -> Rewrite extra
  -- ^ Transformation to perform
  -> Term
  -- ^ Term to transform
  -> RewriteMonad extra Term
runRewrite name is rewrite = apply name rewrite topCtx
 where
  topCtx = TransformContext is []

-- | Run a rewrite computation to completion with the given environment and
-- initial state, and return its result. A trace reports how many
-- transformations were applied, as counted by 'transformCounter' in the final
-- state.
runRewriteSession :: RewriteEnv
                  -> RewriteState extra
                  -> RewriteMonad extra a
                  -> a
runRewriteSession r s m =
  -- Lazy pattern binding: the tuple is only taken apart on demand.
  let (result, finalState, _) = runR m r s
      applied                 = show (finalState ^. transformCounter)
  in  traceIf True ("Clash: Applied " ++ applied ++ " transformations") result

-- | Notify that a transformation has changed the expression, by writing
-- @'Monoid.Any' 'True'@ to the writer output.
setChanged :: RewriteMonad extra ()
setChanged = Writer.tell changeFlag
 where
  changeFlag = Monoid.Any True

-- | Return the given value unchanged, while also writing
-- @'Monoid.Any' 'True'@ to the writer output to signal that a transformation
-- has changed the expression.
changed :: a -> RewriteMonad extra a
changed val = val <$ Writer.tell (Monoid.Any True)

-- | The 'Id' of the first 'LetBinding' entry in the context, if there is one.
-- As the name suggests, this is meant to be the closest enclosing let-binder.
closestLetBinder :: Context -> Maybe Id
closestLetBinder = foldr step Nothing
 where
  step (LetBinding id_ _) _    = Just id_
  step _                  rest = rest

-- | Make a name derived from the surrounding context. If there is an enclosing
-- let-binder (see 'closestLetBinder'), the result is that binder's name with
-- @_@ followed by the given suffix appended. Otherwise it is an internal name
-- with the suffix as its occurrence name and unique 0.
mkDerivedName :: TransformContext -> OccName -> TmName
mkDerivedName (TransformContext _ ctx) sf =
  maybe fromScratch fromBinder (closestLetBinder ctx)
 where
  fromBinder b = appendToName (varName b) (Text.cons '_' sf)
  fromScratch  = mkUnsafeInternalName sf 0

-- | Make a new binder for a term: a local 'Id' whose type is the type of the
-- term and whose name is obtained by cloning the given name with
-- 'cloneNameWithInScopeSet' against the 'InScopeSet'.
--
-- This is the term case of 'mkBinderFor', written out directly so no
-- (unreachable) pattern-match failure on 'Right' is involved.
mkTmBinderFor
  :: (Monad m, MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Name a -- ^ Name of the new binder
  -> Term -- ^ Term to bind
  -> m Id
mkTmBinderFor is tcm name e = do
  clonedName <- cloneNameWithInScopeSet is name
  return (mkLocalId (termType tcm e) (coerce clonedName))

-- | Make a new binder for either a term or a type. The binder's name is
-- obtained by cloning the given name with 'cloneNameWithInScopeSet' against
-- the 'InScopeSet'.
--
-- * For a term ('Left'), the result is a local 'Id' with the term's type.
-- * For a type ('Right'), the result is a 'TyVar' with the type's kind.
mkBinderFor
  :: (Monad m, MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Name a -- ^ Name of the new binder
  -> Either Term Type -- ^ Type or Term to bind
  -> m (Either Id TyVar)
mkBinderFor is tcm name = \case
  Left term -> do
    tmName <- cloneNameWithInScopeSet is name
    return (Left (mkLocalId (termType tcm term) (coerce tmName)))
  Right ty -> do
    tyName <- cloneNameWithInScopeSet is name
    return (Right (mkTyVar (typeKind tcm ty) (coerce tyName)))

-- | Make a new local identifier of the given type. Its name is an internal
-- name built from the given occurrence name and a unique from 'getUniqueM'.
-- The resulting 'Id' is then passed through 'uniqAway' with the given
-- 'InScopeSet'.
mkInternalVar
  :: (Monad m, MonadUnique m)
  => InScopeSet
  -> OccName
  -- ^ Name of the identifier
  -> KindOrType
  -> m Id
mkInternalVar inScope name ty = mkVar <$> getUniqueM
 where
  mkVar u = uniqAway inScope (mkLocalId ty (mkUnsafeInternalName name u))

-- | Inline the binders in a let-binding that have a certain property.
--
-- Every binding of a 'Letrec' is tested with the property (given the whole
-- let-expression and the binding). The bindings that pass are substituted,
-- via 'substituteBinders', into the remaining bindings and the body. If no
-- bindings remain, just the body is returned, and the result is flagged with
-- 'changed'. If no binding passes, or the expression is not a 'Letrec', it is
-- returned as is without flagging a change.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> Rewrite extra
inlineBinders condition (TransformContext inScope0 _) expr@(Letrec xes res) = do
  (toInline, toKeep) <- partitionM (condition expr) xes
  if null toInline
    then return expr
    else changed (rebuild toInline toKeep)
 where
  -- All binders of this let-group are in scope for the substitution.
  inScope1 = extendInScopeSetList inScope0 (map fst xes)

  rebuild toInline toKeep =
    case substituteBinders inScope1 toInline toKeep res of
      ([], body)    -> body
      (binds, body) -> Letrec binds body

inlineBinders _ _ e = return e

-- | Determine whether a binder is a join-point created for a complex case
-- expression.
--
-- A join-point is a local function that occurs only in tail-call positions,
-- and does so more than once. 'tailCalls' yields 'Nothing' when there is a
-- non-tail use; any such use disqualifies the binder.
isJoinPointIn :: Id   -- ^ 'Id' of the local binder
              -> Term -- ^ Expression in which the binder is bound
              -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)

-- | Count the number of (only) tail calls of a function in an expression.
-- 'Nothing' indicates that the function was used in a non-tail call position.
tailCalls :: Id   -- ^ Function to check
          -> Term -- ^ Expression to check it in
          -> Maybe Int
tailCalls :: Id -> Term -> Maybe Int
tailCalls id_ :: Id
id_ = \case
  Var nm :: Id
nm | Id
id_ Id -> Id -> Bool
forall a. Eq a => a -> a -> Bool
== Id
nm -> Int -> Maybe Int
forall a. a -> Maybe a
Just 1
         | Bool
otherwise -> Int -> Maybe Int
forall a. a -> Maybe a
Just 0
  Lam _ e :: Term
e -> Id -> Term -> Maybe Int
tailCalls Id
id_ Term
e
  TyLam _ e :: Term
e -> Id -> Term -> Maybe Int
tailCalls Id
id_ Term
e
  App l :: Term
l r :: Term
r  -> case Id -> Term -> Maybe Int
tailCalls Id
id_ Term
r of
                Just 0 -> Id -> Term -> Maybe Int
tailCalls Id
id_ Term
l
                _      -> Maybe Int
forall a. Maybe a
Nothing
  TyApp l :: Term
l _ -> Id -> Term -> Maybe Int
tailCalls Id
id_ Term
l
  Letrec bs :: [LetBinding]
bs e :: Term
e ->
    let (bsIds :: [Id]
bsIds,bsExprs :: [Term]
bsExprs) = [LetBinding] -> ([Id], [Term])
forall a b. [(a, b)] -> ([a], [b])
unzip [LetBinding]
bs
        bsTls :: [Maybe Int]
bsTls           = (Term -> Maybe Int) -> [Term] -> [Maybe Int]
forall a b. (a -> b) -> [a] -> [b]
map (Id -> Term -> Maybe Int
tailCalls Id
id_) [Term]
bsExprs
        bsIdsUsed :: [Id]
bsIdsUsed       = ((Id, Maybe Int) -> Maybe Id) -> [(Id, Maybe Int)] -> [Id]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (\(l :: Id
l,r :: Maybe Int
r) -> Id -> Maybe Id
forall (f :: * -> *) a. Applicative f => a -> f a
pure Id
l Maybe Id -> Maybe Int -> Maybe Id
forall (f :: * -> *) a b. Applicative f => f a -> f b -> f a
<* Maybe Int
r) ([Id] -> [Maybe Int] -> [(Id, Maybe Int)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Id]
bsIds [Maybe Int]
bsTls)
        bsIdsTls :: [Maybe Int]
bsIdsTls        = (Id -> Maybe Int) -> [Id] -> [Maybe Int]
forall a b. (a -> b) -> [a] -> [b]
map (Id -> Term -> Maybe Int
`tailCalls` Term
e) [Id]
bsIdsUsed
        bsCount :: Maybe Int
bsCount         = Int -> Maybe Int
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Int -> Maybe Int) -> ([Int] -> Int) -> [Int] -> Maybe Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Maybe Int) -> [Int] -> Maybe Int
forall a b. (a -> b) -> a -> b
$ [Maybe Int] -> [Int]
forall a. [Maybe a] -> [a]
catMaybes [Maybe Int]
bsTls
    in  case ((Maybe Int -> Bool) -> [Maybe Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Maybe Int -> Bool
forall a. Maybe a -> Bool
isJust [Maybe Int]
bsTls) of
          False -> Maybe Int
forall a. Maybe a
Nothing
          True  -> case ((Int -> Bool) -> [Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
==0) ([Int] -> Bool) -> [Int] -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe Int] -> [Int]
forall a. [Maybe a] -> [a]
catMaybes [Maybe Int]
bsTls) of
            False  -> case (Maybe Int -> Bool) -> [Maybe Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Maybe Int -> Bool
forall a. Maybe a -> Bool
isJust [Maybe Int]
bsIdsTls of
              False -> Maybe Int
forall a. Maybe a
Nothing
              True  -> Int -> Int -> Int
forall a. Num a => a -> a -> a
(+) (Int -> Int -> Int) -> Maybe Int -> Maybe (Int -> Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe Int
bsCount Maybe (Int -> Int) -> Maybe Int -> Maybe Int
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Id -> Term -> Maybe Int
tailCalls Id
id_ Term
e
            True -> Id -> Term -> Maybe Int
tailCalls Id
id_ Term
e
  Case scrut :: Term
scrut _ alts :: [Alt]
alts ->
    let scrutTl :: Maybe Int
scrutTl = Id -> Term -> Maybe Int
tailCalls Id
id_ Term
scrut
        altsTl :: [Maybe Int]
altsTl  = (Alt -> Maybe Int) -> [Alt] -> [Maybe Int]
forall a b. (a -> b) -> [a] -> [b]
map (Id -> Term -> Maybe Int
tailCalls Id
id_ (Term -> Maybe Int) -> (Alt -> Term) -> Alt -> Maybe Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Alt -> Term
forall a b. (a, b) -> b
snd) [Alt]
alts
    in  case Maybe Int
scrutTl of
          Just 0 | (Maybe Int -> Bool) -> [Maybe Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Maybe Int -> Maybe Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Maybe Int
forall a. Maybe a
Nothing) [Maybe Int]
altsTl -> Int -> Maybe Int
forall a. a -> Maybe a
Just ([Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Maybe Int] -> [Int]
forall a. [Maybe a] -> [a]
catMaybes [Maybe Int]
altsTl))
          _ -> Maybe Int
forall a. Maybe a
Nothing
  _ -> Int -> Maybe Int
forall a. a -> Maybe a
Just 0

-- | Determines whether a function has the following shape:
--
-- > \(w :: Void) -> f a b c
--
-- i.e. is a wrapper around a (partially) applied function 'f', where the
-- introduced argument 'w' is not used by 'f'.
--
-- Only the syntactic shape is checked: the term must be a lambda whose body,
-- once its arguments are stripped with 'collectArgs', is headed by a variable,
-- and the lambda's binder must not occur in that body. The binder's type is
-- not inspected.
isVoidWrapper :: Term -> Bool
isVoidWrapper = \case
  Lam bndr body
    | (Var _, _) <- collectArgs body
    -> bndr `localIdDoesNotOccurIn` body
  _ -> False

-- | Substitute the RHS of the first set of Let-binders for references to the
-- first set of Let-binders in: the second set of Let-binders and the additional
-- term.
--
-- The first set is processed left to right, and each substitution also
-- applies to the not-yet-processed binders of the first set. A binder whose
-- RHS refers to itself cannot be substituted away. It is instead prepended,
-- unchanged, to the second set of binders, where later substitutions still
-- reach it.
substituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to substitute
  -> [LetBinding]
  -- ^ Let-binders where substitution takes place
  -> Term
  -- ^ Expression where substitution takes place
  -> ([LetBinding],Term)
substituteBinders inScope = go
 where
  go :: [LetBinding] -> [LetBinding] -> Term -> ([LetBinding], Term)
  go [] others res = (others, res)
  go ((bndr, val) : rest) others res
    -- Self-recursive binder: keep it as a binding instead of substituting it
    | bndr `localIdOccursIn` val
    = go rest ((bndr, val) : others) res
    | otherwise
    = go (map (second (substTm "substituteBindersRest" subst)) rest)
         (map (second (substTm "substituteBindersOthers" subst)) others)
         (substTm "substituteBindersRes" subst res)
   where
    subst = extendIdSubst (mkSubst inScope) bndr val

-- | Determine whether a term does any work, i.e. adds to the size of the
-- circuit.
--
-- The verdict is driven by the head of the application (as split off by
-- 'collectArgs'). For most heads, the term arguments must be work free as
-- well; type arguments never count as work.
isWorkFree
  :: Term
  -> Bool
isWorkFree term = case fun of
  Var i            -> isLocalId i && not (isPolyFunTy (varType i))
  Data {}          -> argsWorkFree
  Literal {}       -> True
  Prim _ pInfo     -> primIsWorkFree (primWorkInfo pInfo)
  Lam _ e          -> isWorkFree e && argsWorkFree
  TyLam _ e        -> isWorkFree e && argsWorkFree
  Letrec bs e      ->
    isWorkFree e && all (isWorkFree . snd) bs && argsWorkFree
  Case s _ [(_,a)] -> isWorkFree s && isWorkFree a && argsWorkFree
  Cast e _ _       -> isWorkFree e && argsWorkFree
  _                -> False
 where
  (fun, args) = collectArgs term

  -- All term arguments are work free; type arguments are ignored
  argsWorkFree = all (either isWorkFree (const True)) args

  primIsWorkFree = \case
    -- The primitive outputs a constant regardless of its arguments, so the
    -- arguments can be ignored
    WorkConstant -> True
    WorkNever    -> argsWorkFree
    WorkVariable -> all (either isConstant (const True)) args
    -- Things like clock or reset generators always perform work
    WorkAlways   -> False

-- | Whether the given name is one of the @fromInteger@ primitives of the
-- sized number types in @Clash.Sized.Internal@ (BitVector, Index, Signed,
-- Unsigned).
isFromInt :: Text -> Bool
isFromInt nm = nm `elem` fromIntegerPrims
 where
  fromIntegerPrims :: [Text]
  fromIntegerPrims =
    [ "Clash.Sized.Internal.BitVector.fromInteger##"
    , "Clash.Sized.Internal.BitVector.fromInteger#"
    , "Clash.Sized.Internal.Index.fromInteger#"
    , "Clash.Sized.Internal.Signed.fromInteger#"
    , "Clash.Sized.Internal.Unsigned.fromInteger#"
    ]

-- | Determine if a term represents a constant.
--
-- * Constructor and primitive applications are constant when all their term
--   arguments are constant (type arguments are ignored).
-- * A lambda is constant when it has no local free variables.
-- * Literals are always constant.
-- * Anything else is not constant.
isConstant :: Term -> Bool
isConstant e
  | (Data _, args)   <- headAndArgs = allConstant args
  | (Prim _ _, args) <- headAndArgs = allConstant args
  | (Lam _ _, _)     <- headAndArgs = not (hasLocalFreeVars e)
  | (Literal _, _)   <- headAndArgs = True
  | otherwise                       = False
 where
  headAndArgs = collectArgs e

  allConstant :: [Either Term Type] -> Bool
  allConstant = all (either isConstant (const True))

-- | Like 'isConstant', except for terms whose type is a clock or reset
-- (according to 'isClockOrReset').
--
-- Such terms only count as constant when their head is the
-- @Clash.Transformations.removedArg@ primitive.
isConstantNotClockReset
  :: Term
  -> RewriteMonad extra Bool
isConstantNotClockReset e = do
  tcm <- Lens.view tcCache
  pure $
    if isClockOrReset tcm (termType tcm e)
      then isRemovedArg (fst (collectArgs e))
      else isConstant e
 where
  isRemovedArg (Prim nm _) = nm == "Clash.Transformations.removedArg"
  isRemovedArg _           = False

-- TODO: Remove function after using WorkInfo in 'isWorkFreeIsh'
--
-- | Returns 'Nothing' when the term's type is not a clock or reset (according
-- to 'isClockOrReset').
--
-- Otherwise, the result says whether the term is work free:
--
-- * the @Clash.Transformations.removedArg@ primitive is work free;
-- * any other primitive is not;
-- * a bare (unapplied) variable is work free;
-- * anything else is not.
isWorkFreeClockOrReset
  :: TyConMap
  -> Term
  -> Maybe Bool
isWorkFreeClockOrReset tcm e
  | isClockOrReset tcm (termType tcm e) = Just $ case collectArgs e of
      (Prim nm _, _) -> nm == "Clash.Transformations.removedArg"
      (Var _, [])    -> True
      _              -> False
  | otherwise = Nothing

-- | A conservative version of 'isWorkFree'. It is used in 'bindConstantVar'
-- to determine whether an expression can be "bound" (locally inlined).
--
-- While binding work-free expressions won't result in extra work for the
-- circuit, it might very well cause extra work for Clash. In fact, using
-- 'isWorkFree' in 'bindConstantVar' makes Clash two orders of magnitude slower
-- for some of our test cases.
--
-- In effect, this function is a version of 'isConstant' that also considers
-- references to clocks and resets constant. This allows us to bind
-- HiddenClock(ResetEnable) constructs, allowing Clash to constant spec
-- subconstants - most notably KnownDomain. Doing that enables Clash to
-- eliminate any case-constructs on it.
isWorkFreeIsh
  :: Term
  -> RewriteMonad extra Bool
isWorkFreeIsh e = do
  tcm <- Lens.view tcCache
  -- Clock/reset-typed terms are decided by 'isWorkFreeClockOrReset';
  -- everything else by the shape of the term
  maybe byShape pure (isWorkFreeClockOrReset tcm e)
 where
  byShape = case collectArgs e of
    (Data _, args)       -> allM isWorkFreeIshArg args
    (Prim _ pInfo, args)
      -- Things like clock or reset generators always perform work
      | WorkAlways <- primWorkInfo pInfo -> pure False
      | otherwise                        -> allM isWorkFreeIshArg args
    (Lam _ _, _)         -> pure (not (hasLocalFreeVars e))
    (Literal _, _)       -> pure True
    _                    -> pure False

  -- Term arguments are checked recursively; type arguments are always fine
  isWorkFreeIshArg = either isWorkFreeIsh (const (pure True))

-- | Inline or lift the let-binders of a 'Letrec' that satisfy a property.
--
-- The property test selects the binders to replace. The second test splits
-- them into two groups:
--
-- * binders to inline, which are substituted into the remaining binders and
--   the body;
-- * binders to lift, which are passed to 'liftBinding' and then substituted
--   into the remaining binders and the body.
--
-- Terms that are not a 'Letrec', and 'Letrec's where no binder satisfies the
-- property, are returned unchanged.
inlineOrLiftBinders
  :: (LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Test whether to lift or inline
  --
  -- * True: inline
  -- * False: lift
  -> Rewrite extra
inlineOrLiftBinders condition inlineOrLift (TransformContext inScope0 _) expr@(Letrec xes res) = do
  (replace,others) <- partitionM condition xes
  case replace of
    [] -> return expr
    _  -> do
      let inScope1 = extendInScopeSetList inScope0 (map fst xes)
      (doInline,doLift) <- partitionM (inlineOrLift expr) replace
      -- We first substitute the binders that we inline into the binders that
      -- we intend to lift, the other binders, and the body
      let (others',res') =
            substituteBinders inScope1 doInline (doLift ++ others) res
          -- 'substituteBinders' prepends any self-referential binder (which
          -- it cannot inline) to its result list. The binders to lift are
          -- therefore not necessarily the first @length doLift@ elements, so
          -- select them by unique rather than by position.
          liftUniqs      = Set.fromList (map (varUniq . fst) doLift)
          isToLift (i,_) = varUniq i `Set.member` liftUniqs
          doLift'        = filter isToLift others'
          others''       = filter (not . isToLift) others'
      doLift'' <- mapM liftBinding doLift'
      -- We then substitute the lifted binders in the other binders and the body
      let (others3,res'') = substituteBinders inScope1 doLift'' others'' res'
          newExpr = case others3 of
                      [] -> res''
                      _  -> Letrec others3 res''
      changed newExpr

inlineOrLiftBinders _ _ _ e = return e

-- | Create a global function for a Let-binding and return a Let-binding where
-- the RHS is a reference to the new global function applied to the free
-- variables of the original RHS.
--
-- If a global binding whose body is alpha-equivalent to the would-be new
-- function already exists, that binding is referenced instead and no new
-- global binding is added.
--
-- Only term binders can be lifted; a binder that is not an 'Id' is reported
-- as invalid core via 'error'.
liftBinding :: LetBinding
            -> RewriteMonad extra LetBinding
liftBinding (var@Id {varName = idName} ,e) = do
  -- Get all local FVs, excluding the 'idName' from the let-binding
  let unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
      unitFV v@(Id {})    = Const (emptyUniqSet,unitUniqSet (coerce v))
      unitFV v@(TyVar {}) = Const (unitUniqSet (coerce v),emptyUniqSet)

      -- Global ids and the let-bound variable itself are not abstracted over;
      -- free type variables always are
      interesting :: Var a -> Bool
      interesting Id {idScope = GlobalId} = False
      interesting v@(Id {idScope = LocalId}) = varUniq v /= varUniq var
      interesting _ = True

      (boundFTVsSet,boundFVsSet) =
        getConst (Lens.foldMapOf (termFreeVars' interesting) unitFV e)
      boundFTVs = eltsUniqSet boundFTVsSet
      boundFVs  = eltsUniqSet boundFVsSet

      -- Apply a global function to the free type and term variables of the
      -- original RHS
      applyToFVs f =
        mkTmApps (mkTyApps (Var f) (map VarTy boundFTVs)) (map Var boundFVs)

  -- Make a new global ID
  tcm       <- Lens.view tcCache
  let newBodyTy = termType tcm $ mkTyLams (mkLams e boundFVs) boundFTVs
  (cf,sp)   <- Lens.use curFun
  binders   <- Lens.use bindings
  newBodyNm <-
    cloneNameWithBindingMap
      binders
      (appendToName (varName cf) ("_" `Text.append` nameOcc idName))
  let newBodyId = mkGlobalId newBodyTy newBodyNm {nameSort = Internal}

  -- Make a new expression, consisting of the lifted function applied to
  -- its free variables
  let newExpr  = applyToFVs newBodyId
      -- The in-scope set of a substitution must cover the free variables of
      -- its range ('newExpr'). Those are the free term variables, the free
      -- type variables, and the new global function.
      inScope0 = extendInScopeSetList
                   (mkInScopeSet (coerce boundFVsSet))
                   boundFTVs
      inScope1 = extendInScopeSetList inScope0 [var,newBodyId]
      subst    = extendIdSubst (mkSubst inScope1) var newExpr
      -- Substitute the recursive calls by the new expression
      e' = substTm "liftBinding" subst e
      -- Create a new body that abstracts over the free variables
      newBody = mkTyLams (mkLams e' boundFVs) boundFTVs

  -- Check if an alpha-equivalent global binder already exists
  aeqExisting <-
    (eltsUniqMap . filterUniqMap ((`aeqTerm` newBody) . (^. _4)))
      <$> Lens.use bindings
  case aeqExisting of
    -- If it doesn't, create a new binder
    [] -> do -- Add the created function to the list of global bindings
             bindings %= extendUniqMap newBodyNm
                                    -- We mark this function as internal so that
                                    -- it can be inlined at the very end of
                                    -- the normalisation pipeline as part of the
                                    -- flattening pass. We don't inline
                                    -- right away because we are lifting this
                                    -- function at this moment for a reason!
                                    -- (termination, CSE and DEC opportunities,
                                    -- etc.)
                                    (newBodyId
                                    ,sp
#if MIN_VERSION_ghc(8,4,1)
                                    ,NoUserInline
#else
                                    ,EmptyInlineSpec
#endif
                                    ,newBody)
             -- Return the new binder
             return (var, newExpr)
    -- If it does, use the existing binder
    ((k,_,_,_):_) -> return (var, applyToFVs k)

liftBinding _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"

-- | Ensure that the 'Unique' of a variable does not occur in the 'BindingMap'
--
-- The search for a free unique starts from the name's current unique; the
-- predicate handed to 'uniqAway'' reports which uniques are already keys of
-- the binding map.
uniqAwayBinder
  :: BindingMap
  -> Name a
  -> Name a
uniqAwayBinder binders nm = uniqAway' isBound (nameUniq nm) nm
  where
    isBound u = elemUniqMapDirectly u binders

-- | Make a global function for a name-term tuple
--
-- The name is cloned (with 'cloneNameWithBindingMap') so that its unique is not
-- already a key of the global 'bindings'; the term is then registered under
-- that fresh name with 'addGlobalBind'.
mkFunction
  :: TmName
  -- ^ Name of the function
  -> SrcSpan
  -> InlineSpec
  -> Term
  -- ^ Term bound to the function
  -> RewriteMonad extra Id
  -- ^ Name with a proper unique and the type of the function
mkFunction bndrNm sp inl body = do
  tcm     <- Lens.view tcCache
  globals <- Lens.use bindings
  freshNm <- cloneNameWithBindingMap globals bndrNm
  -- The function's type is the type of the term bound to it
  let fnTy = termType tcm body
  addGlobalBind freshNm fnTy sp inl body
  pure (mkGlobalId fnTy freshNm)

-- | Add a function to the set of global binders
--
-- A global 'Id' is built from the name and type with 'mkGlobalId', and the
-- tuple @(Id, SrcSpan, InlineSpec, Term)@ is stored in 'bindings' under the
-- given name using 'extendUniqMap'. The type and body are fully evaluated
-- (with 'deepseq') as part of the update action.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vNm ty sp inl body =
  -- The parentheses make the scope of 'deepseq' explicit. Without them, how
  -- the expression parses depends on the fixity the installed deepseq
  -- declares for 'deepseq': with the default fixity (infixl 9) it binds
  -- tighter than '%=' and only guards the lens 'bindings', not the update.
  (ty, body) `deepseq` (bindings %= extendUniqMap vNm (vId, sp, inl, body))
  where
    vId = mkGlobalId ty vNm

-- | Create a new name out of the given name, but with another unique. Resulting
-- unique is guaranteed to not be in the given InScopeSet.
--
-- A fresh unique is drawn from the 'MonadUnique' supply and set on the name;
-- 'uniqAway' then makes sure the result is not in the 'InScopeSet'.
cloneNameWithInScopeSet
  :: (Monad m, MonadUnique m)
  => InScopeSet
  -> Name a
  -> m (Name a)
cloneNameWithInScopeSet is nm =
  fmap (uniqAway is . setUnique nm) getUniqueM

-- | Create a new name out of the given name, but with another unique. Resulting
-- unique is guaranteed to not be in the given BindingMap.
--
-- A fresh unique is drawn from the 'MonadUnique' supply; 'uniqAway'' starts
-- from it and skips every unique that is already a key of the 'BindingMap'.
cloneNameWithBindingMap
  :: (Monad m, MonadUnique m)
  => BindingMap
  -> Name a
  -> m (Name a)
cloneNameWithBindingMap binders nm = do
  fresh <- getUniqueM
  let taken u = u `elemUniqMapDirectly` binders
  pure (uniqAway' taken fresh (setUnique nm fresh))

{-# INLINE isUntranslatable #-}
-- | Determine if a term cannot be represented in hardware
--
-- The term's type is computed with the 'TyConMap' from the rewrite environment
-- and the check is delegated to 'isUntranslatableType'. This keeps the
-- term-level and type-level representability checks from drifting apart.
isUntranslatable
  :: Bool
  -- ^ String representable
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm <- Lens.view tcCache
  isUntranslatableType stringRepresentable (termType tcm tm)

{-# INLINE isUntranslatableType #-}
-- | Determine if a type cannot be represented in hardware
--
-- Reads the type translator, the custom representations and the 'TyConMap'
-- from the rewrite environment, and negates the answer of 'representableType'.
isUntranslatableType
  :: Bool
  -- ^ String representable
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translator <- Lens.view typeTranslator
  reprs      <- Lens.view customReprs
  tcm        <- Lens.view tcCache
  let representable =
        representableType translator reprs stringRepresentable tcm ty
  pure (not representable)

-- | Make a binder that should not be referenced
--
-- The binder is an internal variable called @wild@ of the given type, created
-- with 'mkInternalVar' relative to the given 'InScopeSet'.
mkWildValBinder
  :: (Monad m, MonadUnique m)
  => InScopeSet
  -> Type
  -> m Id
mkWildValBinder is ty = mkInternalVar is "wild" ty

-- | Make a case-decomposition that extracts a field out of a (Sum-of-)Product type
--
-- The result is a @case@ on the subject with a single alternative for the
-- requested data constructor. In that alternative, the requested field is
-- bound to a fresh @sel@ binder, which is also the alternative's result, and
-- every other field is bound to a @wild@ binder.
--
-- The data constructor index is 1-based (constructor @dcI - 1@ of the 'TyCon'
-- is selected), while the field index is 0-based.
--
-- Any of the following results in an 'error' that mentions the caller and
-- both indices:
--
-- * an out-of-range index;
-- * field types that cannot be instantiated;
-- * a subject whose type is not a datatype.
mkSelectorCase
  :: HasCallStack
  => (Functor m, Monad m, MonadUnique m)
  => String -- ^ Name of the caller of this function
  -> InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Term -- ^ Subject of the case-composition
  -> Int -- ^ n'th DataCon (1-based)
  -> Int -- ^ n'th field (0-based)
  -> m Term
mkSelectorCase caller inScope tcm scrut dcI fieldI = go (termType tcm scrut)
  where
    go (coreView1 tcm -> Just ty') = go ty'
    go scrutTy@(tyView -> TyConApp tc args) =
      case tyConDataCons (lookupUniqMap' tcm tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showPpr tc) scrutTy
        dcs | dcI < 1 -> cantCreate $(curLoc) "DC index must be at least 1" scrutTy
            | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max" scrutTy
            | otherwise -> do
          let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
          -- Instantiate the DC's field types at the TyCon's type arguments.
          -- This used to be an irrefutable 'Just' pattern, which failed with an
          -- unhelpful message; report the failure through 'cantCreate' instead.
          case dataConInstArgTysE inScope tcm dc args of
            Nothing ->
              cantCreate $(curLoc) "Could not instantiate the field types of the DC" scrutTy
            Just fieldTys
              | fieldI < 0 ->
                  cantCreate $(curLoc) "Field index must be non-negative" scrutTy
              | fieldI >= length fieldTys ->
                  cantCreate $(curLoc) "Field index exceed max" scrutTy
              | otherwise -> do
                  wildBndrs <- mapM (mkWildValBinder inScope) fieldTys
                  let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
                  selBndr <- mkInternalVar inScope "sel" ty
                  -- Only the selected field gets a binder that is referenced
                  let bndrs  = take fieldI wildBndrs ++ [selBndr] ++ drop (fieldI+1) wildBndrs
                      pat    = DataPat dc (dcExtTyVars dc) bndrs
                      retVal = Case scrut ty [ (pat, Var selBndr) ]
                  return retVal
    go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showPpr scrutTy) scrutTy

    cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showPpr scrut ++ " :: " ++ showPpr scrutTy ++ ")\nAdditional info: " ++ info

-- | Specialise an application on its argument
--
-- Only a type application or a term application is handled, and it is
-- specialised on its outermost argument. The function part is split with
-- 'collectArgsTicks' and the work is done by 'specialise''. Any other term is
-- returned unchanged.
specialise :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
           -> Lens' extra (VarEnv Int) -- ^ Lens into the specialisation history
           -> Lens' extra Int -- ^ Lens into the specialisation limit
           -> Rewrite extra
specialise specMapLbl specHistLbl specLimitLbl ctx e = case e of
  TyApp fun ty  -> specialiseOn fun (Right ty)
  App   fun arg -> specialiseOn fun (Left arg)
  _             -> return e
  where
    -- Specialise 'fun' (split into head, arguments and ticks) on 'specArg'
    specialiseOn fun specArg =
      specialise' specMapLbl specHistLbl specLimitLbl ctx e
                  (collectArgsTicks fun) specArg

-- | Specialise an application on its argument
specialise' :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
            -> Lens' extra (VarEnv Int) -- ^ Lens into specialisation history
            -> Lens' extra Int -- ^ Lens into the specialisation limit
            -> TransformContext -- Transformation context
            -> Term -- ^ Original term
            -> (Term, [Either Term Type], [TickInfo]) -- ^ Function part of the term, split into root and applied arguments
            -> Either Term Type -- ^ Argument to specialize on
            -> RewriteMonad extra Term
specialise' :: Lens' extra (Map (Id, Int, Either Term Type) Id)
-> Lens' extra (VarEnv Int)
-> Lens' extra Int
-> TransformContext
-> Term
-> (Term, [Either Term Type], [TickInfo])
-> Either Term Type
-> RewriteMonad extra Term
specialise' specMapLbl :: Lens' extra (Map (Id, Int, Either Term Type) Id)
specMapLbl specHistLbl :: Lens' extra (VarEnv Int)
specHistLbl specLimitLbl :: Lens' extra Int
specLimitLbl (TransformContext is0 :: InScopeSet
is0 _) e :: Term
e (Var f :: Id
f, args :: [Either Term Type]
args, ticks :: [TickInfo]
ticks) specArgIn :: Either Term Type
specArgIn = do
  DebugLevel
lvl <- Getting DebugLevel RewriteEnv DebugLevel
-> RewriteMonad extra DebugLevel
forall s (m :: * -> *) a. MonadReader s m => Getting a s a -> m a
Lens.view Getting DebugLevel RewriteEnv DebugLevel
Lens' RewriteEnv DebugLevel
dbgLevel

  -- Don't specialise TopEntities
  VarSet
topEnts <- Getting VarSet RewriteEnv VarSet -> RewriteMonad extra VarSet
forall s (m :: * -> *) a. MonadReader s m => Getting a s a -> m a
Lens.view Getting VarSet RewriteEnv VarSet
Lens' RewriteEnv VarSet
topEntities
  if Id
f Id -> VarSet -> Bool
forall a. Var a -> VarSet -> Bool
`elemVarSet` VarSet
topEnts
  then Bool
-> String -> RewriteMonad extra Term -> RewriteMonad extra Term
forall a. Bool -> String -> a -> a
traceIf (DebugLevel
lvl DebugLevel -> DebugLevel -> Bool
forall a. Ord a => a -> a -> Bool
>= DebugLevel
DebugNone) ("Not specialising TopEntity: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ TmName -> String
forall p. PrettyPrec p => p -> String
showPpr (Id -> TmName
forall a. Var a -> Name a
varName Id
f)) (Term -> RewriteMonad extra Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e)
  else do -- NondecreasingIndentation

  TyConMap
tcm <- Getting TyConMap RewriteEnv TyConMap -> RewriteMonad extra TyConMap
forall s (m :: * -> *) a. MonadReader s m => Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Lens' RewriteEnv TyConMap
tcCache

  let specArg :: Either Term Type
specArg = (Term -> Term)
-> (Type -> Type) -> Either Term Type -> Either Term Type
forall (p :: * -> * -> *) a b c d.
Bifunctor p =>
(a -> b) -> (c -> d) -> p a c -> p b d
bimap (TyConMap -> Term -> Term
normalizeTermTypes TyConMap
tcm) (TyConMap -> Type -> Type
normalizeType TyConMap
tcm) Either Term Type
specArgIn
      -- Create binders and variable references for free variables in 'specArg'
      -- (specBndrsIn,specVars) :: ([Either Id TyVar], [Either Term Type])
      (specBndrsIn :: [Either Id TyVar]
specBndrsIn,specVars :: [Either Term Type]
specVars) = Either Term Type -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars Either Term Type
specArg
      argLen :: Int
argLen  = [Either Term Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Type]
args
      specBndrs :: [Either Id TyVar]
      specBndrs :: [Either Id TyVar]
specBndrs = (Either Id TyVar -> Either Id TyVar)
-> [Either Id TyVar] -> [Either Id TyVar]
forall a b. (a -> b) -> [a] -> [b]
map (ASetter (Either Id TyVar) (Either Id TyVar) Id Id
-> (Id -> Id) -> Either Id TyVar -> Either Id TyVar
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
Lens.over ASetter (Either Id TyVar) (Either Id TyVar) Id Id
forall a c b. Prism (Either a c) (Either b c) a b
_Left (TyConMap -> Id -> Id
normalizeId TyConMap
tcm)) [Either Id TyVar]
specBndrsIn
      specAbs :: Either Term Type
      specAbs :: Either Term Type
specAbs = (Term -> Either Term Type)
-> (Type -> Either Term Type)
-> Either Term Type
-> Either Term Type
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either (Term -> Either Term Type
forall a b. a -> Either a b
Left (Term -> Either Term Type)
-> (Term -> Term) -> Term -> Either Term Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Term -> [Either Id TyVar] -> Term
`mkAbstraction` [Either Id TyVar]
specBndrs)) (Type -> Either Term Type
forall a b. b -> Either a b
Right (Type -> Either Term Type)
-> (Type -> Type) -> Type -> Either Term Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall a. a -> a
id) Either Term Type
specArg
  -- Determine if 'f' has already been specialized on (a type-normalized) 'specArg'
  Maybe Id
specM <- (Id, Int, Either Term Type)
-> Map (Id, Int, Either Term Type) Id -> Maybe Id
forall k a. Ord k => k -> Map k a -> Maybe a
Map.lookup (Id
f,Int
argLen,Either Term Type
specAbs) (Map (Id, Int, Either Term Type) Id -> Maybe Id)
-> RewriteMonad extra (Map (Id, Int, Either Term Type) Id)
-> RewriteMonad extra (Maybe Id)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Getting
  (Map (Id, Int, Either Term Type) Id)
  (RewriteState extra)
  (Map (Id, Int, Either Term Type) Id)
-> RewriteMonad extra (Map (Id, Int, Either Term Type) Id)
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use ((extra -> Const (Map (Id, Int, Either Term Type) Id) extra)
-> RewriteState extra
-> Const (Map (Id, Int, Either Term Type) Id) (RewriteState extra)
forall extra1 extra2.
Lens (RewriteState extra1) (RewriteState extra2) extra1 extra2
extra((extra -> Const (Map (Id, Int, Either Term Type) Id) extra)
 -> RewriteState extra
 -> Const (Map (Id, Int, Either Term Type) Id) (RewriteState extra))
-> ((Map (Id, Int, Either Term Type) Id
     -> Const
          (Map (Id, Int, Either Term Type) Id)
          (Map (Id, Int, Either Term Type) Id))
    -> extra -> Const (Map (Id, Int, Either Term Type) Id) extra)
-> Getting
     (Map (Id, Int, Either Term Type) Id)
     (RewriteState extra)
     (Map (Id, Int, Either Term Type) Id)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Map (Id, Int, Either Term Type) Id
 -> Const
      (Map (Id, Int, Either Term Type) Id)
      (Map (Id, Int, Either Term Type) Id))
-> extra -> Const (Map (Id, Int, Either Term Type) Id) extra
Lens' extra (Map (Id, Int, Either Term Type) Id)
specMapLbl)
  case Maybe Id
specM of
    -- Use previously specialized function
    Just f' :: Id
f' ->
      Bool
-> String -> RewriteMonad extra Term -> RewriteMonad extra Term
forall a. Bool -> String -> a -> a
traceIf (DebugLevel
lvl DebugLevel -> DebugLevel -> Bool
forall a. Ord a => a -> a -> Bool
>= DebugLevel
DebugApplied)
        ("Using previous specialization of " String -> String -> String
forall a. [a] -> [a] -> [a]
++ TmName -> String
forall p. PrettyPrec p => p -> String
showPpr (Id -> TmName
forall a. Var a -> Name a
varName Id
f) String -> String -> String
forall a. [a] -> [a] -> [a]
++ " on " String -> String -> String
forall a. [a] -> [a] -> [a]
++
          ((Term -> String) -> (Type -> String) -> Either Term Type -> String
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either Term -> String
forall p. PrettyPrec p => p -> String
showPpr Type -> String
forall p. PrettyPrec p => p -> String
showPpr) Either Term Type
specAbs String -> String -> String
forall a. [a] -> [a] -> [a]
++ ": " String -> String -> String
forall a. [a] -> [a] -> [a]
++ TmName -> String
forall p. PrettyPrec p => p -> String
showPpr (Id -> TmName
forall a. Var a -> Name a
varName Id
f')) (RewriteMonad extra Term -> RewriteMonad extra Term)
-> RewriteMonad extra Term -> RewriteMonad extra Term
forall a b. (a -> b) -> a -> b
$
        Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> RewriteMonad extra Term)
-> Term -> RewriteMonad extra Term
forall a b. (a -> b) -> a -> b
$ Term -> [Either Term Type] -> Term
mkApps (Term -> [TickInfo] -> Term
mkTicks (Id -> Term
Var Id
f') [TickInfo]
ticks) ([Either Term Type]
args [Either Term Type] -> [Either Term Type] -> [Either Term Type]
forall a. [a] -> [a] -> [a]
++ [Either Term Type]
specVars)
    -- Create new specialized function
    Nothing -> do
      -- Determine if we can specialize f
      Maybe (Id, SrcSpan, InlineSpec, Term)
bodyMaybe <- (BindingMap -> Maybe (Id, SrcSpan, InlineSpec, Term))
-> RewriteMonad extra BindingMap
-> RewriteMonad extra (Maybe (Id, SrcSpan, InlineSpec, Term))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (TmName -> BindingMap -> Maybe (Id, SrcSpan, InlineSpec, Term)
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap (Id -> TmName
forall a. Var a -> Name a
varName Id
f)) (RewriteMonad extra BindingMap
 -> RewriteMonad extra (Maybe (Id, SrcSpan, InlineSpec, Term)))
-> RewriteMonad extra BindingMap
-> RewriteMonad extra (Maybe (Id, SrcSpan, InlineSpec, Term))
forall a b. (a -> b) -> a -> b
$ Getting BindingMap (RewriteState extra) BindingMap
-> RewriteMonad extra BindingMap
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use Getting BindingMap (RewriteState extra) BindingMap
forall extra1. Lens' (RewriteState extra1) BindingMap
bindings
      case Maybe (Id, SrcSpan, InlineSpec, Term)
bodyMaybe of
        Just (_,sp :: SrcSpan
sp,inl :: InlineSpec
inl,bodyTm :: Term
bodyTm) -> do
          -- Determine if we see a sequence of specialisations on a growing argument
          Maybe Int
specHistM <- Id -> VarEnv Int -> Maybe Int
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap Id
f (VarEnv Int -> Maybe Int)
-> RewriteMonad extra (VarEnv Int)
-> RewriteMonad extra (Maybe Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Getting (VarEnv Int) (RewriteState extra) (VarEnv Int)
-> RewriteMonad extra (VarEnv Int)
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use ((extra -> Const (VarEnv Int) extra)
-> RewriteState extra -> Const (VarEnv Int) (RewriteState extra)
forall extra1 extra2.
Lens (RewriteState extra1) (RewriteState extra2) extra1 extra2
extra((extra -> Const (VarEnv Int) extra)
 -> RewriteState extra -> Const (VarEnv Int) (RewriteState extra))
-> ((VarEnv Int -> Const (VarEnv Int) (VarEnv Int))
    -> extra -> Const (VarEnv Int) extra)
-> Getting (VarEnv Int) (RewriteState extra) (VarEnv Int)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(VarEnv Int -> Const (VarEnv Int) (VarEnv Int))
-> extra -> Const (VarEnv Int) extra
Lens' extra (VarEnv Int)
specHistLbl)
          Int
specLim   <- Getting Int (RewriteState extra) Int -> RewriteMonad extra Int
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use ((extra -> Const Int extra)
-> RewriteState extra -> Const Int (RewriteState extra)
forall extra1 extra2.
Lens (RewriteState extra1) (RewriteState extra2) extra1 extra2
extra ((extra -> Const Int extra)
 -> RewriteState extra -> Const Int (RewriteState extra))
-> ((Int -> Const Int Int) -> extra -> Const Int extra)
-> Getting Int (RewriteState extra) Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int -> Const Int Int) -> extra -> Const Int extra
Lens' extra Int
specLimitLbl)
          if Bool -> (Int -> Bool) -> Maybe Int -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
False (Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
specLim) Maybe Int
specHistM
            then ClashException -> RewriteMonad extra Term
forall a e. Exception e => e -> a
throw (SrcSpan -> String -> Maybe String -> ClashException
ClashException
                        SrcSpan
sp
                        ([String] -> String
unlines [ "Hit specialisation limit " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
specLim String -> String -> String
forall a. [a] -> [a] -> [a]
++ " on function `" String -> String -> String
forall a. [a] -> [a] -> [a]
++ TmName -> String
forall p. PrettyPrec p => p -> String
showPpr (Id -> TmName
forall a. Var a -> Name a
varName Id
f) String -> String -> String
forall a. [a] -> [a] -> [a]
++ "'.\n"
                                 , "The function `" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Id -> String
forall p. PrettyPrec p => p -> String
showPpr Id
f String -> String -> String
forall a. [a] -> [a] -> [a]
++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
                                 , "Body of `" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Id -> String
forall p. PrettyPrec p => p -> String
showPpr Id
f String -> String -> String
forall a. [a] -> [a] -> [a]
++ "':\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Term -> String
forall p. PrettyPrec p => p -> String
showPpr Term
bodyTm String -> String -> String
forall a. [a] -> [a] -> [a]
++ "\n"
                                 , "Argument (in position: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
argLen String -> String -> String
forall a. [a] -> [a] -> [a]
++ ") that triggered termination:\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++ ((Term -> String) -> (Type -> String) -> Either Term Type -> String
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either Term -> String
forall p. PrettyPrec p => p -> String
showPpr Type -> String
forall p. PrettyPrec p => p -> String
showPpr) Either Term Type
specArg
                                 , "Run with '-fclash-spec-limit=N' to increase the specialisation limit to N."
                                 ])
                        Maybe String
forall a. Maybe a
Nothing)
            else do
              let existingNames :: [Name a]
existingNames = Term -> [Name a]
forall a. Term -> [Name a]
collectBndrsMinusApps Term
bodyTm
                  newNames :: [Name a]
newNames      = [ OccName -> Int -> Name a
forall a. OccName -> Int -> Name a
mkUnsafeInternalName ("pTS" OccName -> OccName -> OccName
`Text.append` String -> OccName
Text.pack (Int -> String
forall a. Show a => a -> String
show Int
n)) Int
n
                                  | Int
n <- [(0::Int)..]
                                  ]
              -- Make new binders for existing arguments
              (boundArgs :: [Either Id TyVar]
boundArgs,argVars :: [Either Term Type]
argVars) <- ([Either Id TyVar] -> ([Either Id TyVar], [Either Term Type]))
-> RewriteMonad extra [Either Id TyVar]
-> RewriteMonad extra ([Either Id TyVar], [Either Term Type])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([(Either Id TyVar, Either Term Type)]
-> ([Either Id TyVar], [Either Term Type])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Either Id TyVar, Either Term Type)]
 -> ([Either Id TyVar], [Either Term Type]))
-> ([Either Id TyVar] -> [(Either Id TyVar, Either Term Type)])
-> [Either Id TyVar]
-> ([Either Id TyVar], [Either Term Type])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Either Id TyVar -> (Either Id TyVar, Either Term Type))
-> [Either Id TyVar] -> [(Either Id TyVar, Either Term Type)]
forall a b. (a -> b) -> [a] -> [b]
map ((Id -> (Either Id TyVar, Either Term Type))
-> (TyVar -> (Either Id TyVar, Either Term Type))
-> Either Id TyVar
-> (Either Id TyVar, Either Term Type)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either (Id -> Either Id TyVar
forall a b. a -> Either a b
Left (Id -> Either Id TyVar)
-> (Id -> Either Term Type)
-> Id
-> (Either Id TyVar, Either Term Type)
forall (a :: * -> * -> *) b c c'.
Arrow a =>
a b c -> a b c' -> a b (c, c')
&&& Term -> Either Term Type
forall a b. a -> Either a b
Left (Term -> Either Term Type)
-> (Id -> Term) -> Id -> Either Term Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Id -> Term
Var) (TyVar -> Either Id TyVar
forall a b. b -> Either a b
Right (TyVar -> Either Id TyVar)
-> (TyVar -> Either Term Type)
-> TyVar
-> (Either Id TyVar, Either Term Type)
forall (a :: * -> * -> *) b c c'.
Arrow a =>
a b c -> a b c' -> a b (c, c')
&&& Type -> Either Term Type
forall a b. b -> Either a b
Right (Type -> Either Term Type)
-> (TyVar -> Type) -> TyVar -> Either Term Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TyVar -> Type
VarTy))) (RewriteMonad extra [Either Id TyVar]
 -> RewriteMonad extra ([Either Id TyVar], [Either Term Type]))
-> RewriteMonad extra [Either Id TyVar]
-> RewriteMonad extra ([Either Id TyVar], [Either Term Type])
forall a b. (a -> b) -> a -> b
$
                                     (Name Any
 -> Either Term Type -> RewriteMonad extra (Either Id TyVar))
-> [Name Any]
-> [Either Term Type]
-> RewriteMonad extra [Either Id TyVar]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
Monad.zipWithM
                                       (InScopeSet
-> TyConMap
-> Name Any
-> Either Term Type
-> RewriteMonad extra (Either Id TyVar)
forall (m :: * -> *) a.
(Monad m, MonadUnique m, MonadFail m) =>
InScopeSet
-> TyConMap -> Name a -> Either Term Type -> m (Either Id TyVar)
mkBinderFor InScopeSet
is0 TyConMap
tcm)
                                       ([Name Any]
forall a. [Name a]
existingNames [Name Any] -> [Name Any] -> [Name Any]
forall a. [a] -> [a] -> [a]
++ [Name Any]
forall a. [Name a]
newNames)
                                       [Either Term Type]
args
              -- Determine name the resulting specialized function, and the
              -- form of the specialized-on argument
              (fId :: Id
fId,inl' :: InlineSpec
inl',specArg' :: Either Term Type
specArg') <- case Either Term Type
specArg of
                Left a :: Term
a@(Term -> (Term, [Either Term Type], [TickInfo])
collectArgsTicks -> (Var g :: Id
g,gArgs :: [Either Term Type]
gArgs,_gTicks :: [TickInfo]
_gTicks)) -> if TyConMap -> Term -> Bool
isPolyFun TyConMap
tcm Term
a
                    then do
                      -- In case we are specialising on an argument that is a
                      -- global function then we use that function's name as the
                      -- name of the specialized higher-order function.
                      -- Additionally, we will return the body of the global
                      -- function, instead of a variable reference to the
                      -- global function.
                      --
                      -- This will turn things like @mealy g k@ into a new
                      -- binding @g'@ where both the body of @mealy@ and @g@
                      -- are inlined, meaning the state-transition-function
                      -- and the memory element will be in a single function.
                      Maybe (Id, SrcSpan, InlineSpec, Term)
gTmM <- (BindingMap -> Maybe (Id, SrcSpan, InlineSpec, Term))
-> RewriteMonad extra BindingMap
-> RewriteMonad extra (Maybe (Id, SrcSpan, InlineSpec, Term))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (TmName -> BindingMap -> Maybe (Id, SrcSpan, InlineSpec, Term)
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap (Id -> TmName
forall a. Var a -> Name a
varName Id
g)) (RewriteMonad extra BindingMap
 -> RewriteMonad extra (Maybe (Id, SrcSpan, InlineSpec, Term)))
-> RewriteMonad extra BindingMap
-> RewriteMonad extra (Maybe (Id, SrcSpan, InlineSpec, Term))
forall a b. (a -> b) -> a -> b
$ Getting BindingMap (RewriteState extra) BindingMap
-> RewriteMonad extra BindingMap
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use Getting BindingMap (RewriteState extra) BindingMap
forall extra1. Lens' (RewriteState extra1) BindingMap
bindings
                      (Id, InlineSpec, Either Term Type)
-> RewriteMonad extra (Id, InlineSpec, Either Term Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Id
g,InlineSpec
-> ((Id, SrcSpan, InlineSpec, Term) -> InlineSpec)
-> Maybe (Id, SrcSpan, InlineSpec, Term)
-> InlineSpec
forall b a. b -> (a -> b) -> Maybe a -> b
maybe InlineSpec
inl ((Id, SrcSpan, InlineSpec, Term)
-> Getting InlineSpec (Id, SrcSpan, InlineSpec, Term) InlineSpec
-> InlineSpec
forall s a. s -> Getting a s a -> a
^. Getting InlineSpec (Id, SrcSpan, InlineSpec, Term) InlineSpec
forall s t a b. Field3 s t a b => Lens s t a b
_3) Maybe (Id, SrcSpan, InlineSpec, Term)
gTmM, Either Term Type
-> ((Id, SrcSpan, InlineSpec, Term) -> Either Term Type)
-> Maybe (Id, SrcSpan, InlineSpec, Term)
-> Either Term Type
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Either Term Type
specArg (Term -> Either Term Type
forall a b. a -> Either a b
Left (Term -> Either Term Type)
-> ((Id, SrcSpan, InlineSpec, Term) -> Term)
-> (Id, SrcSpan, InlineSpec, Term)
-> Either Term Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Term -> [Either Term Type] -> Term
`mkApps` [Either Term Type]
gArgs) (Term -> Term)
-> ((Id, SrcSpan, InlineSpec, Term) -> Term)
-> (Id, SrcSpan, InlineSpec, Term)
-> Term
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Id, SrcSpan, InlineSpec, Term)
-> Getting Term (Id, SrcSpan, InlineSpec, Term) Term -> Term
forall s a. s -> Getting a s a -> a
^. Getting Term (Id, SrcSpan, InlineSpec, Term) Term
forall s t a b. Field4 s t a b => Lens s t a b
_4)) Maybe (Id, SrcSpan, InlineSpec, Term)
gTmM)
                    else (Id, InlineSpec, Either Term Type)
-> RewriteMonad extra (Id, InlineSpec, Either Term Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Id
f,InlineSpec
inl,Either Term Type
specArg)
                _ -> (Id, InlineSpec, Either Term Type)
-> RewriteMonad extra (Id, InlineSpec, Either Term Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Id
f,InlineSpec
inl,Either Term Type
specArg)
              -- Create specialized functions
              let newBody :: Term
newBody = Term -> [Either Id TyVar] -> Term
mkAbstraction (Term -> [Either Term Type] -> Term
mkApps Term
bodyTm ([Either Term Type]
argVars [Either Term Type] -> [Either Term Type] -> [Either Term Type]
forall a. [a] -> [a] -> [a]
++ [Either Term Type
specArg'])) ([Either Id TyVar]
boundArgs [Either Id TyVar] -> [Either Id TyVar] -> [Either Id TyVar]
forall a. [a] -> [a] -> [a]
++ [Either Id TyVar]
specBndrs)
              Id
newf <- TmName -> SrcSpan -> InlineSpec -> Term -> RewriteMonad extra Id
forall extra.
TmName -> SrcSpan -> InlineSpec -> Term -> RewriteMonad extra Id
mkFunction (Id -> TmName
forall a. Var a -> Name a
varName Id
fId) SrcSpan
sp InlineSpec
inl' Term
newBody
              -- Remember specialization
              ((extra -> Identity extra)
-> RewriteState extra -> Identity (RewriteState extra)
forall extra1 extra2.
Lens (RewriteState extra1) (RewriteState extra2) extra1 extra2
extra((extra -> Identity extra)
 -> RewriteState extra -> Identity (RewriteState extra))
-> ((VarEnv Int -> Identity (VarEnv Int))
    -> extra -> Identity extra)
-> (VarEnv Int -> Identity (VarEnv Int))
-> RewriteState extra
-> Identity (RewriteState extra)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(VarEnv Int -> Identity (VarEnv Int)) -> extra -> Identity extra
Lens' extra (VarEnv Int)
specHistLbl) ((VarEnv Int -> Identity (VarEnv Int))
 -> RewriteState extra -> Identity (RewriteState extra))
-> (VarEnv Int -> VarEnv Int) -> RewriteMonad extra ()
forall s (m :: * -> *) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= Id -> Int -> (Int -> Int -> Int) -> VarEnv Int -> VarEnv Int
forall a b.
Uniquable a =>
a -> b -> (b -> b -> b) -> UniqMap b -> UniqMap b
extendUniqMapWith Id
f 1 Int -> Int -> Int
forall a. Num a => a -> a -> a
(+)
              ((extra -> Identity extra)
-> RewriteState extra -> Identity (RewriteState extra)
forall extra1 extra2.
Lens (RewriteState extra1) (RewriteState extra2) extra1 extra2
extra((extra -> Identity extra)
 -> RewriteState extra -> Identity (RewriteState extra))
-> ((Map (Id, Int, Either Term Type) Id
     -> Identity (Map (Id, Int, Either Term Type) Id))
    -> extra -> Identity extra)
-> (Map (Id, Int, Either Term Type) Id
    -> Identity (Map (Id, Int, Either Term Type) Id))
-> RewriteState extra
-> Identity (RewriteState extra)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Map (Id, Int, Either Term Type) Id
 -> Identity (Map (Id, Int, Either Term Type) Id))
-> extra -> Identity extra
Lens' extra (Map (Id, Int, Either Term Type) Id)
specMapLbl)  ((Map (Id, Int, Either Term Type) Id
  -> Identity (Map (Id, Int, Either Term Type) Id))
 -> RewriteState extra -> Identity (RewriteState extra))
-> (Map (Id, Int, Either Term Type) Id
    -> Map (Id, Int, Either Term Type) Id)
-> RewriteMonad extra ()
forall s (m :: * -> *) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= (Id, Int, Either Term Type)
-> Id
-> Map (Id, Int, Either Term Type) Id
-> Map (Id, Int, Either Term Type) Id
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert (Id
f,Int
argLen,Either Term Type
specAbs) Id
newf
              -- use specialized function
              let newExpr :: Term
newExpr = Term -> [Either Term Type] -> Term
mkApps (Term -> [TickInfo] -> Term
mkTicks (Id -> Term
Var Id
newf) [TickInfo]
ticks) ([Either Term Type]
args [Either Term Type] -> [Either Term Type] -> [Either Term Type]
forall a. [a] -> [a] -> [a]
++ [Either Term Type]
specVars)
              Id
newf Id -> RewriteMonad extra Term -> RewriteMonad extra Term
forall a b. NFData a => a -> b -> b
`deepseq` Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
newExpr
        Nothing -> Term -> RewriteMonad extra Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
  where
    collectBndrsMinusApps :: Term -> [Name a]
    collectBndrsMinusApps :: Term -> [Name a]
collectBndrsMinusApps = [Name a] -> [Name a]
forall a. [a] -> [a]
reverse ([Name a] -> [Name a]) -> (Term -> [Name a]) -> Term -> [Name a]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Name a] -> Term -> [Name a]
forall a.
(Coercible a TmName, Coercible a TyName) =>
[a] -> Term -> [a]
go []
      where
        go :: [a] -> Term -> [a]
go bs :: [a]
bs (Lam v :: Id
v e' :: Term
e')    = [a] -> Term -> [a]
go (TmName -> a
forall a b. Coercible a b => a -> b
coerce (Id -> TmName
forall a. Var a -> Name a
varName Id
v)a -> [a] -> [a]
forall a. a -> [a] -> [a]
:[a]
bs)  Term
e'
        go bs :: [a]
bs (TyLam tv :: TyVar
tv e' :: Term
e') = [a] -> Term -> [a]
go (TyName -> a
forall a b. Coercible a b => a -> b
coerce (TyVar -> TyName
forall a. Var a -> Name a
varName TyVar
tv)a -> [a] -> [a]
forall a. a -> [a] -> [a]
:[a]
bs) Term
e'
        go bs :: [a]
bs (App e' :: Term
e' _) = case [a] -> Term -> [a]
go [] Term
e' of
          []  -> [a]
bs
          bs' :: [a]
bs' -> [a] -> [a]
forall a. [a] -> [a]
init [a]
bs' [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
bs
        go bs :: [a]
bs (TyApp e' :: Term
e' _) = case [a] -> Term -> [a]
go [] Term
e' of
          []  -> [a]
bs
          bs' :: [a]
bs' -> [a] -> [a]
forall a. [a] -> [a]
init [a]
bs' [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
bs
        go bs :: [a]
bs _ = [a]
bs

specialise' _ _ _ _ctx :: TransformContext
_ctx _ (appE :: Term
appE,args :: [Either Term Type]
args,ticks :: [TickInfo]
ticks) (Left specArg :: Term
specArg) = do
  -- Create binders and variable references for free variables in 'specArg'
  let (specBndrs :: [Either Id TyVar]
specBndrs,specVars :: [Either Term Type]
specVars) = Either Term Type -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars (Term -> Either Term Type
forall a b. a -> Either a b
Left Term
specArg)
  -- Create specialized function
      newBody :: Term
newBody = Term -> [Either Id TyVar] -> Term
mkAbstraction Term
specArg [Either Id TyVar]
specBndrs
  -- See if there's an existing binder that's alpha-equivalent to the
  -- specialized function
  BindingMap
existing <- ((Id, SrcSpan, InlineSpec, Term) -> Bool)
-> BindingMap -> BindingMap
forall b. (b -> Bool) -> UniqMap b -> UniqMap b
filterUniqMap ((Term -> Term -> Bool
`aeqTerm` Term
newBody) (Term -> Bool)
-> ((Id, SrcSpan, InlineSpec, Term) -> Term)
-> (Id, SrcSpan, InlineSpec, Term)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Id, SrcSpan, InlineSpec, Term)
-> Getting Term (Id, SrcSpan, InlineSpec, Term) Term -> Term
forall s a. s -> Getting a s a -> a
^. Getting Term (Id, SrcSpan, InlineSpec, Term) Term
forall s t a b. Field4 s t a b => Lens s t a b
_4)) (BindingMap -> BindingMap)
-> RewriteMonad extra BindingMap -> RewriteMonad extra BindingMap
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Getting BindingMap (RewriteState extra) BindingMap
-> RewriteMonad extra BindingMap
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use Getting BindingMap (RewriteState extra) BindingMap
forall extra1. Lens' (RewriteState extra1) BindingMap
bindings
  -- Create a new function if an alpha-equivalent binder doesn't exist
  Id
newf <- case BindingMap -> [(Id, SrcSpan, InlineSpec, Term)]
forall a. UniqMap a -> [a]
eltsUniqMap BindingMap
existing of
    [] -> do (cf :: Id
cf,sp :: SrcSpan
sp) <- Getting (Id, SrcSpan) (RewriteState extra) (Id, SrcSpan)
-> RewriteMonad extra (Id, SrcSpan)
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use Getting (Id, SrcSpan) (RewriteState extra) (Id, SrcSpan)
forall extra1. Lens' (RewriteState extra1) (Id, SrcSpan)
curFun
             TmName -> SrcSpan -> InlineSpec -> Term -> RewriteMonad extra Id
forall extra.
TmName -> SrcSpan -> InlineSpec -> Term -> RewriteMonad extra Id
mkFunction (TmName -> OccName -> TmName
forall a. Name a -> OccName -> Name a
appendToName (Id -> TmName
forall a. Var a -> Name a
varName Id
cf) "_specF")
                        SrcSpan
sp
#if MIN_VERSION_ghc(8,4,1)
                        InlineSpec
NoUserInline
#else
                        EmptyInlineSpec
#endif
                        Term
newBody
    ((k :: Id
k,_,_,_):_) -> Id -> RewriteMonad extra Id
forall (m :: * -> *) a. Monad m => a -> m a
return Id
k
  -- Create specialized argument
  let newArg :: Either Term b
newArg  = Term -> Either Term b
forall a b. a -> Either a b
Left (Term -> Either Term b) -> Term -> Either Term b
forall a b. (a -> b) -> a -> b
$ Term -> [Either Term Type] -> Term
mkApps (Id -> Term
Var Id
newf) [Either Term Type]
specVars
  -- Use specialized argument
  let newExpr :: Term
newExpr = Term -> [Either Term Type] -> Term
mkApps (Term -> [TickInfo] -> Term
mkTicks Term
appE [TickInfo]
ticks) ([Either Term Type]
args [Either Term Type] -> [Either Term Type] -> [Either Term Type]
forall a. [a] -> [a] -> [a]
++ [Either Term Type
forall b. Either Term b
newArg])
  Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
newExpr

specialise' _ _ _ _ e :: Term
e _ _ = Term -> RewriteMonad extra Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e

-- | Normalize the types that appear at the top of a term, using
-- 'normalizeType':
--
--   * for a 'Cast', both the source and target types are normalized and the
--     cast's subject is processed recursively;
--   * for a 'Var', the variable's type is normalized via 'normalizeId'.
--
-- Every other term is returned untouched.
normalizeTermTypes :: TyConMap -> Term -> Term
normalizeTermTypes tcm (Cast inner fromTy toTy) =
  let normTy = normalizeType tcm
  in  Cast (normalizeTermTypes tcm inner) (normTy fromTy) (normTy toTy)
normalizeTermTypes tcm (Var var) = Var (normalizeId tcm var)
-- TODO other terms?
normalizeTermTypes _ term = term

-- | Normalize the type annotation of a term variable using 'normalizeType'.
-- A value built with the 'TyVar' constructor is returned unchanged.
normalizeId :: TyConMap -> Id -> Id
normalizeId tcm v = case v of
  Id {} -> v { varType = normalizeType tcm (varType v) }
  _     -> v


-- | Create binders and variable references for free variables in 'specArg'.
--
-- The result is a pair of:
--
--   1. binders for the free variables: every free type variable (as 'Right'),
--      followed by every free term variable (as 'Left');
--   2. references to those same variables, in the same order: 'VarTy' for the
--      type variables and 'Var' for the term variables.
--
-- For a term argument ('Left') the free local variables, both type and term
-- level, are collected. For a type argument ('Right') only its free type
-- variables are collected, so the term-variable part is empty.
specArgBndrsAndVars
  :: Either Term Type
  -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars specArg =
  ( map Right freeTyVars ++ map Left freeTmVars
  , map (Right . VarTy) freeTyVars ++ map (Left . Var) freeTmVars
  )
  where
    -- Put a single free variable into either the type-variable set or the
    -- term-variable set, depending on its constructor.
    classify :: Var a -> Const (UniqSet TyVar, UniqSet Id) (Var a)
    classify v = case v of
      Id {}    -> Const (emptyUniqSet, unitUniqSet (coerce v))
      TyVar {} -> Const (unitUniqSet (coerce v), emptyUniqSet)

    (freeTyVars, freeTmVars) = case specArg of
      Left tm ->
        let (tvSet, idSet) = getConst (Lens.foldMapOf freeLocalVars classify tm)
        in  (eltsUniqSet tvSet, eltsUniqSet idSet)
      Right ty ->
        (eltsUniqSet (Lens.foldMapOf typeFreeVars unitUniqSet ty), [] :: [Id])