{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
module Statics.Constraint (
  -- * The constraint solver interface
  MonadConstraint(..),
  generalize, generalizeList, generalizeEx,
  -- * An implementation of the interface
  ConstraintT,
  runConstraintT, mapConstraintT,
  ConstraintState, constraintState0, pprConstraintState,
  runConstraintIO,
) where

import Util
import Util.Trace
import Util.MonadRef
import qualified Syntax.Ppr as Ppr
import qualified Alt.Graph as Gr
import qualified Data.UnionFind as UF
import Type
import Statics.Error

import Prelude ()
import qualified Data.List as List
import qualified Data.Set  as S
import qualified Data.Map  as M
import qualified Data.Boolean.SatSolver as SAT
import Data.IORef (IORef)

---
--- A CONSTRAINT-SOLVING MONAD
---

-- | Monads that accumulate subtype, equality, and subqualifier
-- constraints over @'Type' tv@, track "pinned" type variables, and
-- answer generalization queries.
class MonadSubst tv r m ⇒ MonadConstraint tv r m | m → tv r where
  -- | Subtype and equality constraints
  (<:), (=:) ∷ Type tv → Type tv → m ()
  -- | Subqualifier constraint
  (⊏:), (~:) ∷ (Qualifier q1 tv, Qualifier q2 tv) ⇒ q1 → q2 → m ()
  -- | Constrain by the given variance
  relate     ∷ Variance → Type tv → Type tv → m ()
  --
  -- Equality and qualifier equivalence default to subtyping /
  -- subqualification in both directions.
  τ1 =: τ2 = τ1 <: τ2 >> τ2 <: τ1
  τ1 ~: τ2 = τ1 ⊏: τ2 >> τ2 ⊏: τ1
  relate variance τ1 τ2 = case variance of
    Covariant      → τ1 <: τ2
    Contravariant  → τ2 <: τ1
    Invariant      → τ1 =: τ2
    QCovariant     → τ1 ⊏: τ2
    QContravariant → τ2 ⊏: τ1
    QInvariant     → τ1 ~: τ2
    Omnivariant    → return ()
  --
  -- | Get the set of pinned type variables.
  getPinnedTVs    ∷ m (S.Set tv)
  -- | Run a computation in the context of some "pinned down" type
  --   variables, which means that they won't be considered for
  --   elimination or generalization.
  withPinnedTVs   ∷ Ftv a tv ⇒ a → m b → m b
  -- | Update the list of pinned type variables to reflect a substitution.
  --   PRECONDITION: τ is substituted.
  updatePinnedTVs ∷ tv → Type tv → m ()
  --
  -- | Figure out which variables to generalize in a piece of syntax.
  --   The 'Bool' indicates whether the syntax whose type is being
  --   generalized is a syntactic value.  Returns a list of
  --   generalizable variables and their qualifier bounds.
  generalize'     ∷ Bool → Rank → Type tv → m [(tv, QLit)]
  -- | Find 'QLit' bounds for a set of type variables.  This assumes
  --   that these variables may safely be removed from the constraint
  --   if bounded as specified.  In particular, all the variables must
  --   appear only on the left-hand side of the qualifier inequalities.
  getTVBounds     ∷ [tv] → m [QLit]
  -- | Ensure that the current constraint is satisfiable.  This is
  --   necessary after each REPL entry, because that's the commit point
  --   for the constraint, and the REPL becomes unusable if a particular
  --   type error hangs around in the constraint forever.
  ensureSatisfiability ∷ m ()

infix 5 <:, =:, ⊏:, ~:

--
-- Pass-through instances
--
-- Each instance lifts every operation into the underlying monad; the
-- scoped operation 'withPinnedTVs' is threaded through with the
-- transformer's own @map…T@ function.
--

instance (MonadConstraint tv s m, Monoid w) ⇒
         MonadConstraint tv s (WriterT w m) where
  (<:)                 = lift <$$> (<:)
  (=:)                 = lift <$$> (=:)
  (⊏:)                 = lift <$$> (⊏:)
  (~:)                 = lift <$$> (~:)
  getPinnedTVs         = lift getPinnedTVs
  withPinnedTVs        = mapWriterT <$> withPinnedTVs
  updatePinnedTVs      = lift <$$> updatePinnedTVs
  generalize'          = lift <$$$> generalize'
  getTVBounds          = lift <$> getTVBounds
  ensureSatisfiability = lift ensureSatisfiability

instance MonadConstraint tv r m ⇒
         MonadConstraint tv r (StateT s m) where
  (<:)                 = lift <$$> (<:)
  (=:)                 = lift <$$> (=:)
  (⊏:)                 = lift <$$> (⊏:)
  (~:)                 = lift <$$> (~:)
  getPinnedTVs         = lift getPinnedTVs
  withPinnedTVs        = mapStateT <$> withPinnedTVs
  updatePinnedTVs      = lift <$$> updatePinnedTVs
  generalize'          = lift <$$$> generalize'
  getTVBounds          = lift <$> getTVBounds
  ensureSatisfiability = lift ensureSatisfiability

instance MonadConstraint tv p m ⇒
         MonadConstraint tv p (ReaderT r m) where
  (<:)                 = lift <$$> (<:)
  (=:)                 = lift <$$> (=:)
  (⊏:)                 = lift <$$> (⊏:)
  (~:)                 = lift <$$> (~:)
  getPinnedTVs         = lift getPinnedTVs
  withPinnedTVs        = mapReaderT <$> withPinnedTVs
  updatePinnedTVs      = lift <$$> updatePinnedTVs
  generalize'          = lift <$$$> generalize'
  getTVBounds          = lift <$> getTVBounds
  ensureSatisfiability = lift ensureSatisfiability

instance (MonadConstraint tv p m, Monoid w) ⇒
         MonadConstraint tv p (RWST r w s m) where
  (<:)                 = lift <$$> (<:)
  (=:)                 = lift <$$> (=:)
  (⊏:)                 = lift <$$> (⊏:)
  (~:)                 = lift <$$> (~:)
  getPinnedTVs         = lift getPinnedTVs
  withPinnedTVs        = mapRWST <$> withPinnedTVs
  updatePinnedTVs      = lift <$$> updatePinnedTVs
  generalize'          = lift <$$$> generalize'
  getTVBounds          = lift <$> getTVBounds
  ensureSatisfiability = lift ensureSatisfiability

--
-- Some generic operations
--

-- | Generalize a type under a constraint and environment,
-- given whether the value restriction is satisfied or not.
generalize    ∷ MonadConstraint tv r m ⇒
                Bool → Rank → Type tv → m (Type tv)
generalize value γrank ρ = do
  αqs ← generalize' value γrank ρ
  standardizeMus <$> closeQuant Forall αqs <$> subst ρ

-- | Generalize a list of types together.
generalizeList ∷ MonadConstraint tv r m ⇒
                 Bool → Rank → [Type tv] → m [Type tv]
generalizeList value γrank ρs = do
  -- The types are bundled into one tuple type so that a single call to
  -- 'generalize'' determines one shared list of quantified variables.
  αqs ← generalize' value γrank (List.foldl' tyTuple tyUnit ρs)
  mapM (standardizeMus <$> closeQuant Forall αqs <$$> subst) ρs

