-- |
-- Module      :  Cryptol.TypeCheck.Unify
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
module Cryptol.TypeCheck.Unify where

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Cryptol.Utils.RecordMap

import Control.Monad.Writer (Writer, writer, runWriter)
import qualified Data.Set as Set

import Prelude ()
import Prelude.Compat

-- | A most general unifier: the substitution it induces, paired with a
-- list of constraints on bound variables.
type MGU = (Subst, [Prop])

-- | Unification computations, which log 'UnificationError's as they run.
type Result a = Writer [UnificationError] a

-- | Run a unification computation, returning its result together with
-- every 'UnificationError' that was logged while computing it.
runResult :: Result a -> (a, [UnificationError])
runResult = runWriter

-- | The ways in which unification can fail.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ The two types have the same kind but no unifier was found.
  | UniKindMismatch Kind Kind
    -- ^ The two kinds differ.
  | UniTypeLenMismatch Int Int
    -- ^ Two lists of types unified pairwise have different lengths.
  | UniRecursive TVar Type
    -- ^ Binding the variable (of kind 'KType') to the type would be
    -- recursive ('singleSubst' reported 'SubstRecursive').
  | UniNonPolyDepends TVar [TParam]
    -- ^ Binding the variable would let these type parameters escape
    -- ('singleSubst' reported 'SubstEscaped').
  | UniNonPoly TVar Type
    -- ^ A bound ('TVBound') variable of non-numeric kind would have to
    -- equal this type.

-- | Log a unification error.  The accompanying result is 'emptyMGU';
-- because 'Writer' does not short-circuit, the surrounding computation
-- keeps running and may log further errors.
uniError :: UnificationError -> Result MGU
uniError e = writer (emptyMGU, [e])


-- | The trivial unifier: the empty substitution and no constraints.
emptyMGU :: MGU
emptyMGU = (emptySubst, [])

-- | Compute the most general unifier of two types.
--
-- On success the result is a substitution together with equality
-- constraints (built with '=#=') for equations that are deferred rather
-- than solved by substitution, such as those involving numeric
-- type-function applications.  Failures are logged via 'uniError'.
--
-- The equations below are tried top to bottom, so their order matters.
mgu :: Type -> Type -> Result MGU

-- Identical applications of the same type synonym unify trivially,
-- without expanding the synonym.
mgu (TUser c1 ts1 _) (TUser c2 ts2 _)
  | c1 == c2 && ts1 == ts2 = return emptyMGU

mgu (TVar x) t = bindVar x t
mgu t (TVar x) = bindVar x t

-- Otherwise look through type synonyms to their expansions.
mgu (TUser _ _ t1) t2 = mgu t1 t2
mgu t1 (TUser _ _ t2) = mgu t1 t2

mgu (TCon (TC tc1) ts1) (TCon (TC tc2) ts2)
  | tc1 == tc2 = mguMany ts1 ts2

-- Syntactically identical type-function applications unify trivially.
mgu (TCon (TF f1) ts1) (TCon (TF f2) ts2)
  | f1 == f2 && ts1 == ts2 = return emptyMGU

-- A numeric type-function application on either side is not solved
-- structurally; the equation is deferred as a constraint instead.
mgu t1 t2
  | TCon (TF _) _ <- t1, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  | TCon (TF _) _ <- t2, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  where
    k1 = kindOf t1
    k2 = kindOf t2

    isNum = k1 == KNum

-- Records with the same set of fields unify element-wise.
mgu (TRec fs1) (TRec fs2)
  | fieldSet fs1 == fieldSet fs2 =
    mguMany (recordElements fs1) (recordElements fs2)

mgu (TNewtype ntx xs) (TNewtype nty ys)
  | ntx == nty = mguMany xs ys

-- Nothing matched: prefer reporting a kind mismatch when the kinds
-- differ, otherwise report the type mismatch itself.
mgu t1 t2
  | k1 /= k2  = uniError $ UniKindMismatch k1 k2
  | otherwise = uniError $ UniTypeMismatch t1 t2
  where
    k1 = kindOf t1
    k2 = kindOf t2


-- | Unify two lists of types pairwise, left to right.
--
-- The substitution obtained from each pair is applied to the remaining
-- elements before they are unified.  The per-pair substitutions are
-- combined with '@@', and the constraints from all pairs are concatenated.
--
-- Lists of different lengths are rejected up front with
-- 'UniTypeLenMismatch' carrying the lengths of the *whole* lists.  The
-- previous formulation unified the common prefix first and then reported
-- the lengths of the leftover suffixes, which was misleading.
mguMany :: [Type] -> [Type] -> Result MGU
mguMany ts1 ts2
  | n1 /= n2  = uniError $ UniTypeLenMismatch n1 n2
  | otherwise = go ts1 ts2
  where
    n1 = length ts1
    n2 = length ts2

    -- Pairwise unification of lists already known to have equal length.
    go [] [] = return emptyMGU
    go (t1 : rest1) (t2 : rest2) =
      do (su1, ps1) <- mgu t1 t2
         -- Apply what we learned from the heads before unifying the tails.
         (su2, ps2) <- go (apSubst su1 rest1) (apSubst su1 rest2)
         return (su2 @@ su1, ps1 ++ ps2)
    -- Defensive: unreachable if 'apSubst' preserves list length.
    go xs ys = uniError $ UniTypeLenMismatch (length xs) (length ys)


-- | Unify a type variable with a type.
--
-- The equations are tried in order:
--
--   * a variable against itself (looking through synonyms) yields
--     'emptyMGU';
--   * a bound variable against a free variable is flipped, so that the
--     free variable is the one that gets bound;
--   * a bound variable against any other type of the same kind becomes
--     an equality constraint when the kind is 'KNum', and is a
--     'UniNonPoly' error otherwise;
--   * two free variables of equal kind, where the first one's scope is a
--     proper subset of the second's, bind the second to the first;
--   * anything else is decided by 'singleSubst'.
bindVar :: TVar -> Type -> Result MGU

bindVar x (tNoUser -> TVar y)
  | x == y = return emptyMGU

bindVar v@(TVBound {}) (tNoUser -> TVar v1@(TVFree {})) =
  bindVar v1 (TVar v)

bindVar v@(TVBound {}) t
  | k == kindOf t = if k == KNum
                      then return (emptySubst, [TVar v =#= t])
                      else uniError $ UniNonPoly v t
  | otherwise     = uniError $ UniKindMismatch k (kindOf t)
  where
    k = kindOf v

bindVar x@(TVFree _ xk xscope _) (tNoUser -> TVar y@(TVFree _ yk yscope _))
  | xscope `Set.isProperSubsetOf` yscope, xk == yk =
    return (uncheckedSingleSubst y (TVar x), [])
    -- In this case, we can add the reverse binding y ~> x to the
    -- substitution, but the instantiation x ~> y would be forbidden
    -- because it would allow y to escape from its scope.

bindVar x t =
  case singleSubst x t of
    -- A recursive binding is an error at kind 'KType'; at any other
    -- kind the equation is kept as a constraint instead.
    Left SubstRecursive
      | kindOf x == KType -> uniError $ UniRecursive x t
      | otherwise         -> return (emptySubst, [TVar x =#= t])
    Left (SubstEscaped tps) ->
      uniError $ UniNonPolyDepends x tps
    Left (SubstKindMismatch k1 k2) ->
      uniError $ UniKindMismatch k1 k2
    Right su ->
      return (su, [])