-- | Unification of type variables for the constraint solver.
--
-- The only export is 'unify'. Every mismatch is reported through
-- 'TS.addError' rather than by throwing.
module Type.Unify (unify) where

import Control.Applicative ((<|>))
import Control.Monad.State
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.UnionFind.IO as UF

import qualified AST.Annotation as A
import qualified AST.Variable as Var
import qualified Type.State as TS
import Type.Type
import Type.PrettyPrint
import Text.PrettyPrint (render)
import Elm.Utils ((|>))


-- | Unify two variables. Does nothing when they are already in the same
-- union-find equivalence class; otherwise delegates to 'actuallyUnify'.
unify :: A.Region -> Variable -> Variable -> StateT TS.SolverState IO ()
unify region variable1 variable2 = do
  equivalent <- liftIO $ UF.equivalent variable1 variable2
  unless equivalent $ actuallyUnify region variable1 variable2


-- | Unify two variables that are known to be in different equivalence
-- classes. Dispatches on the 'structure' of both descriptors and, when at
-- least one side has no structure, on their 'flex' and 'name' fields.
actuallyUnify :: A.Region -> Variable -> Variable -> StateT TS.SolverState IO ()
actuallyUnify region variable1 variable2 = do
  desc1 <- liftIO $ UF.descriptor variable1
  desc2 <- liftIO $ UF.descriptor variable2

  let unify' = unify region

      (name', flex', rank', alias') = combinedDescriptors desc1 desc2

      -- Union the two classes, then overwrite the descriptor reached through
      -- @var@ with the combined flex/name/alias and the structure of @kept@.
      -- The union direction depends only on the ranks.
      -- NOTE(review): this relies on UF.union keeping the descriptor of its
      -- second argument, so the lower rank survives -- confirm against the
      -- union-find documentation.
      mergeKeeping :: Variable -> Descriptor -> StateT TS.SolverState IO ()
      mergeKeeping var kept = liftIO $ do
        if rank desc1 < rank desc2
          then UF.union variable2 variable1
          else UF.union variable1 variable2
        UF.modifyDescriptor var $ \desc ->
          desc
            { structure = structure kept
            , flex = flex'
            , name = name'
            , alias = alias'
            }

      -- Merge, keeping the structure of the first variable.
      merge1 :: StateT TS.SolverState IO ()
      merge1 = mergeKeeping variable1 desc1

      -- Merge, keeping the structure of the second variable.
      merge2 :: StateT TS.SolverState IO ()
      merge2 = mergeKeeping variable2 desc2

      -- Merge, keeping the structure of whichever variable has lower rank.
      merge = if rank desc1 < rank desc2 then merge1 else merge2

      -- Create and register a new variable that carries the combined
      -- rank/flex/name/alias of the two variables being unified.
      fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable
      fresh struct = do
        v <- liftIO . UF.fresh $ Descriptor
          { structure = struct
          , rank = rank'
          , flex = flex'
          , name = name'
          , copy = Nothing
          , mark = noMark
          , alias = alias'
          }
        TS.register v

      -- Relax @v@ to 'Flexible' and retry the unification of the original
      -- pair, which can then go through the ordinary flexible cases.
      flexAndUnify v = do
        liftIO $ UF.modifyDescriptor v $ \desc -> desc { flex = Flexible }
        unify' variable1 variable2

      unifyNumber svar (Var.Canonical home typeName) =
        case home of
          Var.BuiltIn
            | typeName `elem` ["Int", "Float"] -> flexAndUnify svar
          Var.Local
            | List.isPrefixOf "number" typeName -> flexAndUnify svar
          _ ->
            let hint = "Looks like something besides an Int or Float is being used as a number."
            in TS.addError region (Just hint) variable1 variable2

      -- Report a comparable mismatch, with an optional custom message.
      comparableError customMsg =
        TS.addError region (Just $ Maybe.fromMaybe msg customMsg) variable1 variable2
        where
          msg =
            "Looks like you want something comparable, but the only valid comparable\n\
            \types are Int, Float, Char, String, lists, or tuples."

      -- Report an appendable mismatch, with an optional custom message.
      appendableError customMsg =
        TS.addError region (Just $ Maybe.fromMaybe msg customMsg) variable1 variable2
        where
          msg =
            "Looks like you want something appendable, but only Strings, Lists,\n\
            \and Text can be appended with the (++) operator."

      unifyComparable v (Var.Canonical home typeName) =
        case home of
          Var.BuiltIn
            | typeName `elem` ["Int", "Float", "Char", "String"] -> flexAndUnify v
          Var.Local
            | List.isPrefixOf "comparable" typeName -> flexAndUnify v
          _ -> comparableError Nothing

      -- The comparable side is unnamed, so inspect the shape of the other
      -- side: lists and tuples of up to six elements are comparable when
      -- their elements are.
      unifyComparableStructure varSuper varFlex = do
        struct <- liftIO $ collectApps varFlex
        case struct of
          Other ->
            comparableError Nothing

          List v -> do
            flexAndUnify varSuper
            unify' v =<< liftIO (variable $ Is Comparable)

          Tuple vs
            | length vs > 6 ->
                comparableError $ Just "Cannot compare a tuple with more than 6 elements."
            | otherwise -> do
                flexAndUnify varSuper
                -- one fresh comparable variable per tuple element
                cmpVars <- liftIO $ forM vs $ \_ -> variable (Is Comparable)
                zipWithM_ unify' vs cmpVars

      unifyAppendable varSuper varFlex = do
        struct <- liftIO $ collectApps varFlex
        case struct of
          List _ -> flexAndUnify varSuper
          _ -> appendableError Nothing

      rigidError var =
        TS.addError region (Just hint) variable1 variable2
        where
          hint =
            "Could not unify rigid type variable '" ++ render (pretty Never var) ++ "'.\n"
              ++ "The problem probably relates to the type variable being shared between a\n\
                 \top-level type annotation and a related let-bound type annotation."

      -- Unification when at least one side is a non-flexible variable with
      -- no structure. Alternatives are tried top to bottom, so their order
      -- matters.
      superUnify =
        case (flex desc1, flex desc2, name desc1, name desc2) of
          (Is super1, Is super2, _, _)
            | super1 == super2 -> merge
          (Is Number, Is Comparable, _, _) -> merge1
          (Is Comparable, Is Number, _, _) -> merge2

          (Is Number, _, _, Just n) -> unifyNumber variable1 n
          (_, Is Number, Just n, _) -> unifyNumber variable2 n

          (Is Comparable, _, _, Just n) -> unifyComparable variable1 n
          (_, Is Comparable, Just n, _) -> unifyComparable variable2 n
          (Is Comparable, _, _, _) -> unifyComparableStructure variable1 variable2
          (_, Is Comparable, _, _) -> unifyComparableStructure variable2 variable1

          (Is Appendable, _, _, Just n)
            | Var.isText n || Var.isPrim "String" n -> flexAndUnify variable1
          (_, Is Appendable, Just n, _)
            | Var.isText n || Var.isPrim "String" n -> flexAndUnify variable2
          (Is Appendable, _, _, _) -> unifyAppendable variable1 variable2
          (_, Is Appendable, _, _) -> unifyAppendable variable2 variable1

          (Rigid, _, _, _) -> rigidError variable1
          (_, Rigid, _, _) -> rigidError variable2

          _ -> TS.addError region Nothing variable1 variable2

  case (structure desc1, structure desc2) of
    -- Rank-based merge applies only when BOTH sides are flexible.
    (Nothing, Nothing)
      | flex desc1 == Flexible && flex desc2 == Flexible -> merge
    (Nothing, _)
      | flex desc1 == Flexible -> merge2
    (_, Nothing)
      | flex desc2 == Flexible -> merge1

    (Just (Var1 v), _) -> unify' v variable2
    (_, Just (Var1 v)) -> unify' v variable1

    (Nothing, _) -> superUnify
    (_, Nothing) -> superUnify

    (Just type1, Just type2) ->
      case (type1, type2) of
        (App1 term1 term2, App1 term1' term2') -> do
          merge
          unify' term1 term1'
          unify' term2 term2'

        (Fun1 term1 term2, Fun1 term1' term2') -> do
          merge
          unify' term1 term1'
          unify' term2 term2'

        (EmptyRecord1, EmptyRecord1) ->
          return ()

        (Record1 fields ext, EmptyRecord1)
          | Map.null fields -> unify' ext variable2

        (EmptyRecord1, Record1 fields ext)
          | Map.null fields -> unify' ext variable1

        (Record1 _ _, Record1 _ _) ->
          recordUnify region fresh variable1 variable2

        _ -> TS.addError region Nothing variable1 variable2


