{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
#include "../../ClashDebug.h"
module Clash.Core.Subst
(
TvSubst (..)
, TvSubstEnv
, extendTvSubst
, extendTvSubstList
, substTy
, substTyWith
, substTyInVar
, substGlobalsInExistentials
, substInExistentials
, substInExistentialsList
, Subst (..)
, mkSubst
, mkTvSubst
, extendInScopeId
, extendInScopeIdList
, extendIdSubst
, extendIdSubstList
, extendGblSubstList
, substTm
, maybeSubstTm
, substAlt
, substId
, deShadowTerm
, deShadowAlt
, freshenTm
, deshadowLetExpr
, aeqType
, aeqTerm
)
where
import Data.Coerce (coerce)
import Data.Text.Prettyprint.Doc
import qualified Data.List as List
import qualified Data.List.Extra as List
import Data.Ord (comparing)
import Clash.Core.FreeVars
(noFreeVarsOfType, localFVsOfTerms, tyFVsOfTypes)
import Clash.Core.Pretty (ppr, fromPpr)
import Clash.Core.Term
(LetBinding, Pat (..), Term (..), TickInfo (..), PrimInfo(primName))
import Clash.Core.Type (Type (..))
import Clash.Core.VarEnv
import Clash.Core.Var (Id, Var (..), TyVar, isGlobalId)
import Clash.Debug (debugIsOn)
import Clash.Unique
import Clash.Util
import Clash.Pretty
-- | A substitution for type variables: each 'TyVar' key is replaced by the
-- associated 'Type'.
type TvSubstEnv = VarEnv Type

-- | A type substitution paired with the set of variables in scope where it
-- is applied. The in-scope set is used when binders have to be given fresh
-- names (see 'substTyVarBndr').
data TvSubst = TvSubst InScopeSet TvSubstEnv
-- | Render a 'TvSubst' as a bracketed block that lists its in-scope set and
-- its type environment.
instance ClashPretty TvSubst where
  clashPretty (TvSubst scope tyEnv) = brackets (sep sections)
   where
    sections =
      [ "TvSubst"
      , nest 2 ("In scope:" <+> clashPretty scope)
      , nest 2 ("Type env:" <+> clashPretty tyEnv)
      ]
-- | A substitution for term variables: each 'Id' key is replaced by the
-- associated 'Term'.
type IdSubstEnv = VarEnv Term

-- | A combined term-and-type substitution, together with the in-scope set
-- used to freshen binders while substituting.
data Subst = Subst
  { substInScope :: InScopeSet
    -- ^ Variables considered in scope; binders are made unique with respect
    -- to this set
  , substTmEnv   :: IdSubstEnv
    -- ^ Substitution for local (non-global) term variables
  , substTyEnv   :: TvSubstEnv
    -- ^ Substitution for type variables
  , substGblEnv  :: IdSubstEnv
    -- ^ Substitution consulted for global ids
  }
-- | The substitution with an empty in-scope set and no mappings at all.
emptySubst :: Subst
emptySubst = mkSubst emptyInScopeSet
-- | A substitution that maps nothing but carries the given in-scope set.
mkSubst :: InScopeSet -> Subst
mkSubst scope = Subst
  { substInScope = scope
  , substTmEnv   = emptyVarEnv
  , substTyEnv   = emptyVarEnv
  , substGblEnv  = emptyVarEnv
  }
-- | A substitution that only maps type variables, using the given
-- in-scope set and type environment.
mkTvSubst :: InScopeSet -> VarEnv Type -> Subst
mkTvSubst scope tyEnv = Subst
  { substInScope = scope
  , substTmEnv   = emptyVarEnv
  , substTyEnv   = tyEnv
  , substGblEnv  = emptyVarEnv
  }
-- | Build a type substitution that maps each type variable to the type at
-- the same position. The in-scope set is the free variables of the types.
--
-- When 'debugIsOn' and the lists differ in length, a trace is emitted and
-- 'emptySubst' is returned. Otherwise the pairing is done by
-- 'List.zipEqual' (presumably failing on a length mismatch; see its docs).
zipTvSubst :: [TyVar] -> [Type] -> Subst
zipTvSubst tvs tys
  | debugIsOn && not (List.equalLength tvs tys)
  = pprTrace "zipTvSubst" (ppr tvs <> line <> ppr tys) emptySubst
  | otherwise
  = mkTvSubst (mkInScopeSet (tyFVsOfTypes tys)) (zipTyEnv tvs tys)
-- | Pair type variables with types positionally into a 'VarEnv'.
zipTyEnv :: [TyVar] -> [Type] -> VarEnv Type
zipTyEnv tvs = mkVarEnv . List.zipEqual tvs
-- | Add the mapping @i := e@ to the term substitution.
extendIdSubst :: Subst -> Id -> Term -> Subst
extendIdSubst subst i e =
  subst { substTmEnv = extendVarEnv i e (substTmEnv subst) }
-- | Add several @id := term@ mappings to the term substitution.
extendIdSubstList :: Subst -> [(Id,Term)] -> Subst
extendIdSubstList subst es =
  subst { substTmEnv = extendVarEnvList (substTmEnv subst) es }
-- | Add several @id := term@ mappings to the substitution for global ids.
extendGblSubstList :: Subst -> [(Id,Term)] -> Subst
extendGblSubstList subst es =
  subst { substGblEnv = extendVarEnvList (substGblEnv subst) es }
-- | Add the mapping @tv := t@ to the type substitution.
extendTvSubst :: Subst -> TyVar -> Type -> Subst
extendTvSubst subst tv t =
  subst { substTyEnv = extendVarEnv tv t (substTyEnv subst) }
-- | Add several @tyvar := type@ mappings to the type substitution.
extendTvSubstList :: Subst -> [(TyVar, Type)] -> Subst
extendTvSubstList subst ts =
  subst { substTyEnv = extendVarEnvList (substTyEnv subst) ts }
-- | Bring an 'Id' into scope and drop any term mapping for it.
extendInScopeId :: Subst -> Id -> Subst
extendInScopeId subst id' = subst
  { substInScope = extendInScopeSet (substInScope subst) id'
  , substTmEnv   = delVarEnv (substTmEnv subst) id'
  }
-- | Bring several 'Id's into scope and drop any term mappings for them.
extendInScopeIdList :: Subst -> [Id] -> Subst
extendInScopeIdList subst ids = subst
  { substInScope = extendInScopeSetList (substInScope subst) ids
  , substTmEnv   = delVarEnvList (substTmEnv subst) ids
  }
-- | Apply the type part of a 'Subst' to a type.
--
-- An empty type environment returns the type untouched. Otherwise the
-- result goes through 'checkValidSubst', which may emit warnings about
-- variables missing from the in-scope set.
substTy :: HasCallStack => Subst -> Type -> Type
substTy (Subst inScope _ tvS _) ty =
  if nullVarEnv tvS
    then ty
    else checkValidSubst tvSubst [ty] (substTy' tvSubst ty)
 where
  tvSubst = TvSubst inScope tvS
-- | Apply a substitution to the type (or kind) annotation of a variable,
-- leaving the rest of the variable as is.
substTyInVar :: HasCallStack => Subst -> Var a -> Var a
substTyInVar subst v = v { varType = substTy subst oldTy }
 where
  oldTy = varType v
-- | Like 'substTy', but works on a 'TvSubst' and skips 'checkValidSubst'.
substTyUnchecked :: HasCallStack => TvSubst -> Type -> Type
substTyUnchecked subst@(TvSubst _ tvS) ty =
  if nullVarEnv tvS then ty else substTy' subst ty
-- | Apply substitutions for variables bound outside the existentials to the
-- kinds of the existentials.
--
-- Each existential is processed with an in-scope set that contains @is@ and
-- all the existentials before it.
substGlobalsInExistentials
  :: HasCallStack
  => InScopeSet
  -> [TyVar]
  -> [(TyVar, Type)]
  -> [TyVar]
substGlobalsInExistentials is exts substs0 =
  zipWith substTyInVar perExtSubsts exts
 where
  perExtSubsts =
    [ extendTvSubstList (mkSubst scope) substs0
    | scope <- scanl extendInScopeSet is exts
    ]
-- | Apply a list of substitutions to a list of existentials, one at a time
-- from left to right, using 'substInExistentials' for each.
substInExistentialsList
  :: HasCallStack
  => InScopeSet
  -> [TyVar]
  -> [(TyVar, Type)]
  -> [TyVar]
substInExistentialsList is exts substs =
  -- Strict left fold: the lazy foldl would build one thunk per substitution.
  List.foldl' (substInExistentials is) exts substs
-- | Apply a single substitution @(tv, ty)@ to the kinds of a list of
-- existential type variables, given in binding order.
--
-- * If @tv@ is not one of the existentials, it refers to an outer variable,
--   so the substitution is applied to every existential (via
--   'substGlobalsInExistentials').
--
-- * If @tv@ is one of the existentials, the substitution targets the last
--   existential with that name. That existential and all earlier ones are
--   returned unchanged; only the existentials after it are substituted.
substInExistentials
  :: HasCallStack
  => InScopeSet
  -> [TyVar]
  -- ^ Existentials, in binding order
  -> (TyVar, Type)
  -- ^ Substitution
  -> [TyVar]
substInExistentials is exts subst@(typeVar, _type) =
  -- Match on the reversed index list so no partial function (last) and no
  -- view pattern is needed.
  case reverse (List.elemIndices typeVar exts) of
    [] ->
      substGlobalsInExistentials is exts [subst]
    lastIdx : _ ->
      let (upToTarget, afterTarget) = splitAt (lastIdx + 1) exts
      in  upToTarget ++ substGlobalsInExistentials is afterTarget [subst]
-- | Sanity check for a type substitution. The value @a@ is returned as is;
-- the two @WARN@ macros (defined in ClashDebug.h, not visible here) report:
--
-- 1. the free variables of the types in @tenv@ are not all contained in
--    @inScope@ (the 'isValidSubst' condition), or
--
-- 2. the free variables of @tys@, minus the domain of @tenv@, are not all
--    contained in @inScope@.
--
-- NOTE(review): whether these warnings only fire in debug builds depends on
-- the WARN definition in ClashDebug.h; confirm there.
checkValidSubst
:: HasCallStack
=> TvSubst
-> [Type]
-> a
-> a
checkValidSubst subst@(TvSubst inScope tenv) tys a =
WARN( not (isValidSubst subst),
"inScope" <+> clashPretty inScope <> line <>
"tenv" <+> clashPretty tenv <> line <>
"tenvFVs" <+> clashPretty (tyFVsOfTypes tenv) <> line <>
"tys" <+> fromPpr tys)
WARN( not tysFVsInSope,
"inScope" <+> clashPretty inScope <> line <>
"tenv" <+> clashPretty tenv <> line <>
"tys" <+> fromPpr tys <> line <>
"needsInScope" <+> clashPretty needsInScope)
a
where
-- Free variables of tys, with every variable in the domain of tenv
-- removed: those are replaced by the substitution, so they need not be
-- in scope.
needsInScope = foldrWithUnique (\k _ s -> delVarSetByKey k s)
(tyFVsOfTypes tys)
tenv
-- True when all remaining free variables of tys are in scope.
-- (The binding name is a misspelling of tysFVsInScope.)
tysFVsInSope = needsInScope `varSetInScope` inScope
-- | A 'TvSubst' is valid when every free variable of the types it maps to
-- is in its in-scope set.
isValidSubst :: TvSubst -> Bool
isValidSubst (TvSubst inScope tenv) =
  varSetInScope (tyFVsOfTypes tenv) inScope
-- | Worker for type substitution. It has neither the empty-substitution
-- shortcut nor the validity check of 'substTy'.
--
-- * Type variables are looked up with 'substTyVar'.
-- * Binders of 'ForAllTy' go through 'substTyVarBndr', which may rename them
--   to avoid capture; the body is then substituted with the extended
--   substitution.
-- * Applications and annotated types are traversed structurally.
--   Previously 'AnnType' fell into the catch-all, so variables under an
--   annotation were never substituted.
-- * Remaining constructors are returned unchanged.
substTy'
  :: HasCallStack
  => TvSubst
  -> Type
  -> Type
substTy' subst = go
 where
  go = \case
    VarTy tv -> substTyVar subst tv
    ForAllTy tv ty -> case substTyVarBndr subst tv of
      (subst', tv') -> ForAllTy tv' (substTy' subst' ty)
    AppTy fun arg -> AppTy (go fun) (go arg)
    AnnType ann ty -> AnnType ann (go ty)
    ty -> ty
-- | Replace a type variable by its image in the substitution. A variable
-- without an image is returned as a 'VarTy'.
substTyVar :: TvSubst -> TyVar -> Type
substTyVar (TvSubst _ tenv) tv = maybe (VarTy tv) id (lookupVarEnv tv tenv)
-- | Substitute into a type-variable binder.
--
-- Returns the new binder together with an updated 'TvSubst' in which:
--
-- * the new binder is added to the in-scope set, and
--
-- * the type environment either drops any mapping for the old binder (when
--   nothing changed) or maps the old binder to the new one, so occurrences
--   in the body are renamed.
--
-- The @ASSERT2@ macro checks that the new binder does not occur free in the
-- types of the environment, i.e. that substituting cannot capture it.
substTyVarBndr
:: TvSubst
-> TyVar
-> (TvSubst, TyVar)
substTyVarBndr subst@(TvSubst inScope tenv) oldVar =
ASSERT2( no_capture, clashPretty oldVar <> line
<> clashPretty newVar <> line
<> clashPretty subst )
(TvSubst (inScope `extendInScopeSet` newVar) newEnv, newVar)
where
-- Unchanged binder: shadow any outer mapping for it. Renamed binder:
-- redirect occurrences of the old binder to the new one.
newEnv | noChange = delVarEnv tenv oldVar
| otherwise = extendVarEnv oldVar (VarTy newVar) tenv
no_capture = not (newVar `elemVarSet` tyFVsOfTypes tenv)
oldKi = varType oldVar
-- A kind without free variables is unaffected by the substitution.
noKindChange = noFreeVarsOfType oldKi
noChange = noKindChange && (newVar == oldVar)
-- Freshen the binder with respect to the in-scope set via uniqAway;
-- when the kind has free variables it is substituted first (unchecked).
newVar | noKindChange = uniqAway inScope oldVar
| otherwise = uniqAway inScope
(oldVar {varType = substTyUnchecked subst oldKi})
-- | Apply a substitution if one is given; with 'Nothing' the term is
-- returned unchanged.
maybeSubstTm :: HasCallStack => Doc () -> Maybe Subst -> Term -> Term
maybeSubstTm doc = maybe id (substTm doc)
-- | Apply a substitution to a term.
--
-- Binders ('Lam', 'TyLam', let-bindings and pattern binders) go through
-- 'substIdBndr' / 'substTyVarBndr'' and may be renamed so that no variable
-- in the substitution is captured. Types inside the term are substituted
-- with 'substTy'.
--
-- The @doc@ argument is a diagnostic label. It is forwarded to
-- 'lookupIdSubst', which includes it in its warning.
substTm
  :: HasCallStack
  => Doc ()
  -> Subst
  -> Term
  -> Term
substTm doc subst = go
 where
  go = \case
    Var v -> lookupIdSubst (doc <> line <> "substTm") subst v
    Lam v e -> case substIdBndr subst v of
      (subst', v') -> Lam v' (substTm doc subst' e)
    TyLam v e -> case substTyVarBndr' subst v of
      (subst', v') -> TyLam v' (substTm doc subst' e)
    App l r -> App (go l) (go r)
    TyApp l r -> TyApp (go l) (substTy subst r)
    Letrec bs e -> case substBind doc subst bs of
      (subst', bs') -> Letrec bs' (substTm doc subst' e)
    Case subj ty alts ->
      -- Alternatives share the exported 'substAlt' rather than a local copy.
      Case (go subj) (substTy subst ty) (map (substAlt doc subst) alts)
    Cast e t1 t2 -> Cast (go e) (substTy subst t1) (substTy subst t2)
    Tick tick e -> Tick (goTick tick) (go e)
    tm -> tm

  -- Only 'NameMod' ticks contain a type; the others are listed explicitly so
  -- that a new 'TickInfo' constructor triggers an incomplete-pattern warning.
  goTick t@(SrcSpan _) = t
  goTick (NameMod m ty) = NameMod m (substTy subst ty)
  goTick t@DeDup = t
  goTick t@NoDeDup = t
