{-# LANGUAGE CPP #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Rewrite.Util where
import Control.DeepSeq
import Control.Exception (throw)
import Control.Lens
(Lens', (%=), (+=), (^.), _2, _4, _5, _Left)
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import qualified Control.Monad.State.Strict as State
import qualified Control.Monad.Writer as Writer
import Data.Bifunctor (bimap)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Lazy as HML
import qualified Data.HashMap.Strict as HMS
import qualified Data.HashSet as HashSet
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Maybe (catMaybes,isJust,mapMaybe)
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Unbound.Generics.LocallyNameless
(Fresh, bind, embed, rebind, rec, unbind, unembed, unrec)
import qualified Unbound.Generics.LocallyNameless as Unbound
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)
import BasicTypes (InlineSpec (..))
import SrcLoc (SrcSpan)
import Clash.Core.DataCon (dataConInstArgTys)
import Clash.Core.FreeVars (termFreeIds, termFreeTyVars,
typeFreeVars)
import Clash.Core.Name
import Clash.Core.Pretty (showDoc)
import Clash.Core.Subst (substTm)
import Clash.Core.Term
(LetBinding, Pat (..), Term (..), TmName, TmOccName)
import Clash.Core.TyCon
(TyCon, TyConOccName, tyConDataCons)
import Clash.Core.Type (KindOrType, Type (..),
TypeView (..), coreView,
normalizeType,
typeKind, tyView)
import Clash.Core.Util
(Delta, Gamma, collectArgs, isPolyFun, mkAbstraction, mkApps, mkId, mkLams,
mkTmApps, mkTyApps, mkTyLams, mkTyVar, termType)
import Clash.Core.Var (Id, TyVar, Var (..))
import Clash.Driver.Types
(ClashException (..), DebugLevel (..))
import Clash.Netlist.Util (representableType)
import Clash.Rewrite.Types
import Clash.Util
-- | Run a 'State.State' computation over the @extra@ component of the rewrite
-- state inside the 'RewriteMonad'.
--
-- The computation starts from the current value of the @extra@ field and the
-- state it finishes with is stored back into that field. The third component
-- of the result is 'mempty'.
zoomExtra
  :: State.State extra a
  -> RewriteMonad extra a
zoomExtra act = R step
  where
    -- The first argument of the wrapped function is not used.
    step _ st = case State.runState act (st ^. extra) of
      (result, extra') -> (result, st {_extra = extra'}, mempty)
-- | Apply a named rewrite to an expression in a context, keeping track of
-- whether it changed anything and, depending on the debug level, tracing and
-- sanity-checking the result.
--
-- * The transformation counter is incremented when the rewrite flagged a
--   change.
-- * The rewritten expression is only returned when a change was flagged;
--   otherwise the original expression is returned.
-- * Above 'DebugNone', a flagged change is checked for newly introduced free
--   (type) variables, which is reported with 'error'.
-- * From 'DebugApplied' on, an expression that differs from the original
--   without a flagged change is reported with 'error'.
apply
  :: String
  -- ^ Name of the rewrite, used in trace and error messages
  -> Rewrite extra
  -- ^ The rewrite to apply
  -> Rewrite extra
