{-# LANGUAGE Safe #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}
module Cryptol.TypeCheck.Monad
( module Cryptol.TypeCheck.Monad
, module Cryptol.TypeCheck.InferTypes
) where
import Cryptol.ModuleSystem.Name
(FreshM(..),Supply,mkParameter
, nameInfo, NameInfo(..),NameSource(..))
import Cryptol.Parser.Position
import qualified Cryptol.Parser.AST as P
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Cryptol.TypeCheck.Unify(mgu, runResult, UnificationError(..))
import Cryptol.TypeCheck.InferTypes
import Cryptol.TypeCheck.Error(Warning(..),Error(..),cleanupErrors)
import Cryptol.TypeCheck.PP (brackets, commaSep)
import qualified Cryptol.TypeCheck.SimpleSolver as Simple
import qualified Cryptol.TypeCheck.Solver.SMT as SMT
import Cryptol.Utils.PP(pp, (<+>), text, quotes)
import Cryptol.Utils.Ident(Ident)
import Cryptol.Utils.Panic(panic)
import qualified Control.Applicative as A
import Control.Monad.Fix(MonadFix(..))
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Map (Map)
import Data.Set (Set)
import Data.List(find, foldl')
import Data.Maybe(mapMaybe,fromMaybe)
import MonadLib hiding (mapM)
import Data.IORef
import GHC.Generics (Generic)
import Control.DeepSeq
import Prelude ()
import Prelude.Compat
-- | Everything the type checker needs to start inferring types.
-- 'runInferM' copies most of these fields into the read-only 'RO'
-- environment or the initial 'RW' state.
data InferInput = InferInput
  { inpRange            :: Range
    -- ^ Initial source range (becomes 'iRange').
  , inpVars             :: Map Name Schema
    -- ^ Value variables already in scope (wrapped as 'ExtVar').
  , inpTSyns            :: Map Name TySyn
    -- ^ Type synonyms already in scope (tagged 'IsExternal').
  , inpNewtypes         :: Map Name Newtype
    -- ^ Newtypes already in scope (tagged 'IsExternal').
  , inpParamTypes       :: !(Map Name ModTParam)
    -- ^ Module type parameters.
  , inpParamConstraints :: !([Located Prop])
    -- ^ Constraints on the module parameters.
  , inpParamFuns        :: !(Map Name ModVParam)
    -- ^ Module value parameters.
  , inpNameSeeds        :: NameSeeds
    -- ^ Starting counters for fresh type variables and goal ids; the
    -- final counters are returned in 'InferOK'.
  , inpMonoBinds        :: Bool
    -- ^ Exposed to checking code via 'getMonoBinds'.
    -- NOTE(review): presumably controls monomorphisation of local
    -- bindings — confirm at the use sites.
  , inpSolverConfig     :: SolverConfig
    -- ^ Configuration passed to 'SMT.withSolver'.
  , inpSearchPath       :: [FilePath]
    -- ^ Not read anywhere in this section.
    -- NOTE(review): likely used elsewhere to locate solver files — confirm.
  , inpPrimNames        :: !PrimMap
    -- ^ Primitive name map (returned by 'getPrimMap').
  , inpSupply           :: !Supply
    -- ^ Initial supply for fresh 'Name's (see the 'FreshM' instance).
  } deriving Show
-- | Counters used to generate fresh unique identifiers during inference.
data NameSeeds = NameSeeds
  { seedTVar :: !Int
    -- ^ Next unique for type variables and type parameters
    -- (used by 'newTVar'' and 'newTParam').
  , seedGoal :: !Int
    -- ^ Next identifier for "has" goals (used by 'newGoalName').
  } deriving (Show, Generic, NFData)
-- | The initial seeds: goal identifiers start at 0 and type-variable
-- uniques start at 10.
nameSeeds :: NameSeeds
nameSeeds = NameSeeds { seedGoal = 0, seedTVar = 10 }
-- | The outcome of 'runInferM'.
data InferOutput a
  = InferFailed [(Range,Warning)] [(Range,Error)]
    -- ^ Checking failed: the warnings and the (cleaned-up) errors.
  | InferOK [(Range,Warning)] NameSeeds Supply a
    -- ^ Checking succeeded: the warnings, the final name seeds, the final
    -- name supply, and the result with defaulting applied.
    deriving Show
-- | Add one to the counter stored in 'iSolveCounter'.
bumpCounter :: InferM ()
bumpCounter =
  do ref <- IM (asks iSolveCounter)
     io (modifyIORef' ref (\n -> n + 1))
-- | Run a type-inference computation.
--
-- The computation runs inside 'SMT.withSolver', configured by
-- 'inpSolverConfig'. The result is 'InferOK' only when all of these hold:
--
--   * no errors were recorded;
--   * the remaining goals are empty ('nullGoals');
--   * no "has" goals are outstanding.
--
-- In that case the result has the defaulting substitution applied.
-- Otherwise the result is 'InferFailed'; unsolved goals are reported as
-- 'UnsolvedGoals' errors. Warnings have the final substitution applied in
-- both cases.
runInferM :: TVars a => InferInput -> InferM a -> IO (InferOutput a)
runInferM info (IM m) = SMT.withSolver (inpSolverConfig info) $ \solver ->
  do counter <- newIORef 0
     rec ro <- return RO { iRange            = inpRange info
                         , iVars             = Map.map ExtVar (inpVars info)
                         , iTVars            = []
                         , iTSyns            = fmap mkExternal (inpTSyns info)
                         , iNewtypes         = fmap mkExternal (inpNewtypes info)
                         , iParamTypes       = inpParamTypes info
                         , iParamFuns        = inpParamFuns info
                         , iParamConstraints = inpParamConstraints info
                           -- Knot: "has" goal solutions only exist once
                           -- checking has finished, so the environment
                           -- refers lazily to the final state.
                         , iSolvedHasLazy    = iSolvedHas finalRW
                         , iMonoBinds        = inpMonoBinds info
                         , iSolver           = solver
                         , iPrimNames        = inpPrimNames info
                         , iSolveCounter     = counter
                         }
         (result, finalRW) <- runStateT rw
                                $ runReaderT ro m

     let theSu = iSubst finalRW
         defSu = defaultingSubst theSu
         warns = [(r,apSubst theSu w) | (r,w) <- iWarnings finalRW ]

     case iErrors finalRW of
       [] ->
         case (iCts finalRW, iHasCts finalRW) of
           (cts,[])
             | nullGoals cts
             -> return $ InferOK warns
                           (iNameSeeds finalRW)
                           (iSupply finalRW)
                           (apSubst defSu result)
           (cts,has) -> return $ InferFailed warns
                          $ cleanupErrors
                          [ ( goalRange g
                            , UnsolvedGoals False [apSubst theSu g]
                            ) | g <- fromGoals cts ++ map hasGoal has
                          ]

       errs -> return $ InferFailed warns
                 $ cleanupErrors [(r,apSubst theSu e) | (r,e) <- errs]

  where
  -- Definitions supplied by the caller are tagged as external.
  mkExternal x = (IsExternal, x)

  -- Initial mutable state.
  rw = RW { iErrors     = []
          , iWarnings   = []
          , iSubst      = emptySubst
          , iExistTVars = []
          , iNameSeeds  = inpNameSeeds info
          , iCts        = emptyGoals
          , iHasCts     = []
          , iSolvedHas  = Map.empty
          , iSupply     = inpSupply info
          }
-- | The type-inference monad: a read-only environment ('RO') over mutable
-- state ('RW') over 'IO'.
newtype InferM a = IM { unIM :: ReaderT RO (StateT RW IO) a }

-- | Records whether a definition was added during checking ('IsLocal', see
-- 'withTySyn' and 'withNewtype') or came from the 'InferInput'
-- ('IsExternal').
data DefLoc = IsLocal | IsExternal

-- | Read-only part of 'InferM'.
data RO = RO
  { iRange            :: Range
    -- ^ Source range currently being checked.
  , iVars             :: Map Name VarType
    -- ^ Value variables in scope.
  , iTVars            :: [TParam]
    -- ^ Type parameters in scope, most recently added first.
  , iTSyns            :: Map Name (DefLoc, TySyn)
    -- ^ Type synonyms in scope.
  , iNewtypes         :: Map Name (DefLoc, Newtype)
    -- ^ Newtypes in scope.
  , iParamTypes       :: Map Name ModTParam
    -- ^ Module type parameters.
  , iParamConstraints :: [Located Prop]
    -- ^ Constraints on module parameters.
  , iParamFuns        :: Map Name ModVParam
    -- ^ Module value parameters.
  , iSolvedHasLazy    :: Map Int (Expr -> Expr)
    -- ^ Solutions to "has" goals, taken lazily from the FINAL state by
    -- 'runInferM'. It must not be forced while checking is still running.
  , iMonoBinds        :: Bool
    -- ^ Copied from 'inpMonoBinds'.
  , iSolver           :: SMT.Solver
    -- ^ Handle to the SMT solver session.
  , iPrimNames        :: !PrimMap
    -- ^ Primitive name map.
  , iSolveCounter     :: !(IORef Int)
    -- ^ Counter incremented by 'bumpCounter'.
  }

-- | Mutable part of 'InferM'.
data RW = RW
  { iErrors     :: ![(Range,Error)]
    -- ^ Errors found so far, most recent first.
  , iWarnings   :: ![(Range,Warning)]
    -- ^ Warnings found so far, most recent first.
  , iSubst      :: !Subst
    -- ^ Current substitution (extended by 'extendSubst').
  , iExistTVars :: [Map Name Type]
    -- ^ Stack of scopes for implicitly introduced type names, innermost
    -- first (see 'existVar' and 'inNewScope').
  , iSolvedHas  :: Map Int (Expr -> Expr)
    -- ^ Solutions to "has" goals, keyed by goal id ('solveHasGoal').
  , iNameSeeds  :: !NameSeeds
    -- ^ Fresh-identifier counters.
  , iCts        :: !Goals
    -- ^ Pending goals.
  , iHasCts     :: ![HasGoal]
    -- ^ Pending "has" goals.
  , iSupply     :: !Supply
    -- ^ Supply of fresh 'Name's.
  }
-- | Map over the result of an inference computation.
instance Functor InferM where
  fmap f m = IM (f <$> unIM m)

-- | 'pure' is 'return'; '<*>' runs the function first, then the argument.
instance A.Applicative InferM where
  pure      = return
  mf <*> mx = do f <- mf
                 x <- mx
                 return (f x)

-- | Each operation is delegated to the underlying reader/state stack.
instance Monad InferM where
  return  = IM . return
  fail    = IM . fail
  m >>= k = IM (unIM m >>= \x -> unIM (k x))

-- | Recursive binding, delegated to the underlying stack.
instance MonadFix InferM where
  mfix f = IM (mfix (\x -> unIM (f x)))
-- | Fresh names come from the 'iSupply' held in the mutable state.
instance FreshM InferM where
  liftSupply f = IM $ sets $ \st ->
    let (a, supply') = f (iSupply st)
    in (a, st { iSupply = supply' })
-- | Lift an 'IO' action into 'InferM'.
io :: IO a -> InferM a
io = IM . inBase
-- | Run a computation with 'iRange' set to the given range.
inRange :: Range -> InferM a -> InferM a
inRange r m = IM (mapReader setRange (unIM m))
  where setRange ro = ro { iRange = r }
-- | Like 'inRange', but leaves the range unchanged when given 'Nothing'.
inRangeMb :: Maybe Range -> InferM a -> InferM a
inRangeMb mbRange m = maybe m (\r -> inRange r m) mbRange
-- | The source range currently being checked.
curRange :: InferM Range
curRange = IM (iRange <$> ask)
-- | Record an error, tagged with the current range.
recordError :: Error -> InferM ()
recordError e =
  curRange >>= \r -> IM (sets_ (\st -> st { iErrors = (r, e) : iErrors st }))
-- | Record a warning.
--
-- A 'DefaultingTo' warning is tagged with the range stored in its type
-- variable info; any other warning is tagged with the current range.
-- Defaulting warnings about variables whose source name is a
-- 'SystemName' are dropped.
recordWarning :: Warning -> InferM ()
recordWarning w
  | aboutSystemName = return ()
  | otherwise =
      do r <- warnRange
         IM $ sets_ $ \s -> s { iWarnings = (r, w) : iWarnings s }
  where
  warnRange =
    case w of
      DefaultingTo d _ -> return (tvarSource d)
      _                -> curRange

  aboutSystemName =
    case w of
      DefaultingTo d _
        | Just n <- tvSourceName (tvarDesc d)
        , Declared _ SystemName <- nameInfo n
        -> True
      _ -> False
-- | The SMT solver handle for this run.
getSolver :: InferM SMT.Solver
getSolver = IM (asks iSolver)
-- | The primitive name map from the input.
getPrimMap :: InferM PrimMap
getPrimMap = IM (asks iPrimNames)
-- | Build a goal for a proposition, tagged with its source and the
-- current range. The goal is not added to the pending set.
newGoal :: ConstraintSource -> Prop -> InferM Goal
newGoal goalSource goal = mkGoal <$> curRange
  where mkGoal goalRange = Goal { .. }
-- | Turn each proposition into a goal and add it to the pending goals.
newGoals :: ConstraintSource -> [Prop] -> InferM ()
newGoals src ps = mapM (newGoal src) ps >>= addGoals
-- | Remove all pending goals from the state and return them with the
-- current substitution applied.
getGoals :: InferM [Goal]
getGoals =
  do pending <- IM $ sets $ \s -> (iCts s, s { iCts = emptyGoals })
     fromGoals <$> applySubst pending
-- | Simplify the goals ('simpGoals') and insert the results into the
-- pending goals. The state is left untouched if nothing remains.
addGoals :: [Goal] -> InferM ()
addGoals gs0 =
  do gs <- simpGoals gs0
     unless (null gs) $
       IM $ sets_ $ \s ->
         s { iCts = foldl' (\acc g -> insertGoal g acc) (iCts s) gs }
-- | Run a computation and return the goals it generated separately.
-- The goals that were pending beforehand (with the substitution applied)
-- are put back afterwards.
collectGoals :: InferM a -> InferM (a, [Goal])
collectGoals m =
  do saved <- applySubst =<< stash
     a     <- m
     fresh <- getGoals
     restore saved
     return (a, fresh)
  where
  -- Take the pending goals out, leaving an empty set behind.
  stash = IM $ sets $ \ RW { .. } -> (iCts, RW { iCts = emptyGoals, .. })

  -- Replace the pending goals.
  restore gs = IM $ sets $ \ RW { .. } -> ((), RW { iCts = gs, .. })
-- | Simplify a goal with 'Simple.simplify'.
--
-- If the result is an error type, an error is recorded and no goals are
-- returned. Otherwise the result is split on conjunctions, giving one
-- goal per conjunct.
simpGoal :: Goal -> InferM [Goal]
simpGoal g
  | Just e <- tIsError simplified =
      do recordError (ErrorMsg (text (tcErrorMessage e)))
         return []
  | otherwise = return [ g { goal = conj } | conj <- pSplitAnd simplified ]
  where
  simplified = Simple.simplify Map.empty (goal g)
-- | Simplify each goal in turn and concatenate the results.
simpGoals :: [Goal] -> InferM [Goal]
simpGoals = fmap concat . mapM simpGoal
-- | Add a "has" goal for a selector and return the (not yet known)
-- expression transformer that will implement it.
--
-- The transformer is looked up lazily in 'iSolvedHasLazy', which refers to
-- the final state. It must therefore not be forced before checking ends;
-- if the goal was never solved, forcing it panics.
newHasGoal :: P.Selector -> Type -> Type -> InferM (Expr -> Expr)
newHasGoal l ty f =
  do goalName  <- newGoalName
     g         <- newGoal CtSelector (pHas l ty f)
     addHasGoal (HasGoal goalName g)
     solutions <- IM (asks iSolvedHasLazy)
     return (fromMaybe (panic "newHasGoal" ["Unsolved has goal in result"])
                       (Map.lookup goalName solutions))
-- | Push a "has" goal onto the pending "has" goals.
addHasGoal :: HasGoal -> InferM ()
addHasGoal g = IM (sets_ push)
  where push st = st { iHasCts = g : iHasCts st }
-- | Remove all pending "has" goals and return them with the current
-- substitution applied.
getHasGoals :: InferM [HasGoal]
getHasGoals = IM (sets drain) >>= applySubst
  where drain st = (iHasCts st, st { iHasCts = [] })
-- | Record the solution for the "has" goal with the given id.
solveHasGoal :: Int -> (Expr -> Expr) -> InferM ()
solveHasGoal n e = IM (sets_ addSolution)
  where addSolution st = st { iSolvedHas = Map.insert n e (iSolvedHas st) }
-- | Make a fresh parameter name for an identifier, located at the
-- current range.
newParamName :: Ident -> InferM Name
newParamName x = curRange >>= liftSupply . mkParameter x
-- | Apply an update to the 'NameSeeds' in the state and return the value
-- it produces.
newName :: (NameSeeds -> (a , NameSeeds)) -> InferM a
newName upd = IM (sets step)
  where
  step st = (fst out, st { iNameSeeds = snd out })
    where out = upd (iNameSeeds st)
-- | Allocate the next "has"-goal identifier.
newGoalName :: InferM Int
newGoalName = newName next
  where next seeds = (seedGoal seeds, seeds { seedGoal = seedGoal seeds + 1 })
-- | Make a fresh type variable that may mention the type parameters
-- currently in scope.
newTVar :: TVarSource -> Kind -> InferM TVar
newTVar src = newTVar' src Set.empty
-- | Make a fresh free type variable.
--
-- The variable may mention the parameters in scope ('getBoundInScope')
-- plus the extra ones supplied. Its unique comes from 'seedTVar'.
newTVar' :: TVarSource -> Set TParam -> Kind -> InferM TVar
newTVar' src extraBound k =
  do here  <- curRange
     scope <- Set.union extraBound <$> getBoundInScope
     let info = TVarInfo { tvarSource = here, tvarDesc = src }
     newName $ \seeds ->
       let u = seedTVar seeds
       in (TVFree u k scope info, seeds { seedTVar = u + 1 })
-- | Make a fresh type parameter from a parsed parameter.
--
-- The unique comes from 'seedTVar'. The source is the parameter's own
-- range, or 'emptyRange' if it has none.
newTParam :: P.TParam Name -> TPFlavor -> Kind -> InferM TParam
newTParam nm flav k = newName mk
  where
  mk seeds =
    let u = seedTVar seeds
    in ( TParam { tpUnique = u, tpKind = k, tpFlav = flav, tpInfo = info }
       , seeds { seedTVar = u + 1 } )

  info = TVarInfo { tvarDesc   = TVFromSignature (P.tpName nm)
                  , tvarSource = fromMaybe emptyRange (P.tpRange nm)
                  }
-- | A fresh type variable, as a 'Type'.
newType :: TVarSource -> Kind -> InferM Type
newType src k = TVar <$> newTVar src k
-- | Unify two types, after applying the current substitution to both.
--
-- The substitution returned by 'mgu' is always added to the state. If
-- unification reported errors, each one is recorded and no propositions
-- are returned; otherwise the residual propositions are returned.
unify :: Type -> Type -> InferM [Prop]
unify t1 t2 =
  do t1' <- applySubst t1
     t2' <- applySubst t2
     let ((su1, ps), errs) = runResult (mgu t1' t2')
     extendSubst su1
     if null errs
       then return ps
       else do mapM_ (recordError . toError t1' t2') errs
               return []
  where
  -- A length mismatch is reported against the whole (substituted) types.
  toError :: Type -> Type -> UnificationError -> Error
  toError whole1 whole2 err =
    case err of
      UniTypeLenMismatch _ _ -> TypeMismatch whole1 whole2
      UniTypeMismatch s1 s2  -> TypeMismatch s1 s2
      UniKindMismatch k1 k2  -> KindMismatch k1 k2
      UniRecursive x t       -> RecursiveType (TVar x) t
      UniNonPolyDepends x vs -> TypeVariableEscaped (TVar x) vs
      UniNonPoly x t         -> NotForAll x t
-- | Apply the current substitution to a value.
applySubst :: TVars t => t -> InferM t
applySubst t = (\su -> apSubst su t) <$> getSubst
-- | Apply the current substitution, then split each proposition on
-- conjunctions.
applySubstPreds :: [Prop] -> InferM [Prop]
applySubstPreds ps = concatMap pSplitAnd <$> applySubst ps
-- | Apply the current substitution to the goals, then split each goal on
-- conjunctions, giving one goal per conjunct.
applySubstGoals :: [Goal] -> InferM [Goal]
applySubstGoals gs = splitAll <$> applySubst gs
  where
  splitAll = concatMap (\g -> [ g { goal = p } | p <- pSplitAnd (goal g) ])
-- | The current substitution.
getSubst :: InferM Subst
getSubst = IM (iSubst <$> get)
-- | Compose a substitution onto the current one (@su \@\@ old@).
--
-- Each binding is checked first. The function panics if a bound variable
-- is being instantiated, or if a free variable's new type mentions type
-- parameters outside that variable's allowed set.
extendSubst :: Subst -> InferM ()
extendSubst su =
  do mapM_ check (substToList su)
     IM $ sets_ $ \s -> s { iSubst = su @@ iSubst s }
  where
  loc :: String
  loc = "Cryptol.TypeCheck.Monad.extendSubst"

  check :: (TVar, Type) -> InferM ()
  check (v, ty) =
    case v of
      TVBound _ ->
        panic loc
          [ "Substitution instantiates bound variable:"
          , "Variable: " ++ show (pp v)
          , "Type: " ++ show (pp ty)
          ]
      TVFree _ _ tvs _ ->
        let mentioned = Set.unions (map boundsOf (Set.elems (fvs ty)))
            escaped   = Set.difference mentioned tvs
        in unless (Set.null escaped) $
             panic loc
               [ "Escaped quantified variables:"
               , "Substitution: " ++ show (pp v <+> text ":=" <+> pp ty)
               , "Vars in scope: " ++ show (brackets (commaSep (map pp (Set.toList tvs))))
               , "Escaped: " ++ show (brackets (commaSep (map pp (Set.toList escaped))))
               ]

  -- The type parameters a variable may mention.
  boundsOf :: TVar -> Set TParam
  boundsOf tv =
    case tv of
      TVBound tp       -> Set.singleton tp
      TVFree _ _ tps _ -> tps
-- | Free type variables (after the current substitution) that appear in
-- any of these:
--
--   * the types of variables in scope;
--   * the pending "has" goals;
--   * the implicitly introduced types in every existential scope.
varsWithAsmps :: InferM (Set TVar)
varsWithAsmps =
  do env    <- IM (asks (Map.elems . iVars))
     envVs  <- mapM fromVarType env
     sels   <- IM (map (goal . hasGoal) . iHasCts <$> get)
     selVs  <- mapM freeAfterSubst sels
     scopes <- IM (iExistTVars <$> get)
     exVs   <- freeAfterSubst (concatMap Map.elems scopes)
     return (Set.unions (envVs ++ selVs ++ [exVs]))
  where
  freeAfterSubst x = fvs <$> applySubst x

  fromVarType v =
    case v of
      ExtVar sch -> freeAfterSubst sch
      CurSCC _ t -> freeAfterSubst t
-- | Look up the type of a value-level name.
--
-- Sources are tried in this order: the variables in scope, then newtype
-- constructors (typed by 'newtypeConType'), then module value parameters.
-- Panics if the name is in none of them. The panic message says
-- "variable", not "type variable", because this looks up values.
lookupVar :: Name -> InferM VarType
lookupVar x =
  do mb <- IM $ asks $ Map.lookup x . iVars
     case mb of
       Just t  -> return t
       Nothing ->
         do mbNT <- lookupNewtype x
            case mbNT of
              Just nt -> return (ExtVar (newtypeConType nt))
              Nothing ->
                do mbParamFun <- lookupParamFun x
                   case mbParamFun of
                     Just pf -> return (ExtVar (mvpType pf))
                     Nothing -> panic "lookupVar" [ "Undefined variable"
                                                  , show x ]
-- | Find the type parameter in scope with the given name. 'iTVars' is
-- searched front to back, so the most recently added parameter wins.
lookupTParam :: Name -> InferM (Maybe TParam)
lookupTParam x = IM (asks (find named . iTVars))
  where named tp = tpName tp == Just x
-- | Look up a type synonym in scope, dropping its 'DefLoc' tag.
lookupTSyn :: Name -> InferM (Maybe TySyn)
lookupTSyn x = do syns <- getTSyns
                  return (snd <$> Map.lookup x syns)

-- | Look up a newtype in scope, dropping its 'DefLoc' tag.
lookupNewtype :: Name -> InferM (Maybe Newtype)
lookupNewtype x = do nts <- getNewtypes
                     return (snd <$> Map.lookup x nts)

-- | Look up a module type parameter.
lookupParamType :: Name -> InferM (Maybe ModTParam)
lookupParamType x = do ps <- getParamTypes
                       return (Map.lookup x ps)

-- | Look up a module value parameter.
lookupParamFun :: Name -> InferM (Maybe ModVParam)
lookupParamFun x = do fs <- getParamFuns
                      return (Map.lookup x fs)
-- | Resolve a type name that was introduced implicitly.
--
-- The existential scopes are searched innermost first. If the name is
-- not found:
--
--   * with no open scope, an "Undefined type" error is recorded and a
--     placeholder type variable is returned;
--   * otherwise a fresh type variable is made, bound in the innermost
--     scope, and returned.
existVar :: Name -> Kind -> InferM Type
existVar x k =
  do scopes <- iExistTVars <$> IM get
     case msum (map (Map.lookup x) scopes) of
       Just ty -> return ty
       Nothing -> missingIn scopes
  where
  placeholder = newType TypeErrorPlaceHolder k

  missingIn [] =
    do recordError $ ErrorMsg
                   $ text "Undefined type" <+> quotes (pp x)
       placeholder
  missingIn (sc : more) =
    do ty <- placeholder
       IM $ sets_ $ \s -> s { iExistTVars = Map.insert x ty sc : more }
       return ty
-- | Type synonyms in scope.
getTSyns :: InferM (Map Name (DefLoc,TySyn))
getTSyns = IM (iTSyns <$> ask)

-- | Newtypes in scope.
getNewtypes :: InferM (Map Name (DefLoc,Newtype))
getNewtypes = IM (iNewtypes <$> ask)

-- | Module value parameters.
getParamFuns :: InferM (Map Name ModVParam)
getParamFuns = IM (iParamFuns <$> ask)

-- | Module type parameters.
getParamTypes :: InferM (Map Name ModTParam)
getParamTypes = IM (iParamTypes <$> ask)

-- | Constraints on module parameters.
getParamConstraints :: InferM [Located Prop]
getParamConstraints = IM (iParamConstraints <$> ask)

-- | Names of the type parameters in scope (unnamed ones are skipped).
getTVars :: InferM (Set Name)
getTVars = Set.fromList . mapMaybe tpName <$> IM (asks iTVars)
-- | All type parameters a fresh variable may mention: the module type
-- parameters plus the type parameters in scope. The set is computed
-- strictly.
getBoundInScope :: InferM (Set TParam)
getBoundInScope =
  do ro <- IM ask
     let fromModule = Set.fromList (map mtpParam (Map.elems (iParamTypes ro)))
         fromScope  = Set.fromList (iTVars ro)
     return $! fromModule `Set.union` fromScope
-- | The 'iMonoBinds' flag from the environment.
getMonoBinds :: InferM Bool
getMonoBinds = IM (iMonoBinds <$> ask)
-- | Record an error if a new type-level name of the given kind ('this')
-- shadows an existing name.
--
-- Checks are made, in this order, against: type synonyms, type parameters
-- in scope, and implicitly introduced types. Only the first match is
-- reported.
checkTShadowing :: String -> Name -> InferM ()
checkTShadowing this new =
  do ro <- IM ask
     rw <- IM get
     let isTParam = new `elem` mapMaybe tpName (iTVars ro)
         shadowed = msum
           [ Map.lookup new (iTSyns ro) >> Just "type synonym"
           , guard isTParam >> Just "type variable"
           , msum (map (Map.lookup new) (iExistTVars rw)) >> Just "type"
           ]
     maybe (return ()) report shadowed
  where
  report that =
    recordError $ ErrorMsg $
      text "Type" <+> text this <+> quotes (pp new) <+>
      text "shadows an existing" <+>
      text that <+> text "with the same name."
-- | Run a computation with an extra type parameter in scope. A named
-- parameter is first checked for shadowing.
withTParam :: TParam -> InferM a -> InferM a
withTParam p m =
  do maybe (return ()) (checkTShadowing "variable") (tpName p)
     IM (mapReader (\r -> r { iTVars = p : iTVars r }) (unIM m))
-- | Bring several type parameters into scope, using 'withTParam' for each.
withTParams :: [TParam] -> InferM a -> InferM a
withTParams = flip (foldr withTParam)
-- | Run a computation with a local type synonym in scope, after checking
-- that its name does not shadow another.
withTySyn :: TySyn -> InferM a -> InferM a
withTySyn t m =
  do checkTShadowing "synonym" synName
     IM (mapReader addSyn (unIM m))
  where
  synName = tsName t
  addSyn r = r { iTSyns = Map.insert synName (IsLocal,t) (iTSyns r) }
-- | Run a computation with a local newtype in scope. No shadowing check
-- is made.
withNewtype :: Newtype -> InferM a -> InferM a
withNewtype t m = IM (mapReader addNT (unIM m))
  where addNT r = r { iNewtypes = Map.insert (ntName t) (IsLocal,t) (iNewtypes r) }
-- | Run a computation with an extra module type parameter.
withParamType :: ModTParam -> InferM a -> InferM a
withParamType a m = IM (mapReader addParam (unIM m))
  where addParam r = r { iParamTypes = Map.insert (mtpName a) a (iParamTypes r) }
-- | Run a computation with a value variable bound to the given type.
withVarType :: Name -> VarType -> InferM a -> InferM a
withVarType x s m = IM (mapReader bindVar (unIM m))
  where bindVar r = r { iVars = Map.insert x s (iVars r) }

-- | Bind several value variables, using 'withVarType' for each.
withVarTypes :: [(Name,VarType)] -> InferM a -> InferM a
withVarTypes xs m = foldr (\(x, vt) k -> withVarType x vt k) m xs

-- | Bind a value variable to an external schema.
withVar :: Name -> Schema -> InferM a -> InferM a
withVar x = withVarType x . ExtVar
-- | Run a computation with extra module value parameters, keyed by name.
withParamFuns :: [ModVParam] -> InferM a -> InferM a
withParamFuns xs m = IM (mapReader extend (unIM m))
  where
  extend r = r { iParamFuns = foldr (\x acc -> Map.insert (mvpName x) x acc)
                                    (iParamFuns r) xs }
-- | Run a computation with extra parameter constraints, placed in front
-- of the existing ones.
withParameterConstraints :: [Located Prop] -> InferM a -> InferM a
withParameterConstraints ps m = IM (mapReader addCtrs (unIM m))
  where addCtrs r = r { iParamConstraints = ps ++ iParamConstraints r }
-- | Bind a variable to a monomorphic type (a schema with no type
-- parameters and no constraints).
withMonoType :: (Name,Located Type) -> InferM a -> InferM a
withMonoType (x,lt) = withVar x (Forall [] [] (thing lt))

-- | Bind several monomorphic variables, in key order.
withMonoTypes :: Map Name (Located Type) -> InferM a -> InferM a
withMonoTypes xs m = Map.foldrWithKey (\x lt k -> withMonoType (x, lt) k) m xs
-- | Bring type synonyms and value declarations into scope. The synonyms
-- wrap the value bindings.
withDecls :: ([TySyn], Map Name Schema) -> InferM a -> InferM a
withDecls (ts,vs) m = foldr withTySyn withValues ts
  where withValues = Map.foldrWithKey withVar m vs
-- | Run a computation with a fresh, empty existential scope pushed.
-- Afterwards the scope stack is restored to what it was before.
inNewScope :: InferM a -> InferM a
inNewScope m =
  do saved <- iExistTVars <$> IM get
     setScopes (Map.empty : saved)
     result <- m
     setScopes saved
     return result
  where
  setScopes scs = IM (sets_ (\s -> s { iExistTVars = scs }))
-- | The kind-checking monad: a 'KRO' environment and 'KRW' state layered
-- over 'InferM'.
newtype KindM a = KM { unKM :: ReaderT KRO (StateT KRW InferM) a }

-- | Read-only environment for kind checking.
data KRO = KRO { lazyTParams :: Map Name TParam
                 -- ^ Local type parameters by name (see 'kLookupTyVar').
                 -- NOTE(review): the name suggests the caller may build
                 -- these lazily — not visible here.
               , allowWild :: AllowWildCards
                 -- ^ Whether wildcards are allowed (read by 'kWildOK').
               }

-- | Whether type wildcards are allowed.
data AllowWildCards = AllowWildCards | NoWildCards

-- | Mutable state for kind checking.
data KRW = KRW { typeParams :: Map Name Kind
                 -- ^ Kinds known so far for local parameters ('kSetKind').
               , kCtrs :: [(ConstraintSource,[Prop])]
                 -- ^ Constraints collected by 'kNewGoals', most recent first.
               }
-- | Map over the result of a kind-checking computation.
instance Functor KindM where
  fmap f m = KM (f <$> unKM m)

-- | 'pure' is 'return'; '<*>' runs the function first, then the argument.
instance A.Applicative KindM where
  pure      = return
  mf <*> mx = do f <- mf
                 x <- mx
                 return (f x)

-- | Each operation is delegated to the underlying stack.
instance Monad KindM where
  return  = KM . return
  fail    = KM . fail
  m >>= k = KM (unKM m >>= \x -> unKM (k x))
-- | Run a kind-checking computation.
--
-- Each entry is (name, known kind, type parameter); every entry becomes a
-- local parameter. Returns:
--
--   * the result;
--   * the final kinds, seeded from the entries whose kind was 'Just';
--   * the collected constraints.
runKindM :: AllowWildCards
         -> [(Name, Maybe Kind, TParam)]
         -> KindM a -> InferM (a, Map Name Kind, [(ConstraintSource,[Prop])])
runKindM wildOK vs m =
  do (a, final) <- runStateT initial (runReaderT env (unKM m))
     return (a, typeParams final, kCtrs final)
  where
  env     = KRO { lazyTParams = Map.fromList [ (x, tp) | (x, _, tp) <- vs ]
                , allowWild   = wildOK
                }
  initial = KRW { typeParams = Map.fromList [ (x, k) | (x, Just k, _) <- vs ]
                , kCtrs      = []
                }
-- | The result of looking up a type variable during kind checking.
data LkpTyVar = TLocalVar TParam (Maybe Kind)
                -- ^ A local parameter, with its kind if known yet.
              | TOuterVar TParam
                -- ^ A type parameter from the enclosing 'InferM' scope.
-- | Look up a type variable: local kind-checking parameters first, then
-- the type parameters in the enclosing 'InferM' scope.
kLookupTyVar :: Name -> KindM (Maybe LkpTyVar)
kLookupTyVar x = KM $
  do locals <- asks lazyTParams
     st     <- get
     case Map.lookup x locals of
       Just tp -> return (Just (TLocalVar tp (Map.lookup x (typeParams st))))
       Nothing -> lift (lift (fmap TOuterVar <$> lookupTParam x))
-- | Whether wildcards are allowed here.
kWildOK :: KindM AllowWildCards
kWildOK = KM (asks allowWild)
-- | Record an error through the underlying 'InferM'.
kRecordError :: Error -> KindM ()
kRecordError = kInInferM . recordError

-- | Record a warning through the underlying 'InferM'.
kRecordWarning :: Warning -> KindM ()
kRecordWarning = kInInferM . recordWarning
-- | A fresh type variable that may also mention the local kind-checking
-- parameters.
kNewType :: TVarSource -> Kind -> KindM Type
kNewType src k =
  do scope <- KM (asks (Set.fromList . Map.elems . lazyTParams))
     kInInferM (TVar <$> newTVar' src scope k)
-- | 'lookupTSyn', lifted into 'KindM'.
kLookupTSyn :: Name -> KindM (Maybe TySyn)
kLookupTSyn = kInInferM . lookupTSyn

-- | 'lookupNewtype', lifted into 'KindM'.
kLookupNewtype :: Name -> KindM (Maybe Newtype)
kLookupNewtype = kInInferM . lookupNewtype

-- | 'lookupParamType', lifted into 'KindM'.
kLookupParamType :: Name -> KindM (Maybe ModTParam)
kLookupParamType = kInInferM . lookupParamType

-- | 'existVar', lifted into 'KindM'.
kExistTVar :: Name -> Kind -> KindM Type
kExistTVar x = kInInferM . existVar x
-- | Replace the given type parameters in a type with the paired types.
kInstantiateT :: Type -> [(TParam,Type)] -> KindM Type
kInstantiateT t as =
  return (apSubst (listSubst [ (tpVar p, ty) | (p, ty) <- as ]) t)
-- | Record (or overwrite) the kind of a local type parameter.
kSetKind :: Name -> Kind -> KindM ()
kSetKind v k = KM (sets_ update)
  where update s = s { typeParams = Map.insert v k (typeParams s) }
-- | Run a kind-checking computation with the current range set, using
-- 'inRange'. The environment and state are passed through and the final
-- state is kept.
kInRange :: Range -> KindM a -> KindM a
kInRange r m = KM $
  do env <- ask
     st  <- get
     (a, st') <- lift $ lift $ inRange r
                             $ runStateT st
                             $ runReaderT env (unKM m)
     set st'
     return a
-- | Collect constraints from a given source. An empty list is ignored.
kNewGoals :: ConstraintSource -> [Prop] -> KindM ()
kNewGoals c ps
  | null ps   = return ()
  | otherwise = KM (sets_ (\s -> s { kCtrs = (c,ps) : kCtrs s }))
-- | Lift an 'InferM' computation through the kind-checking reader and
-- state layers.
kInInferM :: InferM a -> KindM a
kInInferM = KM . lift . lift