-- | Generalize the existential type variables in a type
generalizeEx  ∷ MonadConstraint tv r m ⇒
                Rank → Type tv → m (Type tv)
generalizeEx γrank ρ0 = do
  ρ   ← subst ρ0
  αs  ← removeByRank γrank (filter (tvFlavorIs Existential) (ftvList ρ))
  αqs ← mapM addQual αs
  return (closeQuant Exists αqs ρ)
  where
    -- Pair each variable with its qualifier bound; a missing bound is
    -- reported as an internal error.
    addQual α = case tvQual α of
      Just ql → return (α, ql)
      Nothing → typeBug "generalizeEx" "existential type variable with no qualifier"

-- | Remove type variables from a list if their rank indicates that
-- they're in the environment or if they're pinned
removeByRank ∷ MonadConstraint tv r m ⇒ Rank → [tv] → m [tv]
removeByRank γrank αs = do
  pinned ← getPinnedTVs
  -- A variable is kept only if its rank is strictly greater than the
  -- environment's rank and it is not pinned.
  let keep α = do
        rank ← getTVRank α
        return (rank > γrank && α `S.notMember` pinned)
  filterM keep αs

---
--- SUBTYPING CONSTRAINT SOLVER
---

--
-- The internal state
--

-- | The state of the constraint solver.
data CTState tv r
  = CTState {
      -- | Subtype edges between type variables and atomic type
      --   constructors
      csGraph   ∷ !(Gr.Gr tv ()),
      -- | Reverse index from atoms to their node numbers in 'csGraph'
      csNodeMap ∷ !(Gr.NodeMap tv),
      -- | Each type variable's same-size equivalence class
      csEquivs  ∷ !(ProxyMap tv r),
      -- | Pairs of types still to be related by the subqualifier
      --   relation
      csQuals   ∷ ![(Type tv, Type tv)],
      -- | Stack of sets of pinned type variables
      csPinned  ∷ ![S.Set tv]
    }

