module Language.PureScript.TypeChecker.Entailment
( InstanceContext
, SolverOptions(..)
, replaceTypeClassDictionaries
, newDictionaries
, entails
) where
import Prelude.Compat
import Control.Arrow (second)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.State
import Control.Monad.Supply.Class (MonadSupply(..))
import Control.Monad.Writer
import Data.Foldable (for_, fold, toList)
import Data.Function (on)
import Data.List (minimumBy, nub)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Text as T
import Data.Text (Text)
import Language.PureScript.AST
import Language.PureScript.Crash
import Language.PureScript.Environment
import Language.PureScript.Errors
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C
-- | Where the dictionary for a solved constraint comes from.
data Evidence
  = NamedInstance (Qualified Ident)
    -- ^ A dictionary referred to by name (an instance, or a dictionary
    -- recorded in the instance context)
  | IsSymbolInstance Text
    -- ^ Built-in evidence for @IsSymbol@, carrying the symbol's text
  | CompareSymbolInstance
    -- ^ Built-in evidence for @CompareSymbol@
  | AppendSymbolInstance
    -- ^ Built-in evidence for @AppendSymbol@
  deriving (Eq)
-- | The identifier behind a piece of evidence, if it is a named dictionary.
-- Built-in evidence for the symbol classes has no identifier.
namedInstanceIdentifier :: Evidence -> Maybe (Qualified Ident)
namedInstanceIdentifier evidence =
  case evidence of
    NamedInstance ident -> Just ident
    _ -> Nothing
-- | A dictionary in scope whose value is described by 'Evidence'.
type TypeClassDict = TypeClassDictionaryInScope Evidence