-- | Apply a substitution to one case alternative.
--
-- For a 'DataPat', the type-variable binders and then the term binders of
-- the pattern are freshened in order, and the body is substituted under the
-- resulting substitution. Other patterns bind nothing, so only the body is
-- substituted.
substAlt
  :: HasCallStack
  => Doc ()
  -> Subst
  -> (Pat, Term)
  -> (Pat, Term)
substAlt doc subst (pat, alt) = case pat of
  DataPat dc tvs ids ->
    let (substTvs, tvs1) = List.mapAccumL substTyVarBndr' subst tvs
        (substIds, ids1) = List.mapAccumL substIdBndr substTvs ids
    in  (DataPat dc tvs1 ids1, substTm doc substIds alt)
  _ -> (pat, substTm doc subst alt)
-- | Apply a substitution to an 'Id' as a binder, keeping only the resulting
-- (possibly renamed and retyped) 'Id'.
substId :: HasCallStack => Subst -> Id -> Id
substId subst = snd . substIdBndr subst
-- | Find the replacement for a term variable occurrence.
--
-- The lookup order is:
--
-- 1. Global ids are looked up only in the global environment; a global id
--    without a mapping is returned as 'Var'.
--
-- 2. Local ids are looked up in the term environment.
--
-- 3. Failing that, if the in-scope set holds an 'Id' for the variable, that
--    in-scope version is used (it may have been renamed or retyped).
--
-- 4. Otherwise a warning labelled with @doc@ is emitted through the WARN
--    macro and the variable is returned as is.
lookupIdSubst
:: HasCallStack
=> Doc ()
-> Subst
-> Id
-> Term
lookupIdSubst doc (Subst inScope tmS _ genv) v
| isGlobalId v = case lookupVarEnv v genv of
Just e -> e
_ -> Var v
| Just e <- lookupVarEnv v tmS = e
| Just v'@(Id {}) <- lookupInScope inScope v = Var (coerce v')
| otherwise = WARN(True, "Subst.lookupIdSubst" <+> doc <+> fromPpr v)
Var v
-- | Substitute into a term-variable binder.
--
-- The binder is made unique with respect to the in-scope set via
-- 'uniqAway'. Its type is substituted only when the type environment is
-- non-empty and the type has free variables. The returned substitution has
-- the new binder in scope. Its term environment either forgets the old
-- binder (when 'uniqAway' left it unchanged) or maps it to the new binder.
substIdBndr :: HasCallStack => Subst -> Id -> (Subst,Id)
substIdBndr subst@(Subst inScope env tenv genv) oldId =
  (Subst (extendInScopeSet inScope newId) newEnv tenv genv, newId)
 where
  freshId = uniqAway inScope oldId
  oldTy = varType oldId

  newId
    | nullVarEnv tenv || noFreeVarsOfType oldTy = freshId
    | otherwise = freshId { varType = substTy subst (varType freshId) }

  newEnv
    | freshId == oldId = delVarEnv env oldId
    | otherwise = extendVarEnv oldId (Var newId) env