-- | Union-find handle standing for one equivalence class of type
-- variables
type TVProxy tv r  = UF.Proxy r (S.Set tv)

-- | Associates each type variable with its equivalence-class handle
type ProxyMap tv r = M.Map tv (TVProxy tv r)

-- | Transform the 'csQuals' field
csQualsUpdate  ∷ ([(Type tv, Type tv)] → [(Type tv, Type tv)]) →
                 CTState tv r → CTState tv r
csQualsUpdate f st@CTState { csQuals = quals } = st { csQuals = f quals }

-- | Transform the 'csEquivs' field
csEquivsUpdate ∷ (ProxyMap tv r → ProxyMap tv r) →
                 CTState tv r → CTState tv r
csEquivsUpdate f st@CTState { csEquivs = equivs } = st { csEquivs = f equivs }

-- | Transform the 'csPinned' field
csPinnedUpdate ∷ ([S.Set tv] → [S.Set tv]) →
                 CTState tv r → CTState tv r
csPinnedUpdate f st@CTState { csPinned = pins } = st { csPinned = f pins }

-- Renders nothing at all when there are neither graph edges nor
-- qualifier constraints; otherwise shows the graph and the qualifier
-- list (the node map, equivalences and pins are not shown).
instance Tv tv ⇒ Show (CTState tv r) where
  showsPrec _ cs
    | isTrivial = id
    | otherwise = showString "CTState { csGraph = "
                . shows (Gr.ShowGraph (csGraph cs))
                . showString ", csQuals = "
                . shows (csQuals cs)
                . showString " }"
    where
      isTrivial = null (Gr.edges (csGraph cs)) && null (csQuals cs)

-- Pretty-prints a two-entry map: "graph" lists the subtype edges as
-- @α1<:α2@ and "quals" lists the qualifier constraints as @τ1⊑τ2@.
instance Tv tv ⇒ Ppr.Ppr (CTState tv r) where
  ppr cs = Ppr.ppr (M.fromList [("graph", graphDoc), ("quals", qualsDoc)])
    where
      commaList = Ppr.fsep . Ppr.punctuate Ppr.comma
      graphDoc  = commaList
        [ Ppr.pprPrec 10 α1 Ppr.<> Ppr.text "<:" Ppr.<> Ppr.pprPrec 10 α2
        | (α1, α2) ← Gr.labNodeEdges (csGraph cs) ]
      qualsDoc  = commaList
        [ Ppr.pprPrec 9 τ1 Ppr.<> Ppr.char '⊑' Ppr.<> Ppr.pprPrec 9 τ2
        | (τ1, τ2) ← csQuals cs ]

