-----------------------------------------------------------------------------

-----------------------------------------------------------------------------

-- |
-- Module      :  Disco.Typecheck.Unify
-- Copyright   :  disco team and contributors
-- Maintainer  :  byorgey@gmail.com
--
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Unification.
module Disco.Typecheck.Unify where

import Unbound.Generics.LocallyNameless (Name, fv)

import Control.Lens (anyOf)
import Control.Monad.State
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S

import Disco.Subst
import Disco.Types

-- XXX todo: it might be better if unification took sorts into account
-- directly.  As it is, however, I think it works properly;
-- e.g. suppose we have a type variable a with sort {sub} and we unify
-- it with Bool.  unify will just return a substitution [a |-> Bool].
-- But then when we call extendSubst, and in particular substSortMap,
-- the sort {sub} will be applied to Bool and decomposed, which will
-- throw an error.

-- | Given a list of equations between types, return a substitution
--   which makes all the equations satisfied (or fail if it is not
--   possible).  Base types are compared with ordinary equality, so
--   two distinct base types never unify.
--
--   This is not the most efficient way to implement unification but
--   it is simple.
unify :: TyDefCtx -> [(Type, Type)] -> Maybe S
unify = unify' (==)

-- | Given a list of equations between types, return a substitution
--   which makes all the equations equal /up to/ identifying all base
--   types.  So, for example, @Int = Nat@ weakly unifies but
--   @Int = (Int -> Int)@ does not.  This is used to check whether
--   subtyping constraints are structurally sound before doing
--   constraint simplification/solving, to ensure termination.
weakUnify :: TyDefCtx -> [(Type, Type)] -> Maybe S
weakUnify = unify' (\_ _ -> True)

-- | Given a list of equations between types, return a substitution
--   which makes all the equations satisfied (or fail if it is not
--   possible), up to the given comparison on base types.
--
--   User-defined types are unfolded using the definitions in the
--   given 'TyDefCtx'; an undefined user type makes unification fail.
--   Every pair of type constructor applications that gets decomposed
--   is remembered, and meeting the same pair again is treated as
--   already solved.  This stops unification from unfolding the same
--   pair of types forever (e.g. when a definition is recursive).
unify' ::
  (BaseTy -> BaseTy -> Bool) ->
  TyDefCtx ->
  [(Type, Type)] ->
  Maybe S
unify' baseEq tyDefns eqs = evalStateT (go eqs) S.empty
 where
  -- Solve equations left to right.  An equation either yields a
  -- substitution, which is applied to the remaining equations and
  -- composed with the substitution solving them, or is decomposed
  -- into new equations that are solved before the remaining ones.
  go :: [(Type, Type)] -> StateT (Set (Type, Type)) Maybe S
  go [] = return idS
  go (e : es) = do
    u <- unifyOne e
    case u of
      Left sub -> (@@ sub) <$> go (applySubst sub es)
      Right newEs -> go (newEs ++ es)

  -- Process one equation, treating a pair that has already been
  -- decomposed as trivially solved.
  unifyOne :: (Type, Type) -> StateT (Set (Type, Type)) Maybe (Either S [(Type, Type)])
  unifyOne pair = do
    seen <- gets (S.member pair)
    if seen then return (Left idS) else unifyOne' pair

  -- Process one equation that has not been decomposed before.
  -- Returns either a substitution solving it or a list of simpler
  -- equations it reduces to; failure is signalled by 'mzero'.
  unifyOne' :: (Type, Type) -> StateT (Set (Type, Type)) Maybe (Either S [(Type, Type)])
  unifyOne' (ty1, ty2)
    | ty1 == ty2 = return $ Left idS
  unifyOne' (TyVar x, ty2)
    -- Occurs check: x = ... x ... has no finite solution.
    | occurs x ty2 = mzero
    | otherwise = return $ Left (x |-> ty2)
  unifyOne' (ty1, x@(TyVar _)) =
    unifyOne (x, ty1)
  -- At this point we know ty2 isn't the same skolem nor a unification variable.
  -- Skolems don't unify with anything.
  unifyOne' (TySkolem _, _) = mzero
  unifyOne' (_, TySkolem _) = mzero
  -- Unify two container types: unify the container descriptors as
  -- well as the type arguments.
  unifyOne' p@(TyCon (CContainer ctr1) tys1, TyCon (CContainer ctr2) tys2) = do
    modify (S.insert p)
    return $ Right ((TyAtom ctr1, TyAtom ctr2) : zip tys1 tys2)
  -- If one of the types to be unified is a user-defined type,
  -- unfold its definition before continuing the matching.  The pair
  -- is recorded first so that meeting it again is not unfolded again.
  unifyOne' p@(TyCon (CUser t) tys1, ty2) = do
    modify (S.insert p)
    body <- expandUserTy t tys1
    return $ Right [(body, ty2)]
  unifyOne' p@(ty1, TyCon (CUser t) tys2) = do
    modify (S.insert p)
    body <- expandUserTy t tys2
    return $ Right [(ty1, body)]
  -- Unify any other pair of type constructor applications: the type
  -- constructors must match exactly.
  unifyOne' p@(TyCon c1 tys1, TyCon c2 tys2)
    | c1 == c2 = do
        modify (S.insert p)
        return $ Right (zip tys1 tys2)
    | otherwise = mzero
  -- Base types are compared with the caller-supplied predicate.
  unifyOne' (TyAtom (ABase b1), TyAtom (ABase b2))
    | baseEq b1 b2 = return $ Left idS
    | otherwise = mzero
  -- Any other combination (e.g. an atom against a constructor
  -- application) cannot be unified.
  unifyOne' _ = mzero

  -- Instantiate the definition of the user-defined type named t at
  -- the given type arguments, failing if t is not defined.
  expandUserTy :: String -> [Type] -> StateT (Set (Type, Type)) Maybe Type
  expandUserTy t tys = case M.lookup t tyDefns of
    Nothing -> mzero
    Just (TyDefBody _ body) -> return (body tys)

