-- | Functions to separate a set of wanted constraints into groups of
--   constraints that require being solved together.
module Control.Supermonad.Plugin.Separation
  ( separateContraints
  , componentTopTyCons
  , componentTopTcVars
  , componentMonoTyCon
  ) where

import Data.Maybe ( fromJust )
import qualified Data.Set as S

import Control.Monad ( filterM, liftM2 )

import Data.Graph.Inductive.Graph
  ( LNode, Edge
  , mkGraph, toLEdge )
import Data.Graph.Inductive.PatriciaTree ( Gr )
import Data.Graph.Inductive.Query.DFS ( components )

import TcRnTypes ( Ct )
import TyCon ( TyCon )
import Type ( Type, TyVar )
import TcType ( isAmbiguousTyVar )

import Control.Supermonad.Plugin.Constraint
  ( WantedCt, constraintClassTyArgs )
import Control.Supermonad.Plugin.Utils
  ( collectTopTyCons, collectTopTcVars, anyM )
import Control.Supermonad.Plugin.Environment
  ( SupermonadPluginM )
import Control.Supermonad.Plugin.Environment.Lift
  ( isReturnConstraint, isBindConstraint )

-- | A graph node labelled with the wanted constraint it represents.
type SCNode = LNode WantedCt

-- | Inspects the return and bind constraints of the given component.
--   Yields @Just tc@ when those constraints mention exactly one top-level
--   type constructor @tc@ and no top-level type constructor variable that
--   is not ambiguous; yields 'Nothing' in every other case.
componentMonoTyCon :: [WantedCt] -> SupermonadPluginM (Maybe TyCon)
componentMonoTyCon cts = do
  -- Only the return and bind constraints take part in the check.
  smCts <- filterM isSupermonadCt cts
  let -- Top-level type constructor variables that are not ambiguous.
      polyVars = S.filter (not . isAmbiguousTyVar) (componentTopTcVars smCts)
      -- Concrete top-level type constructors.
      concreteTyCons = componentTopTyCons smCts
  return $ case S.toList concreteTyCons of
    [tc] | S.null polyVars -> Just tc
    _                      -> Nothing
  where
    -- Both checks are always run, the return check first.
    isSupermonadCt ct = do
      isRet <- isReturnConstraint ct
      isBnd <- isBindConstraint ct
      return (isRet || isBnd)

-- | Collect all top-level type constructors for the given list of
--   wanted constraints. See 'collectTopTyCons'.
componentTopTyCons :: [WantedCt] -> S.Set TyCon
componentTopTyCons cts = collect collectTopTyCons cts

-- | Collect all top-level type constructor variables for the given list of
--   wanted constraints. See 'collectTopTcVars'.
componentTopTcVars :: [WantedCt] -> S.Set TyVar
componentTopTcVars cts = collect collectTopTcVars cts

-- | Utility function that applies the given collection function to the
--   class type arguments of each given constraint and returns the union of
--   the collected results as a set. A constraint for which
--   'constraintClassTyArgs' yields 'Nothing' contributes the empty set.
collect :: (Ord a) => ([Type] -> S.Set a) -> [Ct] -> S.Set a
collect f cts =
  S.unions [ maybe S.empty f (constraintClassTyArgs ct) | ct <- cts ]

-- | Creates a graph of the constraints and how they are
--   connected by their top-level ambiguous type constructor variables,
--   computes the connected components of that graph and keeps only those
--   components that contain at least one bind or return constraint.
--   The remaining components represent the groups of constraints that are
--   in need of solving and have to be handled together.
separateContraints :: [WantedCt] -> SupermonadPluginM [[WantedCt]]
separateContraints wantedCts = filterM containsBindOrReturn comps
  where
    -- Checks if the given constraint group contains any bind or return
    -- constraint.
    containsBindOrReturn :: [WantedCt] -> SupermonadPluginM Bool
    containsBindOrReturn = anyM isBindOrReturn
      where
        isBindOrReturn ct =
          liftM2 (||) (isBindConstraint ct) (isReturnConstraint ct)

    -- The connected components of the graph, with every node id mapped
    -- back to the constraint it was created from.
    comps :: [[WantedCt]]
    comps = map (map constraintOf) (components g)
      where
        -- Node ids handed back by 'components' are expected to originate
        -- from 'nodes', so the lookup is expected to succeed.
        constraintOf n = fromJust (lookup n nodes)

    g :: Gr WantedCt ()
    g = mkGraph nodes [ toLEdge e () | e <- edges ]

    -- Each constraint is a node; its id is its position in the input list.
    nodes :: [SCNode]
    nodes = zip [0..] wantedCts

    -- An edge between two constraints exists when they have a common
    -- ambiguous top-level type constructor variable in their type
    -- arguments.
    edges :: [Edge]
    edges = filter isEdge (allEdgesFor nodes)

    -- Returns 'True' if the given edge is an edge of the graph. If either
    -- end has no node or no class type arguments the result is 'False'.
    isEdge :: Edge -> Bool
    isEdge (na, nb) =
      case (ambiguousTopTcVars na, ambiguousTopTcVars nb) of
        (Just ta, Just tb) -> not (S.null (S.intersection ta tb))
        _                  -> False

    -- Looks up the constraint of the given node and collects the ambiguous
    -- top-level type constructor variables of its class type arguments.
    ambiguousTopTcVars n = do
      tyArgs <- lookup n nodes >>= constraintClassTyArgs
      return (S.filter isAmbiguousTyVar (collectTopTcVars tyArgs))

    -- Returns the edges for a complete undirected graph of the given
    -- nodes. Every unordered pair is emitted in both directions.
    allEdgesFor :: [SCNode] -> [Edge]
    allEdgesFor [] = []
    allEdgesFor ((n, _) : rest) =
      [ (m, n) | m <- others ] ++ [ (n, m) | m <- others ] ++ allEdgesFor rest
      where
        others = map fst rest