apply name rewrite ctx expr = do
  lvl <- Lens.view dbgLevel
  let before = showDoc expr
  (expr', anyChanged) <-
    traceIf (lvl >= DebugAll) ("Trying: " ++ name ++ " on:\n" ++ before) $
      Writer.listen $ rewrite ctx expr
  let hasChanged = Monoid.getAny anyChanged
  Monad.when hasChanged $ transformCounter += 1
  let after  = showDoc expr'
      expr'' = if hasChanged then expr' else expr

  Monad.when (lvl > DebugNone && hasChanged) $ do
    tcm      <- Lens.view tcCache
    beforeTy <- termType tcm expr
    let beforeFTV = Lens.setOf termFreeTyVars expr
    beforeFV <- Lens.setOf <$> localFreeIds <*> pure expr
    afterTy  <- termType tcm expr'
    -- The free type variables of the *result* have to be taken from expr';
    -- taking them from expr makes the comparison below vacuous.
    let afterFTV = Lens.setOf termFreeTyVars expr'
    afterFV  <- Lens.setOf <$> localFreeIds <*> pure expr'
    let newFV = Set.size afterFTV > Set.size beforeFTV ||
                Set.size afterFV > Set.size beforeFV
    Monad.when newFV $
      error ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "Changes free variables from: ", show (beforeFTV,beforeFV)
                     , "\nto: ", show (afterFTV,afterFV)
                     ]
            )
    -- A changed type is only traced, not treated as fatal.
    traceIf (lvl >= DebugAll && beforeTy /= afterTy)
            ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "Changes type from:\n", showDoc beforeTy
                     , "\nto:\n", showDoc afterTy
                     ]
            ) (return ())

  Monad.when (lvl >= DebugApplied && not hasChanged && expr /= expr') $
    error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++ "): before" ++ before ++ "\nafter:\n" ++ after

  traceIf (lvl >= DebugName && hasChanged) name $
    traceIf (lvl >= DebugApplied && hasChanged) ("Changes when applying rewrite to:\n" ++ before ++ "\nResult:\n" ++ after ++ "\n") $
      traceIf (lvl >= DebugAll && not hasChanged) ("No changes when applying rewrite " ++ name ++ " to:\n" ++ after ++ "\n") $
        return expr''
-- | Apply a named rewrite to a term, starting from an empty context.
runRewrite
  :: String
  -- ^ Name of the rewrite
  -> Rewrite extra
  -- ^ The rewrite to run
  -> Term
  -- ^ Term to rewrite
  -> RewriteMonad extra Term
runRewrite name rewrite = apply name rewrite []
-- | Run a computation in the 'RewriteMonad' with the given environment and
-- initial state, and return its result.
--
-- A trace message with the value of the transformation counter in the final
-- state is always emitted.
runRewriteSession
  :: RewriteEnv
  -> RewriteState extra
  -> RewriteMonad extra a
  -> a
runRewriteSession r s m =
  let (result, finalState, _) = runR m r s
      summary = "Applied " ++
                show (finalState ^. transformCounter) ++
                " transformations"
  in  traceIf True summary result
-- | Flag that a change was made, without producing a value.
setChanged :: RewriteMonad extra ()
setChanged = Writer.tell changeFlag
  where
    changeFlag = Monoid.Any True
-- | Flag that a change was made and return the given value.
changed
  :: a
  -> RewriteMonad extra a
changed val = Writer.tell (Monoid.Any True) >> return val
-- | Build the term-variable environment ('Gamma') and the type-variable
-- environment ('Delta') from the binders introduced by a rewrite context.
--
-- Identifiers bound by let-bindings, let-bodies, lambda-bodies and case
-- alternatives are added to 'Gamma'; type variables bound by type-lambdas
-- are added to 'Delta'. All other context elements are skipped.
--
-- Calls 'error' when a type variable shows up where an identifier is
-- expected, or the other way around.
contextEnv
  :: [CoreContext]
  -> (Gamma, Delta)
contextEnv = go HML.empty HML.empty
  where
    go gamma delta []      = (gamma, delta)
    go gamma delta (c:ctx) = case c of
      LetBinding _ ids -> go (addIds gamma ids) delta ctx
      LetBody ids      -> go (addIds gamma ids) delta ctx
      LamBody lId      -> go (addToGamma gamma lId) delta ctx
      TyLamBody tv     -> go gamma (addToDelta delta tv) ctx
      CaseAlt ids      -> go (addIds gamma ids) delta ctx
      _                -> go gamma delta ctx

    -- Strict left fold: the environment is extended one binder at a time
    -- without building up a chain of thunks.
    addIds gamma ids = List.foldl' addToGamma gamma ids

    addToGamma gamma (Id idName ty) = HML.insert (nameOcc idName) (unembed ty) gamma
    addToGamma _     _              = error $ $(curLoc) ++ "Adding TyVar to Gamma"

    addToDelta delta (TyVar tvName ki) = HML.insert (nameOcc tvName) (unembed ki) delta
    addToDelta _     _                 = error $ $(curLoc) ++ "Adding Id to Delta"
-- | Find the identifier of the first 'LetBinding' element in the context,
-- searching from the head of the list. Yields 'Nothing' when the context has
-- no such element.
closestLetBinder
  :: [CoreContext]
  -> Maybe Id
closestLetBinder ctx = case ctx of
  []                     -> Nothing
  (LetBinding id_ _ : _) -> Just id_
  (_ : rest)             -> closestLetBinder rest
-- | Create a name from a suffix and the closest let-binder in the context.
--
-- When the context has a let-binder, the result is the name of that binder
-- with an underscore and the suffix appended. Otherwise an internal name is
-- made out of the suffix alone.
mkDerivedName
  :: [CoreContext]
  -> String
  -> TmName
mkDerivedName ctx sf =
  maybe (string2InternalName sf) fromBinder (closestLetBinder ctx)
  where
    fromBinder id_ = appendToName (varName id_) ('_':sf)
-- | Create the typing environment for a context.
--
-- The term-variable environment is the union of the second component of
-- every global binding and the binders found in the context by 'contextEnv';
-- on a key clash the entry coming from the global bindings is the one that
-- is kept. The type-variable environment is the one produced by
-- 'contextEnv'.
mkEnv
  :: [CoreContext]
  -> RewriteMonad extra (Gamma, Delta)
mkEnv ctx = do
  globalTys <- HML.map (^. _2) <$> Lens.use bindings
  let (localGamma, delta) = contextEnv ctx
  return (HML.union globalTys localGamma, delta)
-- | Make a new binder, and a variable reference to it, for a term.
--
-- This is 'mkBinderFor' restricted to terms. The result is matched with a
-- total @case@ instead of a failable pattern in @do@-notation, so no 'fail'
-- of the underlying monad is needed; an unexpected type binder is reported
-- with 'error'.
mkTmBinderFor
  :: (Fresh m, MonadUnique m)
  => HashMap TyConOccName TyCon
  -- ^ TyCon cache
  -> Name a
  -- ^ Name of the new binder
  -> Term
  -- ^ Term to bind
  -> m (Id, Term)
mkTmBinderFor tcm name e = do
  binder <- mkBinderFor tcm name (Left e)
  case binder of
    Left r  -> return r
    Right _ -> error $ $(curLoc) ++ "mkTmBinderFor: mkBinderFor returned a type binder for a term"
-- | Make a new binder, and a variable reference to it, for either a term or
-- a type.
--
-- For a term, the type of the term is computed with 'termType' and an
-- internal variable of that type is created. For a type, the name is cloned,
-- the kind is computed with 'typeKind', and a type variable of that kind is
-- created.
mkBinderFor
  :: (Monad m, MonadUnique m, Fresh m)
  => HashMap TyConOccName TyCon
  -- ^ TyCon cache
  -> Name a
  -- ^ Name of the new binder
  -> Either Term Type
  -- ^ Term or type to bind
  -> m (Either (Id,Term) (TyVar,Type))
mkBinderFor tcm name target = case target of
  Left term -> do
    ty <- termType tcm term
    Left <$> mkInternalVar (coerceName name) ty
  Right ty -> do
    name' <- cloneVar (coerceName name)
    let kind = typeKind tcm ty
    return (Right (TyVar name' (embed kind), VarTy kind name'))
-- | Make a new binder with the given type, together with a variable
-- reference to it. The given name is cloned first, so the binder carries a
-- name obtained from 'cloneVar'.
mkInternalVar
  :: (Monad m, MonadUnique m)
  => TmName
  -- ^ Name of the identifier
  -> KindOrType
  -- ^ Type of the identifier
  -> m (Id,Term)
mkInternalVar name ty =
  cloneVar name >>= \name' ->
    return (Id name' (embed ty), Var ty name')
-- | Inline the binders of a let-expression that satisfy a condition.
--
-- The condition receives the (unbound and re-bound) let-expression and one
-- of its bindings. The selected bindings are substituted into the remaining
-- bindings and the body using 'substituteBinders'. When no binding is left
-- afterwards, the let-expression is replaced by its body. Expressions that
-- are not let-expressions, and let-expressions without selected binders, are
-- returned as they are.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> Rewrite extra
inlineBinders condition _ expr@(Letrec b) = do
  (bndrsR, body) <- unbind b
  let freshExpr = Letrec (bind bndrsR body)
  (toInline, toKeep) <- partitionM (condition freshExpr) (unrec bndrsR)
  if null toInline
    then return expr
    else do
      let (kept, body') = substituteBinders toInline toKeep body
      changed $ if null kept
                  then body'
                  else Letrec (bind (rec kept) body')
inlineBinders _ _ e = return e
-- | Determine whether a binder is a join-point in an expression: that is the
-- case when 'tailCalls' reports more than one call for it.
isJoinPointIn
  :: Id
  -- ^ Id of the binder
  -> Term
  -- ^ Expression in which the binder is used
  -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)
-- | Count the calls to a binder in an expression.
--
-- The result is 'Nothing' as soon as the binder is used somewhere this
-- function does not accept: in the argument of an application, in the
-- scrutinee of a case-expression, or, through a let-binder that uses it, in
-- such a position in a let-body. Otherwise the result is @Just n@, where @n@
-- is the number of counted occurrences.
tailCalls
  :: Id
  -- ^ Binder to count the calls of
  -> Term
  -- ^ Expression to inspect
  -> Maybe Int
tailCalls id_ expr = case expr of
  Var _ nm
    | varName id_ == nm -> Just 1
    | otherwise         -> Just 0
  Lam b   -> tailCalls id_ (snd (unsafeUnbind b))
  TyLam b -> tailCalls id_ (snd (unsafeUnbind b))
  App l r
    -- The binder may not occur in the argument at all.
    | tailCalls id_ r == Just 0 -> tailCalls id_ l
    | otherwise                 -> Nothing
  TyApp l _ -> tailCalls id_ l
  Letrec b ->
    let (bsR, body)      = unsafeUnbind b
        (bsIds, bsExprs) = unzip (unrec bsR)
        -- Calls to the binder in each of the let-bound expressions
        bsTls            = map (tailCalls id_ . unembed) bsExprs
        -- Let-binders whose bound expression got a definite count
        bsIdsUsed        = [ bndr | (bndr, Just _) <- zip bsIds bsTls ]
        -- Calls to those let-binders in the let-body
        bsIdsTls         = map (`tailCalls` body) bsIdsUsed
        bodyTl           = tailCalls id_ body
        result
          | not (all isJust bsTls)       = Nothing
          | all (== 0) (catMaybes bsTls) = bodyTl
          | all isJust bsIdsTls          = (sum (catMaybes bsTls) +) <$> bodyTl
          | otherwise                    = Nothing
    in  result
  Case scrut _ alts ->
    let altsTl = map (tailCalls id_ . snd . unsafeUnbind) alts
    in  if tailCalls id_ scrut == Just 0 && all isJust altsTl
          then Just (sum (catMaybes altsTl))
          else Nothing
  _ -> Just 0
-- | Determine whether a term is a lambda whose body is a variable applied to
-- zero or more arguments, where the lambda-bound variable is not among the
-- free identifiers of that body.
isVoidWrapper
  :: Term
  -> Bool
isVoidWrapper term = case term of
  Lam b
    | (bndr, body) <- unsafeUnbind b
    , (Var _ _, _) <- collectArgs body
    -> all (/= nameOcc (varName bndr)) (Lens.toListOf termFreeIds body)
  _ -> False
-- | Substitute the first set of let-binders in the second set of let-binders
-- and in the expression.
--
-- The binders are processed one by one. A binder whose own name is among the
-- free identifiers of its bound expression is not substituted; it is put in
-- front of the second set instead. Any other binder is substituted in the
-- binders still to be processed, in the second set, and in the expression.
substituteBinders
  :: [LetBinding]
  -- ^ Let-binders to substitute
  -> [LetBinding]
  -- ^ Let-binders in which to apply the substitution
  -> Term
  -- ^ Expression in which to apply the substitution
  -> ([LetBinding],Term)
substituteBinders [] others res = (others, res)
substituteBinders ((bndr, valE) : rest) others res
  | isSelfRef = substituteBinders rest ((bndr, valE) : others) res
  | otherwise = substituteBinders (map substIn rest)
                                  (map substIn others)
                                  (substVal res)
  where
    val       = unembed valE
    bndrName  = nameOcc (varName bndr)
    isSelfRef = bndrName `elem` Lens.toListOf termFreeIds val

    -- Replace the binder by its value in a term
    substVal tm = substTm bndrName val tm

    -- Replace the binder by its value in the expression of a let-binding
    substIn bnd = second (embed . substVal . unembed) bnd
-- | Calculate the /local/ free identifiers of a term: the free identifiers
-- that are not keys of the global bindings, read at the moment this action
-- runs.
localFreeIds
  :: (Applicative f, Lens.Contravariant f)
  => RewriteMonad extra ((TmOccName -> f TmOccName) -> Term -> f Term)
localFreeIds = do
  globals <- Lens.use bindings
  let isLocal occ = not (occ `HML.member` globals)
  return (termFreeIds . Lens.filtered isLocal)
-- | Inline or lift the let-binders of a let-expression that satisfy a
-- condition.
--
-- The selected binders are split once more by the second test into binders
-- to inline and binders to lift (the original code names the first group of
-- the 'partitionM' result the inlined ones). The binders to inline are
-- substituted first, the binders to lift are then turned into references to
-- global binders by 'liftBinding', and finally those are substituted as
-- well. When no binding is left, the let-expression is replaced by its body.
-- Expressions that are not let-expressions, and let-expressions without
-- selected binders, are returned as they are.
inlineOrLiftBinders
  :: (LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Test whether to inline or lift; it receives the let-expression and
  -- one of the selected bindings
  -> Rewrite extra
inlineOrLiftBinders condition inlineOrLift ctx expr@(Letrec b) = do
  (bndrsR, body) <- unbind b
  let bndrs = unrec bndrsR
  (candidates, untouched) <- partitionM condition bndrs
  if null candidates
    then return expr
    else do
      let freshExpr = Letrec (bind bndrsR body)
      (toInline, toLift) <- partitionM (inlineOrLift freshExpr) candidates
      -- The binders to lift are put in front, so they can be split off again
      -- after the inlining substitution.
      let (afterInline, body1)  = substituteBinders toInline (toLift ++ untouched) body
          (toLift1, untouched1) = splitAt (length toLift) afterInline
      -- Only the list of identifiers of this context element is consumed by
      -- 'contextEnv'; its first field is left undefined.
      (gamma, delta) <- mkEnv (LetBinding undefined (map fst bndrs) : ctx)
      lifted <- mapM (liftBinding gamma delta) toLift1
      let (remaining, body2) = substituteBinders lifted untouched1 body1
      changed $ case remaining of
        [] -> body2
        _  -> Letrec (bind (rec remaining) body2)
inlineOrLiftBinders _ _ _ e = return e
-- | Create a global function for a let-binding and return a let-binding in
-- which the original binder is bound to a call of that global function.
--
-- The bound expression is abstracted over its free type variables and over
-- its local free identifiers (the binder itself excluded); their kinds and
-- types are looked up in the given Delta and Gamma. The abstraction is added
-- to the global bindings, unless a global binding with an equal body already
-- exists, in which case that one is called instead.
--
-- Calls error when a free (type) variable is missing from the environment,
-- or when the binder of the let-binding is a type variable.
liftBinding :: Gamma
-> Delta
-> LetBinding
-> RewriteMonad extra LetBinding
liftBinding gamma delta (Id idName tyE,eE) = do
let e = unembed eE
-- Free type variables of the bound expression, without duplicates
let localFTVs = map (\nm -> Name Internal nm noSrcSpan)
. List.nub $ Lens.toListOf termFreeTyVars e
-- Free identifiers of the bound expression that are not global binders
localFVs <- map (\nm -> Name Internal nm noSrcSpan) . List.nub <$>
(Lens.toListOf <$> localFreeIds <*> pure e)
let localFTVkinds = map (\k -> HML.lookupDefault (error $ $(curLoc) ++ show k ++ " not found") (nameOcc k) delta) localFTVs
-- The binder itself is not abstracted over
localFVs' = filter ((/= (nameOcc idName)) . nameOcc) localFVs
localFVtys' = map (\k -> HML.lookupDefault (error $ $(curLoc) ++ show k ++ " not found") (nameOcc k) gamma) localFVs'
boundFTVs = zipWith mkTyVar localFTVkinds localFTVs
boundFVs = zipWith mkId localFVtys' localFVs'
tcm <- Lens.view tcCache
-- Type of the expression abstracted over its free (type) variables
newBodyTy <- termType tcm $ mkTyLams (mkLams e boundFVs) boundFTVs
(cf,sp) <- Lens.use curFun
-- Name of the new global: the current function name, an underscore, and the
-- name of the binder, cloned with cloneVar
newBodyId <- cloneVar (appendToName cf ("_" ++ name2String idName))
-- Call of the new global, applied to the free type variables and variables
let newExpr = mkTmApps
(mkTyApps (Var newBodyTy newBodyId)
(zipWith VarTy localFTVkinds localFTVs))
(zipWith Var localFVtys' localFVs')
-- References of the binder to itself become calls of the new global
e' = substTm (nameOcc idName) newExpr e
newBody = mkTyLams (mkLams e' boundFVs) boundFTVs
-- Global binders whose body is equal to the new body
-- NOTE(review): the name suggests == on Term is alpha-equivalence; confirm
aeqExisting <- (HMS.toList . HMS.filter ((== newBody) . (^. _5))) <$> Lens.use bindings
case aeqExisting of
-- No equal global binder: register the new one
[] -> do
bindings %= HMS.insert (nameOcc newBodyId)
(newBodyId {nameSort = Internal}
,newBodyTy,sp
#if MIN_VERSION_ghc(8,4,1)
,NoUserInline
#else
,EmptyInlineSpec
#endif
,newBody)
return (Id idName tyE, embed newExpr)
-- Otherwise call the first global binder that was found
((_,(k,aeqTy,_,_,_)):_) ->
let newExpr' = mkTmApps
(mkTyApps (Var aeqTy k)
(zipWith VarTy localFTVkinds localFTVs))
(zipWith Var localFVtys' localFVs')
in return (Id idName tyE, embed newExpr')
liftBinding _ _ _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"
-- | Make a global function for a name-term tuple.
--
-- The type of the body is computed with 'termType', the name is cloned, and
-- the resulting binder is added to the global bindings with
-- 'addGlobalBind'. The cloned name and the type of the body are returned.
mkFunction
  :: TmName
  -- ^ Name of the function
  -> SrcSpan
  -> InlineSpec
  -> Term
  -- ^ Term bound to the function
  -> RewriteMonad extra (TmName,Type)
mkFunction bndr sp inl body = do
  bodyTy <- Lens.view tcCache >>= \tcm -> termType tcm body
  bodyId <- cloneVar bndr
  let newBinder = (bodyId, bodyTy)
  addGlobalBind bodyId bodyTy sp inl body
  return newBinder
-- | Add a binder to the global bindings, keyed on the occurrence name of the
-- given name. The type and the body are fully evaluated ('deepseq') before
-- the state update is built.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vId ty sp inl body = deepseq (ty, body) register
  where
    register = bindings %= HMS.insert (nameOcc vId) (vId, ty, sp, inl, body)
-- | Create a new name out of the given name: the string of the original name
-- is combined with a number taken from 'getUniqueM'. The sort and the
-- location of the name are kept.
cloneVar
  :: (Monad m, MonadUnique m)
  => Name a
  -> m (Name a)
cloneVar (Name sort nm loc) = rename . toInteger <$> getUniqueM
  where
    rename i = Name sort (Unbound.makeName (Unbound.name2String nm) i) loc
-- | Test whether a term is a variable that is not a key of the global
-- bindings. Any term that is not a variable yields 'False'.
isLocalVar
  :: Term
  -> RewriteMonad extra Bool
isLocalVar term = case term of
  Var _ name -> do
    globals <- Lens.use bindings
    return (not (HML.member (nameOcc name) globals))
  _ -> return False
{-# INLINE isUntranslatable #-}
-- | Determine whether the type of a term is /not/ representable, according
-- to 'representableType' with the type translator and the @allowZero@
-- setting of the rewrite environment.
isUntranslatable
  :: Bool
  -- ^ Passed on to 'representableType' (named @stringRepresentable@)
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm        <- Lens.view tcCache
  translator <- Lens.view typeTranslator
  zeroOk     <- Lens.view allowZero
  ty         <- termType tcm tm
  return (not (representableType translator zeroOk stringRepresentable tcm ty))
{-# INLINE isUntranslatableType #-}
-- | Determine whether a type is /not/ representable, according to
-- 'representableType' with the type translator, the @allowZero@ setting and
-- the TyCon cache of the rewrite environment.
isUntranslatableType
  :: Bool
  -- ^ Passed on to 'representableType' (named @stringRepresentable@)
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translator <- Lens.view typeTranslator
  zeroOk     <- Lens.view allowZero
  tcm        <- Lens.view tcCache
  return (not (representableType translator zeroOk stringRepresentable tcm ty))
-- | Is the context element a lambda body?
isLambdaBodyCtx
  :: CoreContext
  -> Bool
isLambdaBodyCtx ctx = case ctx of
  LamBody _ -> True
  _         -> False
-- | Make a binder of the given type whose name is a clone of the internal
-- name @wild@; the variable reference made along with it is dropped.
mkWildValBinder
  :: (Monad m, MonadUnique m)
  => Type
  -> m Id
mkWildValBinder ty = fst <$> mkInternalVar (string2InternalName "wild") ty
-- | Make a case-expression that selects one field of a data constructor from
-- the given subject.
--
-- The result has a single alternative: a pattern for the requested data
-- constructor in which the requested field is bound to a variable that is
-- returned, and all other fields are bound to wildcard binders.
--
-- Every failure is reported with 'error' through the same message format:
-- the type of the subject is not an applied type constructor, the type
-- constructor has no data constructors, the data constructor index or field
-- index is too large, or the argument types of the data constructor cannot
-- be instantiated (previously an uninformative pattern-match failure).
mkSelectorCase
  :: (Functor m, Monad m, MonadUnique m, Fresh m)
  => String
  -- ^ Name of the caller of this function, used in error messages
  -> HashMap TyConOccName TyCon
  -- ^ TyCon cache
  -> Term
  -- ^ Subject of the case-expression
  -> Int
  -- ^ n'th data constructor; the first data constructor has index 1
  -> Int
  -- ^ n'th field; the first field has index 0
  -> m Term
mkSelectorCase caller tcm scrut dcI fieldI = do
  scrutTy <- termType tcm scrut
  go scrutTy
  where
    -- Keep looking through the type for as long as 'coreView' gives a view.
    go (coreView tcm -> Just ty') = go ty'
    go scrutTy@(tyView -> TyConApp tc args) =
      case tyConDataCons (tcm HMS.! nameOcc tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showDoc tc) scrutTy
        dcs
          | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max" scrutTy
          | otherwise ->
              let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
              in  case dataConInstArgTys dc args of
                    Nothing ->
                      cantCreate $(curLoc) "Cannot instantiate the argument types of the DC" scrutTy
                    Just fieldTys
                      | fieldI >= length fieldTys ->
                          cantCreate $(curLoc) "Field index exceed max" scrutTy
                      | otherwise -> do
                          wildBndrs <- mapM mkWildValBinder fieldTys
                          let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
                          selBndr <- mkInternalVar (string2InternalName "sel") ty
                          -- Only the selected field gets a binder that is used
                          let bndrs  = take fieldI wildBndrs ++ [fst selBndr] ++ drop (fieldI+1) wildBndrs
                              pat    = DataPat (embed dc) (rebind [] bndrs)
                              retVal = Case scrut ty [ bind pat (snd selBndr) ]
                          return retVal
    go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showDoc scrutTy) scrutTy

    cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showDoc scrut ++ " :: " ++ showDoc scrutTy ++ ")\nAdditional info: " ++ info
-- | Specialise an application on its argument.
--
-- Type applications and term applications are handed to @specialise'@,
-- together with the function part split into head and arguments by
-- 'collectArgs' and the argument to specialise on. Any other expression is
-- returned as it is.
specialise
  :: Lens' extra (Map.Map (TmOccName, Int, Either Term Type) (TmName,Type))
  -- ^ Lens into previous specialisations
  -> Lens' extra (HashMap TmOccName Int)
  -- ^ Lens into the specialisation history
  -> Lens' extra Int
  -- ^ Lens into the specialisation limit
  -> Rewrite extra
specialise specMapLbl specHistLbl specLimitLbl ctx e = case e of
  TyApp fun ty  -> specialiseOn fun (Right ty)
  App   fun arg -> specialiseOn fun (Left arg)
  _             -> return e
  where
    specialiseOn fun =
      specialise' specMapLbl specHistLbl specLimitLbl ctx e (collectArgs fun)
-- | Specialise an application on its argument.
--
-- Arguments, in order: lens into the previous specialisations, lens into the
-- specialisation history, lens into the specialisation limit, the context,
-- the original application, the function part of the application split into
-- head and arguments, and the argument to specialise on.
--
-- * Head is a variable: unless it is a top entity, the global function it
--   refers to is specialised on the argument, reusing an earlier
--   specialisation when one is recorded. A ClashException is thrown when the
--   recorded number of specialisations of the function exceeds the limit.
-- * Head is not a variable and the argument is a term: the argument is turned
--   into a global function, abstracted over its free variables, and replaced
--   by a call of that function.
-- * Anything else: the application is returned as it is.
specialise' :: Lens' extra (Map.Map (TmOccName, Int, Either Term Type) (TmName,Type))
-> Lens' extra (HashMap TmOccName Int)
-> Lens' extra Int
-> [CoreContext]
-> Term
-> (Term, [Either Term Type])
-> Either Term Type
-> RewriteMonad extra Term
specialise' specMapLbl specHistLbl specLimitLbl ctx e (Var _ f, args) specArgIn = do
lvl <- Lens.view dbgLevel
topEnts <- Lens.view topEntities
-- Top entities are never specialised
-- NOTE(review): lvl >= DebugNone looks like it always holds, so this would
-- trace at every debug level; confirm against the DebugLevel definition
if nameOcc f `HashSet.member` topEnts
then traceIf (lvl >= DebugNone) ("Not specialising TopEntity: " ++ showDoc f) (return e)
else do
tcm <- Lens.view tcCache
let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn
-- Binders for, and references to, the free variables of the argument
(specBndrsIn,specVars) <- specArgBndrsAndVars ctx specArg
let argLen = length args
specBndrs :: [Either Id TyVar]
specBndrs = map (Lens.over _Left (normalizeId tcm)) specBndrsIn
-- The argument abstracted over its free variables; together with the
-- function name and the argument position it is the key of a specialisation
specAbs :: Either Term Type
specAbs = either (Left . (`mkAbstraction` specBndrs)) (Right . id) specArg
specM <- Map.lookup (nameOcc f,argLen,specAbs) <$> Lens.use (extra.specMapLbl)
case specM of
-- Specialised on this argument before: call the existing specialisation
Just (fname,fty) ->
traceIf (lvl >= DebugApplied) ("Using previous specialization of " ++ showDoc (nameOcc f) ++ " on " ++ (either showDoc showDoc) specAbs ++ ": " ++ showDoc fname) $
changed $ mkApps (Var fty fname) (args ++ specVars)
Nothing -> do
-- Only functions with a global binding can be specialised
bodyMaybe <- fmap (HML.lookup (nameOcc f)) $ Lens.use bindings
case bodyMaybe of
Just (_,_,sp,inl,bodyTm) -> do
specHistM <- HML.lookup (nameOcc f) <$> Lens.use (extra.specHistLbl)
specLim <- Lens.use (extra . specLimitLbl)
if maybe False (> specLim) specHistM
then throw (ClashException
sp
(unlines [ "Hit specialisation limit " ++ show specLim ++ " on function `" ++ showDoc (nameOcc f) ++ "'.\n"
, "The function `" ++ showDoc f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
, "Body of `" ++ showDoc f ++ "':\n" ++ showDoc bodyTm ++ "\n"
, "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showDoc showDoc) specArg
, "Run with '-fclash-spec-limit=N' to increase the specialisation limit to N."
])
Nothing)
else do
-- Binders for the arguments that are already applied; they are named after
-- the binders collected from the body, and after pTS once those run out
(boundArgs,argVars) <- fmap (unzip . map (either (Left *** Left) (Right *** Right))) $
Monad.zipWithM
(mkBinderFor tcm)
(unsafeCollectBndrs bodyTm ++ repeat (string2InternalName "pTS"))
args
-- When the argument is an application of a variable g and isPolyFun holds
-- for it, the new function takes the name of g and, when g has a global
-- binding, also its inline spec, and the argument becomes the body of g
-- applied to the arguments of g
(fName,inl',specArg') <- case specArg of
Left a@(collectArgs -> (Var _ g,gArgs)) -> do
polyFun <- isPolyFun tcm a
if polyFun
then do
gTmM <- fmap (HML.lookup (nameOcc g)) $ Lens.use bindings
return (g,maybe inl (^. _4) gTmM, maybe specArg (Left . (`mkApps` gArgs) . (^. _5)) gTmM)
else return (f,inl,specArg)
_ -> return (f,inl,specArg)
-- The body applied to the argument binders and the specialised argument,
-- abstracted over those binders and the free variables of the argument
let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
newf <- mkFunction fName sp inl' newBody
-- Record the specialisation in the history and in the specialisation map
(extra.specHistLbl) %= HML.insertWith (+) (nameOcc f) 1
(extra.specMapLbl) %= Map.insert (nameOcc f,argLen,specAbs) newf
let newExpr = mkApps ((uncurry . flip) Var newf) (args ++ specVars)
newf `deepseq` changed newExpr
Nothing -> return e
where
-- Names of the binders of the leading (type) lambdas of a term. For an
-- application, the binders of the function part are collected and the
-- outermost of them is dropped.
unsafeCollectBndrs :: Term -> [Name a]
unsafeCollectBndrs =
map (either (coerceName . varName) (coerceName . varName)) . reverse . go []
where
go bs (Lam b) = let (v,e') = unsafeUnbind b in go (Left v:bs) e'
go bs (TyLam b) = let (tv,e') = unsafeUnbind b in go (Right tv:bs) e'
go bs (App e' _) = case go [] e' of
[] -> bs
bs' -> init bs' ++ bs
go bs (TyApp e' _) = case go [] e' of
[] -> bs
bs' -> init bs' ++ bs
go bs _ = bs
-- The head of the application is not a variable and the argument is a term
specialise' _ _ _ ctx _ (appE,args) (Left specArg) = do
(specBndrs,specVars) <- specArgBndrsAndVars ctx (Left specArg)
let newBody = mkAbstraction specArg specBndrs
-- Reuse a global binder with an equal body when there is one
existing <- HML.filter ((== newBody) . (^. _5)) <$> Lens.use bindings
newf <- case HML.toList existing of
[] -> do (cf,sp) <- Lens.use curFun
mkFunction (appendToName cf "_specF")
sp
#if MIN_VERSION_ghc(8,4,1)
NoUserInline
#else
EmptyInlineSpec
#endif
newBody
((_,(k,kTy,_,_,_)):_) -> return (k,kTy)
-- The argument is replaced by a call of the global function
let newArg = Left $ mkApps ((uncurry . flip) Var newf) specVars
let newExpr = mkApps appE (args ++ [newArg])
changed newExpr
specialise' _ _ _ _ e _ _ = return e
-- | Normalize the types that are directly visible in a term: the type
-- annotation of a variable, and both types of a cast together with the term
-- under the cast. All other terms are returned as they are.
normalizeTermTypes
  :: HashMap TyConOccName TyCon
  -> Term
  -> Term
normalizeTermTypes tcm = go
  where
    norm = normalizeType tcm

    go (Cast e' ty1 ty2) = Cast (go e') (norm ty1) (norm ty2)
    go (Var ty nm)       = Var (norm ty) nm
    go e                 = e
-- | Normalize the type of an identifier. Anything that is not built with the
-- 'Id' constructor is returned as it is.
normalizeId
  :: HashMap TyConOccName TyCon
  -> Id
  -> Id
normalizeId tcm var = case var of
  Id nm (Unbound.Embed ty) -> Id nm (Unbound.Embed (normalizeType tcm ty))
  _                        -> var
-- | Create binders for, and variable references to, the free variables of
-- the argument to specialise on.
--
-- For a term these are its free type variables and its local free
-- identifiers; for a type only its free type variables. Kinds and types are
-- looked up in the environment made by 'mkEnv' for the context, and a
-- variable that is missing from it leads to an 'error'. In both results the
-- type variables come before the term variables.
specArgBndrsAndVars
  :: [CoreContext]
  -> Either Term Type
  -> RewriteMonad extra ([Either Id TyVar],[Either Term Type])
specArgBndrsAndVars ctx specArg = do
  let specFTVs = List.nub $ case specArg of
        Left tm  -> Lens.toListOf termFreeTyVars tm
        Right ty -> Lens.toListOf typeFreeVars ty
  specFVs <- case specArg of
    Left tm -> List.nub <$> (Lens.toListOf <$> localFreeIds <*> pure tm)
    Right _ -> return []
  (gamma, delta) <- mkEnv ctx
  let tyBinder tv =
        let ki  = HML.lookupDefault (error $ $(curLoc) ++ show tv ++ " not found") tv delta
            tv' = Name Internal tv noSrcSpan
        in  (Right $ TyVar tv' (embed ki), Right $ VarTy ki tv')
      tmBinder tm =
        let ty  = HML.lookupDefault (error $ $(curLoc) ++ show tm ++ " not found") tm gamma
            tm' = Name Internal tm noSrcSpan
        in  (Left $ Id tm' (embed ty), Left $ Var ty tm')
      (specTyBndrs, specTyVars) = unzip (map tyBinder specFTVs)
      (specTmBndrs, specTmVars) = unzip (map tmBinder specFVs)
  return (specTyBndrs ++ specTmBndrs, specTyVars ++ specTmVars)