{-# LANGUAGE PatternGuards, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
module Cryptol.Transform.MonoValues (rewModule) where
import Cryptol.ModuleSystem.Name
(SupplyT,liftSupply,Supply,mkDeclared,NameSource(..))
import Cryptol.Parser.Position (emptyRange)
import Cryptol.TypeCheck.AST hiding (splitTApp)
import Cryptol.TypeCheck.TypeMap
import Cryptol.Utils.Ident (ModName)
import Data.List(sortBy,groupBy)
import Data.Either(partitionEithers)
import Data.Map (Map)
import MonadLib hiding (mapM)
import Prelude ()
import Prelude.Compat
-- | A three-level trie of rewrites, keyed (see the 'TrieMap' instance) by a
-- variable name, the list of type arguments it is applied to, and the number
-- of proof arguments it is applied to.
newtype RewMap' a =
  RM (Map Name          -- the variable being rewritten
       (TypesMap         -- the type arguments it is applied to
         (Map Int a)))   -- the number of proof arguments

-- | The rewrite table used by this pass: each key maps to the name that
-- replaces the matching application.
type RewMap = RewMap' Name
-- | 'RewMap'' is a trie indexed by @(name, type arguments, proof count)@,
-- built by nesting the 'TrieMap' instances of its three levels.
--
-- Inner levels that become empty (after 'alterTM' deletes an entry, or after
-- 'mapMaybeWithKeyTM' drops every entry below them) are removed. As a result,
-- 'nullTM' is 'True' exactly when the map holds no entries.
instance TrieMap RewMap' (Name,[Type],Int) where
  emptyTM = RM emptyTM

  nullTM (RM m) = nullTM m

  lookupTM (x,ts,n) (RM m) =
    do tM <- lookupTM x m
       tP <- lookupTM ts tM
       lookupTM n tP

  alterTM (x,ts,n) f (RM m) = RM (alterTM x f1 m)
    where
    -- No entry at this level yet: build a fresh chain only if 'f' yields a value.
    f1 Nothing   = do a <- f Nothing
                      return (insertTM ts (insertTM n a emptyTM) emptyTM)
    -- Existing entry: update it, and drop the level if it is now empty.
    f1 (Just tM) = nonEmptyTM (alterTM ts f2 tM)

    f2 Nothing   = do a <- f Nothing
                      return (insertTM n a emptyTM)
    f2 (Just pM) = nonEmptyTM (alterTM n f pM)

  unionTM f (RM a) (RM b) = RM (unionTM (unionTM (unionTM f)) a b)

  toListTM (RM m) = [ ((x,ts,n),y) | (x,tM)  <- toListTM m
                                   , (ts,pM) <- toListTM tM
                                   , (n,y)   <- toListTM pM ]

  -- Empty inner levels are pruned here too, so the result keeps the
  -- "no empty sub-tries" invariant that 'nullTM' relies on.
  mapMaybeWithKeyTM f (RM m) =
    RM (mapMaybeWithKeyTM (\qn tm ->
          nonEmptyTM (mapMaybeWithKeyTM (\tys is ->
            nonEmptyTM (mapMaybeWithKeyTM (\i a -> f (qn,tys,i) a) is)) tm)) m)

-- | Keep a sub-trie only if it still holds at least one entry.
nonEmptyTM :: TrieMap m k => m a -> Maybe (m a)
nonEmptyTM tm
  | nullTM tm = Nothing
  | otherwise = Just tm
-- | Run the rewriting pass over every declaration group of a module. It starts
-- from an empty rewrite table, draws fresh names from the given supply, and
-- returns the rewritten module together with the updated supply.
rewModule :: Supply -> Module -> (Module,Supply)
rewModule s m = runM rewriteAll (mName m) s
  where
  rewriteAll  = setDecls <$> mapM (rewDeclGroup emptyTM) (mDecls m)
  setDecls ds = m { mDecls = ds }
-- | Read-only environment of the pass: the name of the module being
-- rewritten, which 'newName' uses when declaring fresh names.
type RO = ModName

-- | The pass's monad: a reader of 'RO' layered over a name supply.
type M = ReaderT RO (SupplyT Id)
-- | Produce a fresh system-generated name (base text @$mono@), declared in
-- the module taken from the environment.
newName :: M Name
newName = ask >>= \ns ->
  liftSupply (mkDeclared ns SystemName "$mono" Nothing emptyRange)

-- | Fresh name for a binding that may end up either top-level or local.
-- At present this is exactly 'newName'.
newTopOrLocalName :: M Name
newTopOrLocalName = newName

-- | Run a computation for a local scope. The pass keeps no separate local
-- state, so the computation is run unchanged.
inLocal :: M a -> M a
inLocal act = act
-- | Rewrite an expression using the given table.
--
-- Consider an application of a variable to type and/or proof arguments. If
-- its key @(name, type arguments, proof count)@ is in the table, the whole
-- application is replaced by the mapped name. Every other construct is
-- traversed structurally.
rewE :: RewMap -> Expr -> M Expr
rewE rews = go
  where
  -- Only a bare variable at the head of the application can be rewritten.
  tryRewrite (EVar x,tps,n) = EVar <$> lookupTM (x,tps,n) rews
  tryRewrite _              = Nothing

  -- Use the rewrite when there is one; otherwise fall back to 'dflt'.
  orElse dflt = maybe dflt pure

  go expr =
    case expr of
      -- Possible rewrite sites.
      ETApp e t         -> orElse (ETApp <$> go e <*> pure t)
                                  (tryRewrite (splitTApp expr 0))
      EProofApp e       -> orElse (EProofApp <$> go e)
                                  (tryRewrite (splitTApp e 1))

      -- Structural traversal.
      EList es t        -> EList <$> traverse go es <*> pure t
      ETuple es         -> ETuple <$> traverse go es
      ERec fs           -> ERec <$> traverse go fs
      ESel e s          -> flip ESel s <$> go e
      ESet e s v        -> (\e' v' -> ESet e' s v') <$> go e <*> go v
      EIf c th el       -> EIf <$> go c <*> go th <*> go el
      EComp len t e mss -> EComp len t <$> go e
                                       <*> traverse (traverse (rewM rews)) mss
      EVar _            -> pure expr
      ETAbs x e         -> ETAbs x <$> go e
      EApp fn arg       -> EApp <$> go fn <*> go arg
      EAbs x t e        -> EAbs x t <$> go e
      EProofAbs x e     -> EProofAbs x <$> go e
      EWhere e dgs      -> EWhere <$> go e
                                  <*> inLocal (traverse (rewDeclGroup rews) dgs)
-- | Rewrite the expression or declaration inside a comprehension match.
rewM :: RewMap -> Match -> M Match
rewM rews (From x len t e) = From x len t <$> rewE rews e
rewM rews (Let d)          = Let <$> rewD rews d
-- | Rewrite the definition of one declaration, keeping its other fields.
rewD :: RewMap -> Decl -> M Decl
rewD rews d = setDef <$> rewDef rews (dDefinition d)
  where
  setDef def = d { dDefinition = def }
-- | Rewrite a declaration body. Primitives have no body and are returned
-- as they are.
rewDef :: RewMap -> DeclDef -> M DeclDef
rewDef rews def =
  case def of
    DExpr e -> DExpr <$> rewE rews e
    DPrim   -> pure DPrim
-- | Rewrite a declaration group.
--
-- A non-recursive group is simply traversed with the current table.
--
-- In a recursive group, a declaration is a candidate for rewriting when its
-- body starts with at least one type abstraction and, after the proof
-- abstractions, is not a lambda. Candidates are grouped by identical type
-- parameters and constraints. Inside each group, recursive uses at the
-- group's own type variables are redirected to freshly named monomorphic
-- copies, which are bound in a local @where@. All other declarations are
-- rewritten with the outer table.
rewDeclGroup :: RewMap -> DeclGroup -> M DeclGroup
rewDeclGroup rews dg =
  case dg of
    NonRecursive d -> NonRecursive <$> rewD rews d
    Recursive ds ->
      do let (leave,rew) = partitionEithers (map consider ds)
             rewGroups   = groupBy sameTParams
                         $ sortBy compareTParams rew
         ds1 <- mapM (rewD rews) leave
         ds2 <- mapM rewSame rewGroups
         return $ Recursive (ds1 ++ concat ds2)

  where
  -- Candidates are grouped by their (type parameters, constraints) pair.
  sameTParams    (_,tps1,x,_) (_,tps2,y,_) = tps1 == tps2 && x == y
  compareTParams (_,tps1,x,_) (_,tps2,y,_) = compare (x,tps1) (y,tps2)

  -- 'Right' for a candidate (type-abstracted, not a function), else 'Left'.
  consider d =
    case dDefinition d of
      DPrim   -> Left d
      DExpr e -> let (tps,props,e') = splitTParams e
                 in if not (null tps) && notFun e'
                      then Right (d, tps, props, e')
                      else Left d

  -- 'groupBy' never produces an empty group. The shared type parameters are
  -- still read through an explicit total match rather than a partial
  -- pattern binding.
  rewSame [] = return []
  rewSame ds@((_,tps,props,_) : _) =
    do new <- forM ds $ \(d,_,_,e) ->
                do x <- newName
                   return (d, x, e)
       let tys            = map (TVar . tpVar) tps
           proofNum       = length props
           addRew (d,x,_) = insertTM (dName d,tys,proofNum) x
           newRews        = foldr addRew rews new

       -- Monomorphic copies: fresh name, no type parameters or constraints,
       -- body rewritten so recursive uses refer to the copies.
       newDs <- forM new $ \(d,newN,e) ->
                  do e1 <- rewE newRews e
                     return ( d
                            , d { dName       = newN
                                , dSignature  = (dSignature d)
                                                  { sVars = [], sProps = [] }
                                , dDefinition = DExpr e1
                                }
                            )

       case newDs of
         -- One value:  f = /\tps. \props. (f' where rec f' = body')
         [(f,f')] ->
           return [ f { dDefinition =
                          let newBody = EVar (dName f')
                              newE    = EWhere newBody [ Recursive [f'] ]
                          in DExpr $ foldr ETAbs (foldr EProofAbs newE props) tps
                      }
                  ]

         -- Several values: bind all copies in one tuple, then project each out.
         _ -> do tupName <- newTopOrLocalName
                 let (polyDs,monoDs) = unzip newDs
                     tupAr  = length monoDs
                     addTPs = flip (foldr ETAbs)     tps
                            . flip (foldr EProofAbs) props

                     -- tup = /\tps. \props. (f1',...,fk') where rec f1' ... fk'
                     tupD = Decl
                       { dName       = tupName
                       , dSignature  =
                           Forall tps props $
                             TCon (TC (TCTuple tupAr))
                                  (map (sType . dSignature) monoDs)
                       , dDefinition =
                           DExpr $
                           addTPs $
                           EWhere (ETuple [ EVar (dName d) | d <- monoDs ])
                                  [ Recursive monoDs ]
                       , dPragmas    = []
                       , dInfix      = False
                       , dFixity     = Nothing
                       , dDoc        = Nothing
                       }

                     mkProof e _ = EProofApp e

                     -- f_i = /\tps. \props. (tup @tys <one proof per prop>).i
                     mkFunDef n f =
                       f { dDefinition =
                             DExpr $
                             addTPs $ ESel ( flip (foldl mkProof) props
                                           $ flip (foldl ETApp) tys
                                           $ EVar tupName
                                           ) (TupleSel n (Just tupAr))
                         }

                 return (tupD : zipWith mkFunDef [ 0 .. ] polyDs)
-- | Split an expression in two steps. First, type parameters are repeatedly
-- split off the front with 'splitTAbs'. Then constraints are split off the
-- remainder with 'splitProofAbs'. The result is the type parameters, the
-- constraints, and the remaining body.
splitTParams :: Expr -> ([TParam], [Prop], Expr)
splitTParams e = (tps, props, body)
  where
  (tps, afterTAbs) = splitWhile splitTAbs e
  (props, body)    = splitWhile splitProofAbs afterTAbs
-- | Decompose an expression in two steps. First, the outer proof
-- applications ('EProofApp') are stripped, and each one adds 1 to the count
-- @n@ (the count is forced at every step). Then the type applications
-- ('ETApp') underneath are stripped. The result is the head expression, the
-- type arguments in application order, and the final proof count.
splitTApp :: Expr -> Int -> (Expr, [Type], Int)
splitTApp expr n =
  case expr of
    EProofApp inner -> let n' = n + 1 in n' `seq` splitTApp inner n'
    _               -> let (hd, tys) = collect expr []
                       in (hd, tys, n)
  where
  collect (ETApp fn t) acc = collect fn (t : acc)
  collect other        acc = (other, acc)
-- | 'False' when the expression is a lambda ('EAbs'), once any proof
-- abstractions in front of it are skipped; 'True' otherwise.
notFun :: Expr -> Bool
notFun expr =
  case expr of
    EProofAbs _ body -> notFun body
    EAbs {}          -> False
    _                -> True