{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
module Cryptol.TypeCheck.Unify where
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Cryptol.Utils.RecordMap
import Control.Monad.Writer (Writer, writer, runWriter)
import qualified Data.Set as Set
import Prelude ()
import Prelude.Compat
-- | The outcome of a successful unification: a substitution together with a
-- list of props.  In this module the props are equalities built with '=#='
-- that unification chose to emit instead of solving by substitution.
type MGU = (Subst,[Prop])
-- | A unification computation.  Errors are accumulated in a 'Writer' log
-- rather than aborting, so a computation always produces a value.
type Result a = Writer [UnificationError] a
-- | Run a unification computation, returning its value paired with every
-- 'UnificationError' that was logged while computing it.
runResult :: Result a -> (a, [UnificationError])
runResult = runWriter
-- | The ways in which unification can fail.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ The two types have the same kind but no unification rule applies.
  | UniKindMismatch Kind Kind
    -- ^ The two sides have different kinds.
  | UniTypeLenMismatch Int Int
    -- ^ Two type lists ran out at different points; the fields are the
    -- lengths of the lists that remained when the mismatch was noticed.
  | UniRecursive TVar Type
    -- ^ Binding the variable to the type was rejected as recursive
    -- (reported only for variables of kind 'KType').
  | UniNonPolyDepends TVar [TParam]
    -- ^ Binding the variable was rejected because of the listed type
    -- parameters (the 'SubstEscaped' failure of 'singleSubst').
  | UniNonPoly TVar Type
    -- ^ A bound variable of a kind other than 'KNum' was unified with a
    -- different type of the same kind.
-- | Log a single unification error.  The computation still yields a value
-- ('emptyMGU'), so the caller keeps going and may log further errors.
uniError :: UnificationError -> Result MGU
uniError e = writer (emptyMGU, [e])
-- | The trivial unifier: the empty substitution and no emitted props.
emptyMGU :: MGU
emptyMGU = (emptySubst, [])
-- | Unify two types.
--
-- The result is a substitution plus a list of equality props that were
-- emitted instead of being solved here.  Failures are logged through
-- 'uniError'; the computation itself always returns an 'MGU'.
--
-- The equations are tried top to bottom and several of them consist only of
-- guards, so an equation whose guards all fail falls through to the ones
-- below it.  Their order is therefore significant.
mgu :: Type -> Type -> Result MGU

-- Syntactically identical 'TUser' applications need no work.
mgu (TUser c1 ts1 _) (TUser c2 ts2 _)
  | c1 == c2 && ts1 == ts2 = return emptyMGU

-- A variable on either side is handled by 'bindVar'.
mgu (TVar x) t = bindVar x t
mgu t (TVar x) = bindVar x t

-- Otherwise look through 'TUser' to the type stored in its third field
-- (presumably the expansion of the user-defined type -- confirm in the AST).
mgu (TUser _ _ t1) t2 = mgu t1 t2
mgu t1 (TUser _ _ t2) = mgu t1 t2

-- Same 'TC' constructor: unify the arguments pairwise.
mgu (TCon (TC tc1) ts1) (TCon (TC tc2) ts2)
  | tc1 == tc2 = mguMany ts1 ts2

-- Identical 'TF' applications are trivially equal.
mgu (TCon (TF f1) ts1) (TCon (TF f2) ts2)
  | f1 == f2 && ts1 == ts2 = return emptyMGU

-- A 'TF' application on either side, at kind 'KNum' with matching kinds, is
-- not solved here: the equation is emitted as a prop instead.
mgu t1 t2
  | TCon (TF _) _ <- t1, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  | TCon (TF _) _ <- t2, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  where
    k1, k2 :: Kind
    k1 = kindOf t1
    k2 = kindOf t2

    isNum :: Bool
    isNum = k1 == KNum

-- Records with the same set of field names: unify the field types.
mgu (TRec fs1) (TRec fs2)
  | fieldSet fs1 == fieldSet fs2 =
      mguMany (recordElements fs1) (recordElements fs2)

-- The same newtype: unify the arguments pairwise.
mgu (TNewtype ntx xs) (TNewtype nty ys)
  | ntx == nty = mguMany xs ys

-- Nothing matched: report a kind mismatch if the kinds differ, otherwise a
-- plain type mismatch.
mgu t1 t2
  | not (k1 == k2) = uniError $ UniKindMismatch k1 k2
  | otherwise      = uniError $ UniTypeMismatch t1 t2
  where
    k1, k2 :: Kind
    k1 = kindOf t1
    k2 = kindOf t2
-- | Unify two lists of types position by position.
--
-- The substitution obtained for each head pair is applied to both tails
-- before they are unified, and the resulting substitutions are composed with
-- '@@'.  Emitted props are concatenated in list order.
mguMany :: [Type] -> [Type] -> Result MGU
mguMany [] [] = return emptyMGU
mguMany (t1 : ts1) (t2 : ts2) =
  do (su1, ps1) <- mgu t1 t2
     (su2, ps2) <- mguMany (apSubst su1 ts1) (apSubst su1 ts2)
     return (su2 @@ su1, ps1 ++ ps2)
-- One list ran out before the other.  Note that the reported lengths are
-- those of the lists remaining at this point, not of the original inputs.
mguMany t1 t2 = uniError $ UniTypeLenMismatch (length t1) (length t2)
-- | Unify a type variable with a type.
--
-- The second argument is inspected through 'tNoUser' in the equations that
-- need to know whether it is itself a variable.  Equations are tried in
-- order; a failed guard falls through to the next equation.
bindVar :: TVar -> Type -> Result MGU

-- A variable unifies with itself trivially.
bindVar x (tNoUser -> TVar y)
  | x == y = return emptyMGU

-- Bound variable against a free variable: flip the arguments so that the
-- free variable is the one being bound.
bindVar v@(TVBound {}) (tNoUser -> TVar v1@(TVFree {})) = bindVar v1 (TVar v)

-- Bound variable against any other type: never substituted.  At kind 'KNum'
-- the equation is emitted as a prop; at other kinds it is an error.
bindVar v@(TVBound {}) t
  | k == kindOf t =
      if k == KNum
        then return (emptySubst, [TVar v =#= t])
        else uniError $ UniNonPoly v t
  | otherwise = uniError $ UniKindMismatch k (kindOf t)
  where
    k :: Kind
    k = kindOf v

-- Two free variables of the same kind where @x@'s scope is a proper subset
-- of @y@'s: bind @y@ to @x@ rather than the other way around.
-- NOTE(review): presumably binding @x@ to @y@ would be rejected by the scope
-- check in 'singleSubst' (its 'SubstEscaped' failure) -- confirm in Subst.
bindVar x@(TVFree _ xk xscope _) (tNoUser -> TVar y@(TVFree _ yk yscope _))
  | xscope `Set.isProperSubsetOf` yscope, xk == yk =
      return (uncheckedSingleSubst y (TVar x), [])

-- General case: try to build the substitution @x := t@ and translate any
-- failure of 'singleSubst' into a unification error.
bindVar x t =
  case singleSubst x t of
    Left SubstRecursive
      -- A recursive binding is an error at kind 'KType'; at other kinds the
      -- equation is emitted as a prop instead.
      | kindOf x == KType -> uniError $ UniRecursive x t
      | otherwise         -> return (emptySubst, [TVar x =#= t])
    Left (SubstEscaped tps) ->
      uniError $ UniNonPolyDepends x tps
    Left (SubstKindMismatch k1 k2) ->
      uniError $ UniKindMismatch k1 k2
    Right su ->
      return (su, [])