-- RECORD UNIFICATION

-- | Unify two record-typed variables. Fields present on both sides are
-- unified pairwise; the remaining fields are reconciled against the
-- extension variables. @fresh@ builds the new variables that hold the
-- leftover fields.
recordUnify
  :: A.Region
  -> (Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable)
  -> Variable
  -> Variable
  -> StateT TS.SolverState IO ()
recordUnify region fresh variable1 variable2 = do
  (ExpandedRecord fields1 ext1) <- liftIO (gatherFields variable1)
  (ExpandedRecord fields2 ext2) <- liftIO (gatherFields variable2)

  unifyOverlappingFields region fields1 fields2

  let freshRecord fields ext = fresh (Just (Record1 fields ext))

      -- fields that appear on one side only
      uniqueFields1 = diffFields fields1 fields2
      uniqueFields2 = diffFields fields2 fields1

      addFieldMismatchError missingFields =
        TS.addError region (Just (fieldMismatchError missingFields)) variable1 variable2

  case (ext1, ext2) of
    (Empty _, Empty _) ->
      if Map.null uniqueFields1 && Map.null uniqueFields2
        then return ()
        else TS.addError region Nothing variable1 variable2

    (Empty var1, Extension var2) ->
      case (Map.null uniqueFields1, Map.null uniqueFields2) of
        -- the closed record on the left cannot supply the extra fields
        (_, False) -> addFieldMismatchError uniqueFields2
        (True, True) -> unify region var1 var2
        (False, True) -> do
          subRecord <- freshRecord uniqueFields1 var1
          unify region subRecord var2

    (Extension var1, Empty var2) ->
      case (Map.null uniqueFields1, Map.null uniqueFields2) of
        -- the closed record on the right cannot supply the extra fields
        (False, _) -> addFieldMismatchError uniqueFields1
        (True, True) -> unify region var1 var2
        (True, False) -> do
          subRecord <- freshRecord uniqueFields2 var2
          unify region var1 subRecord

    (Extension var1, Extension var2) ->
      case (Map.null uniqueFields1, Map.null uniqueFields2) of
        (True, True) ->
          unify region var1 var2

        (True, False) -> do
          subRecord <- freshRecord uniqueFields2 var2
          unify region var1 subRecord

        (False, True) -> do
          subRecord <- freshRecord uniqueFields1 var1
          unify region subRecord var2

        (False, False) -> do
          record1' <- freshRecord uniqueFields1 =<< fresh Nothing
          record2' <- freshRecord uniqueFields2 =<< fresh Nothing
          unify region record1' var2
          unify region var1 record2'


