module CLaSH.Rewrite.Util where
import Control.DeepSeq
import Control.Lens (Lens', (%=), (+=), (^.))
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import qualified Control.Monad.Reader as Reader
import qualified Control.Monad.State as State
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Writer as Writer
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Lazy as HML
import qualified Data.HashMap.Strict as HMS
import qualified Data.Map as Map
import Data.Maybe (mapMaybe)
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set
import Unbound.LocallyNameless (Collection (..), Fresh, bind,
embed, makeName, name2String,
rebind, rec, string2Name, unbind,
unembed, unrec)
import qualified Unbound.LocallyNameless as Unbound
import Unbound.Util (filterC)
import CLaSH.Core.DataCon (dataConInstArgTys)
import CLaSH.Core.FreeVars (termFreeVars, typeFreeVars, termFreeIds)
import CLaSH.Core.Pretty (showDoc)
import CLaSH.Core.Subst (substTm)
import CLaSH.Core.Term (LetBinding, Pat (..), Term (..),
TmName)
import CLaSH.Core.TyCon (TyCon, TyConName, tyConDataCons)
import CLaSH.Core.Type (KindOrType, TyName, Type (..),
TypeView (..), transparentTy,
typeKind, coreView)
import CLaSH.Core.Util (Delta, Gamma, collectArgs,
mkAbstraction, mkApps, mkId,
mkLams, mkTmApps, mkTyApps,
mkTyLams, mkTyVar, termType)
import CLaSH.Core.Var (Id, TyVar, Var (..))
import CLaSH.Netlist.Util (representableType)
import CLaSH.Rewrite.Types
import CLaSH.Util
-- | Lift an action of the base monad through the four transformer layers
-- that make up 'RewriteMonad'.
liftR :: Monad m => m a -> RewriteMonad m a
liftR = lift . lift . lift . lift
-- | Lift an action of the base monad through the three transformer layers
-- that make up 'RewriteSession'.
liftRS :: Monad m => m a -> RewriteSession m a
liftRS = lift . lift . lift
-- | Wrap a rewrite so that its application is counted and, depending on the
-- debug level, traced and sanity-checked.
--
-- * The 'Monoid.Any' change flag emitted by the inner rewrite is observed
--   with 'Writer.listen'; when it is set, 'transformCounter' is incremented.
-- * When the debug level is above 'DebugNone' and a change was reported,
--   the number of local free (type) variables before and after is compared.
--   Growth aborts with 'error'. A change of type is only reported via
--   'traceIf'.
-- * At 'DebugApplied' or higher, a result that differs from the input while
--   the change flag is NOT set aborts with 'error'.
--
-- If no change was reported, the original expression is returned.
apply :: (Monad m, Functor m)
      => String    -- ^ Name of the transformation (used in debug output)
      -> Rewrite m -- ^ The rewrite to wrap
      -> Rewrite m
apply name rewrite ctx expr = R $ do
  lvl <- Lens.view dbgLevel
  let before = showDoc expr
  -- Run the rewrite while capturing the change flag it tells.
  (expr', anyChanged) <- traceIf (lvl >= DebugAll) ("Trying: " ++ name ++ " on:\n" ++ before) $ Writer.listen $ runR $ rewrite ctx expr
  let hasChanged = Monoid.getAny anyChanged
  Monad.when hasChanged $ transformCounter += 1
  let after = showDoc expr'
  -- Only adopt the rewritten term when the rewrite reported a change.
  let expr'' = if hasChanged then expr' else expr
  Monad.when (lvl > DebugNone && hasChanged) $ do
    tcm <- Lens.use tcCache
    beforeTy <- fmap transparentTy $ termType tcm expr
    (beforeFTV,beforeFV) <- localFreeVars expr
    afterTy <- fmap transparentTy $ termType tcm expr'
    (afterFTV,afterFV) <- localFreeVars expr'
    -- Sanity check: a rewrite must not introduce new local free variables.
    -- NOTE(review): only the set sizes are compared, not the set contents.
    let newFV = Set.size afterFTV > Set.size beforeFTV ||
                Set.size afterFV > Set.size beforeFV
    Monad.when newFV $
      error ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "Changes free variables from: ", show (beforeFTV,beforeFV)
                     , "\nto: ", show (afterFTV,afterFV)
                     ]
            )
    -- A change of type is reported through 'traceIf' only; it is not fatal.
    traceIf ( beforeTy /= afterTy)
            ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "Changes type from:\n", showDoc beforeTy
                     , "\nto:\n", showDoc afterTy
                     ]
            ) (return ())
  -- A rewrite that modifies the term must also set the change flag.
  Monad.when (lvl >= DebugApplied && not hasChanged && expr /= expr') $
    error $ "Expression changed without notice(" ++ name ++ "): before" ++ before ++ "\nafter:\n" ++ after
  traceIf (lvl >= DebugName && hasChanged) name $
    traceIf (lvl >= DebugApplied && hasChanged) ("Changes when applying rewrite to:\n" ++ before ++ "\nResult:\n" ++ after ++ "\n") $
    traceIf (lvl >= DebugAll && not hasChanged) ("No changes when applying rewrite " ++ name ++ " to:\n" ++ after ++ "\n") $
    return expr''
-- | Run a single named rewrite on a top-level term (empty context). The
-- change flag is discarded and only the resulting term is returned.
runRewrite :: (Monad m, Functor m)
           => String
           -> Rewrite m
           -> Term
           -> RewriteSession m Term
runRewrite name rewrite expr =
  fmap fst (Writer.runWriterT (runR (apply name rewrite [] expr)))
-- | Run a rewrite session with the given debug level and initial state.
-- After the session finishes, the number of applied transformations (read
-- from the final state's 'transformCounter') is always traced.
runRewriteSession :: (Functor m, Monad m)
                  => DebugLevel
                  -> RewriteState
                  -> RewriteSession m a
                  -> m a
runRewriteSession lvl st session
  = Unbound.runFreshMT
  $ fmap reportCount
  $ State.runStateT (Reader.runReaderT session (RE lvl)) st
  where
    reportCount (a,s) =
      traceIf True ("Applied " ++ show (s ^. transformCounter) ++ " transformations") a
-- | Record in the writer output that the current rewrite changed the term.
setChanged :: Monad m => RewriteMonad m ()
setChanged = Writer.tell (Monoid.Any { Monoid.getAny = True })
-- | Mark the current rewrite as having made a change and return the given
-- value.
changed :: Monad m => a -> RewriteMonad m a
changed val = Writer.tell (Monoid.Any True) >> return val
-- | Build the local term environment ('Gamma') and type environment
-- ('Delta') from a context, processing context elements left to right.
-- Let-binding, let-body and case-alternative binders, as well as lambda
-- binders, go into 'Gamma'. Type-lambda binders go into 'Delta'. All other
-- context elements contribute nothing.
contextEnv :: [CoreContext]
           -> (Gamma, Delta)
contextEnv = foldl extend (HML.empty, HML.empty)
  where
    extend (gamma,delta) ctxElem = case ctxElem of
      LetBinding ids -> (foldl addToGamma gamma ids, delta)
      LetBody ids    -> (foldl addToGamma gamma ids, delta)
      LamBody lId    -> (addToGamma gamma lId, delta)
      TyLamBody tv   -> (gamma, addToDelta delta tv)
      CaseAlt ids    -> (foldl addToGamma gamma ids, delta)
      _              -> (gamma, delta)

    addToGamma gamma (Id idName ty) = HML.insert idName (unembed ty) gamma
    addToGamma _ _ = error $ $(curLoc) ++ "Adding TyVar to Gamma"

    addToDelta delta (TyVar tvName ki) = HML.insert tvName (unembed ki) delta
    addToDelta _ _ = error $ $(curLoc) ++ "Adding Id to Delta"
-- | Build the environments for a context, extending the local 'Gamma' with
-- the types of all global bindings. On a name clash the global binding
-- wins, because 'HML.union' is left-biased.
mkEnv :: (Functor m, Monad m)
      => [CoreContext]
      -> RewriteMonad m (Gamma, Delta)
mkEnv ctx = do
  globals <- fmap (HML.map fst) (Lens.use bindings)
  let (gamma,delta) = contextEnv ctx
  return (HML.union globals gamma, delta)
-- | Create a fresh term binder (and the matching variable reference) whose
-- type is the type of the given term.
--
-- This calls 'mkInternalVar' on the term's type directly. The original
-- matched the result of 'mkBinderFor' with the failable pattern @Left r@.
-- That match needs a 'fail'/'MonadFail' instance, which the constraints here
-- do not provide, and it hides a partial match.
mkTmBinderFor :: (Functor m, Fresh m, MonadUnique m)
              => HashMap TyConName TyCon
              -> String -- ^ Base name of the new binder
              -> Term   -- ^ Term whose type the binder gets
              -> m (Id, Term)
mkTmBinderFor tcm name e = mkInternalVar name =<< termType tcm e
-- | Create a fresh binder for either a term or a type.
--
-- * For a term: a term binder typed with the term's type, plus a variable
--   reference to it.
-- * For a type: a type variable whose kind is the type's kind, plus a type
--   variable reference to it.
mkBinderFor :: (Functor m, Monad m, MonadUnique m, Fresh m)
            => HashMap TyConName TyCon
            -> String
            -> Either Term Type
            -> m (Either (Id,Term) (TyVar,Type))
mkBinderFor tcm name arg = case arg of
  Left term -> fmap Left (termType tcm term >>= mkInternalVar name)
  Right ty  -> do
    uniq <- getUniqueM
    let tvName = makeName name (toInteger uniq)
        kind   = typeKind tcm ty
    return (Right (TyVar tvName (embed kind), VarTy kind tvName))
-- | Create a fresh term binder with the given base name and type. Returns
-- the binder and a variable reference to it.
mkInternalVar :: (Functor m, Monad m, MonadUnique m)
              => String
              -> KindOrType
              -> m (Id,Term)
mkInternalVar name ty = do
  uniq <- getUniqueM
  let freshName = makeName name (toInteger uniq)
  return (Id freshName (embed ty), Var ty freshName)
-- | Inline the let-bindings that satisfy the given predicate. If no binding
-- qualifies, the expression is returned unchanged. If every binding ends up
-- inlined, the @letrec@ disappears. Non-let expressions are left alone.
inlineBinders :: Monad m
              => (LetBinding -> RewriteMonad m Bool)
              -> Rewrite m
inlineBinders condition _ expr@(Letrec b) = R $ do
  (xes,res) <- unbind b
  (replace,others) <- partitionM condition (unrec xes)
  if null replace
    then return expr
    else do
      let (others',res') = substituteBinders replace others res
      changed $ if null others'
        then res'
        else Letrec (bind (rec others') res')
inlineBinders _ _ e = return e
-- | Substitute each binding of the first list into the remaining bindings
-- still to be substituted, into the kept bindings (second list), and into
-- the body. A binding whose right-hand side refers to its own name is not
-- substituted; it is moved to the kept bindings instead. Returns the kept
-- bindings and the new body.
substituteBinders :: [LetBinding]
                  -> [LetBinding]
                  -> Term
                  -> ([LetBinding],Term)
substituteBinders [] others res = (others,res)
substituteBinders ((bndr,valE):rest) others res
    | selfRef   = substituteBinders rest ((bndr,valE):others) res
    | otherwise = substituteBinders (map (second substE) rest)
                                    (map (second substE) others)
                                    (subst res)
  where
    val      = unembed valE
    bndrName = varName bndr
    selfRef  = bndrName `elem` snd (termFreeVars val)
    subst    = substTm bndrName val
    substE   = embed . subst . unembed
-- | Free type variables and free term variables of a term. Term variables
-- that name a global binding are filtered out; type variables are returned
-- as-is.
localFreeVars :: (Functor m, Monad m, Collection c)
              => Term
              -> RewriteMonad m (c TyName,c TmName)
localFreeVars term = do
  globalBndrs <- Lens.use bindings
  let (tyFVs,tmFVs) = termFreeVars term
      keepLocal v
        | v `HML.member` globalBndrs = Nothing
        | otherwise                  = Just v
  return (tyFVs, filterC (cmap keepLocal tmFVs))
-- | Lift the let-bindings that satisfy the predicate to new global
-- functions (via 'liftBinding') and substitute the resulting applications
-- into the remaining bindings and the body. If no binding qualifies, the
-- expression is returned unchanged. Non-let expressions are left alone.
liftBinders :: (Functor m, Monad m)
            => (LetBinding -> RewriteMonad m Bool)
            -> Rewrite m
liftBinders condition ctx expr@(Letrec b) = R $ do
  (xes,res) <- unbind b
  let bndrs = unrec xes
  (replace,others) <- partitionM condition bndrs
  if null replace
    then return expr
    else do
      (gamma,delta) <- mkEnv (LetBinding (map fst bndrs) : ctx)
      replace' <- mapM (liftBinding gamma delta) replace
      let (others',res') = substituteBinders replace' others res
      changed $ if null others'
        then res'
        else Letrec (bind (rec others') res')
liftBinders _ _ e = return e
-- | Lift the right-hand side of a let-binding to a new global function.
--
-- The new global is abstracted over the local free type variables and the
-- local free term variables of the right-hand side. The binding's own name
-- is excluded from the term variables. Kinds and types are looked up in
-- @delta@ and @gamma@, with an 'error' if one is missing. Occurrences of the
-- binding's own name inside the lifted body are replaced by the application
-- of the new global to those variables. The global is registered in
-- 'bindings'.
--
-- Returns the original binder, now bound to that application.
liftBinding :: (Functor m, Monad m)
            => Gamma
            -> Delta
            -> LetBinding
            -> RewriteMonad m LetBinding
liftBinding gamma delta (Id idName tyE,eE) = do
  let ty = unembed tyE
      e = unembed eE
  (localFTVs,localFVs) <- fmap (Set.toList *** Set.toList) $ localFreeVars e
  let localFTVkinds = map (\k -> HML.lookupDefault (error $ $(curLoc) ++ show k ++ " not found") k delta) localFTVs
      -- The binder itself is not abstracted over; self-references are
      -- rewritten to the new global's application below.
      localFVs' = filter (/= idName) localFVs
      localFVtys' = map (\k -> HML.lookupDefault (error $ $(curLoc) ++ show k ++ " not found") k gamma) localFVs'
      boundFTVs = zipWith mkTyVar localFTVkinds localFTVs
      boundFVs = zipWith mkId localFVtys' localFVs'
  tcm <- Lens.use tcCache
  newBodyTy <- termType tcm $ mkTyLams (mkLams e boundFVs) boundFTVs
  newBodyId <- fmap (makeName (name2String idName) . toInteger) getUniqueM
  -- The new global applied to the abstracted type and term variables.
  let newExpr = mkTmApps
                  (mkTyApps (Var newBodyTy newBodyId)
                            (zipWith VarTy localFTVkinds localFTVs))
                  (zipWith Var localFVtys' localFVs')
      e' = substTm idName newExpr e
      newBody = mkTyLams (mkLams e' boundFVs) boundFTVs
  bindings %= HMS.insert newBodyId (newBodyTy,newBody)
  return (Id idName (embed ty), embed newExpr)
liftBinding _ _ _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"
-- | Register a term as a new global function under a fresh name derived
-- from the given name. Returns the new name and the term's type.
mkFunction :: (Functor m, Monad m)
           => TmName
           -> Term
           -> RewriteMonad m (TmName,Type)
mkFunction bndr body = do
  bodyTy <- Lens.use tcCache >>= \tcm -> termType tcm body
  bodyId <- cloneVar bndr
  addGlobalBind bodyId bodyTy body
  return (bodyId, bodyTy)
-- | Insert a global binding. The type and body are fully evaluated (with
-- 'deepseq') before the insertion happens.
addGlobalBind :: (Functor m, Monad m)
              => TmName
              -> Type
              -> Term
              -> RewriteMonad m ()
addGlobalBind vId ty body =
  let entry = (ty,body)
  in  entry `deepseq` (bindings %= HMS.insert vId entry)
-- | Make a fresh name that has the same string as the given name and a new
-- unique.
cloneVar :: (Functor m, Monad m)
         => TmName
         -> RewriteMonad m TmName
cloneVar name = do
  uniq <- getUniqueM
  return (makeName (name2String name) (toInteger uniq))
-- | True when the term is a variable that is not the name of a global
-- binding. False for any other term.
isLocalVar :: (Functor m, Monad m)
           => Term
           -> RewriteMonad m Bool
isLocalVar tm = case tm of
  Var _ name -> do
    globals <- Lens.use bindings
    return (not (name `HML.member` globals))
  _ -> return False
-- | True when the term's type is not representable according to the
-- current type translator.
isUntranslatable :: (Functor m, Monad m)
                 => Term
                 -> RewriteMonad m Bool
isUntranslatable tm = do
  tcm <- Lens.use tcCache
  translator <- Lens.use typeTranslator
  ty <- termType tcm tm
  return (not (representableType translator tcm ty))
-- | True only for the 'LamBody' context element.
isLambdaBodyCtx :: CoreContext
                -> Bool
isLambdaBodyCtx ctx = case ctx of
  LamBody _ -> True
  _         -> False
-- | Create a fresh term binder named @wild@ with the given type.
mkWildValBinder :: (Functor m, Monad m, MonadUnique m)
                => Type
                -> m Id
mkWildValBinder ty = fst <$> mkInternalVar "wild" ty
-- | Build a case expression that selects one field of one data constructor
-- of the scrutinee, i.e. @case scrut of DC wild .. sel .. wild -> sel@.
--
-- * @dcI@ is a 1-based data-constructor index. It must not exceed the
--   number of constructors; the list is indexed with @dcI-1@.
-- * @fieldI@ is a 0-based field index. It must be smaller than the number
--   of fields.
--
-- If the scrutinee has a function type, the result is instead a lambda
-- @\\selector -> selector@ over a binder of that type. Any other scrutinee
-- type, or an out-of-range index, raises an 'error' that names @caller@.
--
-- Fix: the constructor lookup used the unbound identifier @dcI1@. It now
-- converts the 1-based @dcI@ to a 0-based index with @dcI-1@.
mkSelectorCase :: (Functor m, Monad m, MonadUnique m, Fresh m)
               => String -- ^ Name of the caller (for error messages)
               -> HashMap TyConName TyCon
               -> [CoreContext]
               -> Term   -- ^ Scrutinee
               -> Int    -- ^ 1-based data constructor index
               -> Int    -- ^ 0-based field index
               -> m Term
mkSelectorCase caller tcm _ scrut dcI fieldI = do
  scrutTy <- termType tcm scrut
  let cantCreate loc info = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showDoc scrut ++ " :: " ++ showDoc scrutTy ++ ")\nAdditional info: " ++ info
  case coreView tcm scrutTy of
    TyConApp tc args ->
      case tyConDataCons (tcm HMS.! tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showDoc tc)
        dcs
          | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max"
          | otherwise -> do
              -- dcI is 1-based, list indexing is 0-based.
              let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
              let fieldTys = dataConInstArgTys dc args
              if fieldI >= length fieldTys
                then cantCreate $(curLoc) "Field index exceed max"
                else do
                  -- Bind every field with a wildcard, except the selected
                  -- one, which is bound to "sel" and returned.
                  wildBndrs <- mapM mkWildValBinder fieldTys
                  selBndr <- mkInternalVar "sel" (indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI)
                  let bndrs = take fieldI wildBndrs ++ [fst selBndr] ++ drop (fieldI+1) wildBndrs
                  let pat = DataPat (embed dc) (rebind [] bndrs)
                  let retVal = Case scrut [ bind pat (snd selBndr) ]
                  return retVal
    FunTy _ _ -> do
      (id_,var) <- mkInternalVar "selector" scrutTy
      return (mkLams var [id_])
    (OtherType oTy) -> cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showDoc oTy)
-- | Specialise a function application on its last argument, which may be a
-- type ('TyApp') or a term ('App'). The work is done by 'specialise''.
-- Other expressions are returned unchanged.
specialise :: (Functor m, State.MonadState s m)
           => Lens' s (Map.Map (TmName, Int, Either Term Type) (TmName,Type))
           -> Lens' s (HashMap TmName Int)
           -> Lens' s Int
           -> Rewrite m
specialise specMapLbl specHistLbl specLimitLbl ctx e = case e of
    TyApp e1 ty -> specialiseOn e1 (Right ty)
    App e1 e2   -> specialiseOn e1 (Left e2)
    _           -> return e
  where
    specialiseOn fun =
      specialise' specMapLbl specHistLbl specLimitLbl ctx e (collectArgs fun)
-- | Worker for 'specialise'.
--
-- * If the head of the application is a variable @f@, the specialisation
--   is keyed by @(f, number of preceding arguments, specArg abstracted over
--   its local free variables)@ in the map behind @specMapLbl@. An existing
--   entry is reused. Otherwise, if @f@ has a global binding:
--
--     * When the history count for @f@ (behind @specHistLbl@) exceeds the
--       limit (behind @specLimitLbl@), the rewrite fails with an
--       explanatory message.
--     * Otherwise a new global function is created that abstracts over the
--       preceding arguments and the free variables of @specArg@. The
--       history counter and the map are updated, and the call is rewritten
--       to use the new function.
--
--   If @f@ has no global binding, the expression is returned unchanged.
--
-- * If the head is not a variable and @specArg@ is a term, that term is
--   lifted into a new global function (named after @specF@), abstracted
--   over its free variables, and passed in its place.
--
-- * Otherwise the expression is returned unchanged.
specialise' :: (Functor m, State.MonadState s m)
            => Lens' s (Map.Map (TmName, Int, Either Term Type) (TmName,Type))
            -> Lens' s (HashMap TmName Int)
            -> Lens' s Int
            -> [CoreContext]
            -> Term
            -> (Term, [Either Term Type])
            -> Either Term Type
            -> R m Term
specialise' specMapLbl specHistLbl specLimitLbl ctx e (Var _ f, args) specArg = R $ do
  lvl <- Lens.view dbgLevel
  (specBndrs,specVars) <- specArgBndrsAndVars ctx specArg
  let argLen = length args
      -- Term arguments are closed over their free variables so that the
      -- lookup key does not depend on local names. Type arguments are used
      -- as-is.
      specAbs = either (Left . (`mkAbstraction` specBndrs)) Right specArg
  specM <- liftR $ fmap (Map.lookup (f,argLen,specAbs))
                 $ Lens.use specMapLbl
  case specM of
    Just (fname,fty) ->
      traceIf (lvl >= DebugApplied) ("Using previous specialization of " ++ showDoc f ++ " on " ++ (either showDoc showDoc) specAbs ++ ": " ++ showDoc fname) $
        changed $ mkApps (Var fty fname) (args ++ specVars)
    Nothing -> do
      bodyMaybe <- fmap (HML.lookup f) $ Lens.use bindings
      case bodyMaybe of
        Just (_,bodyTm) -> do
          specHistM <- liftR $ fmap (HML.lookup f) (Lens.use specHistLbl)
          specLim <- liftR $ Lens.use specLimitLbl
          -- Guard against unbounded specialisation of recursive functions.
          if maybe False (> specLim) specHistM
            then fail $ unlines [ "Hit specialisation limit on function `" ++ showDoc f ++ "'.\n"
                                , "The function `" ++ showDoc f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
                                , "Body of `" ++ showDoc f ++ "':\n" ++ showDoc bodyTm ++ "\n"
                                , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showDoc showDoc) specArg
                                ]
            else do
              tcm <- Lens.use tcCache
              (boundArgs,argVars) <- fmap (unzip . map (either (Left *** Left) (Right *** Right))) $
                                       mapM (mkBinderFor tcm "pTS") args
              let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg])) (boundArgs ++ specBndrs)
              newf <- mkFunction f newBody
              liftR $ specHistLbl %= HML.insertWith (+) f 1
              liftR $ specMapLbl %= Map.insert (f,argLen,specAbs) newf
              let newExpr = mkApps ((uncurry . flip) Var newf) (args ++ specVars)
              newf `deepseq` changed newExpr
        Nothing -> return e
