-- | Unification {-# LANGUAGE BangPatterns #-} module AST.Unify ( unify , module AST.Class.Unify , module AST.Unify.Constraints , module AST.Unify.Error , -- | Exported only for SPECIALIZE pragmas updateConstraints, updateTermConstraints, updateTermConstraintsH , unifyUTerms, unifyUnbound ) where import Algebra.PartialOrd (PartialOrd(..)) import AST import AST.Class.Unify (Unify(..), UVarOf, BindingDict(..)) import AST.Class.ZipMatch (zipMatchA) import AST.Unify.Constraints import AST.Unify.Error (UnifyError(..)) import AST.Unify.Lookup (semiPruneLookup) import AST.Unify.Occurs (occursError) import AST.Unify.Term (UTerm(..), UTermBody(..), uConstraints, uBody) import Control.Lens.Operators import Data.Constraint (withDict) import Data.Maybe (fromMaybe) import Data.Proxy (Proxy(..)) import Prelude.Compat -- TODO: implement when need / better understand motivations for - -- occursIn, seenAs, getFreeVars, freshen, equals, equiv -- (from unification-fd package) {-# INLINE updateConstraints #-} updateConstraints :: Unify m t => TypeConstraintsOf t -> Tree (UVarOf m) t -> Tree (UTerm (UVarOf m)) t -> m () updateConstraints !newConstraints v x = case x of UUnbound l | newConstraints `leq` l -> pure () | otherwise -> bindVar binding v (UUnbound newConstraints) USkolem l | newConstraints `leq` l -> pure () | otherwise -> SkolemEscape v & unifyError UTerm t -> updateTermConstraints v t newConstraints UResolving t -> () <$ occursError v t _ -> error "This shouldn't happen in unification stage" {-# INLINE updateTermConstraints #-} updateTermConstraints :: forall m t. Unify m t => Tree (UVarOf m) t -> Tree (UTermBody (UVarOf m)) t -> TypeConstraintsOf t -> m () updateTermConstraints v t newConstraints | newConstraints `leq` (t ^. uConstraints) = pure () | otherwise = withDict (unifyRecursive (Proxy @m) (Proxy @t)) $ do bindVar binding v (UResolving t) case verifyConstraints newConstraints (t ^. uBody) of Nothing -> ConstraintsViolation (t ^. uBody) newConstraints & unifyError Just prop -> do traverseK_ (Proxy @(Unify m) #> updateTermConstraintsH) prop UTermBody newConstraints (t ^. uBody) & UTerm & bindVar binding v {-# INLINE updateTermConstraintsH #-} updateTermConstraintsH :: Unify m t => Tree (WithConstraint (UVarOf m)) t -> m () updateTermConstraintsH (WithConstraint c v0) = do (v1, x) <- semiPruneLookup v0 updateConstraints c v1 x -- | Unify unification variables {-# INLINE unify #-} unify :: forall m t. Unify m t => Tree (UVarOf m) t -> Tree (UVarOf m) t -> m (Tree (UVarOf m) t) unify x0 y0 | x0 == y0 = pure x0 | otherwise = do (x1, xu) <- semiPruneLookup x0 if x1 == y0 then pure x1 else do (y1, yu) <- semiPruneLookup y0 if x1 == y1 then pure x1 else unifyUTerms x1 xu y1 yu {-# INLINE unifyUnbound #-} unifyUnbound :: Unify m t => Tree (UVarOf m) t -> TypeConstraintsOf t -> Tree (UVarOf m) t -> Tree (UTerm (UVarOf m)) t -> m (Tree (UVarOf m) t) unifyUnbound xv level yv yt = do updateConstraints level yv yt yv <$ bindVar binding xv (UToVar yv) {-# INLINE unifyUTerms #-} unifyUTerms :: forall m t. Unify m t => Tree (UVarOf m) t -> Tree (UTerm (UVarOf m)) t -> Tree (UVarOf m) t -> Tree (UTerm (UVarOf m)) t -> m (Tree (UVarOf m) t) unifyUTerms xv (UUnbound level) yv yt = unifyUnbound xv level yv yt unifyUTerms xv xt yv (UUnbound level) = unifyUnbound yv level xv xt unifyUTerms xv USkolem{} yv _ = xv <$ unifyError (SkolemUnified xv yv) unifyUTerms xv _ yv USkolem{} = yv <$ unifyError (SkolemUnified yv xv) unifyUTerms xv (UTerm xt) yv (UTerm yt) = withDict (unifyRecursive (Proxy @m) (Proxy @t)) $ do bindVar binding yv (UToVar xv) zipMatchA (Proxy @(Unify m) #> unify) (xt ^. uBody) (yt ^. uBody) & fromMaybe (xt ^. uBody <$ structureMismatch unify xt yt) >>= bindVar binding xv . UTerm . UTermBody (xt ^. uConstraints <> yt ^. uConstraints) pure xv unifyUTerms _ _ _ _ = error "This shouldn't happen in unification stage"