-- |
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Utilities related to type unification.
module Swarm.Language.Typecheck.Unify (
  UnifyStatus (..),
  unifyCheck,
) where

import Control.Unification
import Data.Foldable qualified as F
import Data.Function (on)
import Data.Map qualified as M
import Data.Map.Merge.Lazy qualified as M
import Swarm.Language.Types

-- | Verdict of a quick, conservative check on whether two types can be
--   unified.
--
--   The constructors are declared from "certainly different" to
--   "certainly the same", so the derived 'Ord' instance orders them as
--   @Apart < MightUnify < Equal@.  Reordering the constructors changes
--   that ordering.
data UnifyStatus
  = -- | The types can never unify, however their unification variables
    --   end up being instantiated.  For example, @(int * u0)@ and
    --   @(u1 -> u2)@ are apart: one is a product type and the other is a
    --   function type.
    Apart
  | -- | Whether the types unify depends on how their unification
    --   variables get filled in, so no definite answer is available yet.
    --   For example, @(int * u0)@ and @(u1 * bool)@.
    MightUnify
  | -- | The types are definitely equal, so no constraint needs to be
    --   generated to make them so.  For example, @(int * text)@ and
    --   @(int * text)@.
    Equal
  deriving (Eq, Ord, Show, Read)

-- | Merge the verdicts for the parts of a compound type into a verdict
--   for the whole: 'Apart' beats 'MightUnify', which in turn beats
--   'Equal'.  When the left verdict is 'Apart', the right one is never
--   inspected.
instance Semigroup UnifyStatus where
  s1 <> s2 = case s1 of
    -- A single apart part makes the whole compound type apart.
    Apart -> Apart
    -- An uncertain part leaves the whole uncertain, unless the other
    -- part is definitely apart.
    MightUnify -> case s2 of
      Apart -> Apart
      _ -> MightUnify
    -- A definitely-equal part adds no information, so the other part
    -- decides the result.
    Equal -> s2

-- | 'Equal' is the neutral verdict: combining it with any other verdict
--   yields that other verdict.  As a result, folding an empty collection
--   of verdicts gives 'Equal'.
instance Monoid UnifyStatus where
  mempty = Equal

-- | Quickly and conservatively decide whether two types can unify.
--
--   The check tries to show either that the types are 'Apart' (they can
--   never unify) or that they are already 'Equal'.  In the first case,
--   a precise error can be reported at the moment the two types meet,
--   which beats a later, less specific complaint from the unifier.  In
--   the second case, no trivial equality constraint needs to be
--   generated.  When neither can be shown, the answer is 'MightUnify'.
--
--   Two unification variables are only known to be 'Equal' when they are
--   the same variable.  A unification variable compared with anything
--   else is reported as 'MightUnify'.  No occurs check is attempted here.
unifyCheck :: UType -> UType -> UnifyStatus
unifyCheck (UTerm layer1) (UTerm layer2) = unifyCheckF layer1 layer2
unifyCheck (UVar v1) (UVar v2)
  | v1 == v2 = Equal
unifyCheck _ _ = MightUnify

-- | Compare two types whose outermost layers are both concrete type
--   constructors rather than unification variables.
--
--   Different outermost constructors are always 'Apart'.  Base types and
--   type variables ('TyVarF') must match exactly.  For constructors with
--   type arguments, the verdicts for corresponding arguments are combined
--   with '<>'.  Records are 'Apart' unless they have exactly the same
--   field names; in that case the verdicts for same-named fields are
--   combined, and two empty records are 'Equal'.
unifyCheckF :: TypeF UType -> TypeF UType -> UnifyStatus
unifyCheckF t1 t2 = case (t1, t2) of
  -- Atomic types: identical names are equal, anything else is apart.
  (TyBaseF b1, TyBaseF b2) -> sameOrApart b1 b2
  (TyVarF v1, TyVarF v2) -> sameOrApart v1 v2
  -- One-argument constructors: the argument decides.
  (TyCmdF c1, TyCmdF c2) -> unifyCheck c1 c2
  (TyDelayF c1, TyDelayF c2) -> unifyCheck c1 c2
  -- Two-argument constructors: combine the verdicts for both arguments.
  (TySumF l1 r1, TySumF l2 r2) -> bothSides l1 r1 l2 r2
  (TyProdF l1 r1, TyProdF l2 r2) -> bothSides l1 r1 l2 r2
  (TyFunF l1 r1, TyFunF l2 r2) -> bothSides l1 r1 l2 r2
  -- Records: the field names must coincide exactly; then combine the
  -- verdicts for each pair of same-named fields.
  (TyRcdF m1, TyRcdF m2)
    | ((==) `on` M.keysSet) m1 m2 ->
        F.fold $
          M.merge
            M.dropMissing
            M.dropMissing
            (M.zipWithMatched (\_ fld1 fld2 -> unifyCheck fld1 fld2))
            m1
            m2
    | otherwise -> Apart
  -- Mismatched outermost constructors can never unify.  Each constructor
  -- is listed explicitly, rather than using one wildcard, so that GHC's
  -- incomplete-pattern warning still fires if 'TypeF' gains constructors.
  (TyBaseF {}, _) -> Apart
  (TyVarF {}, _) -> Apart
  (TyCmdF {}, _) -> Apart
  (TyDelayF {}, _) -> Apart
  (TySumF {}, _) -> Apart
  (TyProdF {}, _) -> Apart
  (TyFunF {}, _) -> Apart
  (TyRcdF {}, _) -> Apart
 where
  -- Leaf comparison: equal values are 'Equal', distinct ones 'Apart'.
  sameOrApart :: Eq a => a -> a -> UnifyStatus
  sameOrApart x y
    | x == y = Equal
    | otherwise = Apart

  -- Check left arguments against each other and right arguments against
  -- each other, then merge the two verdicts.
  bothSides :: UType -> UType -> UType -> UType -> UnifyStatus
  bothSides l1 r1 l2 r2 = unifyCheck l1 l2 <> unifyCheck r1 r2