{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2018, Google Inc.,
                    2021-2022, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Transformations on letrec expressions.
-}

{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module Clash.Normalize.Transformations.Letrec
  ( deadCode
  , flattenLet
  , recToLetRec
  , removeUnusedExpr
  , simpleCSE
  , topLet
  ) where

import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import Control.Monad.Trans.Except (runExcept)
import Control.Monad.Writer (listen)
import Data.Bifunctor (second)
import qualified Data.Either as Either
import qualified Data.HashMap.Lazy as HashMap
import Data.List ((\\))
import qualified Data.List as List
import qualified Data.List.Extra as List
import qualified Data.Monoid as Monoid (Any(..))
import qualified Data.Text as Text
import qualified Data.Text.Extra as Text
import GHC.Stack (HasCallStack)

import Clash.Annotations.BitRepresentation.Deriving (dontApplyInHDL)
import Clash.Sized.Vector as Vec (Vec(Cons), splitAt)

import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.DataCon (DataCon(..))
import Clash.Core.FreeVars (freeLocalIds)
import Clash.Core.HasFreeVars
import Clash.Core.HasType
import Clash.Core.Name (mkUnsafeSystemName, nameOcc)
import Clash.Core.Subst
import Clash.Core.Term
  ( LetBinding, Pat(..), PrimInfo(..), Term(..), collectArgs, collectArgsTicks
  , collectTicks, isLambdaBodyCtx, isTickCtx, mkApps, mkLams, mkTicks, Bind(..)
  , partitionTicks, stripAllTicks)
import Clash.Core.TermInfo (isCon, isLet, isLocalVar, isTick)
import Clash.Core.TyCon (tyConDataCons)
import Clash.Core.Type
  (Type(..), TypeView(..), normalizeType
  , splitFunForallTy, tyView)
import Clash.Core.Util (inverseTopSortLetBindings, mkVec, tyNatSize)
import Clash.Core.Var (isGlobalId)
import Clash.Core.VarEnv
  ( InScopeSet, elemInScopeSet, emptyVarEnv, extendInScopeSetList, lookupVarEnv
  , unionVarEnvWith, unitVarEnv, mkVarSet)
import Clash.Netlist.BlackBox.Types ()
import Clash.Netlist.BlackBox.Util (getUsedArguments)
import Clash.Netlist.Util (splitNormalized)
import Clash.Normalize.Primitives (removedArg)
import Clash.Normalize.Transformations.Reduce (reduceBinders)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Primitives.Types (Primitive(..), UsedArguments(..))
import Clash.Rewrite.Types
  (TransformContext(..), bindings, curFun, tcCache, workFreeBinders, primitives)
import Clash.Rewrite.Util
  (changed, isFromInt, isUntranslatable, mkTmBinderFor, removeUnusedBinders, setChanged)
import Clash.Rewrite.WorkFree
import Clash.Unique (lookupUniqMap)

{- [Note: Name re-creation]
The names of heap bound variables are safely generated with mkUniqSystemId in
Clash.Core.Evaluator.newLetBinding. But only their uniqs end up in the heap,
not the complete names. So we use mkUnsafeSystemName to recreate the same Name.
-}

-- | Remove unused let-bindings.
--
-- For a 'Let' expression the binders and body are handed to
-- 'removeUnusedBinders'. When that yields a replacement term, the replacement
-- is returned and flagged as a change; in every other situation the input
-- term is returned untouched.
deadCode :: HasCallStack => NormRewrite
deadCode _ e
  | Let binds body <- e
  , Just pruned <- removeUnusedBinders binds body
  = changed pruned
  | otherwise
  = return e
{-# SCC deadCode #-}

-- | Remove (sub)expressions whose value is never used.
--
-- Three shapes of term are rewritten; anything else is returned unchanged:
--
-- 1. An application of a primitive. The primitive is looked up by name in the
--    'primitives' environment and the indices of the term arguments it uses
--    are determined. Every term argument whose index is below the arity of
--    the primitive's type and is not in the used set is replaced by
--    @removedArg \@ty@, where @ty@ is the inferred type of the argument.
--    Arguments that already are an application of 'removedArg' are left
--    alone, so they do not count as a change.
--
-- 2. A case-expression with a single data-constructor alternative whose
--    pattern has an empty second field (presumably the bound type variables
--    -- TODO confirm) and whose bound identifiers do not occur free in the
--    alternative. The whole case is replaced by the alternative.
--
-- 3. An application of the 'Vec.Cons' constructor whose length argument
--    evaluates to 0 according to 'tyNatSize' and whose tail is not itself a
--    constructor application. The expression is rebuilt as a one-element
--    vector with 'mkVec'.
removeUnusedExpr :: HasCallStack => NormRewrite
removeUnusedExpr _ e@(collectArgsTicks -> (p@(Prim pInfo),args,ticks)) = do
  bbM <- HashMap.lookup (primName pInfo) <$> Lens.view primitives
  let
    -- Indices (counting term arguments only) that the primitive uses, or
    -- Nothing when no usable primitive definition is available.
    usedArgs0 =
      case Monad.join (extractPrim <$> bbM) of
        Just (BlackBoxHaskell{usedArguments}) ->
          case usedArguments of
            UsedArguments used -> Just used
            IgnoredArguments ignored -> Just ([0..length args - 1] \\ ignored)
        Just (BlackBox pNm _ _ _ _ _ _ _ _ _ inc r ri templ) -> Just $
          if | isFromInt pNm -> [0,1,2]
             | primName pInfo `elem` [ Text.showt 'dontApplyInHDL
                                     , Text.showt 'Vec.splitAt
                                     ] -> [0,1]
               -- Otherwise: whatever the templates themselves reference
             | otherwise -> concat [ concatMap getUsedArguments r
                                   , concatMap getUsedArguments ri
                                   , getUsedArguments templ
                                   , concatMap (getUsedArguments . snd) inc ]
        _ ->
          Nothing

  case usedArgs0 of
    Nothing ->
      return e
    Just usedArgs1 -> do
      tcm <- Lens.view tcCache
      -- 'listen' exposes whether any call to 'changed' happened inside 'go'
      (args1, Monoid.getAny -> hasChanged) <- listen (go tcm 0 usedArgs1 args)
      if hasChanged then
        return (mkApps (mkTicks p ticks) args1)
      else
        return e

  where
    -- Number of term arguments in the primitive's type
    arity = length . Either.rights . fst $ splitFunForallTy (coreTypeOf pInfo)

    -- Walk the argument list; @n@ is the index of the current term argument
    -- (type arguments do not advance it).
    go _ _ _ [] = return []
    go tcm !n used (Right ty:args') = do
      args'' <- go tcm n used args'
      return (Right ty : args'')
    go tcm !n used (Left tm : args') = do
      args'' <- go tcm (n+1) used args'
      case tm of
        -- Already removed on an earlier pass: keep as-is, report no change
        TyApp (Prim p0) _
          | primName p0 == Text.showt 'removedArg
          -> return (Left tm : args'')
        _ -> do
          let ty = inferCoreTypeOf tcm tm
              p' = TyApp (Prim removedArg) ty
          if n < arity && n `notElem` used
             then changed (Left p' : args'')
             else return  (Left tm : args'')

removeUnusedExpr _ e@(Case _ _ [(DataPat _ [] xs,altExpr)]) =
  if mkVarSet xs `disjointFreeVars` altExpr
     then changed altExpr
     else return e

-- Replace any expression that creates a Vector of size 0 within the application
-- of the Cons constructor, by the Nil constructor.
removeUnusedExpr _ e@(collectArgsTicks -> (Data dc, [_,Right aTy,Right nTy,_,Left a,Left nil],ticks))
  | nameOcc (dcName dc) == Text.showt 'Vec.Cons
  = do
    tcm <- Lens.view tcCache
    case runExcept (tyNatSize tcm nTy) of
      Right 0
        | (con, _) <- collectArgs nil
        , not (isCon con)
          -- The lookups below are pattern guards rather than irrefutable
          -- let-patterns: if the type of the expression is not a type
          -- constructor application, the type constructor is unknown, or it
          -- does not have exactly two data constructors, we fall through and
          -- leave the expression untouched instead of raising a
          -- pattern-match failure.
        , TyConApp vecTcNm _ <- tyView (inferCoreTypeOf tcm e)
        , Just vecTc <- lookupUniqMap vecTcNm tcm
        , [nilCon,consCon] <- tyConDataCons vecTc
        -> changed (mkTicks (mkVec nilCon consCon aTy 1 [a]) ticks)
      _ -> return e

removeUnusedExpr _ e = return e
{-# SCC removeUnusedExpr #-}

-- | Flattens letrecs after `inlineCleanup`
--
-- `inlineCleanup` sometimes exposes additional possibilities for `caseCon`,
-- which then introduces let-bindings in what should be ANF. This transformation
-- flattens those nested let-bindings again.
--
-- Two nestings are handled: a letrec whose body is again a letrec, and a
-- letrec in which the right-hand side of a binding is (modulo ticks) a letrec.
-- When, after flattening, exactly one binding remains that occurs in the body
-- and is either work-free or used at most once, that binding is inlined into
-- the body, unless it is recursive.
--
-- NB: must only be called in the cleaning up phase.
flattenLet :: HasCallStack => NormRewrite
flattenLet ctx@(TransformContext is0 _) (Letrec outerBinds inner@Letrec{}) = do
  -- Deshadow the inner letrec against the outer binders, so that the two
  -- groups of bindings do not conflict when merged
  let is1 = extendInScopeSetList is0 (fmap fst outerBinds)
      Letrec innerBinds innerBody = deShadowTerm is1 inner

  setChanged
  flattenLet ctx{tfInScope=is1} (Letrec (outerBinds <> innerBinds) innerBody)

flattenLet (TransformContext is0 _) (Letrec binds body) = do
  let is1 = extendInScopeSetList is0 (map fst binds)
  (is2,nested) <- List.mapAccumLM go is1 binds
  let binds1 = concat nested
      unchanged = return (Letrec binds1 body)
  e1WorkFree <- soleBinderWorkFree binds1
  case binds1 of
    -- inline binders into the body when there's only a single binder, and only
    -- if that binder doesn't perform any work or is only used once in the body
    [(id1,e1)]
      | Just occ <- lookupVarEnv id1 (occurrences body)
      , e1WorkFree || occ < 2
      -> if | id1 `elemFreeVars` e1
              -- Except when the binder is recursive!
              -> unchanged
            | otherwise
              -> changed (substTm "flattenLet"
                                  (extendIdSubst (mkSubst is2) id1 e1)
                                  body)
    _ -> unchanged
  where
    -- Count, per local identifier, how often it occurs free in a term
    occurrences tm =
      Lens.foldMapByOf
        freeLocalIds (unionVarEnvWith (+))
        emptyVarEnv (`unitVarEnv` (1 :: Int))
        tm

    -- Work-freeness of the right-hand side of a singleton binding group. For
    -- any other group the result is an error thunk; callers only inspect it
    -- after matching on a singleton group, so it is never forced.
    soleBinderWorkFree :: [LetBinding] -> NormalizeSession Bool
    soleBinderWorkFree bs = do
      bndrs <- Lens.use bindings
      case bs of
        [(_,rhs)] -> isWorkFree workFreeBinders bndrs rhs
        _ -> pure (error "flattenLet: unreachable")

    -- Flatten one binding whose right-hand side is a (possibly ticked) letrec.
    -- Threads the in-scope set and returns the bindings that replace it.
    go :: InScopeSet -> LetBinding -> NormalizeSession (InScopeSet,[LetBinding])
    go isN (id1,collectTicks -> (Letrec binds1 body1,ticks)) = do
      let innerIds = map fst binds1
          (binds2,body2,isN1)
            -- We need to deshadow because we're merging nested let-expressions
            -- into a single let-expression: and within a let-expression, the
            -- bindings are not allowed to shadow each-other. Of course, we
            -- only need to deshadow if any shadowing is happening in the
            -- first place.
            --
            -- This is much better than blindly calling freshenTm, and saves
            -- almost 30% run-time of the normalization phase on some examples.
            | any (`elemInScopeSet` isN) innerIds =
                let Letrec bindsN bodyN = deShadowTerm isN (Letrec binds1 body1)
                in  (bindsN,bodyN,extendInScopeSetList isN (map fst bindsN))
            | otherwise =
                (binds1,body1,extendInScopeSetList isN innerIds)
          (srcTicks,nmTicks) = partitionTicks ticks
      e2WorkFree <- soleBinderWorkFree binds2
      flattened <- case binds2 of
        -- inline binders into the body when there's only a single binder, and
        -- only if that binder doesn't perform any work or is only used once in
        -- the body
        [(id2,e2)]
          | Just occ <- lookupVarEnv id2 (occurrences body2)
          , e2WorkFree || occ < 2
          -> if | id2 `elemFreeVars` e2
                  -- Except when the binder is recursive!
                  -- NOTE(review): unlike the other branches, srcTicks are not
                  -- attached to body2 here -- confirm this is intended.
                  -> changed ([(id2,e2),(id1, body2)])
                | otherwise
                  -> changed
                       [ ( id1
                           -- Only apply srcTicks to the body
                         , mkTicks (substTm "flattenLetGo"
                                            (extendIdSubst (mkSubst isN1) id2 e2)
                                            body2)
                                   srcTicks ) ]
        bs ->
          -- Only apply srcTicks to the body
          changed (bs ++ [(id1, mkTicks body2 srcTicks)])
      -- Distribute the name ticks of the let-expression over all the bindings
      return (isN1, map (second (`mkTicks` nmTicks)) flattened)
    go isN b = return (isN,[b])

flattenLet _ e = return e
{-# SCC flattenLet #-}

-- | Turn a  normalized recursive function, where the recursive calls only pass
-- along the unchanged original arguments, into let-recursive function. This
-- means that all recursive calls are replaced by the same variable reference as
-- found in the body of the top-level let-expression.
recToLetRec :: HasCallStack => NormRewrite
recToLetRec :: NormRewrite
recToLetRec (TransformContext InScopeSet
is0 []) Term
e = do
  (Id
fn,SrcSpan
_) <- Getting (Id, SrcSpan) (RewriteState NormalizeState) (Id, SrcSpan)
-> RewriteMonad NormalizeState (Id, SrcSpan)
forall s (m :: Type -> Type) a.
MonadState s m =>
Getting a s a -> m a
Lens.use Getting (Id, SrcSpan) (RewriteState NormalizeState) (Id, SrcSpan)
forall extra. Lens' (RewriteState extra) (Id, SrcSpan)
curFun
  TyConMap
tcm    <- Getting TyConMap RewriteEnv TyConMap
-> RewriteMonad NormalizeState TyConMap
forall s (m :: Type -> Type) a.
MonadReader s m =>
Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Getter RewriteEnv TyConMap
tcCache
  case TyConMap -> Term -> Either String ([Id], [LetBinding], Id)
splitNormalized TyConMap
tcm Term
e of
    Right ([Id]
args,[LetBinding]
bndrs,Id
res) -> do
      let args' :: [Term]
args'             = (Id -> Term) -> [Id] -> [Term]
forall a b. (a -> b) -> [a] -> [b]
map Id -> Term
Var [Id]
args
          ([LetBinding]
toInline,[LetBinding]
others) = (LetBinding -> Bool)
-> [LetBinding] -> ([LetBinding], [LetBinding])
forall a. (a -> Bool) -> [a] -> ([a], [a])
List.partition (TyConMap -> Id -> [Term] -> Term -> Bool
eqApp TyConMap
tcm Id
fn [Term]
args' (Term -> Bool) -> (LetBinding -> Term) -> LetBinding -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LetBinding -> Term
forall a b. (a, b) -> b
snd) [LetBinding]
bndrs
          resV :: Term
resV              = Id -> Term
Var Id
res
      case ([LetBinding]
toInline,[LetBinding]
others) of
        (LetBinding
_:[LetBinding]
_,LetBinding
_:[LetBinding]
_) -> do
          let is1 :: InScopeSet
is1          = InScopeSet -> [Id] -> InScopeSet
forall a. InScopeSet -> [Var a] -> InScopeSet
extendInScopeSetList InScopeSet
is0 ([Id]
args [Id] -> [Id] -> [Id]
forall a. [a] -> [a] -> [a]
++ (LetBinding -> Id) -> [LetBinding] -> [Id]
forall a b. (a -> b) -> [a] -> [b]
map LetBinding -> Id
forall a b. (a, b) -> a
fst [LetBinding]
bndrs)
          let substsInline :: Subst
substsInline = Subst -> [LetBinding] -> Subst
extendIdSubstList (InScopeSet -> Subst
mkSubst InScopeSet
is1)
                           ([LetBinding] -> Subst) -> [LetBinding] -> Subst
forall a b. (a -> b) -> a -> b
$ (LetBinding -> LetBinding) -> [LetBinding] -> [LetBinding]
forall a b. (a -> b) -> [a] -> [b]
map ((Term -> Term) -> LetBinding -> LetBinding
forall (p :: Type -> Type -> Type) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (Term -> Term -> Term
forall a b. a -> b -> a
const Term
resV)) [LetBinding]
toInline
              others' :: [LetBinding]
others'      = (LetBinding -> LetBinding) -> [LetBinding] -> [LetBinding]
forall a b. (a -> b) -> [a] -> [b]
map ((Term -> Term) -> LetBinding -> LetBinding
forall (p :: Type -> Type -> Type) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (HasCallStack => Doc () -> Subst -> Term -> Term
Doc () -> Subst -> Term -> Term
substTm Doc ()
"recToLetRec" Subst
substsInline))
                                 [LetBinding]
others
          Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> RewriteMonad NormalizeState Term)
-> Term -> RewriteMonad NormalizeState Term
forall a b. (a -> b) -> a -> b
$ Term -> [Id] -> Term
mkLams ([LetBinding] -> Term -> Term
Letrec [LetBinding]
others' Term
resV) [Id]
args
        ([LetBinding], [LetBinding])
_ -> Term -> RewriteMonad NormalizeState Term
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
    Either String ([Id], [LetBinding], Id)
_ -> Term -> RewriteMonad NormalizeState Term
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
  where
    -- This checks whether things are semantically equal. For example, say we
    -- have:
    --
    --   x :: (a, (b, c))
    --
    -- and
    --
    --   y :: (a, (b, c))
    --
    -- If we can determine that 'y' is constructed solely using the
    -- corresponding fields in 'x', then we can say they are semantically
    -- equal. The algorithm below keeps track of what (sub)field it is
    -- constructing, and checks if the field-expression projects the
    -- corresponding (sub)field from the target variable.
    --
    -- TODO: See [Note: Breaks on constants and predetermined equality]
    --
    -- Since 'aeqTerm' now looks at ticks when determining equality, it is
    -- required that all ticks are removed with 'stripAllTicks' to keep the
    -- previous behaviour of this function. If we remove this, most terms will
    -- not be identified as equal.
    --
    -- Concretely, @eqApp tcm v args e@ holds when, after stripping all ticks,
    -- /e/ is an application whose head is the global variable /v/, the number
    -- of term arguments equals the length of /args/, and every expected
    -- argument is pairwise accepted by 'eqArg'. Type arguments of the
    -- application are ignored. Anything else yields 'False'.
    eqApp :: TyConMap -> Id -> [Term] -> Term -> Bool
    eqApp tcm v args (collectArgs . stripAllTicks -> (Var v', args'))
      | isGlobalId v'
      , v == v'
        -- Only the term arguments take part in the comparison
      , let args2 = Either.lefts args'
        -- 'zipWith' would silently truncate, so the lengths must agree first
      , length args == length args2
      = and (zipWith (eqArg tcm) args args2)
    eqApp _ _ _ _ = False

    -- Decide whether the actual argument /v2/ is equal to the expected
    -- argument /v1/:
    --
    --   * When /v2/ is a variable, the two terms must be equal under '=='.
    --
    --   * When /v2/ is a data constructor application, the normalized types
    --     of both terms must be equal. Class constraints are then accepted
    --     outright; for any other type every term argument of the
    --     constructor must be a projection of the matching field of /v1/
    --     (see 'eqDat'). Field positions are counted from 0.
    --
    --   * Every other shape of /v2/ is rejected.
    eqArg :: TyConMap -> Term -> Term -> Bool
    eqArg _ v1 v2@Var{}
      = v1 == v2
    eqArg tcm v1 v2@(collectArgs -> (Data _, args'))
      | let t1 = normalizeType tcm (inferCoreTypeOf tcm v1)
      , let t2 = normalizeType tcm (inferCoreTypeOf tcm v2)
      , t1 == t2
      = -- Class constraints are equal if their types are equal, so we can
        -- take a shortcut here. Otherwise check whether all arguments to the
        -- data constructor are projections.
        isClassConstraint t1
          || and (zipWith (eqDat v1) (map pure [0..]) (Either.lefts args'))
    eqArg _ _ _
      = False

    -- Recursively check whether a term /e/ is semantically equal to some variable /v/.
    -- Currently it can only assert equality when /e/ is  syntactically equal
    -- to /v/, or is constructed out of projections of /v/, importantly:
    --
    -- [Note: Breaks on constants and predetermined equality]
    -- This function currently breaks if:
    --
    --   * One or more subfields are constants. Constants might have been
    --     inlined for the construction, instead of being a projection of the
    --     target variable.
    --
    --   * One or more subfields are determined to be equal and one is simply
    --     swapped / replaced by the other. For example, say we have
    --     `x :: (a, a)`. If GHC determines that both elements of the tuple will
    --     always be the same, it might replace the (semantically equal to 'x')
    --     construction of `y` with `(fst x, fst x)`.
    --
    -- @eqDat v fTrace e@: /fTrace/ is the path of field indices leading to
    -- the (sub)field currently being constructed, innermost index first.
    --
    --   * If /e/ is itself a data constructor application, descend into each
    --     of its term arguments, prepending that argument's position
    --     (counted from 0) to the path.
    --
    --   * Otherwise /e/ must be a chain of projections from /v/ that consumes
    --     the whole path, i.e. 'stripProjection' must succeed with nothing
    --     left over.
    eqDat :: Term -> [Int] -> Term -> Bool
    eqDat v fTrace (collectArgs -> (Data _, args)) =
      and (zipWith (eqDat v) (map (:fTrace) [0..]) (Either.lefts args))
    eqDat v1 fTrace v2 =
      -- The path is stored innermost-first; 'stripProjection' consumes it
      -- outermost-first, hence the 'reverse'.
      case stripProjection (reverse fTrace) v1 v2 of
        Just [] -> True
        _ -> False

    -- @stripProjection fTrace vTarget e@ checks that /e/ is a chain of
    -- single-alternative case-projections rooted in /vTarget/, following the
    -- field indices in /fTrace/ (outermost index first).
    --
    -- Returns the indices that were not consumed by the projections in /e/,
    -- or 'Nothing' when /e/ does not have the expected shape, the trace runs
    -- out, an index is out of range for the pattern's fields, or the chain is
    -- rooted in a different variable.
    stripProjection :: [Int] -> Term -> Term -> Maybe [Int]
    stripProjection fTrace0 vTarget0 (Case v _ [(DataPat _ _ xs, r)]) = do
      -- Get projection made in subject of case:
      fTrace1 <- stripProjection fTrace0 vTarget0 v

      -- Extract projection of this case statement. Subsequent calls to
      -- 'stripProjection' will check if new target is actually used.
      (n, fTrace2) <- List.uncons fTrace1
      vTarget1 <- List.indexMaybe xs n

      stripProjection fTrace2 (Var vTarget1) r

    stripProjection fTrace (Var sTarget) (Var s) =
      -- End of the chain: it must be the variable we were looking for
      if sTarget == s then Just fTrace else Nothing

    stripProjection _fTrace _vTarget _v =
      Nothing

recToLetRec TransformContext
_ Term
e = Term -> RewriteMonad NormalizeState Term
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
{-# SCC recToLetRec #-}

-- | Determine whether a type is a class constraint, judged purely by the name
-- of its type constructor.
--
-- A type counts as a constraint when it is a type-constructor application and
-- either:
--
--   * the constructor's name (as given by 'nameOcc') contains
--     @GHC.Classes.(%@, which is treated as a constraint tuple, or
--
--   * the part of that name after its last @.@ contains @C:@, which is
--     treated as a constraint class.
--
-- Every other type yields 'False'.
isClassConstraint :: Type -> Bool
isClassConstraint (tyView -> TyConApp nm0 _) =
  if -- Constraint tuple:
     | "GHC.Classes.(%" `Text.isInfixOf` nm1 -> True
     -- Constraint class:
     | "C:" `Text.isInfixOf` nm2 -> True
     | otherwise -> False
 where
  -- Name of the type constructor as text
  nm1 = nameOcc nm0
  -- Everything after the last '.' in 'nm1' (all of 'nm1' if there is none)
  nm2 = snd (Text.breakOnEnd "." nm1)

isClassConstraint _ = False

-- | Simplified CSE, only works on let-bindings, does an inverse topological
-- sort of the let-bindings and then works from top to bottom
--
-- XXX: Check whether inverse top-sort followed by single traversal removes as
-- many binders as the previous "apply-until-fixpoint" approach in the presence
-- of recursive groups in the let-bindings. If not but just for checking whether
-- changes to transformation affect the eventual size of the circuit, it would
-- be really helpful if we tracked circuit size in the regression/test suite.
-- On the two examples that were tested, Reducer and PipelinesViaFolds, this new
-- version of CSE removed the same amount of let-binders.
--
-- Terms that are not let-expressions are returned unchanged.
simpleCSE :: HasCallStack => NormRewrite
simpleCSE (TransformContext is0 _) term@Letrec{} = do
  let Letrec bndrs body = inverseTopSortLetBindings term
  let is1 = extendInScopeSetList is0 (map fst bndrs)
  -- 'listen' exposes whether 'reduceBinders' reported a change, so the
  -- substitution is only applied when something was actually reduced.
  ((subst, bndrs1), change) <- listen $ reduceBinders (mkSubst is1) [] bndrs
  -- TODO: check whether a substitution over the body is enough, the reason I'm
  -- doing a substitution over the binders as well is that I don't know in
  -- what order a recursive group shows up in an inverse topological sort.
  -- Depending on the order and forgetting to apply the substitution over the
  -- let-bindings might lead to the introduction of free variables.
  --
  -- NB: don't apply the substitution to the entire let-expression, as that
  -- would rename the let-bindings because they've been added to the InScopeSet
  -- of the substitution.
  if Monoid.getAny change
     then
       let bndrs2 = map (second (substTm "simpleCSE.bndrs" subst)) bndrs1
           body1  = substTm "simpleCSE.body" subst body
        in changed (Letrec bndrs2 body1)
     else
       return term

simpleCSE _ e = return e
{-# SCC simpleCSE #-}

-- | Ensure that top-level lambda's eventually bind a let-expression of which
-- the body is a variable-reference.
--
-- The transformation only fires when every element of the context is a
-- lambda-body or tick context. Two shapes are rewritten:
--
--   * A term that is neither a let-expression nor a tick is bound to a fresh
--     @result@ binder, and the term becomes @let result = e in result@.
--
--   * A let-expression whose body is not already a local variable gets its
--     body bound to a fresh @result@ binder, appended to the existing
--     bindings, and the body becomes a reference to that binder.
--
-- In both cases untranslatable terms (as judged by 'isUntranslatable') are
-- left alone, as is everything that matches neither shape.
topLet :: HasCallStack => NormRewrite
topLet (TransformContext is0 ctx) e
  | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx
      && not (isLet e)
      && not (isTick e)
  = do
  untranslatable <- isUntranslatable False e
  if untranslatable
    then return e
    else do
      tcm <- Lens.view tcCache
      argId <- mkTmBinderFor is0 tcm (mkUnsafeSystemName "result" 0) e
      changed (Let (NonRec argId e) (Var argId))

topLet (TransformContext is0 ctx) e@(Letrec binds body)
  | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx
  = do
    let localVar = isLocalVar body
    untranslatable <- isUntranslatable False body
    if localVar || untranslatable
      then return e
      else do
        tcm <- Lens.view tcCache
        -- The new binder must not clash with any of the existing let-binders
        let is2 = extendInScopeSetList is0 (fmap fst binds)
        argId <- mkTmBinderFor is2 tcm (mkUnsafeSystemName "result" 0) body

        -- TODO We would like this to be
        --
        -- Let binds (Let (NonRec argId body) (Var argId))
        --
        -- but this makes tests/shouldwork/SimIO/Test00.hs fail.
        changed (Letrec (binds ++ [(argId, body)]) (Var argId))

topLet _ e = return e
{-# SCC topLet #-}