module Cryptol.Transform.Specialize
where
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.TypeMap
import Cryptol.TypeCheck.Subst
import qualified Cryptol.ModuleSystem as M
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Monad as M
import Cryptol.ModuleSystem.Name
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import MonadLib hiding (mapM)
import Prelude ()
import Prelude.Compat
-- | The specialization cache: for every declaration currently eligible for
-- specialization, the original declaration paired with a table keyed by
-- type-argument list. Each table entry holds the name of the specialized
-- instance and its specialized declaration; the declaration is 'Nothing'
-- while the instance's right-hand side is still being specialized.
type SpecCache = Map Name (Decl, TypesMap (Name, Maybe Decl))
-- | The specializer monad transformer: a 'SpecCache' state layer on top of
-- the module-system monad 'M.ModuleT'.
type SpecT m a = StateT SpecCache (M.ModuleT m) a
-- | The specializer monad, i.e. 'SpecT' over 'IO'.
type SpecM a = SpecT IO a
-- | Run a specializer computation from the given initial cache, yielding
-- the result together with the final cache in the underlying module monad.
runSpecT :: SpecCache -> SpecT m a -> M.ModuleT m (a, SpecCache)
runSpecT = runStateT
-- | Lift a module-system computation into the specializer monad, leaving
-- the cache untouched.
liftSpecT :: Monad m => M.ModuleT m a -> SpecT m a
liftSpecT = lift
-- | Read the current specialization cache from the state layer.
getSpecCache :: Monad m => SpecT m SpecCache
getSpecCache = get
-- | Replace the current specialization cache wholesale.
setSpecCache :: Monad m => SpecCache -> SpecT m ()
setSpecCache cache = set cache
-- | Apply a pure update function to the current specialization cache.
modifySpecCache :: Monad m => (SpecCache -> SpecCache) -> SpecT m ()
modifySpecCache update = modify update
-- | Update the state of any 'StateM' monad with a pure function: read the
-- state, apply the function, and write the result back.
modify :: StateM m s => (s -> s) -> m ()
modify f = do
  s <- get
  set (f s)
