-- |
-- Functions and instances relating to unification
--
module Language.PureScript.TypeChecker.Unify
  ( freshType
  , freshTypeWithKind
  , solveType
  , substituteType
  , unknownsInType
  , unifyTypes
  , unifyRows
  , alignRowsWith
  , replaceTypeWildcards
  , varIfUnknown
  ) where

-- NOTE(review): this module relies on postfix @qualified@ imports, @pattern@
-- import items, @\\case@ and scoped type variables. Presumably the matching
-- extensions are enabled project-wide in the build configuration -- confirm.

import Prelude

import Control.Monad (forM_, void)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.State.Class (MonadState(..), gets, modify, state)
import Control.Monad.Writer.Class (MonadWriter(..))

import Data.Foldable (traverse_)
import Data.Maybe (fromMaybe)
import Data.Map qualified as M
import Data.Text qualified as T

import Language.PureScript.Crash (internalError)
import Language.PureScript.Environment qualified as E
import Language.PureScript.Errors
  ( ErrorMessageHint(..)
  , MultipleErrors
  , SimpleErrorMessage(..)
  , SourceAnn
  , errorMessage
  , internalCompilerError
  , onErrorMessages
  , rethrow
  , warnWithPosition
  , withoutPosition
  )
import Language.PureScript.TypeChecker.Kinds (elaborateKind, instantiateKind, unifyKinds')
import Language.PureScript.TypeChecker.Monad
  ( CheckState(..)
  , Substitution(..)
  , UnkLevel(..)
  , Unknown
  , getLocalContext
  , guardWith
  , lookupUnkName
  , withErrorMessageHint
  )
import Language.PureScript.TypeChecker.Skolems (newSkolemConstant, skolemize)
import Language.PureScript.Types
  ( Constraint(..)
  , pattern REmptyKinded
  , RowListItem(..)
  , SourceType
  , Type(..)
  , WildcardData(..)
  , alignRowsWith
  , everythingOnTypes
  , everywhereOnTypes
  , everywhereOnTypesM
  , getAnnForType
  , mkForAll
  , rowFromList
  , srcTUnknown
  )

-- | Generate a fresh type variable with an unknown kind. Avoid this if at all possible.
--
-- Two consecutive unknowns are allocated from 'checkNextType': unknown @t@ is
-- registered with kind 'E.kindType', and unknown @t + 1@ is registered with
-- kind @srcTUnknown t@. The returned type is @srcTUnknown (t + 1)@, and the
-- counter is advanced by two.
freshType :: (MonadState CheckState m) => m SourceType
freshType = state $ \st -> do
  let
    t = checkNextType st
    st' = st
      { checkNextType = t + 2
      , checkSubstitution =
          (checkSubstitution st)
            { substUnsolved =
                M.insert t (UnkLevel (pure t), E.kindType)
                  . M.insert (t + 1) (UnkLevel (pure (t + 1)), srcTUnknown t)
                  . substUnsolved
                  $ checkSubstitution st
            }
      }
  (srcTUnknown (t + 1), st')

-- | Generate a fresh type variable with a known kind.
--
-- A single unknown is allocated from 'checkNextType' and recorded in
-- 'substUnsolved' together with the supplied kind; the counter is advanced by
-- one.
freshTypeWithKind :: (MonadState CheckState m) => SourceType -> m SourceType
freshTypeWithKind kind = state $ \st -> do
  let
    t = checkNextType st
    st' = st
      { checkNextType = t + 1
      , checkSubstitution =
          (checkSubstitution st)
            { substUnsolved = M.insert t (UnkLevel (pure t), kind) (substUnsolved (checkSubstitution st)) }
      }
  (srcTUnknown t, st')

-- | Update the substitution to solve a type constraint
--
-- Steps, in order: run the occurs check for @u@ against @t@, elaborate the
-- kind of @t@, look up the kind recorded for @u@ in 'substUnsolved' (an
-- internal compiler error is raised if there is none), instantiate @t@ against
-- that kind, and finally record the result for @u@ in 'substType'.
solveType :: (MonadError MultipleErrors m, MonadState CheckState m) => Int -> SourceType -> m ()
solveType u t = rethrow (onErrorMessages withoutPosition) $ do
  -- We strip the position so that any errors get rethrown with the position of
  -- the original unification constraint. Otherwise errors may arise from arbitrary
  -- locations. We don't otherwise have the "correct" position on hand, since it
  -- is maintained as part of the type-checker stack.
  occursCheck u t
  k1 <- elaborateKind t
  subst <- gets checkSubstitution
  k2 <- maybe (internalCompilerError ("No kind for unification variable ?" <> T.pack (show u))) (pure . substituteType subst . snd) . M.lookup u . substUnsolved $ subst
  t' <- instantiateKind (t, k1) k2
  modify $ \cs -> cs { checkSubstitution = (checkSubstitution cs) { substType = M.insert u t' $ substType $ checkSubstitution cs } }

-- | Apply a substitution to a type
--
-- Every unknown with an entry in 'substType' is replaced by its solution, and
-- the solution is itself substituted recursively.
substituteType :: Substitution -> SourceType -> SourceType
substituteType sub = everywhereOnTypes go
  where
    go (TUnknown ann u) =
      case M.lookup u (substType sub) of
        Nothing -> TUnknown ann u
        -- An unknown mapped to itself is returned as-is; recursing on it
        -- would look up the same entry again.
        Just (TUnknown ann' u1) | u1 == u -> TUnknown ann' u1
        Just t -> substituteType sub t
    go other = other

-- | Make sure that an unknown does not occur in a type
--
-- A type that is itself a bare unknown is accepted without further checks.
-- Otherwise, if @u@ appears anywhere inside @t@, an 'InfiniteType' error
-- carrying the whole of @t@ is thrown.
occursCheck :: (MonadError MultipleErrors m) => Int -> SourceType -> m ()
occursCheck _ TUnknown{} = return ()
occursCheck u t = void $ everywhereOnTypesM go t
  where
    go (TUnknown _ u') | u == u' = throwError . errorMessage . InfiniteType $ t
    go other = return other

-- | Compute a list of all unknowns appearing in a type
--
-- Each entry pairs the annotation found on the unknown with its number. The
-- list is assembled by composing one prepend function per node and applying
-- the composite to the empty list.
unknownsInType :: Type a -> [(a, Int)]
unknownsInType t = everythingOnTypes (.) go t []
  where
    go :: Type a -> [(a, Int)] -> [(a, Int)]
    go (TUnknown ann u) = ((ann, u) :)
    go _ = id

-- | Unify two types, updating the current substitution
--
-- The current substitution is applied to both types before they are compared,
-- and any error is reported under an 'ErrorUnifyingTypes' hint that mentions
-- the original (unsubstituted) arguments.
unifyTypes :: (MonadError MultipleErrors m, MonadState CheckState m) => SourceType -> SourceType -> m ()
unifyTypes t1 t2 = do
  sub <- gets checkSubstitution
  withErrorMessageHint (ErrorUnifyingTypes t1 t2) $ unifyTypes' (substituteType sub t1) (substituteType sub t2)
  where
    -- The equations are tried top to bottom, so their order is significant.
    unifyTypes' (TUnknown _ u1) (TUnknown _ u2) | u1 == u2 = return ()
    unifyTypes' (TUnknown _ u) t = solveType u t
    unifyTypes' t (TUnknown _ u) = solveType u t
    -- Two quantified types: skolemize both bodies with the same fresh skolem
    -- constant, then unify the results.
    unifyTypes' (ForAll ann1 _ ident1 mbK1 ty1 sc1) (ForAll ann2 _ ident2 mbK2 ty2 sc2) =
      case (sc1, sc2) of
        (Just sc1', Just sc2') -> do
          sko <- newSkolemConstant
          let sk1 = skolemize ann1 ident1 mbK1 sko sc1' ty1
          let sk2 = skolemize ann2 ident2 mbK2 sko sc2' ty2
          sk1 `unifyTypes` sk2
        _ -> internalError "unifyTypes: unspecified skolem scope"
    -- A quantified type against anything else: skolemize the quantified side.
    unifyTypes' (ForAll ann _ ident mbK ty1 (Just sc)) ty2 = do
      sko <- newSkolemConstant
      let sk = skolemize ann ident mbK sko sc ty1
      sk `unifyTypes` ty2
    unifyTypes' ForAll{} _ = internalError "unifyTypes: unspecified skolem scope"
    -- Quantifier on the right: swap the arguments and retry.
    unifyTypes' ty f@ForAll{} = f `unifyTypes` ty
    unifyTypes' (TypeVar _ v1) (TypeVar _ v2) | v1 == v2 = return ()
    unifyTypes' ty1@(TypeConstructor _ c1) ty2@(TypeConstructor _ c2) =
      guardWith (errorMessage (TypesDoNotUnify ty1 ty2)) (c1 == c2)
    unifyTypes' (TypeLevelString _ s1) (TypeLevelString _ s2) | s1 == s2 = return ()
    unifyTypes' (TypeLevelInt _ n1) (TypeLevelInt _ n2) | n1 == n2 = return ()
    unifyTypes' (TypeApp _ t3 t4) (TypeApp _ t5 t6) = do
      t3 `unifyTypes` t5
      t4 `unifyTypes` t6
    unifyTypes' (KindApp _ t3 t4) (KindApp _ t5 t6) = do
      t3 `unifyKinds'` t5
      t4 `unifyTypes` t6
    unifyTypes' (Skolem _ _ _ s1 _) (Skolem _ _ _ s2 _) | s1 == s2 = return ()
    -- Kind annotations are dropped; only the annotated types are unified.
    unifyTypes' (KindedType _ ty1 _) ty2 = ty1 `unifyTypes` ty2
    unifyTypes' ty1 (KindedType _ ty2 _) = ty1 `unifyTypes` ty2
    -- Anything with a row on either side is delegated to row unification.
    unifyTypes' r1@RCons{} r2 = unifyRows r1 r2
    unifyTypes' r1 r2@RCons{} = unifyRows r1 r2
    unifyTypes' r1@REmptyKinded{} r2 = unifyRows r1 r2
    unifyTypes' r1 r2@REmptyKinded{} = unifyRows r1 r2
    unifyTypes' (ConstrainedType _ c1 ty1) (ConstrainedType _ c2 ty2)
      | constraintClass c1 == constraintClass c2 && constraintData c1 == constraintData c2 = do
          traverse_ (uncurry unifyTypes) (constraintArgs c1 `zip` constraintArgs c2)
          ty1 `unifyTypes` ty2
    unifyTypes' ty1@ConstrainedType{} ty2 = throwError . errorMessage $ ConstrainedTypeUnified ty1 ty2
    -- Constrained type on the right: swap so the equation above reports it.
    unifyTypes' t3 t4@ConstrainedType{} = unifyTypes' t4 t3
    unifyTypes' t3 t4 =
      throwError . errorMessage $ TypesDoNotUnify t3 t4

-- | Unify two rows, updating the current substitution
--
-- Common labels are identified and unified. Remaining labels and types are unified with a
-- trailing row unification variable, if appropriate.
unifyRows :: forall m. (MonadError MultipleErrors m, MonadState CheckState m) => SourceType -> SourceType -> m ()
unifyRows r1 r2 = sequence_ matches *> uncurry unifyTails rest
  where
    -- Unify the types of a shared label, tagging errors with that label.
    unifyTypesWithLabel l t1 t2 = withErrorMessageHint (ErrorInRowLabel l) $ unifyTypes t1 t2

    -- @matches@ holds one unification action per label aligned by
    -- 'alignRowsWith'; @rest@ holds what is left over on each side.
    (matches, rest) = alignRowsWith unifyTypesWithLabel r1 r2

    unifyTails :: ([RowListItem SourceAnn], SourceType) -> ([RowListItem SourceAnn], SourceType) -> m ()
    unifyTails ([], TUnknown _ u) (sd, r) = solveType u (rowFromList (sd, r))
    unifyTails (sd, r) ([], TUnknown _ u) = solveType u (rowFromList (sd, r))
    unifyTails ([], REmptyKinded _ _) ([], REmptyKinded _ _) = return ()
    unifyTails ([], TypeVar _ v1) ([], TypeVar _ v2) | v1 == v2 = return ()
    unifyTails ([], Skolem _ _ _ s1 _) ([], Skolem _ _ _ s2 _) | s1 == s2 = return ()
    -- Both tails are distinct unknowns and labels remain: solve each unknown
    -- to the other side's leftover labels followed by a shared fresh tail,
    -- whose kind is elaborated from the first unknown.
    unifyTails (sd1, TUnknown a u1) (sd2, TUnknown _ u2) | u1 /= u2 = do
      forM_ sd1 $ occursCheck u2 . rowListType
      forM_ sd2 $ occursCheck u1 . rowListType
      rest' <- freshTypeWithKind =<< elaborateKind (TUnknown a u1)
      solveType u1 (rowFromList (sd2, rest'))
      solveType u2 (rowFromList (sd1, rest'))
    -- Any other combination fails, reporting the original rows.
    unifyTails _ _ =
      throwError . errorMessage $ TypesDoNotUnify r1 r2

-- |
-- Replace type wildcards with unknowns
--
-- Each 'TypeWildcard' is replaced by the result of 'freshType'. A
-- 'HoleWildcard' or 'UnnamedWildcard' additionally emits a warning through
-- 'tell' that carries the fresh type and the local context; an
-- 'IgnoredWildcard' emits nothing.
--
replaceTypeWildcards :: (MonadWriter MultipleErrors m, MonadState CheckState m) => SourceType -> m SourceType
replaceTypeWildcards = everywhereOnTypesM replace
  where
    replace (TypeWildcard ann wdata) = do
      t <- freshType
      ctx <- getLocalContext
      let err = case wdata of
            HoleWildcard n -> Just $ HoleInferredType n t ctx Nothing
            UnnamedWildcard -> Just $ WildcardInferredType t ctx
            IgnoredWildcard -> Nothing
      forM_ err $ warnWithPosition (fst ann) . tell . errorMessage
      return t
    replace other = return other

-- |
-- Replace outermost unsolved unification variables with named type variables
--
-- Each supplied @(unknown, kind)@ pair becomes a binder passed to 'mkForAll'
-- around the rewritten type. Unknowns inside both the type and the binder
-- kinds are rewritten to 'TypeVar's.
--
varIfUnknown :: forall m. (MonadState CheckState m) => [(Unknown, SourceType)] -> SourceType -> m SourceType
varIfUnknown unks ty = do
  bn' <- traverse toBinding unks
  ty' <- go ty
  pure $ mkForAll bn' ty'
  where
    -- The name is the result of 'lookupUnkName' (or @"t"@ when there is none)
    -- with the shown unknown appended.
    toName :: Unknown -> m T.Text
    toName u = (<> T.pack (show u)) . fromMaybe "t" <$> lookupUnkName u

    -- Every binder reuses the annotation of the type being generalized.
    toBinding :: (Unknown, SourceType) -> m (SourceAnn, (T.Text, Maybe SourceType))
    toBinding (u, k) = do
      u' <- toName u
      k' <- go k
      pure (getAnnForType ty, (u', Just k'))

    go :: SourceType -> m SourceType
    go = everywhereOnTypesM $ \case
      (TUnknown ann u) ->
        TypeVar ann <$> toName u
      t -> pure t