{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Rewrite.Util where
import Control.Monad.Extra (andM, eitherM)
import Control.Concurrent.Supply (splitSupply)
import Control.DeepSeq
import Control.Exception (throw)
import Control.Lens
(Lens', (%=), (+=), (^.), _Left)
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
#endif
import qualified Control.Monad.State.Strict as State
import qualified Control.Monad.Writer as Writer
import Data.Bool (bool)
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import Data.Functor.Const (Const (..))
import Data.List (group, partition, sort)
import qualified Data.List as List
import qualified Data.List.Extra as List
import Data.List.Extra (allM, partitionM)
import qualified Data.Map as Map
import Data.Maybe
(catMaybes, isJust, mapMaybe, fromMaybe)
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import qualified Data.Set.Ordered as OSet
import qualified Data.Set.Ordered.Extra as OSet
import Data.Text (Text)
import qualified Data.Text as Text
#ifdef HISTORY
import Data.Binary (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import System.IO.Unsafe (unsafePerformIO)
#endif
import BasicTypes (InlineSpec (..))
import Clash.Core.DataCon (dcExtTyVars)
import Clash.Core.Evaluator (whnf')
import Clash.Core.Evaluator.Types (PureHeap)
import Clash.Core.FreeVars
(freeLocalVars, hasLocalFreeVars, localIdDoesNotOccurIn, localIdOccursIn,
typeFreeVars, termFreeVars', freeLocalIds, globalIdOccursIn)
import Clash.Core.Name
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
(substTmEnv, aeqTerm, aeqType, extendIdSubst, mkSubst, substTm)
import Clash.Core.Term
import Clash.Core.TermInfo
import Clash.Core.TyCon
(TyConMap, tyConDataCons)
import Clash.Core.Type (KindOrType, Type (..),
TypeView (..), coreView1,
normalizeType,
typeKind, tyView, isPolyFunTy)
import Clash.Core.Util
(dataConInstArgTysE, isClockOrReset, isEnable)
import Clash.Core.Var
(Id, IdScope (..), TyVar, Var (..), isLocalId, mkGlobalId, mkLocalId, mkTyVar)
import Clash.Core.VarEnv
(InScopeSet, VarEnv, elemVarSet, extendInScopeSetList, mkInScopeSet,
uniqAway, uniqAway', mapVarEnv, eltsVarEnv, unitVarSet, emptyVarEnv,
mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv)
import Clash.Debug
import Clash.Driver.Types
(DebugLevel (..), BindingMap, Binding(..))
import Clash.Netlist.Util (representableType)
import Clash.Pretty (clashPretty, showDoc)
import Clash.Rewrite.Types
import Clash.Unique
import Clash.Util
import qualified Clash.Util.Interpolate as I
-- | Lift a 'State' computation over the @extra@ component of the rewrite
-- state into 'RewriteMonad'.
--
-- The computation starts from the current '_extra' value and its final
-- state is written back into '_extra'. The environment argument is ignored
-- and the third component @w@ is passed through untouched.
zoomExtra :: State.State extra a
          -> RewriteMonad extra a
zoomExtra m = R $ \_ st w ->
  let -- Function-argument patterns are strict, so this forces the
      -- 'runState' result exactly like a @case@ would.
      writeBack (res, extra1) = (res, st {_extra = extra1}, w)
  in  writeBack (State.runState m (st ^. extra))
-- | Collect binders that are bound more than once within a single binding
-- group of a term. A binding group is either the binders of one @letrec@ or
-- the binders of one data-constructor pattern.
--
-- Each inner list holds two or more equal 'Id's, grouped after sorting.
-- Shadowing across nested scopes is not reported. Only duplicates inside
-- the same group are.
findAccidentialShadows :: Term -> [[Id]]
findAccidentialShadows term = case term of
  Var {}            -> []
  Data {}           -> []
  Literal {}        -> []
  Prim {}           -> []
  Lam _ body        -> findAccidentialShadows body
  TyLam _ body      -> findAccidentialShadows body
  App fun arg       -> findAccidentialShadows fun ++ findAccidentialShadows arg
  TyApp fun _       -> findAccidentialShadows fun
  Cast body _ _     -> findAccidentialShadows body
  Tick _ body       -> findAccidentialShadows body
  Case scrut _ alts ->
    concat [ patDups pat | (pat, _) <- alts ]
      ++ concatMap findAccidentialShadows (scrut : [ rhs | (_, rhs) <- alts ])
  Letrec binds body ->
    duplicates (map fst binds) ++ findAccidentialShadows body
  where
    -- Duplicates among the binders introduced by a single pattern.
    patDups :: Pat -> [[Id]]
    patDups pat = case pat of
      DataPat _ _ ids -> duplicates ids
      LitPat _        -> []
      DefaultPat      -> []

    -- Groups of equal identifiers that have at least two members.
    duplicates :: [Id] -> [[Id]]
    duplicates ids = [ grp | grp@(_:_:_) <- group (sort ids) ]
-- | Wrap a rewrite so that each application is counted and, depending on
-- the debug settings in the environment, traced and sanity-checked.
--
-- The 'String' argument is the transformation's name. It is used in trace
-- output and matched against 'dbgTransformations'. Whether the wrapped
-- rewrite changed anything is taken from its 'Writer' output ('Monoid.Any').
-- When it changed something, 'transformCounter' is incremented and the
-- rewritten expression is used. Otherwise the original expression is
-- returned. When 'dbgLevel' is not 'DebugNone', the result is handed to
-- 'applyDebug' for checking and tracing.
apply
:: String
-> Rewrite extra
-> Rewrite extra
apply = \s rewrite ctx expr0 -> do
lvl <- Lens.view dbgLevel
dbgTranss <- Lens.view dbgTransformations
-- Announce the attempt at DebugTry, or at DebugAll and above, but only for
-- transformations selected by -fclash-debug-transformations (an empty set
-- selects all of them).
let isTryLvl = lvl == DebugTry || lvl >= DebugAll
isRelevantTrans = s `Set.member` dbgTranss || Set.null dbgTranss
traceIf (isTryLvl && isRelevantTrans) ("Trying: " ++ s) (pure ())
(expr1,anyChanged) <- Writer.listen (rewrite ctx expr0)
-- Fall back to the input expression when the rewrite reported no change.
let hasChanged = Monoid.getAny anyChanged
!expr2 = if hasChanged then expr1 else expr0
Monad.when hasChanged (transformCounter += 1)
-- With the HISTORY flag, every changing step is appended to history.dat as
-- a binary-encoded RewriteStep, via unsafePerformIO.
#ifdef HISTORY
Monad.when hasChanged $ do
(curBndr, _) <- Lens.use curFun
let !_ = unsafePerformIO
$ BS.appendFile "history.dat"
$ BL.toStrict
$ encode RewriteStep
{ t_ctx = tfContext ctx
, t_name = s
, t_bndrS = showPpr (varName curBndr)
, t_before = expr0
, t_after = expr1
}
return ()
#endif
dbgFrom <- Lens.view dbgTransformationsFrom
dbgLimit <- Lens.view dbgTransformationsLimit
-- The pair (0, maxBound) means no counter window was requested.
let fromLimit =
if (dbgFrom, dbgLimit) == (0, maxBound)
then Nothing
else Just (dbgFrom, dbgLimit)
-- Skip all debug work when debugging is off.
if lvl == DebugNone
then return expr2
else applyDebug lvl dbgTranss fromLimit s expr0 hasChanged expr2
{-# INLINE apply #-}
-- | Debug post-processing of a single rewrite step. 'apply' calls this
-- whenever the debug level is not 'DebugNone'.
--
-- The equations are tried in order:
--
-- 1. With a @(from, limit)@ window: while the transformation counter is at
--    most @from@, the new expression is returned without checks. Once the
--    counter exceeds @from + limit@, compilation is aborted. In between,
--    the step is handled as if no window had been given.
--
-- 2. With a non-empty set of transformation names: the step keeps debug
--    level @lvl@ only when its name is in the set. Otherwise it is handled
--    at 'DebugNone'.
--
-- 3. Otherwise, the step is checked and traced according to @lvl@:
--
--    * A changing step must not introduce free variables, and must not bind
--      the same identifier twice within one let/case binding group. Both
--      are fatal errors.
--    * A change of type is only traced.
--    * At 'DebugApplied' and above, a step that reports no change must
--      return an alpha-equivalent expression.
--
-- The returned term is always @exprNew@.
applyDebug
  :: DebugLevel
  -- ^ Requested debug level
  -> Set.Set String
  -- ^ Names of the transformations to debug; empty means all of them
  -> Maybe (Int, Int)
  -- ^ Optional @(from, limit)@ window over the transformation counter
  -> String
  -- ^ Name of the transformation that was applied
  -> Term
  -- ^ Expression before the rewrite
  -> Bool
  -- ^ Whether the rewrite reported a change
  -> Term
  -- ^ Expression after the rewrite
  -> RewriteMonad extra Term
applyDebug lvl transformations fromLimit name exprOld hasChanged exprNew
  | Just (from, limit) <- fromLimit = do
      nTrans <- Lens.use transformCounter
      if | nTrans - from > limit ->
             error "-fclash-debug-transformations-limit exceeded"
         | nTrans > from ->
             applyDebug lvl transformations Nothing name exprOld hasChanged exprNew
         | otherwise ->
             pure exprNew
applyDebug lvl transformations fromLimit name exprOld hasChanged exprNew
  | not (Set.null transformations) =
      let newLvl = bool DebugNone lvl (name `Set.member` transformations)
      in  applyDebug newLvl Set.empty fromLimit name exprOld hasChanged exprNew
applyDebug lvl _transformations _fromLimit name exprOld hasChanged exprNew =
  traceIf (lvl >= DebugAll) ("Tried: " ++ name ++ " on:\n" ++ before) $ do
    -- 'apply' increments the counter before calling us when the step changed
    -- something, so the number of this step is one less than the counter.
    nTrans <- pred <$> Lens.use transformCounter

    Monad.when (lvl > DebugNone && hasChanged) $ do
      tcm <- Lens.view tcCache
      let beforeTy = termType tcm exprOld
          beforeFV = Lens.setOf freeLocalVars exprOld
          afterTy = termType tcm exprNew
          afterFV = Lens.setOf freeLocalVars exprNew
          newFV = not (afterFV `Set.isSubsetOf` beforeFV)
          accidentalShadows = findAccidentialShadows exprNew

      Monad.when newFV $
        error ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "It introduces free variables."
                       , "\nBefore: " ++ showPpr (Set.toList beforeFV)
                       , "\nAfter: " ++ showPpr (Set.toList afterFV)
                       ]
              )

      Monad.unless (null accidentalShadows) $
        error ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "It accidentally creates shadowing let/case-bindings:\n"
                       , " ", showPpr accidentalShadows, "\n"
                       , "This usually means that a transformation did not extend "
                       , "or incorrectly extended its InScopeSet before applying a "
                       , "substitution."
                       ])

      -- A change of type is reported, but does not abort compilation.
      traceIf (lvl >= DebugApplied && not (beforeTy `aeqType` afterTy))
        ( concat [ $(curLoc)
                 , "Error when applying rewrite ", name
                 , " to:\n" , before
                 , "\nResult:\n" ++ after ++ "\n"
                 , "Changes type from:\n", showPpr beforeTy
                 , "\nto:\n", showPpr afterTy
                 ]
        ) (return ())

    -- A rewrite that claims to have changed nothing must return an
    -- alpha-equivalent expression. The "before:" label now ends in a newline,
    -- matching the "after:" label, instead of being glued to the term.
    Monad.when (lvl >= DebugApplied && not hasChanged && not (exprOld `aeqTerm` exprNew)) $
      error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++ "): before:\n"
           ++ before ++ "\nafter:\n" ++ after

    traceIf (lvl >= DebugName && hasChanged) (name <> " {" <> show nTrans <> "}") $
      traceIf (lvl >= DebugApplied && hasChanged) ("Changes when applying rewrite to:\n"
                        ++ before ++ "\nResult:\n" ++ after ++ "\n") $
        traceIf (lvl >= DebugAll && not hasChanged) ("No changes when applying rewrite "
                          ++ name ++ " to:\n" ++ after ++ "\n") $
          return exprNew
  where
    before = showPpr exprOld
    after = showPpr exprNew
-- | Apply a named rewrite to an expression at the top level. The rewrite
-- runs with the given in-scope set and an empty context.
runRewrite
  :: String
  -> InScopeSet
  -> Rewrite extra
  -> Term
  -> RewriteMonad extra Term
runRewrite name is rewrite = apply name rewrite topCtx
  where
    topCtx = TransformContext is []
-- | Run a complete rewrite computation from the given environment and
-- initial state, and return only its result.
--
-- When the debug level is above 'DebugNone', the number of applied
-- transformations (read from the final state) is traced.
runRewriteSession :: RewriteEnv
                  -> RewriteState extra
                  -> RewriteMonad extra a
                  -> a
runRewriteSession r s m =
  let (result, finalState, _) = runR m r s
      summary = "Clash: Applied " ++ show (finalState ^. transformCounter) ++ " transformations"
  in  traceIf (_dbgLevel r > DebugNone) summary result
-- | Record in the writer output that the current rewrite changed the
-- expression.
setChanged :: RewriteMonad extra ()
setChanged = Writer.tell $ Monoid.Any True
-- | Return the given value and record in the writer output that the
-- expression changed.
changed :: a -> RewriteMonad extra a
changed val = val <$ Writer.tell (Monoid.Any True)
-- | The binder of the first 'LetBinding' entry in the context, if there is
-- one.
closestLetBinder :: Context -> Maybe Id
closestLetBinder ctx =
  case [ bndr | LetBinding bndr _ <- ctx ] of
    (bndr:_) -> Just bndr
    []       -> Nothing
-- | Derive a term name from the closest let-binder in the context by
-- appending @_<sf>@ to its name.
--
-- Without such a binder, the result is an internal name with occurrence
-- @sf@ and unique 0.
mkDerivedName :: TransformContext -> OccName -> TmName
mkDerivedName (TransformContext _ ctx) sf =
  maybe (mkUnsafeInternalName sf 0) derive (closestLetBinder ctx)
  where
    derive bndr = appendToName (varName bndr) (Text.cons '_' sf)
-- | Create a fresh local term binder whose type is the type of the given
-- term.
--
-- The binder's name is a clone of the given name, made unique with respect
-- to the in-scope set (see 'mkBinderFor'). The 'MonadFail' constraint comes
-- from the pattern match on the 'Left' result, which 'mkBinderFor' always
-- produces for a term argument.
mkTmBinderFor
  :: (MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap
  -> Name a
  -> Term
  -> m Id
mkTmBinderFor inScope tyConMap nm tm = do
  Left bndr <- mkBinderFor inScope tyConMap nm (Left tm)
  pure bndr
-- | Create a fresh binder for a term or a type.
--
-- The name is cloned from the given name with a fresh unique and made
-- unique with respect to the in-scope set. For a term, the result is a
-- local 'Id' with the term's type. For a type, it is a 'TyVar' with the
-- type's kind.
mkBinderFor
  :: (MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap
  -> Name a
  -> Either Term Type
  -> m (Either Id TyVar)
mkBinderFor is tcm name arg = do
  name' <- cloneNameWithInScopeSet is name
  pure $ case arg of
    Left term -> Left (mkLocalId (termType tcm term) (coerce name'))
    Right ty  -> Right (mkTyVar (typeKind tcm ty) (coerce name'))
-- | Create a fresh local 'Id' with an internal name built from the given
-- occurrence name and a fresh unique. The 'Id' is then made unique with
-- respect to the in-scope set.
mkInternalVar
  :: (MonadUnique m)
  => InScopeSet
  -> OccName
  -> KindOrType
  -> m Id
mkInternalVar inScope name ty =
  uniqAway inScope . mkLocalId ty . mkUnsafeInternalName name <$> getUniqueM
-- | Inline the bindings of a @letrec@ that satisfy @condition@.
--
-- @condition@ receives the whole @letrec@ and one of its bindings.
-- Selected bindings are substituted into the remaining bindings and the
-- body via 'substituteBinders'. Selected bindings that turn out to be
-- self-recursive are kept.
--
-- * When nothing is selected, the expression is returned as-is.
-- * Otherwise, the result is marked as changed, and the @letrec@ disappears
--   when no bindings remain.
-- * Terms that are not a @letrec@ are returned unchanged.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -> Rewrite extra
inlineBinders condition (TransformContext inScope0 _) expr@(Letrec xes res) = do
  (toInline, toKeep) <- partitionM (condition expr) xes
  if null toInline
    then return expr
    else do
      let inScope1 = extendInScopeSetList inScope0 (map fst xes)
          (recBinds, (kept, body)) = substituteBinders inScope1 toInline toKeep res
      changed $ case recBinds ++ kept of
        []        -> body
        remaining -> Letrec remaining body
inlineBinders _ _ e = return e
-- | An identifier is a join point in a term when 'tailCalls' accepts the
-- term and finds more than one tail call to the identifier.
isJoinPointIn :: Id
              -> Term
              -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)
-- | Count the tail calls to the given identifier in a term.
--
-- Returns @'Just' n@, where @n@ is the number of tail-position occurrences
-- found. Returns 'Nothing' when the analysis rejects the term, for example
-- because a function argument or a case scrutinee yields a non-zero count
-- or is itself rejected. A result of @'Just' 0@ therefore means that no
-- tail calls were seen and nothing disqualifying was seen either.
tailCalls :: Id
-> Term
-> Maybe Int
tailCalls id_ = \case
-- A lone variable counts as one tail call when it is the identifier itself.
Var nm | id_ == nm -> Just 1
| otherwise -> Just 0
Lam _ e -> tailCalls id_ e
TyLam _ e -> tailCalls id_ e
-- The argument must yield Just 0; the count then comes from the head.
App l r -> case tailCalls id_ r of
Just 0 -> tailCalls id_ l
_ -> Nothing
TyApp l _ -> tailCalls id_ l
-- Let-bindings: every right-hand side must be accepted. If none of them
-- tail-calls the identifier, only the body counts. Otherwise each let-binder
-- must itself be accepted by tailCalls on the body (only success is checked,
-- not the count). The result is then the sum of the right-hand-side counts
-- plus the body's count.
Letrec bs e ->
let (bsIds,bsExprs) = unzip bs
bsTls = map (tailCalls id_) bsExprs
bsIdsUsed = mapMaybe (\(l,r) -> pure l <* r) (zip bsIds bsTls)
bsIdsTls = map (`tailCalls` e) bsIdsUsed
bsCount = pure . sum $ catMaybes bsTls
in case (all isJust bsTls) of
False -> Nothing
True -> case (all (==0) $ catMaybes bsTls) of
False -> case all isJust bsIdsTls of
False -> Nothing
True -> (+) <$> bsCount <*> tailCalls id_ e
True -> tailCalls id_ e
-- Case: the scrutinee must yield Just 0 and every alternative must be
-- accepted. The alternatives' counts are summed.
Case scrut _ alts ->
let scrutTl = tailCalls id_ scrut
altsTl = map (tailCalls id_ . snd) alts
in case scrutTl of
Just 0 | all (/= Nothing) altsTl -> Just (sum (catMaybes altsTl))
_ -> Nothing
-- NOTE(review): all remaining constructors (Data, Literal, Prim, Cast, Tick)
-- yield Just 0 without inspecting sub-terms. Occurrences of the identifier
-- under a Cast or Tick are therefore not counted. Confirm this is intended.
_ -> Just 0
-- | A lambda whose body is an application headed by a variable, and whose
-- bound variable does not occur in that body.
isVoidWrapper :: Term -> Bool
isVoidWrapper (Lam bndr body)
  | (Var _, _) <- collectArgs body = bndr `localIdDoesNotOccurIn` body
isVoidWrapper _ = False
-- | Substitute let-bindings into each other, into other bindings, and into
-- a body.
--
-- The bindings in @toInline@ are processed from left to right:
--
-- * Each right-hand side first has the substitution built so far applied
--   to it.
-- * If the binding's own binder still occurs in the result, the binding is
--   self-recursive and cannot be inlined. It is set aside, with the final
--   substitution applied, in the first component of the result.
-- * Otherwise the binder is added to the substitution. Before that, the
--   binder is substituted into the substitution's existing range, so that
--   no earlier entry still mentions it.
--
-- Returns @(self-recursive bindings, (substituted toKeep, substituted body))@.
-- The self-recursive bindings come out in reverse order of appearance.
substituteBinders
:: InScopeSet
-> [LetBinding]
-> [LetBinding]
-> Term
-> ([LetBinding],([LetBinding],Term))
substituteBinders inScope toInline toKeep body =
let (subst,toInlRec) = go (mkSubst inScope) [] toInline
in ( map (second (substTm "substToInlRec" subst)) toInlRec
, ( map (second (substTm "substToKeep" subst)) toKeep
, substTm "substBody" subst body) )
where
-- go: build the substitution and collect self-recursive bindings.
go subst inlRec [] = (subst,inlRec)
go !subst !inlRec ((x,e):toInl) =
let e1 = substTm "substInl" subst e
-- A substitution mapping only x to e1, used to update the range of the
-- substitution built so far.
substE = extendIdSubst (mkSubst inScope) x e1
subst1 = subst { substTmEnv = mapVarEnv (substTm "substSubst" substE)
(substTmEnv subst)}
subst2 = extendIdSubst subst1 x e1
in if x `localIdOccursIn` e1 then
go subst ((x,e1):inlRec) toInl
else
go subst2 inlRec toInl
-- | Lift let-bindings to global functions and substitute the resulting
-- calls into the remaining bindings and the body.
--
-- The bindings in @toLift@ are processed from left to right:
--
-- * Each right-hand side has the substitution built so far applied to it.
-- * It is then lifted with 'liftBinding', which yields an expression that
--   calls the new global function.
-- * The binder is added to the substitution, and the substitution's
--   existing range is updated with it.
--
-- Returns @toKeep@ and the body with the final substitution applied.
--
-- Throws a 'ClashException' (at the current function's source span) when a
-- binder still occurs in its own replacement after lifting, i.e. when the
-- result would be a self-recursive let-binding.
liftAndSubsituteBinders
:: InScopeSet
-> [LetBinding]
-> [LetBinding]
-> Term
-> RewriteMonad extra ([LetBinding],Term)
liftAndSubsituteBinders inScope toLift toKeep body = do
subst <- go (mkSubst inScope) toLift
pure ( map (second (substTm "liftToKeep" subst)) toKeep
, substTm "keepBody" subst body
)
where
go subst [] = pure subst
go !subst ((x,e):inl) = do
let e1 = substTm "liftInl" subst e
(_,e2) <- liftBinding (x,e1)
-- A substitution mapping only x to its replacement, used to update the
-- range of the substitution built so far.
let substE = extendIdSubst (mkSubst inScope) x e2
subst1 = subst { substTmEnv = mapVarEnv (substTm "liftSubst" substE)
(substTmEnv subst) }
subst2 = extendIdSubst subst1 x e2
if x `localIdOccursIn` e2 then do
-- NOTE(review): the message below misspells "inlineOrLiftBinders" and
-- "substitution"; it is runtime text and is left unchanged here.
(_,sp) <- Lens.use curFun
throw (ClashException sp [I.i|
Internal error: inlineOrLiftBInders failed on:
#{showPpr (x,e)}
creating a self-recursive let-binding:
#{showPpr (x,e2)}
given the already built subtitution:
#{showDoc (clashPretty (substTmEnv subst))}
|] Nothing)
else
go subst2 inl
-- | Whether the definition of a global binder is work free (see
-- 'isWorkFree'). Results are cached in 'workFreeBinders'.
--
-- A self-recursive binder is never considered work free. Calls 'error' when
-- the binder has no entry in 'bindings'.
isWorkFreeBinder :: HasCallStack => Id -> RewriteMonad extra Bool
isWorkFreeBinder bndr =
  makeCachedU bndr workFreeBinders $ do
    globals <- Lens.use bindings
    case bindingTerm <$> lookupVarEnv bndr globals of
      Nothing -> error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr)
      Just t
        | bndr `globalIdOccursIn` t -> pure False
        | otherwise                 -> isWorkFree t
-- | Whether evaluating a term is work free, judged by the head of the
-- application and its arguments.
--
-- * Variables: polymorphic functions never are; local variables always
--   are; a global is work free when its definition is (see
--   'isWorkFreeBinder').
-- * Primitives: decided by their 'primWorkInfo'.
-- * Lambdas, let-expressions, single-alternative cases and casts:
--   decided by their sub-terms.
--
-- In every case except literals, constant primitives, polymorphic
-- functions and local variables, the term arguments must be work free as
-- well. Type arguments are always fine.
isWorkFree
  :: Term
  -> RewriteMonad extra Bool
isWorkFree term = case hd of
  Var v
    | isPolyFunTy (varType v) -> pure False
    | isLocalId v             -> pure True
    | otherwise               -> andM [isWorkFreeBinder v, argsWorkFree]
  Data {} -> argsWorkFree
  Literal {} -> pure True
  Prim info -> case primWorkInfo info of
    WorkConstant -> pure True
    WorkNever    -> argsWorkFree
    WorkVariable -> pure (all (either isConstant (const True)) args)
    WorkAlways   -> pure False
  Lam _ body   -> andM [isWorkFree body, argsWorkFree]
  TyLam _ body -> andM [isWorkFree body, argsWorkFree]
  Letrec binds body ->
    andM [isWorkFree body, allM (isWorkFree . snd) binds, argsWorkFree]
  Case scrut _ [(_, alt)] ->
    andM [isWorkFree scrut, isWorkFree alt, argsWorkFree]
  Cast inner _ _ ->
    andM [isWorkFree inner, argsWorkFree]
  _ -> pure False
  where
    (hd, args) = collectArgs term
    -- Term arguments must be work free; type arguments always are.
    argsWorkFree = allM (either isWorkFree (const (pure True))) args
-- | Whether a primitive name is one of the @fromInteger#@ /
-- @fromInteger##@ primitives of BitVector, Index, Signed or Unsigned.
isFromInt :: Text -> Bool
isFromInt nm = nm `elem` fromIntegerPrims
  where
    fromIntegerPrims :: [Text]
    fromIntegerPrims =
      [ "Clash.Sized.Internal.BitVector.fromInteger##"
      , "Clash.Sized.Internal.BitVector.fromInteger#"
      , "Clash.Sized.Internal.Index.fromInteger#"
      , "Clash.Sized.Internal.Signed.fromInteger#"
      , "Clash.Sized.Internal.Unsigned.fromInteger#"
      ]
-- | Whether a term is a constant.
--
-- * Constructor and primitive applications are constant when all of their
--   term arguments are.
-- * Lambdas are constant when they have no free local variables.
-- * Literals are always constant.
-- * Anything else is not constant.
isConstant :: Term -> Bool
isConstant e = case collectArgs e of
  (Data _, args)  -> allArgsConstant args
  (Prim _, args)  -> allArgsConstant args
  (Lam _ _, _)    -> not (hasLocalFreeVars e)
  (Literal _, _)  -> True
  _               -> False
  where
    allArgsConstant args = all (either isConstant (const True)) args
-- | Like 'isConstant', but clock and reset typed terms only count when
-- they are an application of the @Clash.Transformations.removedArg@
-- primitive.
isConstantNotClockReset
  :: Term
  -> RewriteMonad extra Bool
isConstantNotClockReset e = do
  tcm <- Lens.view tcCache
  pure $
    if isClockOrReset tcm (termType tcm e)
      then isRemovedArg (fst (collectArgs e))
      else isConstant e
  where
    isRemovedArg (Prim p) = primName p == "Clash.Transformations.removedArg"
    isRemovedArg _        = False
-- | Work-freeness verdict for clock, reset and enable typed terms.
--
-- Returns 'Nothing' for terms of any other type. For those three types,
-- the following count as work free:
--
-- * applications of the @Clash.Transformations.removedArg@ primitive;
-- * bare variables and data constructors (without arguments);
-- * literals.
--
-- Anything else of those types is not work free.
isWorkFreeClockOrResetOrEnable
  :: TyConMap
  -> Term
  -> Maybe Bool
isWorkFreeClockOrResetOrEnable tcm e
  | isClockOrReset tcm eTy || isEnable tcm eTy = Just $ case collectArgs e of
      (Prim p, _)    -> primName p == "Clash.Transformations.removedArg"
      (Var _, [])    -> True
      (Data _, [])   -> True
      (Literal _, _) -> True
      _              -> False
  | otherwise = Nothing
  where
    eTy = termType tcm e
-- | A structural work-freeness check.
--
-- Clock, reset and enable typed terms are judged by
-- 'isWorkFreeClockOrResetOrEnable'. For other terms:
--
-- * Constructor applications are work free when all term arguments are.
-- * Primitive applications: 'WorkAlways' primitives never are;
--   'WorkVariable' primitives are when all term arguments are constant;
--   all other primitives are when all term arguments are work free.
-- * Lambdas are work free when they have no free local variables.
-- * Literals always are.
-- * Everything else is not.
isWorkFreeIsh
  :: Term
  -> RewriteMonad extra Bool
isWorkFreeIsh e = do
  tcm <- Lens.view tcCache
  maybe structural pure (isWorkFreeClockOrResetOrEnable tcm e)
  where
    structural = case collectArgs e of
      (Data _, args) -> allM argOk args
      (Prim pInfo, args)
        | WorkAlways <- primWorkInfo pInfo   -> pure False
        | WorkVariable <- primWorkInfo pInfo -> pure (all (either isConstant (const True)) args)
        | otherwise                          -> allM argOk args
      (Lam _ _, _)   -> pure (not (hasLocalFreeVars e))
      (Literal _, _) -> pure True
      _              -> pure False

    -- Term arguments are checked recursively; type arguments always pass.
    argOk = either isWorkFreeIsh (const (pure True))
-- | Inline or lift the bindings of a @letrec@ that satisfy @condition@.
--
-- 1. Bindings selected by @condition@ are split by @inlineOrLift@ (applied
--    to the whole @letrec@): 'True' means inline, 'False' means lift to a
--    global function.
-- 2. Inlining happens first ('substituteBinders'). Inline candidates that
--    turn out to be self-recursive are lifted as well.
-- 3. Lifting ('liftAndSubsituteBinders') then substitutes calls to the new
--    global functions into the kept bindings and the body.
--
-- When nothing is selected, the expression is returned as-is. Otherwise
-- the result is marked as changed, and the @letrec@ disappears when no
-- bindings remain. Terms that are not a @letrec@ are returned unchanged.
inlineOrLiftBinders
:: (LetBinding -> RewriteMonad extra Bool)
-> (Term -> LetBinding -> Bool)
-> Rewrite extra
inlineOrLiftBinders condition inlineOrLift (TransformContext inScope0 _) e@(Letrec bndrs body) = do
(toReplace,toKeep) <- partitionM condition bndrs
case toReplace of
[] -> return e
_ -> do
let inScope1 = extendInScopeSetList inScope0 (map fst bndrs)
let (toInline,toLift) = partition (inlineOrLift e) toReplace
-- Inline into the to-be-lifted and the kept bindings together. Their
-- relative order is preserved, so splitAt below separates them again.
let (toLiftExtra,(toReplace1,body1)) =
substituteBinders inScope1 toInline (toLift ++ toKeep) body
(toLift1,toKeep1) = splitAt (length toLift) toReplace1
(toKeep2,body2) <- liftAndSubsituteBinders inScope1
(toLiftExtra ++ toLift1)
toKeep1 body1
case toKeep2 of
[] -> changed body2
_ -> changed (Letrec toKeep2 body2)
inlineOrLiftBinders _ _ _ e = return e
-- | Lift a local let-binding @(x, e)@ to a new global function.
--
-- The new function abstracts over the free type variables and free local
-- term variables of @e@. Global identifiers and @x@ itself are not
-- abstracted over. Its name is derived from the current function's name
-- followed by @_@ and the occurrence name of @x@, made unique with respect
-- to the binding map. Occurrences of @x@ inside @e@ (recursive references)
-- are replaced by a call to the new function.
--
-- If a global binding alpha-equivalent to the new body already exists,
-- that binding is called instead and no new binding is added. The result
-- pairs @x@ with the call expression. Calls 'error' when the binder is a
-- type variable.
liftBinding :: LetBinding
-> RewriteMonad extra LetBinding
liftBinding (var@Id {varName = idName} ,e) = do
-- Collect free variables into separate type-variable and term-variable sets.
let unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
unitFV v@(Id {}) = Const (emptyUniqSet,unitUniqSet (coerce v))
unitFV v@(TyVar {}) = Const (unitUniqSet (coerce v),emptyUniqSet)
-- Skip global ids and the binder being lifted. Other local ids and all
-- type variables are of interest.
interesting :: Var a -> Bool
interesting Id {idScope = GlobalId} = False
interesting v@(Id {idScope = LocalId}) = varUniq v /= varUniq var
interesting _ = True
(boundFTVsSet,boundFVsSet) =
getConst (Lens.foldMapOf (termFreeVars' interesting) unitFV e)
boundFTVs = eltsUniqSet boundFTVsSet
boundFVs = eltsUniqSet boundFVsSet
tcm <- Lens.view tcCache
let newBodyTy = termType tcm $ mkTyLams (mkLams e boundFVs) boundFTVs
(cf,sp) <- Lens.use curFun
binders <- Lens.use bindings
newBodyNm <-
cloneNameWithBindingMap
binders
(appendToName (varName cf) ("_" `Text.append` nameOcc idName))
let newBodyId = mkGlobalId newBodyTy newBodyNm {nameSort = Internal}
-- The call that replaces x: the new global applied to the abstracted
-- type and term variables.
let newExpr = mkTmApps
(mkTyApps (Var newBodyId)
(map VarTy boundFTVs))
(map Var boundFVs)
inScope0 = mkInScopeSet (coerce boundFVsSet)
inScope1 = extendInScopeSetList inScope0 [var,newBodyId]
-- Replace recursive references to x inside e by the call.
let subst = extendIdSubst (mkSubst inScope1) var newExpr
e' = substTm "liftBinding" subst e
newBody = mkTyLams (mkLams e' boundFVs) boundFTVs
-- Reuse an existing alpha-equivalent global binding when there is one.
aeqExisting <- (eltsUniqMap . filterUniqMap ((`aeqTerm` newBody) . bindingTerm)) <$> Lens.use bindings
case aeqExisting of
[] -> do
bindings %= extendUniqMap newBodyNm
(Binding
newBodyId
sp
#if MIN_VERSION_ghc(8,4,1)
NoUserInline
#else
EmptyInlineSpec
#endif
newBody)
return (var, newExpr)
(b:_) ->
let newExpr' = mkTmApps
(mkTyApps (Var $ bindingId b)
(map VarTy boundFTVs))
(map Var boundFVs)
in return (var, newExpr')
liftBinding _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"
-- | Pick a unique for the name, via @uniqAway'@, that is not a key of the
-- given binding map. The search starts from the name's own unique.
uniqAwayBinder
  :: BindingMap
  -> Name a
  -> Name a
uniqAwayBinder binders nm = uniqAway' taken (nameUniq nm) nm
  where
    taken u = u `elemUniqMapDirectly` binders
-- | Add a new global function with the given body, source span and inline
-- spec, and return its 'Id'.
--
-- The name is a clone of @bndrNm@ that is unique with respect to the
-- current binding map. The type is computed from the body.
mkFunction
  :: TmName
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra Id
mkFunction bndrNm sp inl body = do
  bodyTy <- (`termType` body) <$> Lens.view tcCache
  bodyNm <- Lens.use bindings >>= \globals -> cloneNameWithBindingMap globals bndrNm
  addGlobalBind bodyNm bodyTy sp inl body
  pure (mkGlobalId bodyTy bodyNm)
-- | Insert a global binding under the given name into 'bindings'.
--
-- The type and body are forced to normal form ('deepseq') before the
-- binding map is updated.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vNm ty sp inl body =
  (ty, body) `deepseq`
    (bindings %= extendUniqMap vNm (Binding (mkGlobalId ty vNm) sp inl body))
-- | Give a name a fresh unique from the supply, then make it unique with
-- respect to the in-scope set.
cloneNameWithInScopeSet
  :: (MonadUnique m)
  => InScopeSet
  -> Name a
  -> m (Name a)
cloneNameWithInScopeSet is nm = uniqAway is . setUnique nm <$> getUniqueM
-- | Give a name a fresh unique from the supply, then move that unique away,
-- via @uniqAway'@, from the keys of the binding map.
cloneNameWithBindingMap
  :: (MonadUnique m)
  => BindingMap
  -> Name a
  -> m (Name a)
cloneNameWithBindingMap binders nm = do
  u <- getUniqueM
  let taken k = k `elemUniqMapDirectly` binders
  pure (uniqAway' taken u (setUnique nm u))
{-# INLINE isUntranslatable #-}
-- | Whether the type of a term is not representable according to
-- 'representableType'. The check uses the environment's type translator
-- and custom representations.
isUntranslatable
  :: Bool
  -- ^ Forwarded to 'representableType' as its string flag
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm <- Lens.view tcCache
  translator <- Lens.view typeTranslator
  reprs <- Lens.view customReprs
  pure . not $
    representableType translator reprs stringRepresentable tcm (termType tcm tm)
{-# INLINE isUntranslatableType #-}
-- | Whether a type is not representable according to 'representableType'.
-- The check uses the environment's type translator, custom representations
-- and type-constructor map.
isUntranslatableType
  :: Bool
  -- ^ Forwarded to 'representableType' as its string flag
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translator <- Lens.view typeTranslator
  reprs <- Lens.view customReprs
  tcm <- Lens.view tcCache
  pure . not $ representableType translator reprs stringRepresentable tcm ty
-- | Create a fresh local binder named @wild@ of the given type. The binder
-- is unique with respect to the in-scope set.
mkWildValBinder
  :: (MonadUnique m)
  => InScopeSet
  -> Type
  -> m Id
mkWildValBinder is ty = mkInternalVar is "wild" ty
-- | Build a case expression that projects field @fieldI@ out of the
-- @dcI@-th data constructor of the scrutinee's type, i.e.
-- @case scrut of DC _ .. sel .. _ -> sel@.
--
-- The data-constructor index @dcI@ is 1-based; the field index @fieldI@ is
-- 0-based. All other fields are bound to fresh @wild@ binders. The
-- scrutinee's type is expanded with 'coreView1' until a type-constructor
-- application is found.
--
-- Calls 'error' (mentioning @caller@ and both indices) when the selector
-- cannot be built:
--
-- * the type is not a data type, or has no constructors;
-- * either index is out of range;
-- * the constructor's field types cannot be instantiated.
mkSelectorCase
  :: HasCallStack
  => (Functor m, MonadUnique m)
  => String
  -- ^ Name of the caller, used in error messages
  -> InScopeSet
  -> TyConMap
  -> Term
  -- ^ Scrutinee
  -> Int
  -- ^ 1-based data-constructor index
  -> Int
  -- ^ 0-based field index
  -> m Term
mkSelectorCase caller inScope tcm scrut dcI fieldI = go (termType tcm scrut)
  where
    go (coreView1 tcm -> Just ty') = go ty'
    go scrutTy@(tyView -> TyConApp tc args) =
      case tyConDataCons (lookupUniqMap' tcm tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showPpr tc) scrutTy
        dcs
          -- The index is 1-based. Without this check, a value below 1 would
          -- reach 'indexNote' with a negative position.
          | dcI < 1 -> cantCreate $(curLoc) "DC index must be at least 1" scrutTy
          | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max" scrutTy
          | otherwise -> do
              let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
              -- Match explicitly instead of using a lazy irrefutable 'Just'
              -- pattern, so a failure produces an informative error.
              case dataConInstArgTysE inScope tcm dc args of
                Nothing ->
                  cantCreate $(curLoc) "Could not instantiate the field types of the DC" scrutTy
                Just fieldTys
                  | fieldI < 0 -> cantCreate $(curLoc) "Field index is negative" scrutTy
                  | fieldI >= length fieldTys -> cantCreate $(curLoc) "Field index exceed max" scrutTy
                  | otherwise -> do
                      wildBndrs <- mapM (mkWildValBinder inScope) fieldTys
                      let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
                      selBndr <- mkInternalVar inScope "sel" ty
                      -- Bind the selected field to 'sel' and all others to wildcards.
                      let bndrs = take fieldI wildBndrs ++ [selBndr] ++ drop (fieldI+1) wildBndrs
                          pat = DataPat dc (dcExtTyVars dc) bndrs
                      return (Case scrut ty [(pat, Var selBndr)])
    go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showPpr scrutTy) scrutTy

    cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showPpr scrut ++ " :: " ++ showPpr scrutTy ++ ")\nAdditional info: " ++ info
-- | Specialise an application on its last argument, which may be a type
-- or a term. The work is done by 'specialise''.
--
-- The three lenses select, in the @extra@ state, the specialisation map,
-- the per-function specialisation history, and the specialisation limit.
-- Expressions that are not term or type applications are returned
-- unchanged.
specialise :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id)
           -> Lens' extra (VarEnv Int)
           -> Lens' extra Int
           -> Rewrite extra
specialise specMapLbl specHistLbl specLimitLbl ctx e =
  case e of
    TyApp fun tyArg -> specialiseOn fun (Right tyArg)
    App fun tmArg   -> specialiseOn fun (Left tmArg)
    _               -> return e
  where
    specialiseOn fun =
      specialise' specMapLbl specHistLbl specLimitLbl ctx e (collectArgsTicks fun)
-- | Worker for 'specialise'.
--
-- The first equation handles an application headed by a variable @f@:
--
-- * If @f@ is a top entity, term arguments are never specialised on and
--   the expression is returned as-is. A type argument is dropped by
--   instantiating the type of @f@ with it.
-- * Otherwise, the argument's types are normalised and the argument is
--   abstracted over its free local variables. If the specialisation map
--   already records a specialisation of @f@ with the same number of
--   preceding arguments and the same abstracted argument, that
--   specialisation is reused.
-- * Else, when @f@ has a global binding, a new global function is created.
--   Its body applies @f@'s body to fresh binders for the preceding
--   arguments and to the argument. The specialisation map and @f@'s
--   history count are then updated.
-- * A 'ClashException' is thrown instead once @f@'s history count exceeds
--   the specialisation limit.
--
-- The second equation handles a term argument under any other head: the
-- argument is lifted into a global function (an existing alpha-equivalent
-- one if there is one) and replaced by a call to it. Everything else is
-- returned unchanged.
specialise' :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id)
-> Lens' extra (VarEnv Int)
-> Lens' extra Int
-> TransformContext
-> Term
-> (Term, [Either Term Type], [TickInfo])
-> Either Term Type
-> RewriteMonad extra Term
specialise' specMapLbl specHistLbl specLimitLbl (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do
lvl <- Lens.view dbgLevel
tcm <- Lens.view tcCache
topEnts <- Lens.view topEntities
if f `elemVarSet` topEnts
then do
case specArgIn of
-- NOTE(review): lvl >= DebugNone looks always true (DebugNone appears to be
-- the lowest level), so this message is presumably always traced. Confirm.
Left _ -> traceIf (lvl >= DebugNone) ("Not specializing TopEntity: " ++ showPpr (varName f)) (return e)
Right tyArg -> traceIf (lvl >= DebugApplied) ("Dropping type application on TopEntity: " ++ showPpr (varName f) ++ "\ntype:\n" ++ showPpr tyArg) $
let newVarTy = piResultTy tcm (varType f) tyArg
in changed (mkApps (mkTicks (Var f{varType = newVarTy}) ticks) args)
else do
-- Normalise the argument and abstract it over its free local variables.
-- (f, number of preceding arguments, abstracted argument) is the key into
-- the specialisation map.
let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn
(specBndrsIn,specVars) = specArgBndrsAndVars specArg
argLen = length args
specBndrs :: [Either Id TyVar]
specBndrs = map (Lens.over _Left (normalizeId tcm)) specBndrsIn
specAbs :: Either Term Type
specAbs = either (Left . (`mkAbstraction` specBndrs)) (Right . id) specArg
specM <- Map.lookup (f,argLen,specAbs) <$> Lens.use (extra.specMapLbl)
case specM of
-- Reuse an earlier specialisation.
Just f' ->
traceIf (lvl >= DebugApplied)
("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
(either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $
changed $ mkApps (mkTicks (Var f') ticks) (args ++ specVars)
Nothing -> do
bodyMaybe <- fmap (lookupUniqMap (varName f)) $ Lens.use bindings
case bodyMaybe of
Just (Binding _ sp inl bodyTm) -> do
-- Guard against endless specialisation: abort once f has been
-- specialised more often than the limit allows.
specHistM <- lookupUniqMap f <$> Lens.use (extra.specHistLbl)
specLim <- Lens.use (extra . specLimitLbl)
if maybe False (> specLim) specHistM
then throw (ClashException
sp
(unlines [ "Hit specialisation limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
, "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
, "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
, "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg
, "Run with '-fclash-spec-limit=N' to increase the specialisation limit to N."
])
Nothing)
else do
-- Fresh binders for the preceding arguments. Their names are taken
-- from f's lambda binders where available, then from pTS0, pTS1, ...
let existingNames = collectBndrsMinusApps bodyTm
newNames = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n
| n <- [(0::Int)..]
]
(boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $
Monad.zipWithM
(mkBinderFor is0 tcm)
(existingNames ++ newNames)
args
-- When the argument is a polymorphic function headed by a variable g,
-- the new function is named after g. If g has a global binding, g's
-- inline spec is used and the argument becomes g's body applied to the
-- same arguments.
(fId,inl',specArg') <- case specArg of
Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a
then do
gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings
return (g,maybe inl bindingSpec gTmM, maybe specArg (Left . (`mkApps` gArgs) . bindingTerm) gTmM)
else return (f,inl,specArg)
_ -> return (f,inl,specArg)
-- f's body applied to the fresh binders and the argument, abstracted
-- over the fresh binders and the argument's free variables.
let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
newf <- mkFunction (varName fId) sp inl' newBody
-- Record the specialisation: bump f's history count, then memoise.
(extra.specHistLbl) %= extendUniqMapWith f 1 (+)
(extra.specMapLbl) %= Map.insert (f,argLen,specAbs) newf
let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars)
newf `deepseq` changed newExpr
-- f has no global binding: leave the expression alone.
Nothing -> return e
where
-- Names of the leading (type-)lambda binders of a term, in order. For an
-- application, the binders of its head are used with one binder dropped
-- per application (the one the argument is applied to).
collectBndrsMinusApps :: Term -> [Name a]
collectBndrsMinusApps = reverse . go []
where
go bs (Lam v e') = go (coerce (varName v):bs) e'
go bs (TyLam tv e') = go (coerce (varName tv):bs) e'
go bs (App e' _) = case go [] e' of
[] -> bs
bs' -> init bs' ++ bs
go bs (TyApp e' _) = case go [] e' of
[] -> bs
bs' -> init bs' ++ bs
go bs _ = bs
-- A term argument under a head that is not a variable: lift the argument
-- itself into a global function and pass a call to it instead.
specialise' _ _ _ _ctx _ (appE,args,ticks) (Left specArg) = do
let (specBndrs,specVars) = specArgBndrsAndVars (Left specArg)
newBody = mkAbstraction specArg specBndrs
existing <- filterUniqMap ((`aeqTerm` newBody) . bindingTerm) <$> Lens.use bindings
newf <- case eltsUniqMap existing of
[] -> do (cf,sp) <- Lens.use curFun
mkFunction (appendToName (varName cf) "_specF")
sp
#if MIN_VERSION_ghc(8,4,1)
NoUserInline
#else
EmptyInlineSpec
#endif
newBody
(b:_) -> return (bindingId b)
let newArg = Left $ mkApps (Var newf) specVars
let newExpr = mkApps (mkTicks appE ticks) (args ++ [newArg])
changed newExpr
specialise' _ _ _ _ e _ _ = return e
-- | Normalise the types that annotate a term.
--
-- * For a 'Cast', both types are normalised and the term being cast is
--   processed recursively.
-- * For a variable, its type is normalised (see 'normalizeId').
-- * All other terms are returned unchanged, including their sub-terms.
normalizeTermTypes :: TyConMap -> Term -> Term
normalizeTermTypes tcm = go
  where
    go (Cast inner ty1 ty2) = Cast (go inner) (normalizeType tcm ty1) (normalizeType tcm ty2)
    go (Var v)              = Var (normalizeId tcm v)
    go other                = other
-- | Normalise the type of a variable built with the 'Id' constructor. Any
-- other variable is returned unchanged.
normalizeId :: TyConMap -> Id -> Id
normalizeId tcm v = case v of
  Id {} -> v {varType = normalizeType tcm (varType v)}
  _     -> v
-- | Binders to abstract a specialisation argument over, together with the
-- matching variables to pass at the call site.
--
-- * For a term argument: its free local type and term variables, gathered
--   with 'freeLocalVars' into insertion-ordered sets (Data.Set.Ordered).
-- * For a type argument: its free type variables, gathered in a 'UniqSet',
--   and no term variables.
--
-- In both result lists the type variables come before the term variables,
-- in the same order. The second list is the first one turned into 'VarTy'
-- and 'Var' references.
specArgBndrsAndVars
:: Either Term Type
-> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars specArg =
-- Route each free variable into the type-variable or the term-variable set.
let unitFV :: Var a -> Const (OSet.OLSet TyVar, OSet.OLSet Id) (Var a)
unitFV v@(Id {}) = Const (mempty, coerce (OSet.singleton (coerce v)))
unitFV v@(TyVar {}) = Const (coerce (OSet.singleton (coerce v)), mempty)
(specFTVs,specFVs) = case specArg of
Left tm -> (OSet.toListL *** OSet.toListL) . getConst $
Lens.foldMapOf freeLocalVars unitFV tm
Right ty -> (eltsUniqSet (Lens.foldMapOf typeFreeVars unitUniqSet ty),[] :: [Id])
specTyBndrs = map Right specFTVs
specTmBndrs = map Left specFVs
specTyVars = map (Right . VarTy) specFTVs
specTmVars = map (Left . Var) specFVs
in (specTyBndrs ++ specTmBndrs,specTyVars ++ specTmVars)
-- | Evaluate a term to weak-head normal form with @whnf'@, then run a
-- rewrite on the result (see 'bindPureHeap').
--
-- The evaluator gets one half of the unique supply; the other half is
-- stored back in the state. The global heap returned by the evaluator
-- replaces the stored one. The evaluator's pure heap is bound around the
-- result while the rewrite runs. The 'Bool' argument is passed straight
-- to @whnf'@.
whnfRW
  :: Bool
  -> TransformContext
  -> Term
  -> Rewrite extra
  -> RewriteMonad extra Term
whnfRW isSubj ctx@(TransformContext is0 _) e rw = do
  tcm <- Lens.view tcCache
  globals <- Lens.use bindings
  (primEval, primUnwind) <- Lens.view evaluator
  supply <- Lens.use uniqSupply
  let (supplyHere, supplyRest) = splitSupply supply
  uniqSupply Lens..= supplyRest
  heap0 <- Lens.use globalHeap
  case whnf' primEval primUnwind globals tcm heap0 supplyHere is0 isSubj e of
    (!heap1, pureHeap, val) -> do
      globalHeap Lens..= heap1
      bindPureHeap tcm pureHeap rw ctx val
{-# SCC whnfRW #-}
-- | Run a rewrite on an evaluator result, with the entries of the
-- evaluator's pure heap available as let-bound variables.
--
-- Every heap entry becomes a local binder named @x@ that carries the
-- entry's unique. These binders are added to the in-scope set, and to the
-- context as a 'LetBody', before the rewrite runs.
--
-- If the rewrite reported a change and the heap is non-empty:
--
-- * the result is wrapped in a @letrec@ of the heap bindings;
-- * work-free bindings are inlined ('inlineBinders', 'isWorkFree');
-- * unused bindings are dropped ('removeUnusedBinders').
--
-- Otherwise the rewrite's result is returned as-is.
bindPureHeap
  :: TyConMap
  -> PureHeap
  -> Rewrite extra
  -> Rewrite extra
bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do
  (e1, anyChange) <- Writer.listen (rw heapCtx e)
  if Monoid.getAny anyChange && not (null heapBinds)
    then do
      inlined <- inlineBinders inlineIfWorkFree ctx0 (Letrec heapBinds e1)
      pure $ case inlined of
        Letrec remaining body -> fromMaybe inlined (removeUnusedBinders remaining body)
        _                     -> inlined
    else
      pure e1
  where
    heapBinds :: [LetBinding]
    heapBinds =
      [ (mkLocalId (termType tcm tm) (mkUnsafeSystemName "x" u), tm)
      | (u, tm) <- toListUniqMap heap
      ]
    heapIds = map fst heapBinds
    heapCtx = TransformContext (extendInScopeSetList is0 heapIds) (LetBody heapIds : hist)
    inlineIfWorkFree _ (_, tm) = isWorkFree (stripTicks tm)
-- | Drop let-bindings that the body does not need, directly or through
-- other needed bindings.
--
-- * Returns @Just body@ when no binding is needed.
-- * Returns @Just (Letrec needed body)@ when only some bindings are needed.
-- * Returns 'Nothing' when every binding is needed, so the caller can keep
--   its original term.
--
-- The needed bindings come out in the order 'eltsVarEnv' produces, which
-- is not necessarily their original order.
removeUnusedBinders
  :: [LetBinding]
  -> Term
  -> Maybe Term
removeUnusedBinders binds body =
  case eltsVarEnv reachable of
    [] -> Just body
    kept
      | List.equalLength kept binds -> Nothing
      | otherwise                   -> Just (Letrec kept body)
  where
    bindsEnv = mkVarEnv [ (x, (x, rhs)) | (x, rhs) <- binds ]
    freeIdsOf tm = eltsVarSet (Lens.foldMapOf freeLocalIds unitVarSet tm)
    -- Depth-first walk from the body's free ids through right-hand sides.
    reachable = List.foldl' visit emptyVarEnv (freeIdsOf body)
    visit env v
      | v `elemVarEnv` env = env
      | otherwise = case lookupVarEnv v bindsEnv of
          Nothing -> env
          Just b@(x, rhs) ->
            List.foldl' visit (extendVarEnv x b env) (freeIdsOf rhs)