-- | Specialize an expression as a module-system command.
--
-- The leading type abstractions of the expression are split off with
-- 'destETAbs'; the remaining body is specialized as if it sat under a
-- @where@ containing every declaration group of every loaded module
-- ('allDeclGroups'); the result is then re-wrapped in the collected type
-- abstractions. The computation starts from an empty cache and the final
-- cache is discarded.
specialize :: Expr -> M.ModuleCmd Expr
specialize expr (ev, modEnv) =
  M.runModuleT (ev, modEnv) (fst <$> runSpecT Map.empty specialized)
  where
    (tparams, inner) = destETAbs expr

    specialized :: SpecM Expr
    specialized = do
      spec' <- specializeEWhere inner (allDeclGroups modEnv)
      return (foldr ETAbs spec' tparams)
-- | Specialize every sub-expression of an expression.
--
-- Structural constructors are rebuilt from their specialized children
-- (children are visited left to right). Variables, type applications and
-- proof applications are handed to 'specializeConst', which decides whether
-- they name a cached declaration. @where@ blocks go through
-- 'specializeEWhere'.
--
-- The body of a type abstraction is processed with an empty cache, so no
-- declaration from the enclosing scopes is specialized inside it; the
-- previous cache is put back afterwards, discarding anything recorded while
-- inside the abstraction.
specializeExpr :: Expr -> SpecM Expr
specializeExpr expr =
  case expr of
    -- Leaves and applications that may name a cached declaration.
    EVar {}      -> specializeConst expr
    ETApp {}     -> specializeConst expr
    EProofApp {} -> specializeConst expr

    EList es t -> do
      es' <- traverse specializeExpr es
      return (EList es' t)

    ETuple es -> do
      es' <- traverse specializeExpr es
      return (ETuple es')

    ERec fs -> do
      fs' <- traverse (traverseSnd specializeExpr) fs
      return (ERec fs')

    ESel e s -> do
      e' <- specializeExpr e
      return (ESel e' s)

    ESet e s v -> do
      e' <- specializeExpr e
      v' <- specializeExpr v
      return (ESet e' s v')

    EIf c t f -> do
      c' <- specializeExpr c
      t' <- specializeExpr t
      f' <- specializeExpr f
      return (EIf c' t' f')

    EComp len t e mss -> do
      e'   <- specializeExpr e
      mss' <- traverse (traverse specializeMatch) mss
      return (EComp len t e' mss')

    ETAbs t e -> do
      saved <- getSpecCache
      setSpecCache Map.empty
      e' <- specializeExpr e
      setSpecCache saved
      return (ETAbs t e')

    EApp f x -> do
      f' <- specializeExpr f
      x' <- specializeExpr x
      return (EApp f' x')

    EAbs qn t e -> do
      e' <- specializeExpr e
      return (EAbs qn t e')

    EProofAbs p e -> do
      e' <- specializeExpr e
      return (EProofAbs p e')

    EWhere e dgs -> specializeEWhere e dgs
-- | Specialize one arm element of a comprehension.
--
-- A generator (@From@) has its source expression specialized. A @Let@ whose
-- signature binds no type variables keeps its name and signature, but its
-- definition body is specialized too, so references inside it are rewritten
-- to specialized instances just like everywhere else in the expression. A
-- primitive definition has no body and is returned unchanged.
--
-- A @Let@ whose signature binds type variables is not supported and fails.
specializeMatch :: Match -> SpecM Match
specializeMatch (From qn l t e) = From qn l t <$> specializeExpr e
specializeMatch (Let decl)
  | null (sVars (dSignature decl)) =
      case dDefinition decl of
        DExpr body -> do
          -- Without this traversal the body would keep referring to the
          -- unspecialized originals of any cached declarations it uses.
          body' <- specializeExpr body
          return (Let (decl { dDefinition = DExpr body' }))
        DPrim -> return (Let decl)
  | otherwise = fail "unimplemented: specializeMatch Let unimplemented"
-- | Run an action with the declarations of the given groups available for
-- specialization, then collect the instances that were generated.
--
-- Every declaration of the groups is entered into the cache with an empty
-- instance table, shadowing any cache entry of the same name (the union is
-- left-biased towards the new entries). After the action has run:
--
-- * each group is replaced by the specialized instances recorded for its
--   declarations: a recursive group becomes a single recursive group of all
--   its members' instances (dropped entirely when there are none), and a
--   non-recursive declaration becomes one non-recursive group per instance;
--
-- * a name table is built, mapping each of the groups' declaration names to
--   a table from type arguments to the instance's name;
--
-- * the groups' entries are removed from the cache and the entries they
--   shadowed are restored.
--
-- Returns the action's result, the specialized groups and the name table.
withDeclGroups :: [DeclGroup] -> SpecM a
               -> SpecM (a, [DeclGroup], Map Name (TypesMap Name))
withDeclGroups dgs action = do
  outer <- getSpecCache
  let scoped   = Map.fromList
                   [ (dName d, (d, emptyTM)) | dg <- dgs, d <- groupDecls dg ]
      shadowed = Map.intersection outer scoped
  setSpecCache (Map.union scoped outer)
  result <- action
  dgs' <- concat <$> traverse instanceGroups dgs
  after <- getSpecCache
  let nameTable = fmap (fmap fst . snd) (Map.intersection after scoped)
  modifySpecCache (\cache -> Map.union shadowed (Map.difference cache scoped))
  return (result, dgs', nameTable)
  where
    -- All finished instances recorded in the cache for one declaration.
    -- The lazy pattern relies on the declaration having a cache entry.
    instancesOf :: Decl -> SpecM [Decl]
    instancesOf d = do
      ~(Just (_, tm)) <- Map.lookup (dName d) <$> getSpecCache
      return (catMaybes [ inst | (_, (_, inst)) <- toListTM tm ])

    -- The specialized replacement for one declaration group.
    instanceGroups :: DeclGroup -> SpecM [DeclGroup]
    instanceGroups (NonRecursive d) = map NonRecursive <$> instancesOf d
    instanceGroups (Recursive ds) = do
      insts <- concat <$> traverse instancesOf ds
      if null insts
        then return []
        else return [Recursive insts]
-- | Specialize an expression under a list of local declaration groups.
--
-- The expression is specialized with the groups in scope; the result is a
-- @where@ over the generated instances, or just the specialized expression
-- when no instance was generated.
specializeEWhere :: Expr -> [DeclGroup] -> SpecM Expr
specializeEWhere e dgs = do
  (e', dgs', _) <- withDeclGroups dgs (specializeExpr e)
  return (rewrap e' dgs')
  where
    rewrap body []     = body
    rewrap body groups = EWhere body groups
-- | Specialize a list of declaration groups on their own.
--
-- Every declaration whose signature has neither type variables nor
-- constraints is used as a root: a reference to it is specialized with the
-- groups in scope, which pulls in the instances of whatever it uses. The
-- results for the roots themselves are discarded; the specialized groups
-- and the name table from 'withDeclGroups' are returned.
specializeDeclGroups :: [DeclGroup] -> SpecM ([DeclGroup], Map Name (TypesMap Name))
specializeDeclGroups dgs = do
  (_, dgs', names) <- withDeclGroups dgs (mapM specializeExpr roots)
  return (dgs', names)
  where
    roots =
      [ EVar (dName d)
      | dg <- dgs
      , d  <- groupDecls dg
      , monomorphic (dSignature d)
      ]
    monomorphic sig = null (sVars sig) && null (sProps sig)
-- | Specialize a (possibly type- and proof-applied) variable reference.
--
-- The expression is peeled into a head, its type arguments and the number
-- of proof applications. If the head is a variable with a cache entry, the
-- whole expression is replaced by a reference to the instance for those
-- type arguments, creating that instance first if it does not exist yet.
-- Anything else is returned unchanged.
--
-- When a new instance is created, a placeholder entry (name without a
-- declaration) is recorded before the right-hand side is specialized, so a
-- reference to the same instance met while specializing the body resolves
-- to the fresh name instead of starting over. The finished declaration
-- replaces the placeholder afterwards.
specializeConst :: Expr -> SpecM Expr
specializeConst e0 =
  case headExpr of
    EVar qname -> do
      cache <- getSpecCache
      case Map.lookup qname cache of
        Nothing -> return e0
        Just (decl, tm) ->
          case lookupTM tyArgs tm of
            Just (known, _) -> return (EVar known)
            Nothing         -> EVar <$> newInstance qname decl
    _ -> return e0
  where
    (applied, nProofs)  = destEProofApps e0
    (headExpr, tyArgs)  = destETApps applied

    -- Store an entry for the current type arguments under the given name.
    record :: Name -> (Name, Maybe Decl) -> SpecM ()
    record qname entry =
      modifySpecCache (Map.adjust (fmap (insertTM tyArgs entry)) qname)

    -- Build, register and name the instance of a declaration at 'tyArgs'.
    newInstance :: Name -> Decl -> SpecM Name
    newInstance qname decl = do
      qname' <- freshName qname tyArgs
      sig'   <- instantiateSchema tyArgs nProofs (dSignature decl)
      record qname (qname', Nothing)
      rhs' <- case dDefinition decl of
                DPrim      -> return DPrim
                DExpr body -> do
                  body' <- specializeExpr =<< instantiateExpr tyArgs nProofs body
                  return (DExpr body')
      let decl' = decl { dName = qname', dSignature = sig', dDefinition = rhs' }
      record qname (qname', Just decl')
      return qname'
-- | Strip all outer proof applications from an expression, returning the
-- expression underneath and how many applications were removed.
destEProofApps :: Expr -> (Expr, Int)
destEProofApps expr = peel expr 0
  where
    peel (EProofApp inner) count = peel inner (count + 1)
    peel other count             = (other, count)
-- | Strip all outer type applications from an expression, returning the
-- head and the type arguments. The argument applied closest to the head
-- comes first in the list.
destETApps :: Expr -> (Expr, [Type])
destETApps expr = collect expr []
  where
    collect (ETApp fun ty) args = collect fun (ty : args)
    collect other args          = (other, args)
-- | Strip all outer proof abstractions from an expression, returning the
-- abstracted props and the body.
--
-- The props are listed outermost binder first, so that
-- @foldr EProofAbs body props@ rebuilds the original expression.
destEProofAbs :: Expr -> ([Prop], Expr)
destEProofAbs = go []
  where
    go acc (EProofAbs p e) = go (p : acc) e
    -- The accumulator is built innermost-first; reverse it to binder order.
    go acc e               = (reverse acc, e)
-- | Strip all outer type abstractions from an expression, returning the
-- bound type parameters and the body.
--
-- The parameters are listed outermost binder first, so that
-- @foldr ETAbs body params@ rebuilds the original expression with its
-- binders in their original order.
destETAbs :: Expr -> ([TParam], Expr)
destETAbs = go []
  where
    go acc (ETAbs t e) = go (t : acc) e
    -- The accumulator is built innermost-first; reverse it to binder order.
    go acc e           = (reverse acc, e)
-- | Make a fresh name for a specialized instance of the given name.
--
-- The new name is drawn from the name supply and reuses the original's
-- identifier and location; a declared name also keeps its namespace
-- information and fixity. The type arguments are currently ignored.
freshName :: Name -> [Type] -> SpecM Name
freshName n _ = liftSupply mkFresh
  where
    mkFresh =
      case nameInfo n of
        Declared ns s -> mkDeclared ns s (nameIdent n) (nameFixity n) (nameLoc n)
        Parameter     -> mkParameter (nameIdent n) (nameLoc n)
-- | Instantiate a schema at the given type arguments and number of proof
-- arguments, producing a schema with no type parameters and no props.
--
-- Fails when the number of type arguments differs from the number of type
-- parameters, or when the number of proof arguments differs from the number
-- of props.
instantiateSchema :: [Type] -> Int -> Schema -> SpecM Schema
instantiateSchema ts n (Forall params props ty) =
  if length params /= length ts
    then fail "instantiateSchema: wrong number of type arguments"
    else if length props /= n
      then fail "instantiateSchema: wrong number of prop arguments"
      else return (Forall [] [] (apSubst paramSubst ty))
  where
    paramSubst = listParamSubst (zip params ts)
-- | Instantiate a definition body at the given type arguments and number
-- of proof arguments.
--
-- Each type argument consumes one leading type abstraction, substituting
-- the argument for the bound parameter in the body. Once the type arguments
-- are used up, each proof argument consumes one leading proof abstraction.
-- Fails when the abstractions do not line up with the arguments.
instantiateExpr :: [Type] -> Int -> Expr -> SpecM Expr
instantiateExpr tys nProofs expr =
  case (tys, nProofs, expr) of
    ([], 0, _) ->
      return expr
    ([], _, EProofAbs _ body) ->
      instantiateExpr [] (nProofs - 1) body
    (t : rest, _, ETAbs param body) ->
      instantiateExpr rest nProofs (apSubst (singleSubst (tpVar param) t) body)
    _ ->
      fail "instantiateExpr: wrong number of type/proof arguments"
-- | The declaration groups of every loaded module in the environment,
-- concatenated in the order the modules are listed.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups env = concatMap mDecls (M.loadedModules env)
-- | Apply an effectful function to the second component of a pair, keeping
-- the first component as it is.
traverseSnd :: Functor f => (b -> f c) -> (a, b) -> f (a, c)
traverseSnd f (x, y) = fmap (\y' -> (x, y')) (f y)