{-# LANGUAGE PatternGuards, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Cryptol.Transform.MonoValues (rewModule) where
import Cryptol.ModuleSystem.Name
import Cryptol.Parser.Position (emptyRange)
import Cryptol.TypeCheck.AST hiding (splitTApp)
import Cryptol.TypeCheck.TypeMap
import Cryptol.Utils.Ident (ModName)
import Data.List(sortBy,groupBy)
import Data.Either(partitionEithers)
import Data.Map (Map)
import MonadLib hiding (mapM)
import Prelude ()
import Prelude.Compat
-- | The rewrite map used by this pass: each key maps to a replacement 'Name'.
type RewMap = RewMap' Name

-- | A nested trie keyed, outermost first, by a 'Name', then a list of types
-- (via 'TypesMap'), then an 'Int'.
newtype RewMap' a = RM (Map Name (TypesMap (Map Int a)))
-- | 'RewMap'' is a three-level trie: an outer map on the 'Name', a
-- 'TypesMap' on the list of types, and an innermost map on the 'Int'.
instance TrieMap RewMap' (Name,[Type],Int) where
  emptyTM = RM emptyTM

  nullTM (RM m) = nullTM m

  -- Descend one level per key component; any miss yields 'Nothing'.
  lookupTM (x,ts,n) (RM m) = lookupTM x m >>= lookupTM ts >>= lookupTM n

  alterTM (x,ts,n) f (RM m) = RM (alterTM x atName m)
    where
    -- No entry for the name: build the nested maps only if 'f' yields a value.
    atName Nothing     = (\a -> insertTM ts (insertTM n a emptyTM) emptyTM)
                           <$> f Nothing
    atName (Just tysM) = Just (alterTM ts atTypes tysM)

    -- No entry for the type list: build the innermost map only on demand.
    atTypes Nothing     = (\a -> insertTM n a emptyTM) <$> f Nothing
    atTypes (Just cntM) = Just (alterTM n f cntM)

  unionTM f (RM l) (RM r) = RM (unionTM mergeTypes l r)
    where
    -- Merge the two lower levels, combining leaf values with 'f'.
    mergeTypes a b = unionTM (unionTM f) a b

  toListTM (RM m) =
    do (x,tysM)  <- toListTM m
       (ts,cntM) <- toListTM tysM
       (n,y)     <- toListTM cntM
       return ((x,ts,n),y)

  mapMaybeWithKeyTM f (RM m) = RM (mapWithKeyTM perName m)
    where
    -- Rebuild the full (name, types, int) key for 'f' at the leaves.
    perName qn      = mapWithKeyTM (perTypes qn)
    perTypes qn tys = mapMaybeWithKeyTM (\i -> f (qn,tys,i))
-- | Rewrite every declaration group of a module, starting from an empty
-- rewrite map.  The module's own name is the reader environment, fresh
-- names are drawn from the given 'Supply', and the final supply is returned
-- alongside the rewritten module.
rewModule :: Supply -> Module -> (Module,Supply)
rewModule supply md = runM rewritten (mName md) supply
  where
  rewritten = (\groups -> md { mDecls = groups })
                <$> mapM (rewDeclGroup emptyTM) (mDecls md)
-- | Read-only environment of the pass: a module name.
type RO = ModName

-- | The pass's monad: reads an 'RO' and draws from a name supply.
type M = ReaderT RO (SupplyT Id)
-- | Allocate a fresh system-generated name with the text @$mono@, declared
-- in the module named by the reader environment.
newName :: M Name
newName =
  ask >>= \modNm ->
    liftSupply (mkDeclared modNm SystemName "$mono" Nothing emptyRange)
-- | Allocate a fresh name for a generated declaration.  There is currently
-- no distinction between top-level and local names, so this is just
-- 'newName'.
newTopOrLocalName :: M Name
newTopOrLocalName = newName
-- | Run a computation for a local (@where@-bound) scope.  Local scopes are
-- currently treated exactly like the enclosing scope.
inLocal :: M a -> M a
inLocal action = action
-- | Rewrite an expression under a 'RewMap'.
--
-- An application whose head is a variable is replaced by @EVar y@ when the
-- triple (variable, type arguments, number of proof applications) maps to
-- @y@.  Every other form is rebuilt with its sub-expressions rewritten,
-- left to right.
rewE :: RewMap -> Expr -> M Expr
rewE rews = go
  where
  tryRewrite (EVar x, tps, n) = EVar <$> lookupTM (x, tps, n) rews
  tryRewrite _                = Nothing

  -- Use a successful rewrite if there is one; otherwise run the fallback.
  orElse fallback = maybe fallback pure

  go expr =
    case expr of
      -- Applications are the only places where a rewrite can fire.
      ETApp e t         -> orElse (ETApp <$> go e <*> pure t)
                                  (tryRewrite (splitTApp expr 0))
      EProofApp e       -> orElse (EProofApp <$> go e)
                                  (tryRewrite (splitTApp e 1))

      -- Purely structural traversal.
      EList es t        -> (\es' -> EList es' t) <$> traverse go es
      ETuple es         -> ETuple <$> traverse go es
      ERec fs           -> ERec <$> traverse go fs
      ESel e s          -> (\e' -> ESel e' s) <$> go e
      ESet e s v        -> (\e' v' -> ESet e' s v') <$> go e <*> go v
      EIf e1 e2 e3      -> EIf <$> go e1 <*> go e2 <*> go e3
      EComp len t e mss -> EComp len t <$> go e
                                       <*> traverse (traverse (rewM rews)) mss
      EVar _            -> pure expr
      ETAbs x e         -> ETAbs x <$> go e
      EApp e1 e2        -> EApp <$> go e1 <*> go e2
      EAbs x t e        -> EAbs x t <$> go e
      EProofAbs x e     -> EProofAbs x <$> go e
      EWhere e dgs      -> EWhere <$> go e
                                  <*> inLocal (traverse (rewDeclGroup rews) dgs)
-- | Rewrite the expression of a generator arm, or the declaration of a
-- @let@ arm, in a comprehension.
rewM :: RewMap -> Match -> M Match
rewM rews (From x len t e) = From x len t <$> rewE rews e
rewM rews (Let d)          = Let <$> rewD rews d
-- | Rewrite the definition of a single declaration, keeping all of its
-- other fields unchanged.
rewD :: RewMap -> Decl -> M Decl
rewD rews d = setDef <$> rewDef rews (dDefinition d)
  where
  setDef def = d { dDefinition = def }
-- | Rewrite a declaration body.  Primitives have no body and are returned
-- as-is.
rewDef :: RewMap -> DeclDef -> M DeclDef
rewDef rews def =
  case def of
    DExpr e -> DExpr <$> rewE rews e
    DPrim   -> pure DPrim
-- | Rewrite a declaration group.
--
-- A non-recursive group is traversed with 'rewD'.
--
-- In a recursive group, a declaration is a candidate for rewriting when its
-- definition is an expression that, after stripping type and proof
-- abstractions, has at least one type parameter and is not a lambda.
-- Candidates with identical proof props and type parameters are processed
-- together by 'rewSame'.  The result lists the untouched declarations
-- (rewritten with 'rewD') first, followed by the declarations produced for
-- the candidate groups.
rewDeclGroup :: RewMap -> DeclGroup -> M DeclGroup
rewDeclGroup rews dg =
  case dg of
    NonRecursive d -> NonRecursive <$> rewD rews d
    Recursive ds ->
      do let (leave,rew) = partitionEithers (map consider ds)
             rewGroups   = groupBy sameTParams
                         $ sortBy compareTParams rew
         ds1 <- mapM (rewD rews) leave
         ds2 <- mapM rewSame rewGroups
         return $ Recursive (ds1 ++ concat ds2)

  where
  -- Candidates are grouped by (props, type params).  The sort and the
  -- grouping equality share this key, so 'groupBy' always sees equal
  -- elements adjacent.
  groupKey (_,tps,props,_) = (props,tps)
  sameTParams    a b = groupKey a == groupKey b
  compareTParams a b = compare (groupKey a) (groupKey b)

  consider d =
    case dDefinition d of
      DPrim   -> Left d
      DExpr e -> let (tps,props,e') = splitTParams e
                 in if not (null tps) && notFun e'
                      then Right (d, tps, props, e')
                      else Left d

  -- 'groupBy' never yields an empty group.  Still, match the group's
  -- representative explicitly instead of relying on a partial pattern
  -- binding.
  rewSame [] = return []
  rewSame ds@((_,tps,props,_) : _) =
     do -- One fresh monomorphic name per declaration in the group.
        new <- forM ds $ \(d,_,_,e) ->
                 do x <- newName
                    return (d, x, e)
        let tys            = map (TVar . tpVar) tps
            proofNum       = length props
            -- Uses of an original name at exactly the group's own type
            -- variables and proof count are redirected to its fresh name.
            addRew (d,x,_) = insertTM (dName d,tys,proofNum) x
            newRews        = foldr addRew rews new

        -- Pair each original declaration with its monomorphic copy.
        newDs <- forM new $ \(d,newN,e) ->
                   do e1 <- rewE newRews e
                      return ( d
                             , d { dName       = newN
                                 , dSignature  = (dSignature d)
                                                   { sVars = [], sProps = [] }
                                 , dDefinition = DExpr e1
                                 }
                             )

        case newDs of
          -- Single declaration:
          --   f = \{tps} (props) -> f' where f' = body'
          [(f,f')] ->
            return [ f { dDefinition =
                           let newBody = EVar (dName f')
                               newE    = EWhere newBody [ Recursive [f'] ]
                           in DExpr $ foldr ETAbs
                                        (foldr EProofAbs newE props) tps
                       }
                   ]

          -- Several declarations share one tuple:
          --   tuple = \{tps} (props) -> (f', g') where f' = ...; g' = ...
          --   f     = \{tps} (props) -> (tuple @tps props).0, and so on
          _ -> do tupName <- newTopOrLocalName
                  let (polyDs,monoDs) = unzip newDs
                      tupAr  = length monoDs
                      addTPs = flip (foldr ETAbs)     tps
                             . flip (foldr EProofAbs) props

                      tupD = Decl
                        { dName       = tupName
                        , dSignature  =
                            Forall tps props $
                              TCon (TC (TCTuple tupAr))
                                   (map (sType . dSignature) monoDs)
                        , dDefinition =
                            DExpr  $
                            addTPs $
                            EWhere (ETuple [ EVar (dName d) | d <- monoDs ])
                                   [ Recursive monoDs ]
                        , dPragmas    = []
                        , dInfix      = False
                        , dFixity     = Nothing
                        , dDoc        = Nothing
                        }

                      -- Each prop contributes one proof application.
                      mkProof e _ = EProofApp e

                      mkFunDef n f =
                        f { dDefinition =
                              DExpr  $
                              addTPs $ ESel ( flip (foldl mkProof) props
                                            $ flip (foldl ETApp) tys
                                            $ EVar tupName
                                            ) (TupleSel n (Just tupAr))
                          }

                  return (tupD : zipWith mkFunDef [ 0 .. ] polyDs)
-- | Peel off the leading type abstractions, then the proof abstractions
-- beneath them.  Returns the type parameters, the props, and the remaining
-- body.
splitTParams :: Expr -> ([TParam], [Prop], Expr)
splitTParams expr = (tparams, props, body)
  where
  (tparams, afterTAbs) = splitWhile splitTAbs expr
  (props,   body)      = splitWhile splitProofAbs afterTAbs
-- | Decompose an application into its head, its type arguments (in
-- application order), and a count of proof applications.  The count starts
-- at the given 'Int' and is incremented (strictly) for each outer
-- 'EProofApp'; type applications are collected beneath those.
splitTApp :: Expr -> Int -> (Expr, [Type], Int)
splitTApp (EProofApp e) n = splitTApp e $! (n + 1)
splitTApp e0 n            = (hd, tyArgs, n)
  where
  (hd, tyArgs) = collect e0 []

  -- Walk inward through 'ETApp', prepending each argument so the
  -- innermost (first-applied) type ends up at the front.
  collect (ETApp e t) acc = collect e (t : acc)
  collect e acc           = (e, acc)
-- | 'False' when the expression, after skipping any proof abstractions, is
-- a lambda ('EAbs'); 'True' otherwise.
notFun :: Expr -> Bool
notFun expr =
  case expr of
    EAbs {}       -> False
    EProofAbs _ e -> notFun e
    _             -> True