module Cryptol.Transform.AddModParams (addModParams) where
import Data.Map ( Map )
import qualified Data.Map as Map
import Data.Set ( Set )
import qualified Data.Set as Set
import Data.Either(partitionEithers)
import Data.List(find,sortBy)
import Data.Ord(comparing)
import Cryptol.TypeCheck.AST
import Cryptol.Parser.Position(thing)
import Cryptol.ModuleSystem.Name(toParamInstName,asParamName,nameIdent
,paramModRecParam)
import Cryptol.Utils.Ident(paramInstModName)
import Cryptol.Utils.RecordMap(recordFromFields)
addModParams :: Module -> Either [Name] Module
addModParams m =
case getParams m of
Left errs -> Left errs
Right ps ->
let toInst = Set.unions ( Map.keysSet (mTySyns m)
: Map.keysSet (mNewtypes m)
: map defs (mDecls m)
)
inp = (toInst, ps { pTypeConstraints = inst inp (pTypeConstraints ps) })
in Right m { mName = paramInstModName (mName m)
, mTySyns = fixMap inp (mTySyns m)
, mNewtypes = fixMap inp (mNewtypes m)
, mDecls = fixUp inp (mDecls m)
, mParamTypes = Map.empty
, mParamConstraints = []
, mParamFuns = Map.empty
}
defs :: DeclGroup -> Set Name
defs dg =
case dg of
Recursive ds -> Set.fromList (map dName ds)
NonRecursive d -> Set.singleton (dName d)
fixUp :: (AddParams a, Inst a) => Inp -> a -> a
fixUp i = addParams (snd i) . inst i
fixMap :: (AddParams a, Inst a) => Inp -> Map Name a -> Map Name a
fixMap i m =
Map.fromList [ (toParamInstName x, fixUp i a) | (x,a) <- Map.toList m ]
data Params = Params
{ pTypes :: [TParam]
, pTypeConstraints :: [Prop]
, pFuns :: [(Name,Type)]
}
getParams :: Module -> Either [Name] Params
getParams m
| null errs =
let ps = Params { pTypes = map rnTP
$ sortBy (comparing mtpNumber)
$ Map.elems
$ mParamTypes m
, pTypeConstraints = map thing (mParamConstraints m)
, pFuns = oks
}
in Right ps
| otherwise = Left errs
where
(errs,oks) = partitionEithers (map checkFunP (Map.toList (mParamFuns m)))
checkFunP (x,s) = case isMono (mvpType s) of
Just t -> Right (asParamName x, t)
Nothing -> Left x
rnTP tp = mtpParam tp { mtpName = asParamName (mtpName tp) }
class AddParams a where
addParams :: Params -> a -> a
instance AddParams a => AddParams [a] where
addParams ps = map (addParams ps)
instance AddParams Schema where
addParams ps s = s { sVars = pTypes ps ++ sVars s
, sProps = pTypeConstraints ps ++ sProps s
, sType = addParams ps (sType s)
}
instance AddParams Type where
addParams ps t
| null (pFuns ps) = t
| otherwise = tFun (paramRecTy ps) t
instance AddParams Expr where
addParams ps e = foldr ETAbs withProps (pTypes ps ++ as)
where (as,rest1) = splitWhile splitTAbs e
(bs,rest2) = splitWhile splitProofAbs rest1
withProps = foldr EProofAbs withArgs (pTypeConstraints ps ++ bs)
withArgs
| null (pFuns ps) = rest2
| otherwise = EAbs paramModRecParam (paramRecTy ps) rest2
instance AddParams DeclGroup where
addParams ps dg =
case dg of
Recursive ds -> Recursive (addParams ps ds)
NonRecursive d -> NonRecursive (addParams ps d)
instance AddParams Decl where
addParams ps d =
case dDefinition d of
DPrim -> d
DExpr e -> d { dSignature = addParams ps (dSignature d)
, dDefinition = DExpr (addParams ps e)
, dName = toParamInstName (dName d)
}
instance AddParams TySyn where
addParams ps ts = ts { tsParams = pTypes ps ++ tsParams ts
, tsConstraints = pTypeConstraints ps ++ tsConstraints ts
, tsName = toParamInstName (tsName ts)
}
instance AddParams Newtype where
addParams ps nt = nt { ntParams = pTypes ps ++ ntParams nt
, ntConstraints = pTypeConstraints ps ++ ntConstraints nt
, ntName = toParamInstName (ntName nt)
}
class Inst a where
inst :: Inp -> a -> a
type Inp = (Set Name, Params)
paramRecTy :: Params -> Type
paramRecTy ps = tRec (recordFromFields [ (nameIdent x, t) | (x,t) <- pFuns ps ])
nameInst :: Inp -> Name -> [Type] -> Int -> Expr
nameInst (_,ps) x ts prfs
| null (pFuns ps) = withProofs
| otherwise = EApp withProofs (EVar paramModRecParam)
where
withProofs = iterate EProofApp withTys !!
(length (pTypeConstraints ps) + prfs)
withTys = foldl ETApp (EVar (toParamInstName x))
(map (TVar . tpVar) (pTypes ps) ++ ts)
instTyParams :: Inp -> [Type]
instTyParams (_,ps) = map (TVar . tpVar) (pTypes ps)
needsInst :: Inp -> Name -> Bool
needsInst (xs,_) x = Set.member x xs
isVParam :: Inp -> Name -> Bool
isVParam (_,ps) x = x `elem` map fst (pFuns ps)
isTParam :: Inp -> TVar -> Maybe TParam
isTParam (_,ps) x =
case x of
TVBound tp -> find thisName (pTypes ps)
where thisName y = tpName tp == tpName y
_ -> Nothing
instance Inst a => Inst [a] where
inst ps = map (inst ps)
instance Inst Expr where
inst ps expr =
case expr of
EVar x
| needsInst ps x -> nameInst ps x [] 0
| isVParam ps x ->
let sh = map (nameIdent . fst) (pFuns (snd ps))
in ESel (EVar paramModRecParam) (RecordSel (nameIdent x) (Just sh))
| otherwise -> EVar x
EList es t -> EList (inst ps es) (inst ps t)
ETuple es -> ETuple (inst ps es)
ERec fs -> ERec (fmap (inst ps) fs)
ESel e s -> ESel (inst ps e) s
ESet e s v -> ESet (inst ps e) s (inst ps v)
EIf e1 e2 e3 -> EIf (inst ps e1) (inst ps e2) (inst ps e3)
EComp t1 t2 e ms -> EComp (inst ps t1) (inst ps t2)
(inst ps e) (inst ps ms)
ETAbs x e -> ETAbs x (inst ps e)
ETApp e1 t ->
case splitExprInst expr of
(EVar x, ts, prfs) | needsInst ps x -> nameInst ps x ts prfs
_ -> ETApp (inst ps e1) t
EApp e1 e2 -> EApp (inst ps e1) (inst ps e2)
EAbs x t e -> EAbs x (inst ps t) (inst ps e)
EProofAbs p e -> EProofAbs (inst ps p) (inst ps e)
EProofApp e1 ->
case splitExprInst expr of
(EVar x, ts, prfs) | needsInst ps x -> nameInst ps x ts prfs
_ -> EProofApp (inst ps e1)
EWhere e dgs -> EWhere (inst ps e) (inst ps dgs)
instance Inst Match where
inst ps m =
case m of
From x t1 t2 e -> From x (inst ps t1) (inst ps t2) (inst ps e)
Let d -> Let (inst ps d)
instance Inst DeclGroup where
inst ps dg =
case dg of
Recursive ds -> Recursive (inst ps ds)
NonRecursive d -> NonRecursive (inst ps d)
instance Inst Decl where
inst ps d = d { dDefinition = inst ps (dDefinition d) }
instance Inst DeclDef where
inst ps d =
case d of
DPrim -> DPrim
DExpr e -> DExpr (inst ps e)
instance Inst Type where
inst ps ty =
case ty of
TUser x ts t
| needsInst ps x -> TUser x (instTyParams ps ++ ts1) t1
| otherwise -> TUser x ts1 t1
where ts1 = inst ps ts
t1 = inst ps t
TCon tc ts ->
case tc of
TC (TCNewtype (UserTC x k))
| needsInst ps x -> TCon (TC (TCNewtype (UserTC x (k1 k))))
(newTs ++ ts1)
_ -> TCon tc ts1
where
ts1 = inst ps ts
newTs = instTyParams ps
k1 k = foldr (:->) k (map kindOf newTs)
TVar x | Just x' <- isTParam ps x -> TVar (TVBound x')
| otherwise -> ty
TRec xs -> TRec (fmap (inst ps) xs)
instance Inst TySyn where
inst ps ts = ts { tsConstraints = inst ps (tsConstraints ts)
, tsDef = inst ps (tsDef ts)
}
instance Inst Newtype where
inst ps nt = nt { ntConstraints = inst ps (ntConstraints nt)
, ntFields = [ (f, inst ps t) | (f,t) <- ntFields nt ]
}