-- | Unify, position by position, the variables of every field name that
-- occurs in both maps. 'zipWith' stops at the shorter list, so surplus
-- entries on either side are left untouched.
unifyOverlappingFields
  :: A.Region
  -> Map.Map String [Variable]
  -> Map.Map String [Variable]
  -> StateT TS.SolverState IO ()
unifyOverlappingFields region fields1 fields2 =
  Map.intersectionWith (zipWith (unify region)) fields1 fields2
    |> Map.elems
    |> concat
    |> sequence_


-- | The entries of @fields1@ not matched by @fields2@. For a shared key,
-- as many leading entries are dropped as @fields2@ has for that key; keys
-- left with an empty list are removed.
diffFields :: Map.Map String [a] -> Map.Map String [a] -> Map.Map String [a]
diffFields fields1 fields2 =
  let eat (_:xs) (_:ys) = eat xs ys
      eat xs _ = xs
  in
    -- Map.union is left-biased, so trimmed lists replace the originals
    Map.union (Map.intersectionWith eat fields1 fields2) fields1
      |> Map.filter (not . null)


-- | A record flattened through its chain of extensions: all collected
-- fields plus the variable found at the end of the chain.
data ExpandedRecord = ExpandedRecord
    { _fields :: Map.Map String [Variable]
    , _extension :: Extension
    }