specialise' _ _ _ ctx _ (appE,args) (Left specArg) = R $ do
  (specBndrs,specVars) <- specArgBndrsAndVars ctx (Left specArg)
  let newBody = mkAbstraction specArg specBndrs
  newf <- mkFunction (string2Name "specF") newBody
  let newArg = Left $ mkApps ((uncurry . flip) Var newf) specVars
  let newExpr = mkApps appE (args ++ [newArg])
  changed newExpr
specialise' _ _ _ _ e _ _ = return e
-- | For the local free variables of a specialisation argument, return
-- binders and matching variable references: type variables first, then
-- term variables. For a type argument only its free type variables are
-- used. Kinds and types come from the environment built by 'mkEnv'; a
-- missing entry raises an 'error'.
specArgBndrsAndVars :: (Functor m, Monad m)
                    => [CoreContext]
                    -> Either Term Type
                    -> RewriteMonad m ([Either Id TyVar],[Either Term Type])
specArgBndrsAndVars ctx specArg = do
  (specFTVs,specFVs) <- (Set.toList *** Set.toList) <$>
    either localFreeVars (pure . (,emptyC) . typeFreeVars) specArg
  (gamma,delta) <- mkEnv ctx
  let tyBinding tv =
        let ki = HML.lookupDefault (error $ $(curLoc) ++ show tv ++ " not found") tv delta
        in  (Right (TyVar tv (embed ki)), Right (VarTy ki tv))
      tmBinding tm =
        let ty = HML.lookupDefault (error $ $(curLoc) ++ show tm ++ " not found") tm gamma
        in  (Left (Id tm (embed ty)), Left (Var ty tm))
      (specTyBndrs,specTyVars) = unzip (map tyBinding specFTVs)
      (specTmBndrs,specTmVars) = unzip (map tmBinding specFVs)
  return (specTyBndrs ++ specTmBndrs, specTyVars ++ specTmVars)
-- | True when any free term variable of the term is untranslatable. Only
-- variables whose type is known in the local context environment are
-- checked; the others are skipped.
untranslatableFVs :: (Functor m, Monad m)
                  => [CoreContext]
                  -> Term
                  -> RewriteMonad m Bool
untranslatableFVs ctx tm = or <$> mapM isUntranslatable localVars
  where
    (gamma,_) = contextEnv ctx
    localVars = mapMaybe (\n -> fmap (\fvTy -> Var fvTy n) (HML.lookup n gamma))
                         (termFreeIds tm)