-- | -- Module : Cryptol.TypeCheck.Unify -- Copyright : (c) 2013-2016 Galois, Inc. -- License : BSD3 -- Maintainer : cryptol@galois.com -- Stability : provisional -- Portability : portable {-# LANGUAGE Safe #-} {-# LANGUAGE PatternGuards, ViewPatterns #-} {-# LANGUAGE DeriveFunctor #-} module Cryptol.TypeCheck.Unify where import Cryptol.TypeCheck.AST import Cryptol.TypeCheck.Subst import Control.Monad.Writer (Writer, writer, runWriter) import Data.Ord(comparing) import Data.List(sortBy) import qualified Data.Set as Set import Prelude () import Prelude.Compat -- | The most general unifier is a substitution and a set of constraints -- on bound variables. type MGU = (Subst,[Prop]) type Result a = Writer [UnificationError] a runResult :: Result a -> (a, [UnificationError]) runResult = runWriter data UnificationError = UniTypeMismatch Type Type | UniKindMismatch Kind Kind | UniTypeLenMismatch Int Int | UniRecursive TVar Type | UniNonPolyDepends TVar [TParam] | UniNonPoly TVar Type uniError :: UnificationError -> Result MGU uniError e = writer (emptyMGU, [e]) emptyMGU :: MGU emptyMGU = (emptySubst, []) mgu :: Type -> Type -> Result MGU mgu (TUser c1 ts1 _) (TUser c2 ts2 _) | c1 == c2 && ts1 == ts2 = return emptyMGU mgu (TVar x) t = bindVar x t mgu t (TVar x) = bindVar x t mgu (TUser _ _ t1) t2 = mgu t1 t2 mgu t1 (TUser _ _ t2) = mgu t1 t2 mgu (TCon (TC tc1) ts1) (TCon (TC tc2) ts2) | tc1 == tc2 = mguMany ts1 ts2 mgu (TCon (TF f1) ts1) (TCon (TF f2) ts2) | f1 == f2 && ts1 == ts2 = return emptyMGU mgu t1 t2 | TCon (TF _) _ <- t1, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2]) | TCon (TF _) _ <- t2, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2]) where k1 = kindOf t1 k2 = kindOf t2 isNum = k1 == KNum mgu (TRec fs1) (TRec fs2) | ns1 == ns2 = mguMany ts1 ts2 where (ns1,ts1) = sortFields fs1 (ns2,ts2) = sortFields fs2 sortFields = unzip . sortBy (comparing fst) mgu t1 t2 | not (k1 == k2) = uniError $ UniKindMismatch k1 k2 | otherwise = uniError $ UniTypeMismatch t1 t2 where k1 = kindOf t1 k2 = kindOf t2 mguMany :: [Type] -> [Type] -> Result MGU mguMany [] [] = return emptyMGU mguMany (t1 : ts1) (t2 : ts2) = do (su1,ps1) <- mgu t1 t2 (su2,ps2) <- mguMany (apSubst su1 ts1) (apSubst su1 ts2) return (su2 @@ su1, ps1 ++ ps2) mguMany t1 t2 = uniError $ UniTypeLenMismatch (length t1) (length t2) bindVar :: TVar -> Type -> Result MGU bindVar x (tNoUser -> TVar y) | x == y = return emptyMGU bindVar v@(TVBound {}) (tNoUser -> TVar v1@(TVFree {})) = bindVar v1 (TVar v) bindVar v@(TVBound {}) t | k == kindOf t = if k == KNum then return (emptySubst, [TVar v =#= t]) else uniError $ UniNonPoly v t | otherwise = uniError $ UniKindMismatch k (kindOf t) where k = kindOf v bindVar x@(TVFree _ _ xscope _) (TVar y@(TVFree _ _ yscope _)) | xscope `Set.isProperSubsetOf` yscope = return (singleSubst y (TVar x), []) bindVar x@(TVFree _ k inScope _d) t | not (k == kindOf t) = uniError $ UniKindMismatch k (kindOf t) | recTy && k == KType = uniError $ UniRecursive x t | not (Set.null escaped) = uniError $ UniNonPolyDepends x $ Set.toList escaped | recTy = return (emptySubst, [TVar x =#= t]) | otherwise = return (singleSubst x t, []) where escaped = freeParams t `Set.difference` inScope recTy = x `Set.member` fvs t freeParams :: FVS t => t -> Set.Set TParam freeParams x = Set.unions (map params (Set.toList (fvs x))) where params (TVFree _ _ tps _) = tps params (TVBound tp) = Set.singleton tp