-- | What terminates a record's extension chain: 'Empty' when the chain ends
-- in an empty record, 'Extension' for any other variable.
data Extension = Empty Variable | Extension Variable


-- | Follow the extension chain of a record variable, accumulating fields.
-- Entries for a repeated field name are concatenated, outer record first.
gatherFields :: Variable -> IO ExpandedRecord
gatherFields var = do
  desc <- UF.descriptor var
  case structure desc of
    (Just (Record1 fields ext)) -> do
      (ExpandedRecord deeperFields rootExt) <- gatherFields ext
      return (ExpandedRecord (Map.unionWith (++) fields deeperFields) rootExt)

    (Just EmptyRecord1) ->
      return (ExpandedRecord Map.empty (Empty var))

    _ ->
      return (ExpandedRecord Map.empty (Extension var))


-- assumes that one of the dicts has stuff in it
-- | Build the hint text for a record that lacks the given fields. Returns
-- the empty string when the map has no keys.
fieldMismatchError :: Map.Map String a -> String
fieldMismatchError missingFields =
  case Map.keys missingFields of
    [] -> ""

    [key] ->
      "Looks like a record is missing the field '" ++ key ++ "'.\n " ++
      "Maybe there is a misspelling in a record access or record update?"

    -- at least two keys here, so 'init' and 'last' are safe
    keys ->
      "Looks like one record is missing fields "
        ++ List.intercalate ", " (init keys) ++ ", and " ++ last keys


-- | Combine two descriptors into the (name, flex, rank, alias) that the
-- merged variable should carry: the smaller rank, the first available
-- alias, and a name/flex pair chosen by the rules below.
combinedDescriptors
  :: Descriptor
  -> Descriptor
  -> (Maybe Var.Canonical, Flex, Int, Maybe Var.Canonical)
combinedDescriptors desc1 desc2 =
    (name', flex', rank', alias')
  where
    rank' :: Int
    rank' = min (rank desc1) (rank desc2)

    alias' :: Maybe Var.Canonical
    alias' = alias desc1 <|> alias desc2

    -- When both sides are named, prefer the name of the non-flexible side,
    -- and Number over any other super type.
    name' :: Maybe Var.Canonical
    name' =
      case (name desc1, name desc2) of
        (Just name1, Just name2) ->
          case (flex desc1, flex desc2) of
            (_, Flexible) -> Just name1
            (Flexible, _) -> Just name2
            (Is Number, Is _) -> Just name1
            (Is _, Is Number) -> Just name2
            (Is _, Is _) -> Just name1
            (_, _) -> Nothing
        (Just name1, _) -> Just name1
        (_, Just name2) -> Just name2
        _ -> Nothing

    -- Same precedence as name': a non-flexible side wins over a flexible
    -- one, and Number wins over any other super type.
    flex' :: Flex
    flex' =
      case (flex desc1, flex desc2) of
        (f, Flexible) -> f
        (Flexible, f) -> f
        (Is Number, Is _) -> Is Number
        (Is _, Is Number) -> Is Number
        (Is super, Is _) -> Is super
        (_, _) -> Flexible