-- | Dictionaries in scope, indexed by module key, then by class name, then by
-- dictionary identifier. The 'Nothing' module key is always consulted when
-- looking up instances, and is where locally created dictionaries are stored.
type InstanceContext =
  M.Map (Maybe ModuleName)
        (M.Map (Qualified (ProperName 'ClassName))
               (M.Map (Qualified Ident) NamedDict))

-- | A mapping from type variable names to values of type @a@.
type Matching a = M.Map Text a
-- | Merge two instance contexts at every level of nesting. If both contexts
-- hold a dictionary with the same module key, class and identifier, the one
-- from the first argument is kept ('M.union' is left-biased).
combineContexts :: InstanceContext -> InstanceContext -> InstanceContext
combineContexts left right = M.unionWith mergeByClass left right
  where
    -- Merge the per-class maps found under the same module key.
    mergeByClass = M.unionWith M.union
-- | Replace every 'TypeClassDictionary' placeholder in an expression with a
-- dictionary expression.
--
-- Passes with error deferral enabled are repeated for as long as the
-- previous pass solved at least one constraint. A final pass then runs with
-- deferral disabled. In that pass, constraints that are still unsolved are
-- either generalized (and returned in the second component) or reported,
-- depending on the flag.
replaceTypeClassDictionaries
  :: forall m
   . (MonadState CheckState m, MonadError MultipleErrors m, MonadWriter MultipleErrors m, MonadSupply m)
  => Bool
  -- ^ Whether unsolved constraints may be generalized
  -> Expr
  -> m (Expr, [(Ident, InstanceContext, Constraint)])
replaceTypeClassDictionaries shouldGeneralize expr =
    evalStateT (solveWhileProgressing expr >>= generalizePass) M.empty
  where
    -- Re-run the deferring pass until a pass solves nothing.
    solveWhileProgressing :: Expr -> StateT InstanceContext m Expr
    solveWhileProgressing e = do
      (e', progress) <- deferPass e
      if getAny progress then solveWhileProgressing e' else return e'

    -- Keep only the "something was solved" flag.
    deferPass :: Expr -> StateT InstanceContext m (Expr, Any)
    deferPass = fmap (second fst) . runPass True

    -- Keep only the list of generalized constraints.
    generalizePass :: Expr -> StateT InstanceContext m (Expr, [(Ident, InstanceContext, Constraint)])
    generalizePass = fmap (second snd) . runPass False

    -- Traverse the expression top-down, solving each placeholder found.
    runPass :: Bool -> Expr -> StateT InstanceContext m (Expr, (Any, [(Ident, InstanceContext, Constraint)]))
    runPass deferErrors = runWriterT . onValue
      where
        onValue :: Expr -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
        (_, onValue, _) = everywhereOnValuesTopDownM return (solveDictionary deferErrors) return

    -- Solve a single placeholder, attaching its hints to any error raised.
    solveDictionary :: Bool -> Expr -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
    solveDictionary deferErrors (TypeClassDictionary constraint context hints) =
      rethrow (addHints hints) $ entails (SolverOptions shouldGeneralize deferErrors) constraint context hints
    solveDictionary _ other = return other
-- | Outcome of choosing an instance for a constraint.
data EntailsResult a
  = Solved a TypeClassDict
    -- ^ An instance was chosen; @a@ carries the data matched for it
  | Unsolved Constraint
    -- ^ No instance was found, but the constraint may be generalized
  | Deferred
    -- ^ No instance was found; the placeholder is left for a later pass
-- | Options controlling what 'entails' does when no instance matches.
-- Deferral takes priority over generalization.
data SolverOptions = SolverOptions
  { solverShouldGeneralize :: Bool
    -- ^ Turn an unmatched constraint into a fresh dictionary variable, provided
    -- it has no type arguments or some argument is an unknown ('TUnknown')
  , solverDeferErrors :: Bool
    -- ^ Leave an unmatched constraint's placeholder in place instead of failing
  }
-- | Solve a single type class constraint, producing an expression for its
-- dictionary.
--
-- Candidate instances come from the given context combined with the
-- dictionaries accumulated in the 'StateT' 'InstanceContext' layer. There is
-- also built-in evidence for @IsSymbol@, @CompareSymbol@ and @AppendSymbol@
-- applied to type-level strings. When an instance is chosen, its own
-- constraints are solved recursively and the writer's 'Any' component is set
-- to 'True'.
--
-- When no instance matches, the 'SolverOptions' decide what happens:
--
-- * with 'solverDeferErrors', the 'TypeClassDictionary' placeholder is
--   returned, with the latest substitution applied to its arguments;
-- * otherwise, with 'solverShouldGeneralize', and if the constraint has no
--   arguments or some argument is an unknown type, a fresh dictionary
--   variable is returned. Its dictionaries (including superclasses) are added
--   to the state, and the constraint is recorded in the writer's list
--   component;
-- * otherwise a 'NoInstanceFound' error is thrown.
entails
  :: forall m
   . (MonadState CheckState m, MonadError MultipleErrors m, MonadWriter MultipleErrors m, MonadSupply m)
  => SolverOptions
  -- ^ How to treat constraints with no matching instance
  -> Constraint
  -- ^ The constraint to solve
  -> InstanceContext
  -- ^ The dictionaries in scope at the constraint's location
  -> [ErrorMessageHint]
  -- ^ Hints kept on the placeholder when the constraint is deferred
  -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
entails SolverOptions{..} constraint context hints =
    solve constraint
  where
    -- Candidate dictionaries for a class applied to the given arguments.
    -- The three symbol classes get built-in evidence when their input
    -- arguments are type-level strings. Otherwise dictionaries are collected
    -- from these module keys: 'Nothing', the class's module, and the module of
    -- each argument's head type constructor.
    forClassName :: InstanceContext -> Qualified (ProperName 'ClassName) -> [Type] -> [TypeClassDict]
    forClassName _ C.IsSymbol [TypeLevelString sym] =
      [TypeClassDictionaryInScope (IsSymbolInstance sym) [] C.IsSymbol [TypeLevelString sym] Nothing]
    forClassName _ C.CompareSymbol [arg0@(TypeLevelString lhs), arg1@(TypeLevelString rhs), _] =
      -- The third instance type is computed by comparing the two strings.
      let ordering = case compare lhs rhs of
                       LT -> C.orderingLT
                       EQ -> C.orderingEQ
                       GT -> C.orderingGT
          args = [arg0, arg1, TypeConstructor ordering]
      in [TypeClassDictionaryInScope CompareSymbolInstance [] C.CompareSymbol args Nothing]
    forClassName _ C.AppendSymbol [arg0@(TypeLevelString lhs), arg1@(TypeLevelString rhs), _] =
      -- The third instance type is the concatenation of the two strings.
      let args = [arg0, arg1, TypeLevelString (lhs <> rhs)]
      in [TypeClassDictionaryInScope AppendSymbolInstance [] C.AppendSymbol args Nothing]
    forClassName ctx cn@(Qualified (Just mn) _) tys = concatMap (findDicts ctx cn) (nub (Nothing : Just mn : map Just (mapMaybe ctorModules tys)))
    forClassName _ _ _ = internalError "forClassName: expected qualified class name"

    -- Module of the type constructor at the head of a type, looking through
    -- type applications and kind annotations.
    ctorModules :: Type -> Maybe ModuleName
    ctorModules (TypeConstructor (Qualified (Just mn) _)) = Just mn
    ctorModules (TypeConstructor (Qualified Nothing _)) = internalError "ctorModules: unqualified type name"
    ctorModules (TypeApp ty _) = ctorModules ty
    ctorModules (KindedType ty _) = ctorModules ty
    ctorModules _ = Nothing

    -- All dictionaries for a class stored under one module key, wrapped as
    -- named evidence.
    findDicts :: InstanceContext -> Qualified (ProperName 'ClassName) -> Maybe ModuleName -> [TypeClassDict]
    findDicts ctx cn = fmap (fmap NamedInstance) . maybe [] M.elems . (>>= M.lookup cn) . flip M.lookup ctx

    -- Reference to Prim's undefined; passed as the argument when applying a
    -- superclass accessor in 'subclassDictionaryValue'.
    valUndefined :: Expr
    valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))

    solve :: Constraint -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
    solve con = go 0 con
      where
        -- Solve one constraint. @work@ counts how deeply subgoals are nested;
        -- past 1000 levels a 'PossiblyInfiniteInstance' error is thrown.
        go :: Int -> Constraint -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) Expr
        go work (Constraint className' tys' _) | work > 1000 = throwError . errorMessage $ PossiblyInfiniteInstance className' tys'
        -- The transformer stack is unwrapped and rebuilt so that
        -- withErrorMessageHint can annotate errors raised in the base monad.
        go work con'@(Constraint className' tys' conInfo) = WriterT . StateT . (withErrorMessageHint (ErrorSolvingConstraint con') .) . runStateT . runWriterT $ do
            -- Apply the latest unification results to the arguments first.
            latestSubst <- lift . lift $ gets checkSubstitution
            let tys'' = map (substituteType latestSubst) tys'
            -- Dictionaries added by earlier generalizations.
            inferred <- lift get
            classesInScope <- lift . lift $ gets (typeClasses . checkEnv)
            TypeClassData{ typeClassDependencies } <- case M.lookup className' classesInScope of
              Nothing -> throwError . errorMessage $ UnknownClass className'
              Just tcd -> pure tcd
            -- Candidate dictionaries whose head matches the arguments.
            let instances =
                  [ (substs, tcd)
                  | tcd <- forClassName (combineContexts context inferred) className' tys''
                  , substs <- maybeToList (matches typeClassDependencies tcd tys'')
                  ]
            solution <- lift . lift $ unique tys'' instances
            case solution of
              Solved substs tcd -> do
                -- Record that this pass made progress.
                tell (Any True, mempty)
                -- Unify all types matched against the same instance variable.
                lift . lift . for_ substs $ pairwiseM unifyTypes
                -- Every list comes from singleton lists combined in 'matches',
                -- so it is non-empty and 'head' is safe.
                let subst = fmap head substs
                currentSubst <- lift . lift $ gets checkSubstitution
                -- Bind any remaining instance type variables to fresh types.
                subst' <- lift . lift $ withFreshTypes tcd (fmap (substituteType currentSubst) subst)
                -- Unify the instantiated instance head with the arguments.
                lift . lift $ zipWithM_ (\t1 t2 -> do
                  let inferredType = replaceAllTypeVars (M.toList subst') t1
                  unifyTypes inferredType t2) (tcdInstanceTypes tcd) tys''
                currentSubst' <- lift . lift $ gets checkSubstitution
                let subst'' = fmap (substituteType currentSubst') subst'
                -- Solve the instance's own constraints.
                args <- solveSubgoals subst'' (tcdDependencies tcd)
                -- Wrap the dictionary in superclass accessors along its path.
                let match = foldr (\(superclassName, index) dict -> subclassDictionaryValue dict superclassName index)
                                  (mkDictionary (tcdValue tcd) args)
                                  (tcdPath tcd)
                return match
              Unsolved unsolved -> do
                -- Generalize: introduce a fresh dictionary variable, make its
                -- dictionaries (and superclasses) available to later
                -- constraints, and report the constraint to the caller.
                ident <- freshIdent ("dict" <> runProperName (disqualify (constraintClass unsolved)))
                let qident = Qualified Nothing ident
                newDicts <- lift . lift $ newDictionaries [] qident unsolved
                let newContext = mkContext newDicts
                modify (combineContexts newContext)
                tell (mempty, [(ident, context, unsolved)])
                return (Var qident)
              Deferred ->
                -- Leave the placeholder (with substituted arguments) for a later pass.
                return (TypeClassDictionary (Constraint className' tys'' conInfo) context hints)
          where
            -- Extend a substitution with a fresh type for every type variable
            -- that appears in the instance head or its constraints but is
            -- not already bound.
            withFreshTypes
              :: TypeClassDict
              -> Matching Type
              -> m (Matching Type)
            withFreshTypes TypeClassDictionaryInScope{..} subst = do
                let onType = everythingOnTypes S.union fromTypeVar
                    typeVarsInHead = foldMap onType tcdInstanceTypes
                                  <> foldMap (foldMap (foldMap onType . constraintArgs)) tcdDependencies
                    typeVarsInSubst = S.fromList (M.keys subst)
                    uninstantiatedTypeVars = typeVarsInHead S.\\ typeVarsInSubst
                newSubst <- traverse withFreshType (S.toList uninstantiatedTypeVars)
                return (subst <> M.fromList newSubst)
              where
                fromTypeVar (TypeVar v) = S.singleton v
                fromTypeVar _ = S.empty

                withFreshType s = do
                  t <- freshType
                  return (s, t)

            -- Choose among the matching candidates:
            -- * none: defer, generalize, or throw 'NoInstanceFound', depending
            --   on the solver options and the arguments;
            -- * one: use it;
            -- * several: if any pair overlaps, emit 'OverlappingInstances'
            --   via 'tell' (not thrown) and use the first; otherwise use the
            --   one with the shortest superclass path.
            unique :: [Type] -> [(a, TypeClassDict)] -> m (EntailsResult a)
            unique tyArgs []
              | solverDeferErrors = return Deferred
              | solverShouldGeneralize && (null tyArgs || any canBeGeneralized tyArgs) = return (Unsolved (Constraint className' tyArgs conInfo))
              | otherwise = throwError . errorMessage $ NoInstanceFound (Constraint className' tyArgs conInfo)
            unique _ [(a, dict)] = return $ Solved a dict
            unique tyArgs tcds
              | pairwiseAny overlapping (map snd tcds) = do
                  tell . errorMessage $ OverlappingInstances className' tyArgs (tcds >>= (toList . namedInstanceIdentifier . tcdValue . snd))
                  return $ uncurry Solved (head tcds)
              | otherwise = return $ uncurry Solved (minimumBy (compare `on` length . tcdPath . snd) tcds)

            -- An unknown type, possibly under a kind annotation.
            canBeGeneralized :: Type -> Bool
            canBeGeneralized TUnknown{} = True
            canBeGeneralized (KindedType t _) = canBeGeneralized t
            canBeGeneralized _ = False

            -- Two candidates overlap only if both have an empty superclass
            -- path, both have a 'Just' dependency list, and their evidence
            -- differs.
            overlapping :: TypeClassDict -> TypeClassDict -> Bool
            overlapping TypeClassDictionaryInScope{ tcdPath = _ : _ } _ = False
            overlapping _ TypeClassDictionaryInScope{ tcdPath = _ : _ } = False
            overlapping TypeClassDictionaryInScope{ tcdDependencies = Nothing } _ = False
            overlapping _ TypeClassDictionaryInScope{ tcdDependencies = Nothing } = False
            overlapping tcd1 tcd2 = tcdValue tcd1 /= tcdValue tcd2

            -- Solve an instance's constraints, instantiated with the
            -- substitution, one nesting level deeper.
            solveSubgoals :: Matching Type -> Maybe [Constraint] -> WriterT (Any, [(Ident, InstanceContext, Constraint)]) (StateT InstanceContext m) (Maybe [Expr])
            solveSubgoals _ Nothing = return Nothing
            solveSubgoals subst (Just subgoals) =
              Just <$> traverse (go (work + 1) . mapConstraintArgs (map (replaceAllTypeVars (M.toList subst)))) subgoals

            -- Build the dictionary expression. A named instance is applied to
            -- its subgoal dictionaries; the built-in symbol classes get
            -- literal dictionaries.
            mkDictionary :: Evidence -> Maybe [Expr] -> Expr
            mkDictionary (NamedInstance n) args = foldl App (Var n) (fold args)
            mkDictionary (IsSymbolInstance sym) _ =
              let fields = [ ("reflectSymbol", Abs (Left (Ident C.__unused)) (Literal (StringLiteral sym))) ] in
              TypeClassDictionaryConstructorApp C.IsSymbol (Literal (ObjectLiteral fields))
            mkDictionary CompareSymbolInstance _ =
              TypeClassDictionaryConstructorApp C.CompareSymbol (Literal (ObjectLiteral []))
            mkDictionary AppendSymbolInstance _ =
              TypeClassDictionaryConstructorApp C.AppendSymbol (Literal (ObjectLiteral []))

        -- Access a superclass dictionary. The accessor field name is built
        -- from C.__superclass_, the qualified class name, "_" and the index;
        -- the result is applied to valUndefined.
        subclassDictionaryValue :: Expr -> Qualified (ProperName a) -> Integer -> Expr
        subclassDictionaryValue dict superclassName index =
          App (Accessor (C.__superclass_ <> showQualified runProperName superclassName <> "_" <> T.pack (show index))
                        dict)
              valUndefined
-- | Check whether an instance head matches a list of type arguments, allowing
-- some argument positions to be solved through the class's functional
-- dependencies.
--
-- Each argument is compared with the corresponding instance type using
-- 'typeHeadsAreEqual'. The instance is accepted only if every position is
-- covered: either it matched directly, or it is determined (transitively) by
-- a functional dependency whose determiners are all covered.
--
-- On success, returns the bindings for instance type variables collected from
-- positions that no functional dependency lists as determined. Each variable
-- maps to every type it was matched against. 'Nothing' is returned if those
-- types are not all equal according to 'typesAreEqual'.
matches :: [FunctionalDependency] -> TypeClassDict -> [Type] -> Maybe (Matching [Type])
matches deps TypeClassDictionaryInScope{..} tys = do
    let matched = zipWith typeHeadsAreEqual tys tcdInstanceTypes
    guard $ covers matched
    -- Bindings found at positions that some functional dependency determines
    -- are not kept.
    let determinedSet = foldMap (S.fromList . fdDetermined) deps
        solved = map snd . filter ((`S.notMember` determinedSet) . fst) $ zipWith (\(_, ts) i -> (i, ts)) matched [0..]
    verifySubstitution (M.unionsWith (++) solved)
  where
    -- True when closing the directly matched positions under the functional
    -- dependencies yields every index of the head, 0 .. length ms - 1.
    covers :: [(Bool, subst)] -> Bool
    covers ms = finalSet == S.fromList [0 .. length ms - 1]
      where
        initialSet :: S.Set Int
        initialSet = S.fromList . map snd . filter (fst . fst) $ zip ms [0..]

        finalSet :: S.Set Int
        finalSet = untilFixedPoint applyAll initialSet

        untilFixedPoint :: Eq a => (a -> a) -> a -> a
        untilFixedPoint f = go
          where
            go a
              | a' == a = a'
              | otherwise = go a'
              where a' = f a

        applyAll :: S.Set Int -> S.Set Int
        applyAll s = foldr applyDependency s deps

        -- A dependency fires once all of its determiners are covered.
        applyDependency :: FunctionalDependency -> S.Set Int -> S.Set Int
        applyDependency FunctionalDependency{..} xs
          | S.fromList fdDeterminers `S.isSubsetOf` xs = xs <> S.fromList fdDetermined
          | otherwise = xs

    -- Compare an argument (left) with an instance type (right). Returns
    -- whether they match, together with the types bound to each instance
    -- type variable. Kind annotations are ignored. Rows are compared label by
    -- label, and leftover labels may be bound to a row variable on the
    -- instance side.
    typeHeadsAreEqual :: Type -> Type -> (Bool, Matching [Type])
    typeHeadsAreEqual (KindedType t1 _) t2 = typeHeadsAreEqual t1 t2
    typeHeadsAreEqual t1 (KindedType t2 _) = typeHeadsAreEqual t1 t2
    typeHeadsAreEqual (TUnknown u1) (TUnknown u2) | u1 == u2 = (True, M.empty)
    typeHeadsAreEqual (Skolem _ s1 _ _) (Skolem _ s2 _ _) | s1 == s2 = (True, M.empty)
    typeHeadsAreEqual t (TypeVar v) = (True, M.singleton v [t])
    typeHeadsAreEqual (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = (True, M.empty)
    typeHeadsAreEqual (TypeLevelString s1) (TypeLevelString s2) | s1 == s2 = (True, M.empty)
    typeHeadsAreEqual (TypeApp h1 t1) (TypeApp h2 t2) =
      both (typeHeadsAreEqual h1 h2) (typeHeadsAreEqual t1 t2)
    typeHeadsAreEqual REmpty REmpty = (True, M.empty)
    typeHeadsAreEqual r1@RCons{} r2@RCons{} =
        foldr both (go sd1 r1' sd2 r2') (map (uncurry typeHeadsAreEqual) int)
      where
        (s1, r1') = rowToList r1
        (s2, r2') = rowToList r2

        -- Labels present in both rows, and labels unique to each side.
        int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
        sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
        sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]

        -- Compare the leftover labels and row tails.
        go :: [(Text, Type)] -> Type -> [(Text, Type)] -> Type -> (Bool, Matching [Type])
        go l (KindedType t1 _) r t2 = go l t1 r t2
        go l t1 r (KindedType t2 _) = go l t1 r t2
        go [] REmpty [] REmpty = (True, M.empty)
        go [] (TUnknown u1) [] (TUnknown u2) | u1 == u2 = (True, M.empty)
        go [] (TypeVar v1) [] (TypeVar v2) | v1 == v2 = (True, M.empty)
        go [] (Skolem _ sk1 _ _) [] (Skolem _ sk2 _ _) | sk1 == sk2 = (True, M.empty)
        go sd r [] (TypeVar v) = (True, M.singleton v [rowFromList (sd, r)])
        go _ _ _ _ = (False, M.empty)
    typeHeadsAreEqual _ _ = (False, M.empty)

    -- Both comparisons must succeed; their bindings are accumulated.
    both :: (Bool, Matching [Type]) -> (Bool, Matching [Type]) -> (Bool, Matching [Type])
    both (b1, m1) (b2, m2) = (b1 && b2, M.unionWith (++) m1 m2)

    -- Reject the substitution unless all types bound to each variable agree.
    verifySubstitution :: Matching [Type] -> Maybe (Matching [Type])
    verifySubstitution = traverse meet where
      meet ts | pairwiseAll typesAreEqual ts = Just ts
              | otherwise = Nothing

    -- Structural equality ignoring kind annotations; rows are compared label
    -- by label.
    typesAreEqual :: Type -> Type -> Bool
    typesAreEqual (KindedType t1 _) t2 = typesAreEqual t1 t2
    typesAreEqual t1 (KindedType t2 _) = typesAreEqual t1 t2
    typesAreEqual (TUnknown u1) (TUnknown u2) | u1 == u2 = True
    typesAreEqual (Skolem _ s1 _ _) (Skolem _ s2 _ _) = s1 == s2
    typesAreEqual (TypeVar v1) (TypeVar v2) = v1 == v2
    typesAreEqual (TypeLevelString s1) (TypeLevelString s2) = s1 == s2
    typesAreEqual (TypeConstructor c1) (TypeConstructor c2) = c1 == c2
    typesAreEqual (TypeApp h1 t1) (TypeApp h2 t2) = typesAreEqual h1 h2 && typesAreEqual t1 t2
    typesAreEqual REmpty REmpty = True
    typesAreEqual r1 r2 | isRCons r1 || isRCons r2 =
        let (s1, r1') = rowToList r1
            (s2, r2') = rowToList r2

            int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
            sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
            sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
        in all (uncurry typesAreEqual) int && go sd1 r1' sd2 r2'
      where
        go :: [(Text, Type)] -> Type -> [(Text, Type)] -> Type -> Bool
        go l (KindedType t1 _) r t2 = go l t1 r t2
        go l t1 r (KindedType t2 _) = go l t1 r t2
        go [] (TUnknown u1) [] (TUnknown u2) | u1 == u2 = True
        go [] (Skolem _ s1 _ _) [] (Skolem _ s2 _ _) = s1 == s2
        go [] REmpty [] REmpty = True
        go [] (TypeVar v1) [] (TypeVar v2) = v1 == v2
        go _ _ _ _ = False
    typesAreEqual _ _ = False

    isRCons :: Type -> Bool
    isRCons RCons{} = True
    isRCons _ = False
-- | Create the dictionaries that a constraint with the given name brings into
-- scope: one for the constraint itself, followed recursively by those of each
-- of its superclasses.
--
-- Every dictionary shares the given name. Its path records the
-- (superclass name, index) steps leading to it from the original dictionary,
-- most recent step first. Crashes with an internal error if the constraint's
-- class is not in the environment.
newDictionaries
  :: MonadState CheckState m
  => [(Qualified (ProperName 'ClassName), Integer)]
  -> Qualified Ident
  -> Constraint
  -> m [NamedDict]
newDictionaries path name (Constraint className instanceTy _) = do
    classes <- gets (typeClasses . checkEnv)
    let TypeClassData{..} = fromMaybe (internalError "newDictionaries: type class lookup failed") $ M.lookup className classes
        -- Dictionaries for the superclass at the given index.
        superclassDicts (Constraint supName supArgs _) index =
          newDictionaries ((supName, index) : path)
                          name
                          (Constraint supName (instantiateSuperclass (map fst typeClassArguments) supArgs instanceTy) Nothing)
    supDicts <- concat <$> zipWithM superclassDicts typeClassSuperclasses [0..]
    return (TypeClassDictionaryInScope name path className instanceTy Nothing : supDicts)
  where
    -- Substitute the constraint's types for the class's type arguments in a
    -- superclass's arguments.
    instantiateSuperclass :: [Text] -> [Type] -> [Type] -> [Type]
    instantiateSuperclass args supArgs tys = map (replaceAllTypeVars (zip args tys)) supArgs
-- | Build an instance context in which every dictionary is stored under the
-- 'Nothing' module key. If two dictionaries share a class and an identifier,
-- the one earlier in the list is kept.
mkContext :: [NamedDict] -> InstanceContext
mkContext = foldr (combineContexts . singletonContext) M.empty
  where
    -- A context containing just one dictionary.
    singletonContext dict =
      M.singleton Nothing (M.singleton (tcdClassName dict) (M.singleton (tcdValue dict) dict))
-- | True when the predicate holds for every pair of distinct positions
-- (earlier element first). Lists with fewer than two elements trivially pass.
pairwiseAll :: (a -> a -> Bool) -> [a] -> Bool
pairwiseAll p = check
  where
    check (x : rest@(_ : _)) = all (p x) rest && check rest
    check _ = True
-- | True when the predicate holds for at least one pair of distinct
-- positions (earlier element first). Lists with fewer than two elements
-- never satisfy it.
pairwiseAny :: (a -> a -> Bool) -> [a] -> Bool
pairwiseAny p = search
  where
    search (x : rest@(_ : _)) = any (p x) rest || search rest
    search _ = False
-- | Run an action on every pair of distinct positions (earlier element
-- first), in order: the first element against each later one, then the
-- second, and so on.
pairwiseM :: Applicative m => (a -> a -> m ()) -> [a] -> m ()
pairwiseM p = visit
  where
    visit (x : rest@(_ : _)) = traverse (p x) rest *> visit rest
    visit _ = pure ()