{-# Language OverloadedStrings #-}
module Cryptol.TypeCheck.Instantiate
( instantiateWith
, TypeArg(..)
, uncheckedTypeArg
, MaybeCheckedType(..)
) where
import Cryptol.ModuleSystem.Name (nameIdent)
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Monad
import Cryptol.TypeCheck.Subst (listParamSubst, apSubst)
import Cryptol.TypeCheck.Kind(checkType)
import Cryptol.TypeCheck.Error
import Cryptol.Parser.Position (Located(..))
import Cryptol.Utils.Ident (Ident)
import Cryptol.Utils.Panic(panic)
import qualified Cryptol.Parser.AST as P
import Control.Monad(zipWithM)
import Data.Function (on)
import Data.List(sortBy, groupBy, find)
import Data.Maybe(mapMaybe,isJust)
import Data.Either(partitionEithers)
import qualified Data.Set as Set
-- | A type argument supplied when instantiating a schema.
data TypeArg =
  TypeArg
    { tyArgName :: Maybe (Located Ident)
      -- ^ The parameter name the argument was given for, with its location;
      -- 'Nothing' when the argument carries no name.
    , tyArgType :: MaybeCheckedType
      -- ^ The argument itself, which may or may not have been checked yet.
    }
-- | Turn a parsed type instantiation into a 'TypeArg' whose type is still
-- 'Unchecked'.  A named instantiation keeps its name; a positional one has
-- no name.
uncheckedTypeArg :: P.TypeInst Name -> TypeArg
uncheckedTypeArg (P.NamedInst named) =
  TypeArg
    { tyArgName = Just (P.name named)
    , tyArgType = Unchecked (P.value named)
    }
uncheckedTypeArg (P.PosInst ty) =
  TypeArg
    { tyArgName = Nothing
    , tyArgType = Unchecked ty
    }
-- | A type argument that is either an internal 'Type' or a parsed type that
-- has not been through 'checkType' yet.
data MaybeCheckedType
  = Checked Type
    -- ^ Holds a type-checker 'Type'.
  | Unchecked (P.Type Name)
    -- ^ Holds a parser-level type.
-- | Produce the 'Type' to use for a type parameter of the given kind.
--
-- * An 'Unchecked' argument is passed to 'checkType' together with the
--   expected kind.
--
-- * A 'Checked' argument is returned unchanged when its kind equals the
--   expected one.  Otherwise a 'KindMismatch' is recorded and the result of
--   'newType' for the given source and expected kind is used in its place.
checkTyParam :: TVarSource -> Kind -> MaybeCheckedType -> InferM Type
checkTyParam _ expected (Unchecked parsed) = checkType parsed (Just expected)
checkTyParam src expected (Checked ty)
  | expected == actual = pure ty
  | otherwise          = recordError (KindMismatch expected actual)
                           >> newType src expected
  where
  actual = kindOf ty
-- | Instantiate a schema with the given type arguments.
--
-- The arguments are split into named and positional ones.  With no named
-- arguments the positional path is taken; with no positional arguments the
-- named path is taken.  When both kinds are present,
-- 'CannotMixPositionalAndNamedTypeParams' is recorded and instantiation
-- proceeds with the named arguments only.
instantiateWith :: Name -> Expr -> Schema -> [TypeArg] -> InferM (Expr,Type)
instantiateWith nm e s ts =
  case (namedArgs, posArgs) of
    ([], _) -> instantiateWithPos nm e s posArgs
    (_, []) -> instantiateWithNames nm e s namedArgs
    _       -> do recordError CannotMixPositionalAndNamedTypeParams
                  instantiateWithNames nm e s namedArgs
  where
  (namedArgs, posArgs) = partitionEithers (map sortArg ts)

  -- Named arguments keep the location of their name; the located payload
  -- becomes the pair of the name and the argument's type.
  sortArg arg =
    case tyArgName arg of
      Nothing  -> Right (tyArgType arg)
      Just loc -> Left loc { thing = (thing loc, tyArgType arg) }
-- | Instantiate a schema using positional type arguments.
--
-- The schema's parameters are visited in order, numbered from 1.  A
-- parameter for which 'tpName' is 'Just' consumes the next argument (if any
-- remain), which is checked against the parameter's kind.  Every other
-- parameter, and every parameter left over once the arguments run out, is
-- given a type made by 'newType'.  If arguments remain after all parameters
-- have been visited, 'TooManyPositionalTypeParams' is recorded.
instantiateWithPos ::
  Name -> Expr -> Schema -> [MaybeCheckedType] -> InferM (Expr,Type)
instantiateWithPos nm e (Forall as ps t) ts =
  do su <- assign (1::Int) as ts
     doInst su e ps t
  where
  acceptsArg q = isJust (tpName q)

  -- Pair each parameter with its type, in parameter order.
  assign _ [] [] = return []
  assign _ [] (_ : _) =
    do recordError TooManyPositionalTypeParams
       return []
  assign n (q : qs) tys =
    case tys of
      mbty : rest
        | acceptsArg q ->
            do ty   <- checkTyParam (TypeParamInstPos nm n) (kindOf q) mbty
               more <- assign (n+1) qs rest
               return ((q,ty) : more)
      -- Either no arguments are left, or this parameter does not take one:
      -- leave the argument list untouched.
      _ ->
        do ty   <- newType (freshSrc n) (kindOf q)
           more <- assign (n+1) qs tys
           return ((q,ty) : more)

  -- Source description for the type made for the n-th parameter (1-based).
  freshSrc n =
    case drop (n-1) as of
      p : _
        | TPOther (Just a) <- tpFlav p -> TypeParamInstNamed nm (nameIdent a)
        | otherwise                    -> TypeParamInstPos nm n
      [] -> panic "instantiateWithPos"
              [ "Invalid parameter index", show n, show as ]
-- | Instantiate a schema using named type arguments.
--
-- Errors are recorded first: a 'RepeatedTypeParameter' for every name given
-- more than once, then an 'UndefinedTypeParameter' for every argument whose
-- name matches no named parameter of the schema.  After that each schema
-- parameter, in order, receives either the checked argument supplied under
-- its name or a type made by 'newType'.
instantiateWithNames ::
  Name -> Expr -> Schema -> [Located (Ident,MaybeCheckedType)]
                                                     -> InferM (Expr,Type)
instantiateWithNames nm e (Forall as ps t) xs =
  do sequence_ duplicateErrors
     mapM_ (recordError . UndefinedTypeParameter . fmap fst) unknownArgs
     su <- zipWithM pick [ 1.. ] as
     doInst su e ps t
  where
  -- The identifier an argument was supplied under.
  argIdent = fst . thing

  -- Decide the type for parameter number @n@.  Arguments are matched to
  -- parameters by comparing identifiers; the first match in @xs@ wins.
  pick n param =
    do ty <- case supplied of
               Nothing  -> newType tvSrc paramKind
               Just arg -> checkTyParam tvSrc paramKind (snd (thing arg))
       return (param, ty)
    where
    paramKind = tpKind param
    tvSrc     = maybe (TypeParamInstPos nm n)
                      (TypeParamInstNamed nm . nameIdent)
                      (tpName param)
    supplied  = do name <- tpName param
                   find (\arg -> argIdent arg == nameIdent name) xs

  -- One error action per identifier that occurs at least twice in @xs@.
  duplicateErrors =
    mapMaybe repeatedError
      (groupBy ((==) `on` argIdent) (sortBy (compare `on` argIdent) xs))

  repeatedError grp =
    case grp of
      firstArg : _ : _ ->
        Just (recordError
                (RepeatedTypeParameter (argIdent firstArg) (map srcRange grp)))
      _ -> Nothing

  -- Identifiers of the schema parameters that have a name.
  schemaIdents = map nameIdent (mapMaybe tpName as)

  -- Arguments whose name is not among the schema's parameter names.
  unknownArgs = filter (\arg -> argIdent arg `notElem` schemaIdents) xs
-- | Carry out an instantiation described by (parameter, type) pairs.
--
-- The pairs are turned into a substitution with 'listParamSubst'.  The
-- schema's props, with the substitution applied, are passed to 'newGoals'.
-- Then every parameter that is a member of the set collected from the
-- 'TVFree' variables occurring in the schema's type and props is unified
-- with its instantiation, and the resulting props are passed to 'newGoals'
-- as well.
--
-- The result is the expression applied to each instantiation type with
-- 'ETApp', wrapped in one 'EProofApp' per prop of the schema, paired with
-- the schema's type under the substitution.
doInst :: [(TParam, Type)] -> Expr -> [Prop] -> Type -> InferM (Expr,Type)
doInst su' e ps t =
  do newGoals (CtInst e) [ apSubst su p | p <- ps ]
     extra <- concat <$> mapM checkInst su'
     newGoals (CtInst e) extra
     return (withProofs, apSubst su t)
  where
  su = listParamSubst su'

  -- The expression applied to every instantiation type, in order.
  withTypes = foldl ETApp e (map snd su')

  -- One proof application per prop; the props themselves are not used.
  withProofs = foldl (\acc _ -> EProofApp acc) withTypes ps

  -- Type variables occurring in the schema's type and props.
  schemaVars = Set.unions [ fvs x | x <- t : ps ]

  -- Union of the parameter sets stored in the 'TVFree' variables among
  -- 'schemaVars' (presumably the parameters in scope of each such
  -- variable -- confirm against the definition of 'TVFree').  'TVBound'
  -- variables contribute nothing.
  scoped = Set.unions [ vs | TVFree _ _ vs _ <- Set.toList schemaVars ]

  -- Props from unifying a parameter in 'scoped' with its instantiation;
  -- other parameters yield none.
  checkInst :: (TParam, Type) -> InferM [Prop]
  checkInst (tp, ty)
    | tp `Set.member` scoped = unify (TVar (tpVar tp)) ty
    | otherwise              = return []