--
-- The monad transformer
--

-- | Underlying 'ConstraintT' is a monad transformer that carries merely
-- the constraint-solving state.
newtype ConstraintT_ tv r m a
  = ConstraintT_ {
      unConstraintT_ ∷ StateT (CTState tv r) m a
    }
  deriving (Functor, Applicative, Monad, MonadAlmsError, MonadTrace)

-- | Push a higher-order operation on the underlying monad (one that
-- leaves the state half of a result pair alone) through 'ConstraintT_'.
mapConstraintT_ ∷ (∀ s. m (a, s) → n (b, s)) →
                  ConstraintT_ tv r m a → ConstraintT_ tv r n b
mapConstraintT_ f (ConstraintT_ inner) = ConstraintT_ (mapStateT f inner)

-- | Constraint monad transformer: the constraint-solver state of
-- 'ConstraintT_' layered over the 'SubstT' substitution state.
type ConstraintT tv r m = ConstraintT_ tv r (SubstT r m)

-- | Push a higher-order operation through both state layers of
-- 'ConstraintT'.
mapConstraintT ∷ (Functor m, Functor n) ⇒
                 (∀ s. m (a, s) → n (b, s)) →
                 ConstraintT tv r m a → ConstraintT tv r n b
mapConstraintT f = mapConstraintT_ (mapSubstT viaNested)
  where
    -- Re-associate @((a, s), s')@ as @(a, (s, s'))@ so that both state
    -- components travel through 'f' as its single state argument.
    viaNested action = flatten <$> f (nest <$> action)
    nest ((a, s), s')    = (a, (s, s'))
    flatten (a, (s, s')) = ((a, s), s')

-- | Run the constraint solver from the given external state, returning
-- the result and the final external state.  'resetEquivTVs' runs
-- before the caller's computation.
runConstraintT ∷ (MonadAlmsError m, MonadRef r m) ⇒
                 ConstraintState (TV r) r →
                 ConstraintT (TV r) r m a →
                 m (a, ConstraintState (TV r) r)
runConstraintT ecs m = do
    ((result, cs), ss) ← runSubstT (ecsSubst ecs) solverRun
    return (result, ExternalConstraintState cs ss)
  where
    solverRun = runStateT (unConstraintT_ (resetEquivTVs >> m))
                          (ecsInternal ecs)

-- | Run a constraint computation in the IO Monad
runConstraintIO ∷ ConstraintState (TV IORef) IORef →
                  ConstraintT (TV IORef) IORef (AlmsErrorT IO) a →
                  IO (Either [AlmsError]
                             (a, ConstraintState (TV IORef) IORef))
runConstraintIO ecs = runAlmsErrorT . runConstraintT ecs

-- | The external representation of the constraint solver's state: the
-- solver's internal state together with the substitution state.
data ConstraintState tv r
  = ExternalConstraintState {
      ecsInternal ∷ !(CTState tv r),
      ecsSubst    ∷ !SubstState
    }

-- | The initial constraint solver state: empty graph and node map, and
-- no equivalences, qualifier constraints, or pinned variables.
constraintState0 ∷ Tv tv ⇒ ConstraintState tv r
constraintState0 =
  ExternalConstraintState {
    ecsInternal = CTState {
      csGraph   = Gr.empty,
      csNodeMap = Gr.new,
      csEquivs  = M.empty,
      csQuals   = [],
      csPinned  = []
    },
    ecsSubst    = substState0
  }

instance Tv tv ⇒ Ppr.Ppr (ConstraintState tv r) where
  ppr ecs = Ppr.ppr (ecsInternal ecs)

instance Tv tv ⇒ Show (ConstraintState tv r) where
  showsPrec = Ppr.showFromPpr

-- | Get a printable representation of the internal constraint-solving
-- state.
pprConstraintState ∷ Tv tv ⇒ ConstraintState tv r → Ppr.Doc
pprConstraintState ecs = Ppr.ppr (ecsInternal ecs)

--
-- Instances
--

-- | Transformer instance
instance MonadTrans (ConstraintT_ tv r) where
  lift action = ConstraintT_ (lift action)

-- | Pass through for reference operations
instance MonadSubst tv r m ⇒ MonadRef r (ConstraintT_ tv r m) where
  newRef x       = lift (newRef x)
  readRef ref    = lift (readRef ref)
  writeRef ref x = lift (writeRef ref x)

-- | Pass through for unification operations.  Creating a 'Universal'
-- variable additionally constrains it below its qualifier bound.
instance MonadSubst tv r m ⇒ MonadSubst tv r (ConstraintT_ tv r m) where
  newTV_ (Universal, kind, bound, descr) = do
    β ← lift (newTV' (kind, descr))
    fvTy β ⊏: bound
    return β
  newTV_ attrs       = lift (newTV' attrs)
  writeTV_ x y       = lift (writeTV_ x y)
  readTV_ x          = lift (readTV_ x)
  getTVRank_ x       = lift (getTVRank_ x)
  setTVRank_ x y     = lift (setTVRank_ x y)
  collectTVs         = mapConstraintT_ (mapListen2 collectTVs)
  reportTVs x        = lift (reportTVs x)
  monitorChange      = mapConstraintT_ (mapListen2 monitorChange)
  setChanged         = lift setChanged

-- | 'ConstraintT' implements 'Graph'/'NodeMap' transformer operations
-- for accessing its graph and node map.
instance (Ord tv, Monad m) ⇒
         Gr.MonadNM tv () Gr.Gr (ConstraintT_ tv r m) where
  getNMState         = ConstraintT_ (gets (\cs → (csNodeMap cs, csGraph cs)))
  getNodeMap         = ConstraintT_ (gets csNodeMap)
  getGraph           = ConstraintT_ (gets csGraph)
  putNMState (nm, g) =
    ConstraintT_ (modify (\cs → cs { csNodeMap = nm, csGraph = g }))
  putNodeMap nm      = ConstraintT_ (modify (\cs → cs { csNodeMap = nm }))
  putGraph g         = ConstraintT_ (modify (\cs → cs { csGraph = g }))

-- | Constraint solver implementation.
instance MonadSubst tv r m ⇒ MonadConstraint tv r (ConstraintT_ tv r m) where
  -- Subtyping and equality are solved eagerly by 'subtypeTypes' under a
  -- fresh seen-map; qualifier constraints are handed to
  -- 'addQualConstraint'.
  τ <: τ' = traceN 3 ("<:", τ, τ') >> runSeenT (subtypeTypes False τ τ')
  τ =: τ' = traceN 3 ("=:", τ, τ') >> runSeenT (subtypeTypes True τ τ')
  τ ⊏: τ' = traceN 3 ("⊏:", qualToType τ, qualToType τ')
            >> addQualConstraint τ τ'
  --
  -- The pinned set is the union of every set on the stack.
  getPinnedTVs = ConstraintT_ (gets (S.unions . csPinned))
  --
  -- Push the free variables of 'a', run 'm', then pop them again.
  withPinnedTVs a m = do
      adjustPins (ftvSet a :)
      result ← m
      adjustPins tail
      return result
    where
      adjustPins f = ConstraintT_ (modify (csPinnedUpdate f))
  --
  -- Scanning the stack right to left, replace 'α' by the free variables
  -- of 'τ' in only the first set found to contain it.
  updatePinnedTVs α τ =
      ConstraintT_ (modify (csPinnedUpdate (snd . mapAccumR rewrite False)))
    where
      βs = ftvSet τ
      rewrite False set
        | α `S.member` set = (True, βs `S.union` S.delete α set)
      rewrite done set     = (done, set)
  --
  generalize'          = solveConstraint
  getTVBounds          = solveBounds
  ensureSatisfiability = checkQualifiers

-- | Trace a message at the given level and, below it, the pretty-printed
-- solver state (omitted when it renders empty).  When 'debug' is off
-- this is a constant no-op.
{-# INLINE gtraceN #-}
gtraceN ∷ (TraceMessage a, Tv tv, MonadTrace m) ⇒
          Int → a → ConstraintT_ tv r m ()
gtraceN
  | debug     = \n info → unless (n > debugLevel) $ do
      trace info
      cs ← ConstraintT_ get
      let doc = Ppr.ppr cs
      unless (Ppr.isEmpty doc) $
        trace (Ppr.nest 4 doc)
  | otherwise = \_ _ → return ()

-- | State layered over 'ConstraintT_' that remembers which pairs of
-- types have already been compared, so that recursive subtyping can be
-- implemented.  A pair mapped to @True@ is known to be equal; a pair
-- mapped to @False@ is known only to be in the subtype relation.
type SeenT tv r m = StateT (M.Map (Type tv, Type tv) Bool) (ConstraintT_ tv r m)

-- | Run a recursive subtyping computation, starting from an empty
-- seen-map and discarding it afterwards.
runSeenT ∷ (Tv tv, MonadTrace m) ⇒ SeenT tv r m a → ConstraintT_ tv r m a
runSeenT m = do
  gtraceN 4 "runSeenT {"
  answer ← evalStateT m M.empty
  gtraceN 4 "} runSeenT"
  return answer

-- | Relate two types at either subtyping or equality, depending on
-- the value of the first parameter (@True@ means equality).
-- This eagerly solves as much as possible, adding to the constraint
-- only as necessary.
subtypeTypes ∷ MonadSubst tv r m ⇒
               Bool → Type tv → Type tv → SeenT tv r m ()
subtypeTypes unify = check where
  -- 'check' substitutes both sides, then consults the seen-map.  The
  -- comparison is skipped when the pair is already recorded at the same
  -- or a stronger relation: 'Maybe' orders
  -- Nothing < Just False < Just True, so an equality record also covers
  -- a later subtyping request.  The pair is recorded before it is
  -- decomposed, so a recursive revisit of the same pair stops here.
  check τ1 τ2 = do
    lift $ gtraceN 4 ("subtypeTypes", unify, τ1, τ2)
    τ1' ← subst τ1
    τ2' ← subst τ2
    seen ← get
    unless (M.lookup (τ1', τ2') seen >= Just unify) $ do
      put (M.insert (τ1', τ2') unify seen)
      decomp τ1' τ2'
  --
  -- 'decomp' compares two already-substituted types structurally.
  decomp τ1 τ2 = case (τ1, τ2) of
    -- The same variable on both sides: nothing to do.
    (TyVar v1, TyVar v2)
      | v1 == v2 → return ()
    -- Two universal variables: unify them in equality mode; otherwise
    -- mark them equivalent and record an atomic subtype edge.
    (TyVar (Free r1), TyVar (Free r2))
      | tvFlavorIs Universal r1, tvFlavorIs Universal r2 →
        if unify
          then unifyVar r1 (fvTy r2)
          else do
            lift (makeEquivTVs r1 r2)
            addEdge r1 r2
    -- Universal variable on the left: after the occurs check, bind it to
    -- the type (equality) or to a copy of the type with fresh variables
    -- (subtyping), and in subtyping mode relate that copy to the
    -- original right-hand side.  'decomp' is passed as the occurs
    -- check's first continuation.
    (TyVar (Free r1), _)
      | tvFlavorIs Universal r1 →
        occursCheck r1 τ2 decomp $ \τ2'' → do
          τ2' ← if unify then return τ2'' else copyType τ2''
          unifyVar r1 τ2'
          unless unify (check τ2' τ2)
    -- Mirror image of the previous case; 'flip decomp' keeps the
    -- arguments in (left, right) order for the continuation.
    (_, TyVar (Free r2))
      | tvFlavorIs Universal r2 → do
        occursCheck r2 τ1 (flip decomp) $ \τ1'' → do
          τ1' ← if unify then return τ1'' else copyType τ1''
          unifyVar r2 τ1'
          unless unify (check τ1 τ1')
    -- Universal quantifiers: equality needs identical binder lists.
    -- Subtyping needs equal length and a pointwise '⊒' on the binders'
    -- second components (presumably their qualifier bounds; confirm in
    -- "Type").
    (TyQu Forall αs1 τ1', TyQu Forall αs2 τ2')
      | if unify
          then αs1 == αs2
          else length αs1 == length αs2 &&
               and (zipWith ((⊒)`on`snd) αs1 αs2) →
        check τ1' τ2'
    -- Existential quantifiers: binder lists must be identical.
    (TyQu Exists αs1 τ1', TyQu Exists αs2 τ2')
      | αs1 == αs2 →
        check τ1' τ2'
    -- Same constructor (other than the row map) at the same number of
    -- arguments.  Equality mode relates Q-variant arguments with '~:'
    -- and all others with 'check'; subtyping mode relates each argument
    -- at its declared variance.
    (TyApp tc1 τs1, TyApp tc2 τs2)
      | tc1 == tc2 && tc1 /= tcRowMap && length τs1 == length τs2 →
        sequence_
          [ if unify
              then if isQVariance var
                then τ1' ~: τ2'
                else check τ1' τ2'
              else relateTypes var τ1' τ2'
          | var ← tcArity tc1
          | τ1' ← τs1
          | τ2' ← τs2 ]
    -- Rows with the same label are compared field by field and tail by
    -- tail.  With different labels, fresh variables α and β stand for
    -- the remaining rows, so each side's leading field is matched
    -- against the other side's tail.
    (TyRow n1 τ11 τ12, TyRow n2 τ21 τ22)
      | n1 == n2 → do
        check τ11 τ21
        check τ12 τ22
      | otherwise → do
        α ← newTVTy
        check (TyRow n1 τ11 α) τ22
        β ← newTVTy
        check τ12 (TyRow n2 τ21 β)
        check α β
    -- Recursive types are unrolled one step by opening the body with
    -- the whole type.
    (TyMu _ τ1', _) →
      decomp (openTy 0 [τ1] τ1') τ2
    (_, TyMu _ τ2') →
      decomp τ1 (openTy 0 [τ2] τ2')
    -- If 'matchReduce' yields a pair of reduced types, compare those.
    _ | Just (τ1', τ2') ← matchReduce τ1 τ2 →
        check τ1' τ2'
    -- Row maps (excluded above) are compared component-wise.
    (TyApp tc1 [τ11, τ12], TyApp tc2 [τ21, τ22])
      | tc1 == tcRowMap && tc2 == tcRowMap → do
        check τ11 τ21
        check τ12 τ22
    -- Anything else is a type error.
    _ | otherwise →
        tErrExp
          (if unify then [msg| Cannot unify: |] else [msg| Cannot subtype: |])
          (pprMsg τ1) (pprMsg τ2)
  --
  -- Add both atoms to the graph (if not already present), add the edge
  -- a1 → a2, and constrain their qualifiers: a1 ⊏ a2.
  addEdge a1 a2 = do
    Gr.insNewMapNodeM a1
    Gr.insNewMapNodeM a2
    Gr.insMapEdgeM (a1, a2, ())
    lift (fvTy a1 ⊏: fvTy a2)

-- | Relate two types at the given variance.
relateTypes ∷ MonadSubst tv r m ⇒
              Variance → Type tv → Type tv → SeenT tv r m ()
relateTypes var = case var of
  -- Type-level variances go through 'subtypeTypes' (True = equality).
  Invariant     → subtypeTypes True
  Covariant     → subtypeTypes False
  Contravariant → flip (subtypeTypes False)
  -- Qualifier variances become subqualifier constraints.
  QInvariant    → (~:)
  QCovariant    → (⊏:)
  QContravariant→ flip (⊏:)
  -- No constraint at all.
  Omnivariant   → \_ _ → return ()

-- | Copy a type while replacing all the type variables with fresh ones
-- of the same kind.
copyType ∷ MonadSubst tv r m ⇒ Type tv → m (Type tv)
copyType =
  foldTypeM (mkQuF (return <$$$> TyQu))
            (mkBvF (return <$$$> bvTy))
            fvar
            fcon
            (return <$$$> TyRow)
            (mkMuF (return <$$> TyMu))
  where
    -- Only 'Universal' free variables are replaced (with a fresh
    -- variable of the same kind); all other free variables are kept.
    fvar α | tvFlavorIs Universal α = newTVTy' (tvKind α)
           | otherwise              = return (fvTy α)
    -- Type-constructor applications are rebuilt with the same
    -- constructor; every argument in a Q-variant position is replaced by
    -- a fresh qualifier-kinded ('KdQual') type variable, and the other
    -- arguments are kept as they are.
    -- NOTE(review): the parallel comprehension truncates to the shorter
    -- of 'τs' and 'tcArity tc'; presumably they always agree.
    fcon tc τs = TyApp tc <$> sequence
      [ -- A Q-variant type constructor parameter becomes a single
        -- type variable:
        if isQVariance var then newTVTy' KdQual else return τ
      | τ ← τs
      | var ← tcArity tc ]

-- | Unify a type variable with a type, where the type must be locally
-- closed.
-- ASSUMPTIONS: @α@ has not been substituted and the occurs check has
-- already passed.
unifyVar ∷ MonadSubst tv r m ⇒ tv → Type tv → SeenT tv r m () unifyVar α τ0 = do lift $ gtraceN 4 ("unifyVar", α, τ0) τ ← subst τ0 tassert (lcTy 0 τ) [msg| Cannot unify because a $τ is insufficiently polymorphic |] writeTV α τ lift (updatePinnedTVs α τ) (n, _) ← Gr.mkNodeM α gr ← Gr.getGraph case Gr.match n gr of (Nothing, _) → return () (Just (pres, _, _, sucs), gr') → do Gr.putGraph gr' sequence_ $ [ case Gr.lab gr' n' of Nothing → return () Just a → subtypeTypes False (fvTy a) τ | (_, n') ← pres ] ++ [ case Gr.lab gr' n' of Nothing → return () Just a → subtypeTypes False τ (fvTy a) | (_, n') ← sucs ] --- OCCURS CHECK -- | Performs the occurs check and returns a type suitable for unifying -- with the given type variable, if possible. This does the subtyping -- occurs check, which checks not in terms of type variables but in -- terms of same-size equivalence classes of type variables. -- Unification is possible if all occurrences of type variables -- size-equivalent to @α@ appear guarded by a type constructor that -- permits recursion, in which case we roll up @τ@ as a recursive type -- and return that. occursCheck ∷ MonadSubst tv r m ⇒ tv → Type tv → (Type tv → Type tv → SeenT tv r m ()) → (Type tv → SeenT tv r m ()) → SeenT tv r m () occursCheck α τ0 kv kt = do lift (gtraceN 3 ("occursCheck", α, τ0)) loop S.empty τ0 where loop seen τ = do let (guarded, unguarded) = (M.keys***M.keys) . M.partition id $ ftvG τ apparentCycle ← lift $ anyA (checkEquivTVs α) unguarded if apparentCycle then case headReduceType τ of Next τ'@(TyVar (Free _)) → kv (fvTy α) τ' Next τ' | τ' ∉ seen → loop (S.insert τ' seen) τ' _ → -- | This type error has to throw because continuing will -- likely cause the type checker to diverge. typeError' [msg| Occurs check failed. Cannot construct an infinite type when unifying: