{-# Language FlexibleInstances, PatternGuards #-}
module Cryptol.ModuleSystem.InstantiateModule
( instantiateModule
) where
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map as Map
import MonadLib(ReaderT,runReaderT,ask)
import Cryptol.Parser.Position(Located(..))
import Cryptol.ModuleSystem.Name
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst(listParamSubst, apSubst)
import Cryptol.Utils.Ident(ModName,modParamIdent)
-- | Instantiate a parameterized module (functor) with concrete parameters.
--
-- The functor's exports, type synonyms, newtypes, primitive types and
-- top-level declarations are passed through 'inst', using the environment
-- built by 'computeEnv'.  The type-level maps are re-keyed by the new names.
-- Each value parameter gets a fresh name and becomes an ordinary
-- non-recursive declaration bound to the supplied expression.  The
-- resulting module has no parameters left.
--
-- The functor's parameter constraints, with the type parameters
-- substituted, are returned next to the new module.  Presumably the caller
-- is responsible for discharging them.
instantiateModule :: FreshM m =>
                     Module          {- ^ The functor being instantiated -} ->
                     ModName         {- ^ Name of the new module -} ->
                     Map TParam Type {- ^ Type for each type parameter -} ->
                     Map Name Expr   {- ^ Value for each value parameter -} ->
                     m ([Located Prop], Module)
instantiateModule func newName tpMap vpMap =
  runReaderT newName $
  do let oldVpNames = Map.keys vpMap
     newVpNames <- mapM freshParamName oldVpNames
     let vpNames = Map.fromList (zip oldVpNames newVpNames)

     env <- computeEnv func tpMap vpNames

     -- rnMp instantiates every element of a map and re-keys it by the
     -- element's (new) name.
     let rnMp :: Inst a => (a -> Name) -> Map Name a -> Map Name a
         rnMp f m = Map.fromList [ (f x, x) | a <- Map.elems m
                                            , let x = inst env a ]

         renamedExports  = inst env (mExports func)
         renamedTySyns   = rnMp tsName (mTySyns func)
         renamedNewtypes = rnMp ntName (mNewtypes func)
         renamedPrimTys  = rnMp atName (mPrimTypes func)

         su    = listParamSubst (Map.toList (tyParamMap env))
         goals = map (fmap (apSubst su)) (mParamConstraints func)

         renamedDecls = inst env (mDecls func)
         -- The parameter definitions are placed before the renamed
         -- declarations, which may refer to them.
         paramDecls   = map (mkParamDecl su vpNames) (Map.toList vpMap)

     return ( goals
            , Module
                { mName             = newName
                , mExports          = renamedExports
                , mImports          = mImports func
                , mTySyns           = renamedTySyns
                , mNewtypes         = renamedNewtypes
                , mPrimTypes        = renamedPrimTys
                , mParamTypes       = Map.empty
                , mParamConstraints = []
                , mParamFuns        = Map.empty
                , mDecls            = paramDecls ++ renamedDecls
                } )
  where
  -- Turn one value-parameter instantiation into a definition.  The
  -- parameter's new name is bound to the supplied expression, and the
  -- signature is the parameter's declared type with the type parameters
  -- substituted.
  mkParamDecl su vpNames (x,e) =
    NonRecursive Decl
      { dName       = lookupOrDie "new name for value parameter" x vpNames
      , dSignature  = apSubst su
                    $ mvpType
                    $ lookupOrDie "declaration of value parameter" x
                                  (mParamFuns func)
      , dDefinition = DExpr e
      , dPragmas    = []
      , dInfix      = False
      , dFixity     = Nothing
      , dDoc        = Nothing
      }

  -- Look up a key that must be present.  If it is missing, fail with a
  -- message that says what was missing and for which name.
  lookupOrDie what x mp =
    case Map.lookup x mp of
      Just v  -> v
      Nothing -> error ("instantiateModule: missing " ++ what ++ " " ++ show x)
-- | Things that bind (define) value-level names.
class Defines t where
  -- | The set of names bound by the given item.
  defines :: t -> Set Name
-- | A list defines every name that any of its elements defines.
instance Defines t => Defines [t] where
  defines xs = Set.unions [ defines x | x <- xs ]
-- | A declaration defines exactly its own name.
instance Defines Decl where
  defines d = Set.singleton (dName d)
-- | A declaration group defines the names of the declarations it contains,
-- whether or not the group is recursive.
instance Defines DeclGroup where
  defines (NonRecursive d) = defines d
  defines (Recursive ds)   = defines ds
-- | The monad used while instantiating: the underlying monad extended with
-- read-only access to the name of the module being created.  New names are
-- declared in that module.
type InstM = ReaderT ModName
-- | Make a fresh replacement for one of the functor's names.  It is declared
-- in the module being created, whose name is read from the 'InstM'
-- environment.
--
-- The new name is built from the original's identifier, fixity and
-- location.  If the original is a 'Declared' name, its name source is
-- reused; any other name is marked as a 'UserName'.
freshenName :: FreshM m => Name -> InstM m Name
freshenName x =
  do modName <- ask
     liftSupply (mkDeclared modName (sourceOf (nameInfo x))
                            (nameIdent x) (nameFixity x) (nameLoc x))
  where
  sourceOf (Declared _ s) = s
  sourceOf _              = UserName
-- | Make a fresh name for a value parameter.  It is declared in the module
-- being created and marked as a 'UserName'.  The identifier is the
-- original's, transformed by 'modParamIdent'; fixity and location are copied
-- from the original.
freshParamName :: FreshM m => Name -> InstM m Name
freshParamName x =
  ask >>= \modName ->
    liftSupply (mkDeclared modName UserName paramIdent (nameFixity x) (nameLoc x))
  where
  paramIdent = modParamIdent (nameIdent x)
-- | Compute the environment used to instantiate a functor.
--
-- Fresh names, declared in the module being created, are made for:
--
-- * the functor's type synonyms, newtypes and primitive types (collected in
--   'tyNameMap');
--
-- * every name bound by its top-level declarations (collected in
--   'funNameMap').
--
-- The caller supplies the renaming of the value parameters.  'Map.union' is
-- left-biased, so that renaming wins over any renaming generated here for
-- the same name.
computeEnv :: FreshM m =>
              Module          {- ^ The functor being instantiated -} ->
              Map TParam Type {- ^ Type for each type parameter -} ->
              Map Name Name   {- ^ New names for the value parameters -} ->
              InstM m Env
computeEnv m tpMap vpMap =
  do tss <- mapM freshTy (Map.toList (mTySyns m))
     nts <- mapM freshTy (Map.toList (mNewtypes m))
     -- Primitive types need new names too.  'renamedPrimTys' in
     -- 'instantiateModule', and the 'AbstractType' and 'TCAbstract' cases of
     -- 'inst', all look their names up in 'tyNameMap'.  Without these
     -- entries those renamings would silently do nothing.
     pts <- mapM freshTy (Map.toList (mPrimTypes m))
     let tnMap = Map.fromList (tss ++ nts ++ pts)

     defHere <- mapM freshPair (Set.toList (defines (mDecls m)))
     let fnMap = Map.union vpMap (Map.fromList defHere)

     return Env { funNameMap = fnMap
                , tyNameMap  = tnMap
                , tyParamMap = tpMap
                }
  where
  -- Pair the key of a map entry with a fresh replacement name.
  freshTy (x,_) = freshPair x

  -- Pair a name with a fresh replacement name.
  freshPair x = do y <- freshenName x
                   return (x,y)
-- | The renaming and type substitution applied by 'inst' when instantiating
-- a functor.
data Env = Env
  { funNameMap :: Map Name Name
    -- ^ New names for the functor's value-level names (its value parameters
    -- and the names bound by its top-level declarations).  'inst' leaves
    -- names that are not in the map unchanged.
  , tyNameMap :: Map Name Name
    -- ^ New names for the functor's type-level declarations.  'inst' leaves
    -- names that are not in the map unchanged.
  , tyParamMap :: Map TParam Type
    -- ^ The type that replaces each of the functor's type parameters.
  } deriving Show
-- | Apply an instantiation environment.  References to the functor's
-- declarations are renamed, and its bound type parameters are replaced by
-- the chosen types.
class Inst t where
  inst :: Env -> t -> t
-- | Instantiate each element of a list.
instance Inst a => Inst [a] where
  inst env xs = [ inst env x | x <- xs ]
-- | Rename variable references found in 'funNameMap', and instantiate every
-- type, proposition, match and local declaration group inside the
-- expression.  Binders ('EAbs' variables and 'ETAbs' type parameters) are
-- not renamed.
instance Inst Expr where
  inst env = go
    where
    go (EVar x)            = EVar (Map.findWithDefault x x (funNameMap env))
    go (EList xs t)        = EList (map go xs) (inst env t)
    go (ETuple es)         = ETuple (map go es)
    go (ERec xs)           = ERec (fmap go xs)
    go (ESel e s)          = ESel (go e) s
    go (ESet e x v)        = ESet (go e) x (go v)
    go (EIf e1 e2 e3)      = EIf (go e1) (go e2) (go e3)
    go (EComp t1 t2 e mss) = EComp (inst env t1) (inst env t2) (go e)
                                   (inst env mss)
    go (ETAbs t e)         = ETAbs t (go e)
    go (ETApp e t)         = ETApp (go e) (inst env t)
    go (EApp e1 e2)        = EApp (go e1) (go e2)
    go (EAbs x t e)        = EAbs x (inst env t) (go e)
    go (EProofAbs p e)     = EProofAbs (inst env p) (go e)
    go (EProofApp e)       = EProofApp (go e)
    go (EWhere e ds)       = EWhere (go e) (inst env ds)
-- | Instantiate every declaration in a group; recursiveness is preserved.
instance Inst DeclGroup where
  inst env (NonRecursive d) = NonRecursive (inst env d)
  inst env (Recursive ds)   = Recursive (map (inst env) ds)
-- | Primitive definitions have nothing to instantiate.  Expression
-- definitions are instantiated.
instance Inst DeclDef where
  inst _   DPrim     = DPrim
  inst env (DExpr e) = DExpr (inst env e)
-- | Instantiate a declaration's signature and definition, and rename the
-- declared name if 'funNameMap' has a replacement for it.  All other fields
-- are kept.
instance Inst Decl where
  inst env d =
    d { dName       = newName
      , dSignature  = inst env (dSignature d)
      , dDefinition = inst env (dDefinition d)
      }
    where
    oldName = dName d
    newName = Map.findWithDefault oldName oldName (funNameMap env)
-- | Instantiate the types and expression of a generator, or the declaration
-- of a @let@ in a comprehension.  The bound variable is not renamed.
instance Inst Match where
  inst env (From x t1 t2 e) = From x (inst env t1) (inst env t2) (inst env e)
  inst env (Let d)          = Let (inst env d)
-- | Instantiate the body type and the constraints of a schema.  The
-- schema's own quantified variables are kept as they are.
instance Inst Schema where
  inst env s =
    s { sType  = inst env (sType s)
      , sProps = map (inst env) (sProps s)
      }
-- | Replace bound type variables that appear in 'tyParamMap', and rename
-- user-defined type constructors and type-synonym uses via 'tyNameMap'.
-- Other type variables are left unchanged.
instance Inst Type where
  inst env (TCon tc ts)   = TCon (inst env tc) (map (inst env) ts)
  inst env ty@(TVar tv)   =
    case tv of
      TVBound tp -> Map.findWithDefault ty tp (tyParamMap env)
      _          -> ty
  inst env (TUser x ts t) = TUser (instTyName env x) (map (inst env) ts)
                                  (inst env t)
  inst env (TRec fs)      = TRec (fmap (inst env) fs)
-- | Only 'TC' constructors can refer to the functor's declarations; all
-- other constructors are returned unchanged.
instance Inst TCon where
  inst env (TC x) = TC (inst env x)
  inst _   other  = other
-- | Rename the user type constructor of newtypes and abstract types.  Every
-- other type constructor is returned unchanged.
instance Inst TC where
  inst env (TCNewtype ut)  = TCNewtype (inst env ut)
  inst env (TCAbstract ut) = TCAbstract (inst env ut)
  inst _   other           = other
-- | Rename a user type constructor via 'tyNameMap'.  The second component
-- is kept unchanged.
instance Inst UserTC where
  inst env (UserTC x k) = UserTC (instTyName env x) k
-- | Rename exported type names via 'tyNameMap' and exported value names via
-- 'funNameMap'.  Names without a replacement are kept.
instance Inst (ExportSpec Name) where
  inst env es =
    es { eBinds = Set.map renameVal (eBinds es)
       , eTypes = Set.map (instTyName env) (eTypes es)
       }
    where
    renameVal x = Map.findWithDefault x x (funNameMap env)
-- | Rename a type synonym and instantiate its constraints and definition.
-- Its parameters and documentation are kept.
instance Inst TySyn where
  inst env ts =
    ts { tsName        = instTyName env (tsName ts)
       , tsConstraints = inst env (tsConstraints ts)
       , tsDef         = inst env (tsDef ts)
       }
-- | Rename a newtype and instantiate its constraints and the type of each
-- field.  Field names, parameters and documentation are kept.
instance Inst Newtype where
  inst env nt =
    nt { ntName        = instTyName env (ntName nt)
       , ntConstraints = inst env (ntConstraints nt)
       , ntFields      = [ (fld, inst env ty) | (fld, ty) <- ntFields nt ]
       }
-- | Rename an abstract (primitive) type and instantiate the propositions in
-- the second component of its 'atCtrs' pair.  Kind, fixity, documentation
-- and the first component of 'atCtrs' are kept.
instance Inst AbstractType where
  inst env a =
    a { atName = instTyName env (atName a)
      , atCtrs = fmap (inst env) (atCtrs a)
      }
-- | The replacement for a type-level name according to 'tyNameMap'.  If the
-- map has no entry for the name, the name itself is returned.
instTyName :: Env -> Name -> Name
instTyName env x =
  case Map.lookup x (tyNameMap env) of
    Just y  -> y
    Nothing -> x