-- | Lift 'substTyVarBndr' from 'TvSubst' to a full 'Subst'. The term and
-- global environments are passed through untouched.
substTyVarBndr' :: HasCallStack => Subst -> TyVar -> (Subst,TyVar)
substTyVarBndr' subst tv =
  case substTyVarBndr tySubst tv of
    (TvSubst scope' tyEnv', tv') ->
      (subst { substInScope = scope', substTyEnv = tyEnv' }, tv')
 where
  tySubst = TvSubst (substInScope subst) (substTyEnv subst)
-- | Apply a substitution to a group of (recursive) let-bindings.
--
-- All binders are freshened first, and every right-hand side is substituted
-- under the resulting substitution, so the binders scope over all of them.
-- The extended substitution is also returned for use on the let body.
substBind
  :: HasCallStack
  => Doc ()
  -> Subst
  -> [LetBinding]
  -> (Subst,[LetBinding])
substBind doc subst xs = (subst', zip bndrs' rhss')
 where
  (subst', bndrs') = List.mapAccumL substIdBndr subst (map fst xs)
  rhss' = map (substTm ("substBind" <+> doc) subst' . snd) xs
-- | Substitute the given types for the given type variables, pairwise.
-- The ASSERT macro checks that both lists have the same length.
substTyWith
  :: HasCallStack
  => [TyVar]
  -> [Type]
  -> Type
  -> Type
substTyWith tvs tys =
  ASSERT( List.equalLength tvs tys )
  substTy pairing
 where
  pairing = zipTvSubst tvs tys
-- | Rename binders in a term that clash with the given in-scope set, by
-- applying an otherwise empty substitution.
deShadowTerm :: HasCallStack => InScopeSet -> Term -> Term
deShadowTerm is = substTm "deShadowTerm" (mkSubst is)
-- | Rename binders in a case alternative that clash with the given in-scope
-- set, by applying an otherwise empty substitution.
deShadowAlt :: HasCallStack => InScopeSet -> (Pat, Term) -> (Pat, Term)
deShadowAlt is alt = substAlt "deShadowAlt" (mkSubst is) alt
-- | Rename clashing binders in a let-expression. The bindings are freshened
-- against the in-scope set, and the body is substituted with the
-- substitution that results from freshening them.
deshadowLetExpr
  :: HasCallStack
  => InScopeSet
  -> [LetBinding]
  -> Term
  -> ([LetBinding],Term)
deshadowLetExpr is bs e = (bs1, substTm "deShadowLetBody" s1 e)
 where
  (s1, bs1) = substBind "deshadowLetBindings" (mkSubst is) bs
-- | Freshen the binders of a term with respect to an in-scope set that is
-- threaded left to right through the whole term.
--
-- Every binder goes through substIdBndr / substTyVarBndr (via uniqAway), so
-- it is added to the in-scope set. The set returned from one subterm is the
-- starting set for its next sibling (App arguments, let right-hand sides,
-- case alternatives). Binders in sibling subterms therefore end up distinct
-- from one another as well.
--
-- Returns the final in-scope set, which contains every binder introduced
-- while freshening, together with the renamed term.
freshenTm
:: InScopeSet
-> Term
-> (InScopeSet, Term)
freshenTm is0 = go (mkSubst is0) where
-- go returns the in-scope set after processing the term, plus the new term.
go subst0 = \case
Var v -> (substInScope subst0, lookupIdSubst "freshenTm" subst0 v)
Lam v e -> case substIdBndr subst0 v of
(subst1,v') -> case go subst1 e of
(is2,e') -> (is2, Lam v' e')
TyLam v e -> case substTyVarBndr' subst0 v of
(subst1,v') -> case go subst1 e of
(is2,e') -> (is2,TyLam v' e')
-- The argument is freshened with the in-scope set produced by the function.
App l r -> case go subst0 l of
(is1,l') -> case go subst0 {substInScope = is1} r of
(is2,r') -> (is2, App l' r')
-- NOTE(review): types here and below are substituted with subst0, whose
-- in-scope set does not include binders introduced in sibling subterms.
TyApp l r -> case go subst0 l of
(is1,l') -> (is1, TyApp l' (substTy subst0 r))
Letrec bs e -> case goBind subst0 bs of
(subst1,bs') -> case go subst1 e of
(is2,e') -> (is2,Letrec bs' e')
Case subj ty alts -> case go subst0 subj of
(is1,subj') -> case List.mapAccumL (\isN -> goAlt subst0 {substInScope = isN}) is1 alts of
(is2,alts') -> (is2, Case subj' (substTy subst0 ty) alts')
Cast e t1 t2 -> case go subst0 e of
(is1, e') -> (is1, Cast e' (substTy subst0 t1) (substTy subst0 t2))
Tick tick e -> case go subst0 e of
(is1, e') -> (is1, Tick (goTick subst0 tick) e')
tm -> (substInScope subst0, tm)
-- Freshen all let binders first, so they scope over every right-hand side.
-- Then thread the in-scope set through the right-hand sides in order.
goBind subst0 xs =
let (bndrs,rhss) = unzip xs
(subst1,bndrs') = List.mapAccumL substIdBndr subst0 bndrs
(is2,rhss') = List.mapAccumL (\isN -> go subst1 {substInScope = isN})
(substInScope subst1)
rhss
in (subst1 {substInScope = is2},zip bndrs' rhss')
-- Freshen pattern binders (type variables, then ids) before the body.
goAlt subst0 (pat,alt) = case pat of
DataPat dc tvs ids -> case List.mapAccumL substTyVarBndr' subst0 tvs of
(subst1,tvs') -> case List.mapAccumL substIdBndr subst1 ids of
(subst2,ids') -> case go subst2 alt of
(is3,alt') -> (is3,(DataPat dc tvs' ids',alt'))
_ -> case go subst0 alt of
(is1,alt') -> (is1,(pat,alt'))
-- Only NameMod ticks carry a type that needs substituting.
goTick subst0 (NameMod m ty) = NameMod m (substTy subst0 ty)
goTick _ tick = tick
-- | Alpha-equivalence of types: equal up to renaming of bound variables,
-- as decided by 'acmpType'.
aeqType :: Type -> Type -> Bool
aeqType t1 t2 = acmpType t1 t2 == EQ
-- | Compare two types up to renaming of bound variables. The renaming
-- environment starts from the free variables of both types.
acmpType :: Type -> Type -> Ordering
acmpType t1 t2 = acmpType' rnEnv t1 t2
 where
  rnEnv = mkRnEnv (mkInScopeSet (tyFVsOfTypes [t1, t2]))
-- | Compare two types under a renaming environment.
--
-- Variables are compared through their renamings. 'ForAllTy' compares the
-- binder kinds first and then the bodies, with the two binders identified.
-- Annotations on 'AnnType' are ignored. Types with different head
-- constructors are ordered by a fixed constructor rank.
acmpType' :: RnEnv -> Type -> Type -> Ordering
acmpType' = cmp
 where
  cmp env ty1 ty2 = case (ty1, ty2) of
    (VarTy a, VarTy b) ->
      compare (rnOccLTy env a) (rnOccRTy env b)
    (ConstTy a, ConstTy b) ->
      compare a b
    (ForAllTy a body1, ForAllTy b body2) ->
      cmp env (varType a) (varType b)
        `thenCompare` cmp (rnTyBndr env a b) body1 body2
    (AppTy f1 x1, AppTy f2 x2) ->
      cmp env f1 f2 `thenCompare` cmp env x1 x2
    (LitTy a, LitTy b) ->
      compare a b
    (AnnType _ a, AnnType _ b) ->
      cmp env a b
    _ ->
      compare (rank ty1) (rank ty2)

  rank :: Type -> Word
  rank = \case
    VarTy {}    -> 0
    LitTy {}    -> 1
    ConstTy {}  -> 2
    AnnType {}  -> 3
    AppTy {}    -> 4
    ForAllTy {} -> 5
-- | Alpha-equivalence of terms, as decided by 'acmpTerm'.
aeqTerm :: Term -> Term -> Bool
aeqTerm t1 t2 = acmpTerm t1 t2 == EQ
-- | Alpha-equivalence of terms with an explicitly supplied in-scope set.
aeqTerm' :: InScopeSet -> Term -> Term -> Bool
aeqTerm' inScope t1 t2 = case acmpTerm' inScope t1 t2 of
  EQ -> True
  _  -> False
-- | Compare two terms up to renaming of bound variables. The in-scope set
-- is the local free variables of both terms.
acmpTerm :: Term -> Term -> Ordering
acmpTerm t1 t2 = acmpTerm' scope t1 t2
 where
  scope = mkInScopeSet (localFVsOfTerms [t1, t2])
-- | Compare two terms up to renaming of bound variables, starting from a
-- renaming environment built from the given in-scope set.
--
-- Notes on what is compared:
--
-- * 'Tick's are ignored on either side.
-- * The result type of a 'Case' is ignored.
-- * Primitives are compared by 'primName' only.
-- * Terms with different head constructors are ordered by a fixed rank.
acmpTerm'
  :: InScopeSet
  -> Term
  -> Term
  -> Ordering
acmpTerm' inScope = go (mkRnEnv inScope)
 where
  -- Lexicographic combination of pairwise results. Callers compare the list
  -- lengths first, so this only runs on equally long lists.
  cmpAll :: [Ordering] -> Ordering
  cmpAll = foldr thenCompare EQ

  go env (Var id1) (Var id2) = compare (rnOccLId env id1) (rnOccRId env id2)
  go _ (Data dc1) (Data dc2) = compare dc1 dc2
  go _ (Literal l1) (Literal l2) = compare l1 l2
  go _ (Prim p1) (Prim p2) = comparing primName p1 p2
  go env (Lam b1 e1) (Lam b2 e2) =
    acmpType' env (varType b1) (varType b2) `thenCompare`
    go (rnTmBndr env b1 b2) e1 e2
  go env (TyLam b1 e1) (TyLam b2 e2) =
    acmpType' env (varType b1) (varType b2) `thenCompare`
    go (rnTyBndr env b1 b2) e1 e2
  go env (App l1 r1) (App l2 r2) =
    go env l1 l2 `thenCompare` go env r1 r2
  go env (TyApp l1 r1) (TyApp l2 r2) =
    go env l1 l2 `thenCompare` acmpType' env r1 r2
  go env (Letrec bs1 e1) (Letrec bs2 e2) =
    compare (length bs1) (length bs2) `thenCompare`
    cmpAll (zipWith (go env') rhs1 rhs2) `thenCompare`
    go env' e1 e2
   where
    (ids1, rhs1) = unzip bs1
    (ids2, rhs2) = unzip bs2
    -- Let binders are recursive: they scope over the rhss and the body.
    env' = rnTmBndrs env ids1 ids2
  go env (Case e1 _ a1) (Case e2 _ a2) =
    compare (length a1) (length a2) `thenCompare`
    go env e1 e2 `thenCompare`
    cmpAll (zipWith (goAlt env) a1 a2)
  go env (Cast e1 l1 r1) (Cast e2 l2 r2) =
    go env e1 e2 `thenCompare`
    acmpType' env l1 l2 `thenCompare`
    acmpType' env r1 r2
  go env (Tick _ e1) e2 = go env e1 e2
  go env e1 (Tick _ e2) = go env e1 e2
  go _ e1 e2 = compare (getRank e1) (getRank e2)

  goAlt env (DataPat c1 tvs1 ids1, e1) (DataPat c2 tvs2 ids2, e2) =
    compare c1 c2 `thenCompare` go env' e1 e2
   where
    env' = rnTmBndrs (rnTyBndrs env tvs1 tvs2) ids1 ids2
  goAlt env (c1, e1) (c2, e2) =
    compare c1 c2 `thenCompare` go env e1 e2

  getRank :: Term -> Word
  getRank = \case
    Var {}     -> 0
    Data {}    -> 1
    Literal {} -> 2
    Prim {}    -> 3
    Cast {}    -> 4
    App {}     -> 5
    TyApp {}   -> 6
    Lam {}     -> 7
    TyLam {}   -> 8
    Letrec {}  -> 9
    Case {}    -> 10
    Tick {}    -> 11
-- | Lexicographic combination of two comparisons. The first result is kept
-- unless it is 'EQ', in which case the second is returned. The second
-- argument is not evaluated when the first is 'LT' or 'GT'; this is exactly
-- the 'Semigroup' instance of 'Ordering'.
thenCompare :: Ordering -> Ordering -> Ordering
thenCompare = (<>)
-- Orphan instances (the module disables the orphan warning): equality and
-- ordering of types and terms are alpha-equivalence and the alpha-respecting
-- comparisons defined above. For terms this ignores ticks and case result
-- types.
instance Eq Type where
  t1 == t2 = aeqType t1 t2

instance Ord Type where
  compare t1 t2 = acmpType t1 t2

instance Eq Term where
  t1 == t2 = aeqTerm t1 t2

instance Ord Term where
  compare t1 t2 = acmpTerm t1 t2