{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Clash.Normalize.Util
( ConstantSpecInfo(..)
, isConstantArg
, shouldReduce
, alreadyInlined
, addNewInline
, specializeNorm
, isRecursiveBndr
, isClosed
, callGraph
, collectCallGraphUniques
, classifyFunction
, isCheapFunction
, isNonRecursiveGlobalVar
, constantSpecInfo
, normalizeTopLvlBndr
, rewriteExpr
, removedTm
, mkInlineTick
, substWithTyEq
, tvSubstWithTyEq
)
where
import Control.Lens ((&),(+~),(%=),(.=))
import qualified Control.Lens as Lens
import Data.Bifunctor (bimap)
import Data.Either (lefts)
import qualified Data.List as List
import qualified Data.List.Extra as List
import qualified Data.Map as Map
import qualified Data.HashMap.Strict as HashMapS
import qualified Data.HashSet as HashSet
import Data.Text (Text)
import qualified Data.Text as Text
import PrelNames (eqTyConKey)
import Unique (getKey)
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.FreeVars
(globalIds, hasLocalFreeVars, globalIdOccursIn)
import Clash.Core.Name (Name(nameOcc,nameUniq))
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
(deShadowTerm, extendTvSubst, extendTvSubstList, mkSubst, substTm, substTy,
substId, extendIdSubst)
import Clash.Core.Term
import Clash.Core.TermInfo (isPolyFun, termType)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type
(Type(LitTy, VarTy), LitTy(SymTy), TypeView (..), tyView, undefinedTy,
splitFunForallTy, splitTyConAppM, mkPolyFunTy)
import Clash.Core.Util
(isClockOrReset)
import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId)
import Clash.Core.VarEnv
(VarEnv, emptyInScopeSet, emptyVarEnv, extendVarEnv, extendVarEnvWith,
lookupVarEnv, unionVarEnvWith, unitVarEnv, extendInScopeSetList)
import Clash.Debug (traceIf)
import Clash.Driver.Types (BindingMap, Binding(..), DebugLevel (..))
import {-# SOURCE #-} Clash.Normalize.Strategy (normalization)
import Clash.Normalize.Types
import Clash.Primitives.Util (constantArgs)
import Clash.Rewrite.Types
(RewriteMonad, TransformContext(..), bindings, curFun, dbgLevel, extra,
tcCache)
import Clash.Rewrite.Util
(runRewrite, specialise, mkTmBinderFor, mkDerivedName)
import Clash.Unique
import Clash.Util (SrcSpan, makeCachedU)
-- | Determine whether argument position @i@ of the primitive named @nm@ is
-- one of the positions reported by 'constantArgs' for that primitive.
--
-- The positions computed for a primitive are stored in the @primitiveArgs@
-- field of the normalisation state, so later queries for the same primitive
-- are answered from that map. For @Clash.Explicit.SimIO.mealyIO@ the answer
-- is hard-coded: positions 2 and 3.
--
-- Returns 'False' when the name has no entry in the primitive map, or when
-- 'extractPrim' yields nothing for it.
isConstantArg
  :: Text
  -- ^ Primitive name
  -> Int
  -- ^ Argument position
  -> RewriteMonad NormalizeState Bool
isConstantArg "Clash.Explicit.SimIO.mealyIO" i = pure (i == 2 || i == 3)
isConstantArg nm i =
  Lens.use (extra.primitiveArgs) >>= maybe analysePrim answer . Map.lookup nm
 where
  -- Membership test shared by the cached and the freshly computed path.
  answer constArgs = pure (i `elem` constArgs)

  -- Cache miss: consult the primitive definitions and remember the result.
  analysePrim = do
    primMap <- Lens.use (extra.primitives)
    case HashMapS.lookup nm primMap >>= extractPrim of
      Just prim -> do
        let constArgs = constantArgs nm prim
        (extra.primitiveArgs) %= Map.insert nm constArgs
        answer constArgs
      Nothing ->
        pure False
-- | Check whether any element of the given context is an 'AppArg' that
-- carries primitive information for which 'isConstantArg' holds.
--
-- Context elements that are not such an 'AppArg' never trigger a reduction.
shouldReduce
  :: Context
  -> RewriteMonad NormalizeState Bool
shouldReduce = List.anyM $ \case
  AppArg (Just (primNm, _, argIx)) -> isConstantArg primNm argIx
  _ -> pure False
-- | Look up the counter stored for @f@ inside the inline-history entry of
-- @cf@.
--
-- Yields 'Nothing' when @cf@ has no entry in the inline history, or when its
-- entry holds nothing for @f@.
alreadyInlined
  :: Id
  -- ^ Key looked up in the inner environment
  -> Id
  -- ^ Key looked up in the outer inline history
  -> NormalizeMonad (Maybe Int)
alreadyInlined f cf = do
  history <- Lens.use inlineHistory
  pure (lookupVarEnv cf history >>= lookupVarEnv f)
-- | Record one more occurrence of @f@ in the inline-history entry of @cf@.
--
-- When @cf@ has no entry yet, it gets one that maps @f@ to 1. Otherwise the
-- existing entry is extended: the counter for @f@ is combined with 1 using
-- '(+)'.
addNewInline
  :: Id
  -- ^ Key of the counter inside the entry
  -> Id
  -- ^ Key of the entry in the inline history
  -> NormalizeMonad ()
addNewInline f cf = inlineHistory %= extendVarEnvWith cf firstTime bump
 where
  -- Entry used when @cf@ is not in the history yet.
  firstTime = unitVarEnv f 1

  -- Combining function for an existing entry; its first argument is ignored.
  bump _ existing = extendVarEnvWith f 1 (+) existing
-- | The 'specialise' rewrite, instantiated with the specialisation cache,
-- history and limit of the normalisation state.
specializeNorm :: NormRewrite
specializeNorm =
  specialise
    specialisationCache
    specialisationHistory
    specialisationLimit
-- | A term counts as closed exactly when 'isPolyFun' does not hold for it.
isClosed
  :: TyConMap
  -> Term
  -> Bool
isClosed tcm tm = not (isPolyFun tcm tm)
-- | Check whether the head of an application spine (as found by
-- 'collectArgs') is a variable that is both a global identifier and not
-- recursive according to 'isRecursiveBndr'.
--
-- Any term whose head is not a variable yields 'False'.
isNonRecursiveGlobalVar
  :: Term
  -> NormalizeSession Bool
isNonRecursiveGlobalVar tm = case fst (collectArgs tm) of
  Var i -> do
    recursive <- isRecursiveBndr i
    pure (isGlobalId i && not recursive)
  _ ->
    pure False
-- | Determine whether the binder @f@ occurs, as a global identifier, in its
-- own body.
--
-- Answers are stored in the @recursiveComponents@ field of the normalisation
-- state and reused on later queries. A binder that is absent from the
-- binding map is reported as non-recursive, and that answer is not stored.
isRecursiveBndr
  :: Id
  -> NormalizeSession Bool
isRecursiveBndr f = do
  known <- Lens.use (extra.recursiveComponents)
  maybe analyse pure (lookupVarEnv f known)
 where
  -- Cache miss: inspect the body of the binder and remember the verdict.
  analyse = do
    bndrs <- Lens.use bindings
    case lookupVarEnv f bndrs of
      Just binding -> do
        let selfRef = f `globalIdOccursIn` bindingTerm binding
        (extra.recursiveComponents) %= extendVarEnv f selfRef
        pure selfRef
      Nothing ->
        pure False
-- | Result of analysing a term with 'constantSpecInfo'.
data ConstantSpecInfo = ConstantSpecInfo
  { csrNewBindings :: [(Id, Term)]
  -- ^ Bindings that accompany 'csrNewTerm'; each pairs a freshly created
  -- identifier with the term it stands for.
  , csrNewTerm :: !Term
  -- ^ The term produced by the analysis.
  , csrFoundConstant :: !Bool
  -- ^ Whether the analysis classified the term as constant.
  } deriving Show
-- | Wrap a term as a 'ConstantSpecInfo' that is marked constant and needs no
-- new bindings.
constantCsr :: Term -> ConstantSpecInfo
constantCsr t =
  ConstantSpecInfo
    { csrNewBindings = []
    , csrNewTerm = t
    , csrFoundConstant = True
    }
-- | Give the term a fresh binder and refer to it by variable.
--
-- The resulting 'ConstantSpecInfo' has a single new binding (fresh identifier
-- paired with the given term), the variable for that identifier as its new
-- term, and is marked as not constant.
bindCsr
  :: TransformContext
  -> Term
  -> RewriteMonad NormalizeState ConstantSpecInfo
bindCsr ctx@(TransformContext inScope _) oldTerm = do
  tcm <- Lens.view tcCache
  toCsr <$> mkTmBinderFor inScope tcm (mkDerivedName ctx "bindCsr") oldTerm
 where
  toCsr bndr = ConstantSpecInfo [(bndr, oldTerm)] (Var bndr) False
-- | Analyse every term argument with 'constantSpecInfo' and combine the
-- results.
--
-- Type arguments are passed through untouched. While walking the arguments
-- left to right, the identifiers of the bindings produced for one argument
-- are added to the in-scope set used for the arguments that follow.
--
-- If there are no term arguments at all, or at least one of them was found
-- to be constant, the result is the term rebuilt from the analysed arguments
-- (with the given ticks re-applied), together with all bindings of the
-- arguments, marked constant. Otherwise the original term is bound as a
-- whole with 'bindCsr'.
mergeCsrs
  :: TransformContext
  -> [TickInfo]
  -- ^ Ticks to put back around the rebuilt term
  -> Term
  -- ^ Original term, used when nothing constant was found
  -> ([Either Term Type] -> Term)
  -- ^ Rebuilds the term from (analysed) arguments
  -> [Either Term Type]
  -- ^ Arguments to analyse
  -> RewriteMonad NormalizeState ConstantSpecInfo
mergeCsrs ctx ticks oldTerm proposedTerm subTerms = do
  (_, subCsrs) <- List.mapAccumLM analyseSubTerm ctx subTerms
  let
    termCsrs = lefts subCsrs
    rebuilt = proposedTerm (map (either (Left . csrNewTerm) Right) subCsrs)
  if null termCsrs || any csrFoundConstant termCsrs
    then
      pure
        ( ConstantSpecInfo
            { csrNewBindings = concatMap csrNewBindings termCsrs
            , csrNewTerm = mkTicks rebuilt ticks
            , csrFoundConstant = True
            }
        )
    else
      bindCsr ctx oldTerm
 where
  analyseSubTerm
    :: TransformContext
    -> Either Term Type
    -> RewriteMonad NormalizeState (TransformContext, Either ConstantSpecInfo Type)
  analyseSubTerm localCtx (Right ty) =
    pure (localCtx, Right ty)
  analyseSubTerm localCtx@(TransformContext inScope tfCtx) (Left tm) = do
    info <- constantSpecInfo localCtx tm
    -- Later arguments must not reuse the binders created for this one.
    let inScope' = extendInScopeSetList inScope (map fst (csrNewBindings info))
    pure (TransformContext inScope' tfCtx, Left info)
-- | Analyse a term and report which parts of it are constant.
--
-- For a term whose type is a clock or reset (per 'isClockOrReset'): an
-- application headed by the @Clash.Transformations.removedArg@ primitive is
-- constant; anything else is bound with 'bindCsr'.
--
-- For all other terms the head of the application spine decides:
--
-- * data constructor: merge the analyses of the arguments ('mergeCsrs');
-- * primitive: merge the arguments, but bind the whole term if that merge
--   produced any new bindings;
-- * lambda: constant unless it has local free variables;
-- * variable: treated like a primitive when it is a non-recursive global
--   other than the function currently being processed, bound otherwise;
-- * literal: constant;
-- * anything else: bound with 'bindCsr'.
constantSpecInfo
  :: TransformContext
  -> Term
  -> RewriteMonad NormalizeState ConstantSpecInfo
constantSpecInfo ctx e = do
  tcm <- Lens.view tcCache
  if isClockOrReset tcm (termType tcm e)
    then clockOrResetCsr
    else generalCsr
 where
  bindWhole = bindCsr ctx e
  constWhole = pure (constantCsr e)

  clockOrResetCsr = case collectArgs e of
    (Prim p, _)
      | primName p == "Clash.Transformations.removedArg" -> constWhole
    _ -> bindWhole

  generalCsr = case collectArgsTicks e of
    (dc@(Data _), args, ticks) ->
      mergeCsrs ctx ticks e (mkApps dc) args
    (prim@(Prim _), args, ticks) ->
      mergeUnlessBindings ticks prim args
    (Lam _ _, _, _)
      | hasLocalFreeVars e -> bindWhole
      | otherwise -> constWhole
    (var@(Var f), args, ticks) -> do
      (curF, _) <- Lens.use curFun
      isNonRecGlobVar <- isNonRecursiveGlobalVar e
      if isNonRecGlobVar && f /= curF
        then mergeUnlessBindings ticks var args
        else bindWhole
    (Literal _, _, _) ->
      constWhole
    _ ->
      bindWhole

  -- Merge the argument analyses, but fall back to binding the whole term
  -- when the merge had to introduce bindings.
  mergeUnlessBindings ticks hd args = do
    csr <- mergeCsrs ctx ticks e (mkApps hd) args
    if null (csrNewBindings csr)
      then pure csr
      else bindWhole
-- | A call graph: an environment that maps a key to an inner environment of
-- 'Word' counters.
type CallGraph = VarEnv (VarEnv Word)
-- | Collect every unique that appears in a call graph, either as a key of
-- the outer environment or as a key of one of the inner environments.
collectCallGraphUniques :: CallGraph -> HashSet.HashSet Unique
collectCallGraphUniques cg = HashSet.fromList (outerKeys ++ innerKeys)
 where
  outerKeys = keysUniqMap cg
  innerKeys = eltsUniqMap cg >>= keysUniqMap
-- | Build the call graph reachable from the given root binder.
--
-- Each visited binder is mapped to the global identifiers occurring in its
-- body, each paired with its number of occurrences. Binders that are already
-- in the graph, or that have no entry in the binding map, are not expanded.
callGraph
  :: BindingMap
  -- ^ All known bindings
  -> Id
  -- ^ Root of the traversal
  -> CallGraph
callGraph bndrs rt = visit emptyVarEnv (varUniq rt)
 where
  visit seen u = case (lookupUniqMap u seen, lookupUniqMap u bndrs) of
    (Nothing, Just binding) ->
      let
        callees =
          Lens.foldMapByOf
            globalIds
            (unionVarEnvWith (+))
            emptyVarEnv
            (`unitUniqMap` 1)
            (bindingTerm binding)
      in
        List.foldl' visit (extendUniqMap u callees seen) (keysUniqMap callees)
    _ ->
      seen
-- | Count the kinds of components a term is built from.
--
-- The traversal looks through lambdas, type lambdas and ticks, and visits
-- the bound expressions of a @Letrec@ (not its body). An application adds
-- one to the primitive count when its head is a primitive and one to the
-- function count when its head is a variable. A case expression with at
-- least two alternatives adds one to the selection count. Nothing else is
-- counted, and the arguments of applications are not inspected.
classifyFunction
  :: Term
  -> TermClassification
classifyFunction = go (TermClassification 0 0 0)
 where
  go !acc = \case
    Lam _ body ->
      go acc body
    TyLam _ body ->
      go acc body
    Letrec binds _ ->
      List.foldl' go acc (map snd binds)
    app@(App {}) ->
      case fst (collectArgs app) of
        Prim {} -> acc & primitive +~ 1
        Var {} -> acc & function +~ 1
        _ -> acc
    Case _ _ (_:_:_) ->
      acc & selection +~ 1
    Tick _ inner ->
      go acc inner
    _ ->
      acc
-- | Decide whether a term is cheap, based on 'classifyFunction'.
--
-- A term is cheap when it contains at most one component in total: at most
-- one function application and nothing else, or at most one primitive
-- application and nothing else, or at most one selection and nothing else.
--
-- Each guard first checks that the two /other/ counters are zero and only
-- then bounds the remaining one. Guarding on @_function <= 1@ first (as an
-- earlier formulation did) also captures @_function == 0@, which makes the
-- primitive-only and selection-only cases unreachable and wrongly classifies
-- a term with a single primitive or a single selection as expensive.
isCheapFunction
  :: Term
  -> Bool
isCheapFunction tm = case classifyFunction tm of
  TermClassification {..}
    | _primitive <= 0 && _selection <= 0 -> _function <= 1
    | _function <= 0 && _selection <= 0 -> _primitive <= 1
    | _function <= 0 && _primitive <= 0 -> _selection <= 1
    | otherwise -> False
-- | Normalize the body of a binding and update the type of its binder to
-- the type of the normalized body.
--
-- The result is cached under @nm@ in the @normalized@ field of the
-- normalisation state (via 'makeCachedU'). Before rewriting, the body is
-- deshadowed and, when the first argument is 'True', passed through
-- 'substWithTyEq'. The value of @curFun@ from before the rewrite is restored
-- afterwards.
normalizeTopLvlBndr
  :: Bool
  -- ^ Whether to apply 'substWithTyEq' to the body first
  -> Id
  -- ^ Cache key; also used to label the debug output
  -> Binding
  -> NormalizeSession Binding
normalizeTopLvlBndr isTop nm (Binding nm' sp inl tm) =
  makeCachedU nm (extra.normalized) $ do
    tcm <- Lens.view tcCache
    savedFun <- Lens.use curFun
    normalized <-
      rewriteExpr
        ("normalization", normalization)
        (showPpr (varName nm), prepared)
        (nm', sp)
    -- 'rewriteExpr' overwrites 'curFun'; put the caller's value back.
    curFun .= savedFun
    pure (Binding nm'{varType = termType tcm normalized} sp inl normalized)
 where
  deshadowed = deShadowTerm emptyInScopeSet tm
  prepared
    | isTop = substWithTyEq deshadowed
    | otherwise = deshadowed
-- | Turn type-equality arguments over bound type variables into
-- substitutions.
--
-- The leading type lambdas and term lambdas of the term are peeled off one
-- by one. When a term lambda binds a variable whose type is an application
-- of the equality type constructor relating a type variable collected so far
-- to some type (see 'tvFirst'), then:
--
-- * that type variable is dropped from the collected type variables;
-- * the binder is kept, with the type variable substituted in it;
-- * in the remaining body the type variable is replaced by the type, and
--   occurrences of the binder are replaced by 'removedTm'.
--
-- If at least one such substitution happened, the remaining body is wrapped
-- first in the remaining type lambdas and then in the collected term
-- lambdas. If none happened, the original term is returned unchanged.
substWithTyEq
  :: Term
  -> Term
substWithTyEq e0 = peel [] False [] e0
 where
  -- Arguments: collected type variables (most recent first), whether any
  -- substitution happened so far, collected binders (most recent first).
  peel
    :: [TyVar]
    -> Bool
    -> [Id]
    -> Term
    -> Term
  peel tvs changed bndrs = \case
    TyLam tv body ->
      peel (tv : tvs) changed bndrs body
    Lam bndr body
      | TyConApp (nameUniq -> tcUniq) (tvFirst -> Just (tv, ty)) <- tyView (varType bndr)
      , tcUniq == getKey eqTyConKey
      , tv `elem` tvs ->
          let
            tySubst = extendTvSubst (mkSubst emptyInScopeSet) tv ty
            tmSubst = extendIdSubst tySubst bndr (removedTm (varType bndr))
          in
            peel
              (tvs List.\\ [tv])
              True
              (substId tySubst bndr : bndrs)
              (substTm "substWithTyEq e" tmSubst body)
      | otherwise ->
          peel tvs changed (bndr : bndrs) body
    body
      | changed ->
          let withTyLams = List.foldl' (flip TyLam) body tvs
          in List.foldl' (flip Lam) withTyLams bndrs
      | otherwise ->
          e0
-- | From a three-element argument list, pick out a type variable in the
-- second or third position together with the type in the other of those two
-- positions. The first element is ignored.
--
-- When both positions hold a type variable, the one in the second position
-- is returned as the variable. Any other list shape yields 'Nothing'.
tvFirst :: [Type] -> Maybe (TyVar, Type)
tvFirst = \case
  [_, VarTy tv, ty] -> Just (tv, ty)
  [_, ty, VarTy tv] -> Just (tv, ty)
  _ -> Nothing
-- | Type-level counterpart of the equality substitution.
--
-- The type is split into its arguments and result with 'splitFunForallTy'.
-- Every term argument that is an application of the equality type
-- constructor relating a type variable to a type (see 'tvFirst') contributes
-- one substitution pair. When no such argument exists the type is returned
-- unchanged. Otherwise the quantifiers of the substituted type variables are
-- removed from the argument list, the type is rebuilt with 'mkPolyFunTy',
-- and the substitution is applied to it.
tvSubstWithTyEq
  :: Type
  -> Type
tvSubstWithTyEq ty0
  | null eqs = ty0
  | otherwise = substTy subst (mkPolyFunTy tyRes keptArgs)
 where
  (args0, tyRes) = splitFunForallTy ty0

  -- Substitution pairs, most recently encountered first.
  eqs = List.foldl' collect [] args0

  collect :: [(TyVar, Type)] -> Either TyVar Type -> [(TyVar, Type)]
  collect acc (Right arg)
    | Just (tc, tcArgs) <- splitTyConAppM arg
    , nameUniq tc == getKey eqTyConKey
    , Just eq <- tvFirst tcArgs
    = eq : acc
  collect acc _ = acc

  subst = extendTvSubstList (mkSubst emptyInScopeSet) eqs
  keptArgs = args0 List.\\ map (Left . fst) eqs
-- | Run a rewrite on an expression, with optional tracing.
--
-- Sets @curFun@ to the given binder and source span, then runs the rewrite
-- starting from an empty in-scope set. When the debug level is at least
-- 'DebugFinal', the pretty-printed expression is traced both before and
-- after the rewrite.
rewriteExpr
  :: (String, NormRewrite)
  -- ^ Name of the rewrite and the rewrite itself
  -> (String, Term)
  -- ^ Label used in the trace output and the expression to rewrite
  -> (Id, SrcSpan)
  -- ^ Value stored in @curFun@ for the duration of the rewrite
  -> NormalizeSession Term
rewriteExpr (nrwS, nrw) (bndrS, expr) (nm, sp) = do
  curFun .= (nm, sp)
  lvl <- Lens.view dbgLevel
  let
    tracing = lvl >= DebugFinal
    message marker rendered =
      bndrS ++ marker ++ nrwS ++ ":\n\n" ++ rendered ++ "\n"
  rewritten <-
    runRewrite nrwS emptyInScopeSet nrw
      (traceIf tracing (message " before " (showPpr expr)) expr)
  traceIf tracing (message " after " (showPpr rewritten)) (pure rewritten)
-- | Placeholder term: the @Clash.Transformations.removedArg@ primitive
-- applied to the given type.
removedTm
  :: Type
  -> Term
removedTm ty = TyApp removedArgPrim ty
 where
  removedArgPrim =
    Prim (PrimInfo "Clash.Transformations.removedArg" undefinedTy WorkNever)
-- | Build a prefix name-modifier tick from an identifier.
--
-- The tick carries, as a type-level symbol, the part of the identifier's
-- occurrence name that follows the last @.@ (the whole name when it contains
-- no @.@).
mkInlineTick :: Id -> TickInfo
mkInlineTick n = NameMod PrefixName (LitTy (SymTy baseName))
 where
  (_, unqualified) = Text.breakOnEnd "." (nameOcc (varName n))
  baseName = Text.unpack unqualified