{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize where
import Control.Concurrent.Supply (Supply)
import Control.Exception (throw)
import qualified Control.Lens as Lens
import Control.Monad (when)
import Control.Monad.State.Strict (State)
import Data.Default (def)
import Data.Either (lefts,partitionEithers)
import qualified Data.IntMap as IntMap
import Data.IntMap.Strict (IntMap)
import Data.List
(intersect, mapAccumL)
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Data.Text.Prettyprint.Doc (vcat)
import BasicTypes (InlineSpec (..))
import Clash.Annotations.BitRepresentation.Internal
(CustomReprs)
import Clash.Core.Evaluator.Types (PrimStep, PrimUnwind)
import Clash.Core.FreeVars
(freeLocalIds, globalIds, globalIdOccursIn, localIdDoesNotOccurIn)
import Clash.Core.Pretty (PrettyOptions(..), showPpr, showPpr', ppr)
import Clash.Core.Subst
(extendGblSubstList, mkSubst, substTm)
import Clash.Core.Term (Term (..), collectArgsTicks
,mkApps, mkTicks)
import Clash.Core.Type (Type, splitCoreFunForallTy)
import Clash.Core.TyCon
(TyConMap, TyConName)
import Clash.Core.Type (isPolyTy)
import Clash.Core.Var (Id, varName, varType)
import Clash.Core.VarEnv
(VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv,
extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv,
mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv)
import Clash.Debug (traceIf)
import Clash.Driver.Types
(BindingMap, Binding(..), ClashOpts (..), DebugLevel (..))
import Clash.Netlist.Types
(HWMap, FilteredHWType(..))
import Clash.Netlist.Util
(splitNormalized)
import Clash.Normalize.Strategy
import Clash.Normalize.Transformations
(appPropFast, bindConstantVar, caseCon, flattenLet, reduceConst, topLet,
reduceNonRepPrim, removeUnusedExpr, deadCode)
import Clash.Normalize.Types
import Clash.Normalize.Util
import Clash.Primitives.Types (CompiledPrimMap)
import Clash.Rewrite.Combinators ((>->),(!->),repeatR,topdownR)
import Clash.Rewrite.Types
(RewriteEnv (..), RewriteState (..), bindings, dbgLevel, extra,
tcCache, topEntities)
import Clash.Rewrite.Util
(apply, isUntranslatableType, runRewriteSession)
import Clash.Util
import Clash.Util.Interpolate (i)
#ifdef HISTORY
import Data.Binary (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import System.IO.Unsafe (unsafePerformIO)
import Clash.Rewrite.Types (RewriteStep(..))
#endif
-- | Run a 'NormalizeSession' and return its result.
--
-- The initial 'RewriteEnv', 'RewriteState' and 'NormalizeState' are
-- assembled from the Clash options and the supplied environment, after
-- which the session is executed with 'runRewriteSession'.
runNormalization
  :: ClashOpts
  -- ^ Options; debug settings and the various limits are read from these
  -> Supply
  -- ^ Stored in the initial 'RewriteState'
  -> BindingMap
  -- ^ Global bindings, stored in the initial 'RewriteState'
  -> (CustomReprs -> TyConMap -> Type ->
      State HWMap (Maybe (Either String FilteredHWType)))
  -- ^ Type translation function, stored in the 'RewriteEnv'
  -> CustomReprs
  -> TyConMap
  -> IntMap TyConName
  -> (PrimStep, PrimUnwind)
  -> CompiledPrimMap
  -> VarEnv Bool
  -> [Id]
  -- ^ Top entities; stored as a set in the 'RewriteEnv'
  -> NormalizeSession a
  -- ^ The session to run
  -> a
runNormalization opts supply globals typeTrans reprs tcm tupTcm eval primMap rcsMap topEnts
  = runRewriteSession rwEnv rwState
  where
    -- NOTE: 'RewriteEnv', 'RewriteState' and 'NormalizeState' are all built
    -- with positional arguments, so the order of the arguments below must
    -- match the field order of their declarations.
    rwEnv     = RewriteEnv
                  (opt_dbgLevel opts)
                  (opt_dbgTransformations opts)
                  (opt_dbgTransformationsFrom opts)
                  (opt_dbgTransformationsLimit opts)
                  (opt_aggressiveXOpt opts)
                  typeTrans
                  tcm
                  tupTcm
                  eval
                  (mkVarSet topEnts)
                  reprs

    rwState   = RewriteState
                  0
                  globals
                  supply
                  -- Placeholder: no function is being rewritten yet, so
                  -- forcing this binder is reported as a bug.
                  (error $ $(curLoc) ++ "Report as bug: no curFun",noSrcSpan)
                  0
                  (IntMap.empty, 0)
                  emptyVarEnv
                  normState

    normState = NormalizeState
                  emptyVarEnv
                  Map.empty
                  emptyVarEnv
                  (opt_specLimit opts)
                  emptyVarEnv
                  (opt_inlineLimit opts)
                  (opt_inlineFunctionLimit opts)
                  (opt_inlineConstantLimit opts)
                  primMap
                  Map.empty
                  rcsMap
                  (opt_newInlineStrat opts)
                  (opt_ultra opts)
                  (opt_inlineWFCacheLimit opts)
-- | Normalize the given binders and, transitively, every global binder that
-- 'normalize'' reports as still needing normalization. Returns all the
-- resulting bindings. An empty list of binders yields an empty map.
normalize
  :: [Id]
  -> NormalizeSession BindingMap
normalize [] = return emptyVarEnv
normalize top = do
  results <- mapM normalize' top
  -- Each result pairs the binders still to normalize with a finished binding.
  let (pending, done) = unzip results
  remaining <- normalize (concat pending)
  return (mkVarEnv done `unionVarEnv` remaining)
-- | Normalize the single global binder @nm@.
--
-- Returns the global binders that still have to be normalized, together
-- with the (possibly normalized) binding of @nm@.
--
-- * Throws a 'ClashException' when the binder's type is polymorphic. For top
--   entities, type equality constraints are first applied with
--   'tvSubstWithTyEq'.
-- * When the result type is representable, the binding is normalized with
--   'normalizeTopLvlBndr'. The returned binders are the global binders used
--   by the normalized term, excluding those already normalized, @nm@ itself,
--   and top entities. A trace message is emitted if the normalized term
--   still refers to @nm@.
-- * When the result type is not representable, throws a 'ClashException'
--   for top entities and functions. Otherwise it returns the binding
--   unnormalized with no further binders.
-- * Calls 'error' when @nm@ is not present in 'bindings'.
normalize' :: Id -> NormalizeSession ([Id], (Id, Binding))
normalize' nm = do
  exprM <- lookupVarEnv nm <$> Lens.use bindings
  let nmS = showPpr (varName nm)
  case exprM of
    Just (Binding nm' sp inl tm) -> do
      tcm <- Lens.view tcCache
      topEnts <- Lens.view topEntities
      let isTop = nm `elemVarSet` topEnts
          ty0 = varType nm'
          -- Type equality constraints are only applied for top entities.
          ty1 = if isTop then tvSubstWithTyEq ty0 else ty0
      when (isPolyTy ty1) $
        let msg = $curLoc ++ [i|
              Clash can only normalize monomorphic functions, but this is polymorphic:
              #{showPpr' def{displayUniques=False\} nm'}
            |]
            msgExtra | ty0 == ty1 = Nothing
                     | otherwise = Just $ [i|
              Even after applying type equality constraints it remained polymorphic:
              #{showPpr' def{displayUniques=False\} nm'{varType=ty1\}}
            |]
        in throw (ClashException sp msg msgExtra)
      let (args,resTy) = splitCoreFunForallTy tcm ty1
          -- A term (Left) argument means the binder is a function.
          isFunction = not $ null $ lefts args
      resTyRep <- not <$> isUntranslatableType False resTy
      if resTyRep
        then do
          tmNorm <- normalizeTopLvlBndr isTop nm (Binding nm' sp inl tm)
          let usedBndrs = Lens.toListOf globalIds (bindingTerm tmNorm)
          traceIf (nm `elem` usedBndrs)
                  (concat [ $(curLoc),"Expr belonging to bndr: ",nmS ," (:: "
                          , showPpr (varType (bindingId tmNorm))
                          , ") remains recursive after normalization:\n"
                          , showPpr (bindingTerm tmNorm) ])
                  (return ())
          prevNorm <- mapVarEnv bindingId <$> Lens.use (extra.normalized)
          -- Skip binders that are already normalized, nm itself, and top
          -- entities.
          let toNormalize = filter (`notElemVarSet` topEnts)
                          $ filter (`notElemVarEnv` (extendVarEnv nm nm prevNorm)) usedBndrs
          return (toNormalize,(nm,tmNorm))
        else do
          when (isTop || isFunction) $
            let msg = $(curLoc) ++ [i|
                  This bndr has a non-representable return type and can't be normalized:
                  #{showPpr' def{displayUniques=False\} nm'}
                |]
            in throw (ClashException sp msg Nothing)
          lvl <- Lens.view dbgLevel
          traceIf (lvl > DebugNone)
            (concat [$(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                    , showPpr (varType nm')
                    , ") has a non-representable return type."
                    , " Not normalising:\n", showPpr tm] )
            (return ([],(nm,(Binding nm' sp inl tm))))
    Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
-- | Check that no binding in a normalized binding map refers to itself.
--
-- Returns the map unchanged when no binding's term mentions its own global
-- binder. Otherwise it calls 'error' and lists every self-referencing binder
-- as @binder = term@, one per line.
checkNonRecursive
  :: BindingMap
  -> BindingMap
checkNonRecursive norm = case mapMaybeVarEnv go norm of
  rcs | nullVarEnv rcs -> norm
  rcs -> error $ $(curLoc) ++ "Callgraph after normalization contains following recursive components: "
                 -- Separate each binder from its term; the bare '<>' used
                 -- before glued them together.
                 ++ show (vcat [ ppr a <> " = " <> ppr b
                               | (a,b) <- eltsVarEnv rcs
                               ])
  where
    -- Keep only bindings whose term mentions their own binder.
    go (Binding nm _ _ tm) =
      if nm `globalIdOccursIn` tm
        then Just (nm,tm)
        else Nothing
-- | Flatten the call graph rooted at @topEntity@.
--
-- If @topEntity@ has a binding in @norm@, the call tree rooted at it is
-- flattened with 'flattenCallTree'. The result contains only the bindings
-- that remain in that flattened tree. Otherwise @norm@ is returned unchanged.
cleanupGraph
  :: Id
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm =
  case mkCallTree [] norm topEntity of
    Nothing -> return norm
    Just callTree -> do
      flat <- flattenCallTree callTree
      let (_, reachable) = callTreeToList [] flat
      return (mkVarEnv reachable)
-- | A tree of bindings and the global bindings they use, as built by
-- 'mkCallTree' and rewritten by 'flattenCallTree'.
data CallTree
  = CLeaf   (Id, Binding)
  -- ^ A binder with its binding; built by 'mkCallTree' when the term uses
  -- no global binders
  | CBranch (Id, Binding) [CallTree]
  -- ^ A binder with its binding and the call trees of (some of) the global
  -- binders its term uses
-- | Build the call tree rooted at @root@ from the bindings in @bindingMap@.
--
-- Returns 'Nothing' when @root@ has no binding. A sub-tree is built for each
-- global binder used by @root@'s term that is not in @visited@. Binders that
-- have no binding are dropped. A binding whose term uses no global binders
-- becomes a 'CLeaf'; otherwise it becomes a 'CBranch', possibly with no
-- children.
mkCallTree
  :: [Id]
  -- ^ Binders on the path from the overall root down to @root@
  -> BindingMap
  -> Id
  -> Maybe CallTree
mkCallTree visited bindingMap root = do
  rootBinding <- lookupVarEnv root bindingMap
  let callees  = Set.toList (Lens.setOf globalIds (bindingTerm rootBinding))
      fresh    = filter (`notElem` visited) callees
      subtrees = Maybe.mapMaybe (mkCallTree (root:visited) bindingMap) fresh
      node     = (root, rootBinding)
  pure $ if null callees
           then CLeaf node
           else CBranch node subtrees
-- | Try to strip trailing arguments that are exactly the given binders.
--
-- The second list holds the binders to strip and the third the arguments,
-- matched front to front (callers pass both reversed). Each stripped
-- argument must be a term 'Var' equal to its binder. Once all binders are
-- consumed, the remaining arguments are returned, provided that no
-- remaining term argument has a binder from the first list among its free
-- local ids. In every other case the result is 'Nothing'.
stripArgs
  :: [Id]
  -> [Id]
  -> [Either Term Type]
  -> Maybe [Either Term Type]
stripArgs allIds = go
  where
    go [] rest
      | any mentionsId rest = Nothing
      | otherwise           = Just rest
    go (id_:ids) (Left (Var nm):rest)
      | id_ == nm = go ids rest
    go _ _ = Nothing

    -- Only term arguments are inspected; type arguments never count.
    mentionsId arg =
      not (null (either (Lens.toListOf freeLocalIds) (const []) arg `intersect` allIds))
-- | Decide how a child node of the call tree is treated when its parent is
-- flattened.
--
-- Returns @Left node@ when the node is kept as a separate component. This
-- happens when it is marked 'NoInline' or is a top entity. It also happens
-- for a branch whose term does not split into a single binding and that is
-- neither cheap ('isCheapFunction') nor covered by the new inline strategy.
--
-- Otherwise returns @Right ((nm, term), children)@. Here @term@ is to be
-- substituted for @nm@ in the parent and @children@ are the node's own
-- children (empty for a leaf).
flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
flattenNode node = case node of
  CLeaf (_,(Binding _ _ NoInline _)) -> return (Left node)
  CLeaf (nm,(Binding _ _ _ e)) ->
    tryInline nm e [] (return (Right ((nm,e),[])))
  CBranch (_,(Binding _ _ NoInline _)) _ -> return (Left node)
  CBranch (nm,(Binding _ _ _ e)) us ->
    tryInline nm e us $ do
      newInlineStrat <- Lens.use (extra.newInlineStrategy)
      if newInlineStrat || isCheapFunction e
        then return (Right ((nm,e),us))
        else return (Left node)
  where
    -- Logic shared by leaves and branches that are not 'NoInline'. Top
    -- entities are kept. When the normalized term splits into exactly one
    -- binding, the term is inlined: in eta-reduced form when that binding's
    -- expression merely applies a function to the split-off argument
    -- binders (see 'stripArgs'), as-is otherwise. In all other cases the
    -- node-specific fallback runs.
    tryInline
      :: Id
      -> Term
      -> [CallTree]
      -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
      -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
    tryInline nm e us fallback = do
      isTopEntity <- elemVarSet nm <$> Lens.view topEntities
      if isTopEntity then return (Left node) else do
        tcm <- Lens.view tcCache
        case splitNormalized tcm e of
          Right (ids,[(bId,bExpr)],_) -> do
            let (fun,args,ticks) = collectArgsTicks bExpr
            case stripArgs ids (reverse ids) (reverse args) of
              Just remainder | bId `localIdDoesNotOccurIn` bExpr ->
                return (Right ((nm,mkApps (mkTicks fun ticks) (reverse remainder)),us))
              _ -> return (Right ((nm,e),us))
          _ -> fallback
-- | Flatten a call tree bottom-up.
--
-- Leaves are returned unchanged. For a branch, these steps run in order:
--
-- 1. The children are flattened recursively.
-- 2. Each child is classified with 'flattenNode'.
-- 3. The terms of the children selected for inlining are substituted for
--    their binders in this node's term, and the result is rewritten with
--    @flatten@. Those children's own children become children of this node.
-- 4. If this node is not 'NoInline' and the resulting term is cheap
--    ('isCheapFunction'), every remaining non-'NoInline' child is inlined
--    the same way, and its children are adopted in turn.
flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree c@(CLeaf _) = return c
flattenCallTree (CBranch (nm,(Binding nm' sp inl tm)) used) = do
  flattenedUsed   <- mapM flattenCallTree used
  (newUsed,il_ct) <- partitionEithers <$> mapM flattenNode flattenedUsed
  let (toInline,il_used) = unzip il_ct
      subst = extendGblSubstList (mkSubst emptyInScopeSet) toInline
  newExpr <- case toInline of
    [] -> return tm
    _  -> do
      let tm1 = substTm "flattenCallTree.flattenExpr" subst tm
      -- Debug aid: when built with HISTORY, append this inlining step to
      -- history.dat. The bang pattern forces the unsafePerformIO.
#ifdef HISTORY
      let !_ = unsafePerformIO
             $ BS.appendFile "history.dat"
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx    = []
                 , t_name   = "INLINE"
                 , t_bndrS  = showPpr (varName nm')
                 , t_before = tm
                 , t_after  = tm1
                 }
#endif
      rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp)
  -- Children kept as separate components, plus the children of inlined ones.
  let allUsed = newUsed ++ concat il_used
  if inl /= NoInline && isCheapFunction newExpr
    then do
      -- The flattened term is cheap: also inline every non-NoInline child.
      let (toInline',allUsed') = unzip (map goCheap allUsed)
          subst' = extendGblSubstList (mkSubst emptyInScopeSet)
                                      (Maybe.catMaybes toInline')
      let tm1 = substTm "flattenCallTree.flattenCheap" subst' newExpr
      newExpr' <- rewriteExpr ("flattenCheap",flatten) (showPpr nm, tm1) (nm', sp)
      return (CBranch (nm,(Binding nm' sp inl newExpr')) (concat allUsed'))
    else return (CBranch (nm,(Binding nm' sp inl newExpr)) allUsed)
  where
    -- Rewrite strategy applied to a term after children have been
    -- substituted into it.
    flatten =
      repeatR (topdownR (apply "appPropFast" appPropFast >->
                         apply "bindConstantVar" bindConstantVar >->
                         apply "caseCon" caseCon >->
                         (apply "reduceConst" reduceConst !-> apply "deadcode" deadCode) >->
                         apply "reduceNonRepPrim" reduceNonRepPrim >->
                         apply "removeUnusedExpr" removeUnusedExpr >->
                         apply "flattenLet" flattenLet)) !->
      topdownSucR (apply "topLet" topLet)

    -- NoInline children are kept as-is. Any other child yields its
    -- (binder, term) for substitution, and its own children are adopted.
    goCheap c@(CLeaf (nm2,(Binding _ _ inl2 e)))
      | inl2 == NoInline = (Nothing     ,[c])
      | otherwise        = (Just (nm2,e),[])

    goCheap c@(CBranch (nm2,(Binding _ _ inl2 e)) us)
      | inl2 == NoInline = (Nothing, [c])
      | otherwise        = (Just (nm2,e),us)
-- | Collect the bindings of a call tree in pre-order.
--
-- A node whose binder is already in the accumulated @visited@ list is
-- skipped together with its whole sub-tree. Returns the updated visited list
-- and the collected bindings.
callTreeToList :: [Id] -> CallTree -> ([Id], [(Id, Binding)])
callTreeToList visited tree
  | nm `elem` visited = (visited, [])
  | otherwise = case tree of
      CLeaf _ -> (nm : visited, [entry])
      CBranch _ children ->
        let (visited', collected) = mapAccumL callTreeToList (nm : visited) children
        in  (visited', entry : concat collected)
  where
    entry@(nm, _) = case tree of
      CLeaf e     -> e
      CBranch e _ -> e