-- | Unify all of the given types with one another, by unifying each
--   type with the next one in the list.  A list of fewer than two
--   types produces no equations and so trivially succeeds.
equate :: TyDefCtx -> [Type] -> Maybe S
equate tyDefns tys = unify tyDefns eqns
 where
  -- Consecutive pairs.  'drop 1' is used instead of the partial
  -- 'tail' so this stays total even for an empty list.
  eqns :: [(Type, Type)]
  eqns = zip tys (drop 1 tys)

-- | @occurs x ty@ tests whether the variable @x@ is among the free
--   variables of @ty@.  Used as the occurs check during unification.
occurs :: Name Type -> Type -> Bool
occurs x = anyOf fv (== x)

-- | Unify a list of atoms all with one another.  Each atom is
--   wrapped as a type with 'TyAtom', the types are unified with
--   'equate', and the resulting substitution is mapped back to atoms.
unifyAtoms :: TyDefCtx -> [Atom] -> Maybe (Substitution Atom)
unifyAtoms tyDefns = fmap (fmap fromTyAtom) . equate tyDefns . map TyAtom
 where
  -- NOTE(review): relies on every type in the resulting substitution
  -- being a 'TyAtom' (presumably because only atoms were unified);
  -- anything else is reported with 'error'.
  fromTyAtom :: Type -> Atom
  fromTyAtom (TyAtom a) = a
  fromTyAtom _ = error "fromTyAtom on non-TyAtom!"

-- | Unify a list of 'UAtom's all with one another.  Each is converted
--   with 'uatomToAtom' and wrapped with 'TyAtom', the types are
--   unified with 'equate', and the resulting substitution is mapped
--   back to 'UAtom's (base types to 'UB', unification variables to
--   'UV').
unifyUAtoms :: TyDefCtx -> [UAtom] -> Maybe (Substitution UAtom)
unifyUAtoms tyDefns =
  fmap (fmap fromTyAtom) . equate tyDefns . map (TyAtom . uatomToAtom)
 where
  -- NOTE(review): relies on every type in the resulting substitution
  -- being a base-type atom or a unification-variable atom; anything
  -- else is reported with 'error'.
  fromTyAtom :: Type -> UAtom
  fromTyAtom (TyAtom (ABase b)) = UB b
  fromTyAtom (TyAtom (AVar (U v))) = UV v
  fromTyAtom _ = error "fromTyAtom on wrong thing!"