{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
module Cryptol.TypeCheck.Unify where
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Control.Monad.Writer (Writer, writer, runWriter)
import Data.Ord(comparing)
import Data.List(sortBy)
import qualified Data.Set as Set
import Prelude ()
import Prelude.Compat
-- | A unifier: a substitution together with the residual constraints
-- produced while unifying (e.g. numeric equations built with '=#=').
type MGU = (Subst, [Prop])

-- | Unification computations log 'UnificationError's in a 'Writer'
-- rather than stopping at the first failure.
type Result a = Writer [UnificationError] a

-- | Run a unification computation, returning its result together with
-- every error it logged.
runResult :: Result a -> (a, [UnificationError])
runResult = runWriter
-- | The ways in which unifying two types can fail.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ No unification rule applies to these two types of equal kind.
  | UniKindMismatch Kind Kind
    -- ^ The two sides have different kinds.
  | UniTypeLenMismatch Int Int
    -- ^ Two type lists unified pointwise have different lengths.
  | UniRecursive TVar Type
    -- ^ The variable occurs in the type it would be bound to, at kind
    -- 'KType'.
  | UniNonPolyDepends TVar [TParam]
    -- ^ The type mentions these type parameters, which are not in scope
    -- for the variable.
  | UniNonPoly TVar Type
    -- ^ A bound type variable would have to equal this type, and both
    -- have the same non-numeric kind.
-- | Record one unification error in the log and yield the empty unifier,
-- instead of aborting the surrounding computation.
uniError :: UnificationError -> Result MGU
uniError err = writer (emptyMGU, pure err)
-- | The trivial unifier: the empty substitution and no residual
-- constraints.
emptyMGU :: MGU
emptyMGU = (emptySubst, mempty)
-- | Compute a unifier for two types.
--
-- The result pairs a substitution with residual '=#=' constraints for
-- numeric equations involving type functions, which are not solved
-- syntactically here.  Failures are logged via 'uniError' and yield the
-- empty unifier.
--
-- Equation order matters, and each step is tried in turn:
--
-- 1. Identical user synonyms (same name and same arguments) unify
--    without expansion.
-- 2. Variables are bound before synonyms are expanded.
-- 3. Synonyms are then expanded.
-- 4. Finally, constructors are compared.
mgu :: Type -> Type -> Result MGU
mgu (TUser name1 args1 _) (TUser name2 args2 _)
  | name1 == name2, args1 == args2 = return emptyMGU

mgu (TVar v) ty = bindVar v ty
mgu ty (TVar v) = bindVar v ty

mgu (TUser _ _ body) ty = mgu body ty
mgu ty (TUser _ _ body) = mgu ty body

mgu (TCon (TC con1) args1) (TCon (TC con2) args2)
  | con1 == con2 = mguMany args1 args2

mgu (TCon (TF fun1) args1) (TCon (TF fun2) args2)
  | fun1 == fun2, args1 == args2 = return emptyMGU

-- If either side is a type-function application and both sides are
-- numeric, defer the equation as a residual constraint.
mgu lhs rhs
  | isTyFun lhs || isTyFun rhs, kl == KNum, kl == kr
  = return (emptySubst, [lhs =#= rhs])
  where
    kl = kindOf lhs
    kr = kindOf rhs
    isTyFun (TCon (TF _) _) = True
    isTyFun _               = False

-- Records unify field-by-field once both have the same set of labels.
mgu (TRec fields1) (TRec fields2)
  | names1 == names2 = mguMany tys1 tys2
  where
    byName = unzip . sortBy (comparing fst)
    (names1, tys1) = byName fields1
    (names2, tys2) = byName fields2

mgu lhs rhs
  | kl /= kr  = uniError (UniKindMismatch kl kr)
  | otherwise = uniError (UniTypeMismatch lhs rhs)
  where
    kl = kindOf lhs
    kr = kindOf rhs
-- | Unify two lists of types pointwise, left to right.
--
-- The substitution found for each pair is applied to the remaining types
-- before they are unified. The substitutions are composed, and the residual
-- constraints are concatenated.  Lists of different lengths are reported
-- as 'UniTypeLenMismatch'.
mguMany :: [Type] -> [Type] -> Result MGU
mguMany [] [] = return emptyMGU
mguMany (x : xs) (y : ys) =
  do (headSu, headPs) <- mgu x y
     (tailSu, tailPs) <- mguMany (apSubst headSu xs) (apSubst headSu ys)
     return (tailSu @@ headSu, headPs ++ tailPs)
mguMany xs ys = uniError $ UniTypeLenMismatch (length xs) (length ys)
-- | Unify a type variable with a type.
--
-- * A variable against itself (possibly under user synonyms) yields the
--   empty unifier.
-- * A bound variable against a free variable binds the free one instead.
-- * A bound variable against any other type of the same kind:
--     - at kind 'KNum', the equation becomes a residual '=#=' constraint;
--     - otherwise it is 'UniNonPoly'.
--   A kind difference is 'UniKindMismatch'.
-- * Two free variables where the first's scope is a proper subset of the
--   second's: the second is bound to the first.
-- * Otherwise a free variable is bound to the type, after these checks:
--     - kinds must agree;
--     - the occurs check fails at kind 'KType' and is deferred as '=#='
--       otherwise;
--     - the type must not mention type parameters outside the variable's
--       scope.
bindVar :: TVar -> Type -> Result MGU
bindVar x (tNoUser -> TVar y)
  | x == y = return emptyMGU
bindVar v@(TVBound {}) (tNoUser -> TVar v1@(TVFree {})) = bindVar v1 (TVar v)
bindVar v@(TVBound {}) t
  | k == kindOf t = if k == KNum
                      then return (emptySubst, [TVar v =#= t])
                      else uniError $ UniNonPoly v t
  | otherwise     = uniError $ UniKindMismatch k (kindOf t)
  where
    k = kindOf v
-- Look through user synonyms here, as the equations above do, so that
-- wrapping @y@ in a synonym does not change the result.  Binding y ~> x is
-- safe because x's scope is the smaller one; the reverse, x ~> y, could
-- capture type parameters that are not in scope for x.
bindVar x@(TVFree _ _ xscope _) (tNoUser -> TVar y@(TVFree _ _ yscope _))
  | xscope `Set.isProperSubsetOf` yscope = return (singleSubst y (TVar x), [])
bindVar x@(TVFree _ k inScope _d) t
  | k /= kindOf t          = uniError $ UniKindMismatch k (kindOf t)
  | recTy && k == KType    = uniError $ UniRecursive x t
  | not (Set.null escaped) = uniError $ UniNonPolyDepends x $ Set.toList escaped
  | recTy                  = return (emptySubst, [TVar x =#= t])
  | otherwise              = return (singleSubst x t, [])
  where
    -- Type parameters mentioned by @t@ that @x@ is not allowed to see.
    escaped = freeParams t `Set.difference` inScope
    -- Occurs check: does @x@ appear free in @t@?
    recTy   = x `Set.member` fvs t
-- | Collect the type parameters that the free type variables of a value
-- may refer to.
--
-- * A bound variable contributes its own parameter.
-- * A free variable contributes the set of parameters recorded as being
--   in scope for it.
freeParams :: FVS t => t -> Set.Set TParam
freeParams x = Set.unions [ paramsOf tv | tv <- Set.toList (fvs x) ]
  where
    paramsOf tv = case tv of
      TVFree _ _ scope _ -> scope
      TVBound tp         -> Set.singleton tp