{-|
Copyright  :  (C) 2015-2016, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij

A type checker plugin for GHC that can solve /equalities/
of types of kind 'GHC.TypeLits.Nat', where these types are either:

* Type-level naturals
* Type variables
* Applications of the arithmetic expressions @(+,-,*,^)@.

It solves these equalities by normalising them to /sort-of/
'GHC.TypeLits.Normalise.SOP.SOP' (Sum-of-Products) form, and then performing a
simple syntactic equality check.

For example, this solver can prove the equality between:

@
(x + 2)^(y + 2)
@

and

@
4*x*(2 + x)^y + 4*(2 + x)^y + (2 + x)^y*x^2
@

Because the latter is actually the 'GHC.TypeLits.Normalise.SOP.SOP' normal form
of the former.

To use the plugin, add

@
{\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise \#-\}
@

To the header of your file.
-}

{-# LANGUAGE LambdaCase      #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections   #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.Normalise
  ( plugin )
where

-- external
import Control.Arrow       (second)
import Data.IORef          (IORef, newIORef, readIORef, modifyIORef')
import Data.List           (foldl', intersect)
import Data.Maybe          (catMaybes, mapMaybe)
import GHC.TcPluginM.Extra (tracePlugin)

-- GHC API
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins    (Plugin (..), defaultPlugin)
import TcEvidence (EvTerm (..))
import TcPluginM  (TcPluginM, tcPluginIO, tcPluginTrace, zonkCt)
import TcRnTypes  (Ct, TcPlugin (..), TcPluginResult(..), ctEvidence, ctEvPred,
                   ctPred, isWanted, mkNonCanonical)
import Type       (EqRel (NomEq), Kind, PredTree (EqPred), PredType, TyVar,
                   classifyPredType, eqType, getEqPredTys, mkTyVarTy)
import TysWiredIn (typeNatKind)

import Coercion   (CoercionHole, Role (..), mkForAllCos, mkHoleCo, mkInstCo,
                   mkNomReflCo, mkUnivCo)
import TcPluginM  (newCoercionHole, newFlexiTyVar)
import TcRnTypes  (CtEvidence (..), TcEvDest (..), ctLoc)
import TyCoRep    (UnivCoProvenance (..))
import Type       (mkPrimEqPred)
import TcType     (typeKind)
import TyCoRep    (Type (..))
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
                   typeNatSubTyCon)

import TcTypeNats (typeNatLeqTyCon)
import Type       (mkNumLitTy,mkTyConApp)
import TysWiredIn (promotedFalseDataCon, promotedTrueDataCon)

-- internal
import GHC.TypeLits.Normalise.Unify

-- | To use the plugin, add
--
-- @
-- {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise \#-\}
-- @
--
-- To the header of your file.
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = const $ Just normalisePlugin }

-- | The type-checker plugin proper, wrapped in 'tracePlugin'.
--
-- Its state is an 'IORef' holding the wanted constraints this plugin has
-- emitted so far (initially empty); 'decideEqualSOP' is the solver and there
-- is nothing to clean up on stop.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-natnormalise"
  TcPlugin { tcPluginInit  = tcPluginIO $ newIORef []
           , tcPluginSolve = decideEqualSOP
           , tcPluginStop  = const (return ())
           }

-- | Solver callback.
--
-- Arguments are: the 'IORef' of wanteds emitted by earlier invocations, the
-- given constraints, the derived constraints (ignored) and the wanted
-- constraints.
--
-- * With no wanteds, or no wanteds that 'toNatEquality' recognises, nothing
--   is solved and nothing is emitted.
-- * Otherwise the (zonked) givens and the wanteds are handed to
--   'simplifyNats'. On 'Simplified', the solved wanteds are reported together
--   with any new wanted constraints they depend on; on 'Impossible', the
--   offending constraint is reported as a contradiction.
decideEqualSOP :: IORef [Ct] -> [Ct] -> [Ct] -> [Ct]
               -> TcPluginM TcPluginResult
decideEqualSOP _ _givens _deriveds [] = return (TcPluginOk [] [])
decideEqualSOP discharged givens _deriveds wanteds = do
    -- GHC 7.10.1 puts deriveds with the wanteds, so filter them out
    let wanteds' = filter (isWanted . ctEvidence) wanteds
    let unit_wanteds = mapMaybe toNatEquality wanteds'
    case unit_wanteds of
      [] -> return (TcPluginOk [] [])
      _  -> do
        unit_givens <- mapMaybe toNatEquality <$> mapM zonkCt givens
        sr <- simplifyNats (unit_givens ++ unit_wanteds)
        tcPluginTrace "normalised" (ppr sr)
        case sr of
          Simplified _subst evs -> do
            -- Evidence is produced for givens too; only wanteds are reported
            let solved = filter (isWanted . ctEvidence . (\(_,x,_) -> x)) evs
            dischargedWanteds <- tcPluginIO (readIORef discharged)
            let existingWanteds = wanteds' ++ dischargedWanteds
            -- Create new wanted constraints
            (solved',newWanteds) <- (second concat . unzip . catMaybes) <$>
                                    mapM (evItemToCt existingWanteds) solved
            -- update set of discharged wanteds; strict modify so the
            -- long-lived IORef does not accumulate a chain of (++) thunks
            tcPluginIO (modifyIORef' discharged (++ newWanteds))
            -- return
            return (TcPluginOk solved' newWanteds)
          Impossible eq -> return (TcPluginContradiction [fromNatEquality eq])

-- | Turn one item of solver evidence into a solved constraint plus the new
-- wanted constraints that the evidence is conditional on.
--
-- Returns 'Nothing' when at least one of the required new wanteds could not
-- be created (see 'substItemToCt'), in which case the constraint is not
-- reported as solved.
evItemToCt :: [Ct] -- ^ Existing wanteds
           -> (EvTerm,Ct,CoreUnify CoreNote)
           -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
evItemToCt existingWanteds (ev,ct,subst)
    | null newWanteds = return (Just ((ev,ct),[]))
    | otherwise = do
        newWanteds' <- catMaybes <$>
                       mapM (substItemToCt existingWanteds) newWanteds
        -- only allow new (conditional) evidence if conditional wanted constraints
        -- can be added as new work
        if length newWanteds == length newWanteds'
           then return (Just ((ev,ct),newWanteds'))
           else return Nothing
  where
    -- Only unifiers that originate from a wanted constraint become new work
    newWanteds = filter (isWanted . ctEvidence . snd . siNote) subst

-- | Create a new, non-canonical wanted constraint for a unifier, unless the
-- same equality (in either orientation) is already among the existing
-- wanteds, in which case 'Nothing' is returned.
substItemToCt :: [Ct] -- ^ Existing wanteds
              -> UnifyItem TyVar CType CoreNote
              -> TcPluginM (Maybe Ct)
substItemToCt existingWanteds si
  | CType predicate  `notElem` wantedPreds
  , CType predicateS `notElem` wantedPreds
  = return (Just (mkNonCanonical (CtWanted predicate (HoleDest ev) (ctLoc ct))))
  | otherwise
  = return Nothing
  where
    predicate     = unifyItemToPredType si
    (ty1,ty2)     = getEqPredTys predicate
    -- The symmetric version of the predicate
    predicateS    = mkPrimEqPred ty2 ty1
    ((ev,_,_),ct) = siNote si
    wantedPreds   = map (CType . ctPred) existingWanteds

-- | The primitive equality predicate expressed by a unifier: @var ~ sop@ for
-- a 'SubstItem', @lhs ~ rhs@ for a 'UnifyItem'.
unifyItemToPredType :: UnifyItem TyVar CType a -> PredType
unifyItemToPredType ui = mkPrimEqPred ty1 ty2
  where
    ty1 = case ui of
            SubstItem {..} -> mkTyVarTy siVar
            UnifyItem {..} -> reifySOP siLHS
    ty2 = case ui of
            SubstItem {..} -> reifySOP siSOP
            UnifyItem {..} -> reifySOP siRHS

-- | A constraint together with the normalised left- and right-hand side of
-- the equality it states.
type NatEquality   = (Ct,CoreSOP,CoreSOP)

-- | A constraint together with a normalised expression that 'simplifyNats'
-- tests with 'isNatural'.
type NatInEquality = (Ct,CoreSOP)

-- | Project out the constraint an (in)equality was derived from.
fromNatEquality :: Either NatEquality NatInEquality -> Ct
fromNatEquality (Left  (ct, _, _)) = ct
fromNatEquality (Right (ct, _))    = ct

-- | Annotation carried by unifiers: the coercion hole, fresh type variable
-- and predicate allocated for the unifier, paired with a 'Ct' (whatever the
-- unifier was already annotated with when 'simplifyNats' received it).
type CoreNote = ((CoercionHole,TyVar,PredType), Ct)

-- | Outcome of 'simplifyNats'.
data SimplifyResult
  = Simplified (CoreUnify CoreNote) [(EvTerm,Ct,CoreUnify CoreNote)]
    -- ^ The accumulated substitution, and for every discharged constraint its
    -- evidence and the unifiers that evidence depends on
  | Impossible (Either NatEquality NatInEquality)
    -- ^ The (in)equality that was found to be unsatisfiable

instance Outputable SimplifyResult where
  ppr (Simplified subst evs) = text "Simplified" $$ ppr subst $$ ppr evs
  ppr (Impossible eq)        = text "Impossible" <+> ppr eq

-- | Try to discharge a list of (in)equalities.
--
-- The worker @simples@ threads four things: the substitution found so far,
-- the evidence collected so far, the (in)equalities that were deferred, and
-- the (in)equalities still to be examined.
simplifyNats :: [Either NatEquality NatInEquality]
             -> TcPluginM SimplifyResult
simplifyNats eqs =
    tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] eqs
  where
    simples :: CoreUnify CoreNote
            -> [Maybe (EvTerm, Ct, CoreUnify CoreNote)]
            -> [Either NatEquality NatInEquality]
            -> [Either NatEquality NatInEquality]
            -> TcPluginM SimplifyResult
    -- Nothing left to examine: deferred items stay unsolved
    simples subst evs _xs [] = return (Simplified subst (catMaybes evs))
    simples subst evs xs (eq@(Left (ct,u,v)):eqs') = do
      ur <- unifyNats ct (substsSOP subst u) (substsSOP subst v)
      tcPluginTrace "unifyNats result" (ppr ur)
      case ur of
        -- Solved outright: record evidence and give deferred items another go
        Win -> simples subst
                 (((,,) <$> evMagic ct [] <*> pure ct <*> pure []):evs)
                 [] (xs ++ eqs')
        Lose -> return (Impossible eq)
        -- No progress: defer this equality
        Draw [] -> simples subst evs (eq:xs) eqs'
        -- Solved up to the unifiers in subst': allocate a coercion hole and a
        -- fresh type variable per unifier, extend the substitution, and give
        -- deferred items another go
        Draw subst' -> do
          newEvs <- mapM (\si -> (,,) <$> newCoercionHole
                                      <*> newFlexiTyVar typeNatKind
                                      <*> pure (unifyItemToPredType si))
                         subst'
          let subst'' = zipWith (\si ev -> si {siNote = (ev,siNote si)})
                                subst' newEvs
          simples (substsSubst subst'' subst ++ subst'')
                  (((,,) <$> evMagic ct newEvs <*> pure ct <*> pure subst''):evs)
                  [] (xs ++ eqs')
    simples subst evs xs (eq@(Right (ct,u)):eqs') =
      case isNatural u of
        Just True  -> simples subst
                        (((,,) <$> evMagic ct [] <*> pure ct <*> pure []):evs)
                        xs eqs'
        Just False -> return (Impossible eq)
        -- Undetermined: defer this inequality
        Nothing    -> simples subst evs (eq:xs) eqs'

-- | Extract the Nat equality constraints.
--
-- For a nominal equality @t1 ~ t2@:
--
-- * If both sides are applications of the same type constructor, and that
--   constructor is not one of the arithmetic ones, and the argument lists
--   differ in exactly one position whose types are both of kind Nat, the
--   result is the equality between those two arguments.
-- * If the left side is an application of 'typeNatLeqTyCon' to @[x,y]@ and
--   the right side is the promoted @True@, the result is the inequality on
--   @y - x@; if it is the promoted @False@, the inequality on @x - (y + 1)@.
-- * Otherwise, if both sides are of kind Nat, the result is the equality
--   between them.
--
-- Anything else yields 'Nothing'.
toNatEquality :: Ct -> Maybe (Either NatEquality NatInEquality)
toNatEquality ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2
      -> go t1 t2
    _ -> Nothing
  where
    go (TyConApp tc xs) (TyConApp tc' ys)
      | tc == tc'
      , null ([tc,tc'] `intersect` [typeNatAddTyCon,typeNatSubTyCon
                                   ,typeNatMulTyCon,typeNatExpTyCon])
      = case filter (not . uncurry eqType) (zip xs ys) of
          [(x,y)]
            | isNatKind (typeKind x) && isNatKind (typeKind y)
            -> Just (Left (ct, normaliseNat x, normaliseNat y))
          _ -> Nothing
      | tc == typeNatLeqTyCon
      , [x,y] <- xs
      = if tc' == promotedTrueDataCon
           then Just (Right (ct,normaliseNat (mkTyConApp typeNatSubTyCon [y,x])))
           else if tc' == promotedFalseDataCon
                then Just (Right (ct,normaliseNat (mkTyConApp typeNatSubTyCon [x,mkTyConApp typeNatAddTyCon [y,mkNumLitTy 1]])))
                else Nothing

    -- Also reached when neither guard above applies
    go x y
      | isNatKind (typeKind x) && isNatKind (typeKind y)
      = Just (Left (ct,normaliseNat x,normaliseNat y))
      | otherwise
      = Nothing

    isNatKind :: Kind -> Bool
    isNatKind = (`eqType` typeNatKind)

-- | Fabricate evidence for a nominal equality constraint.
--
-- The core of the evidence is a universal coercion between the two sides,
-- tagged with this plugin's name as provenance. For every
-- @(hole, tyvar, predicate)@ triple the coercion is additionally quantified
-- over @tyvar@ and instantiated with a hole coercion for @predicate@.
--
-- Returns 'Nothing' if the constraint is not a nominal equality.
evMagic :: Ct -> [(CoercionHole, TyVar, PredType)] -> Maybe EvTerm
evMagic ct evs = case classifyPredType $ ctEvPred $ ctEvidence ct of
  EqPred NomEq t1 t2 ->
    let ctEv = mkUnivCo (PluginProv "ghc-typelits-natnormalise") Nominal t1 t2
        (holes,tvs,preds) = unzip3 evs
        holeEvs = zipWith (\h p -> uncurry (mkHoleCo h Nominal) (getEqPredTys p))
                          holes preds
        natReflCo = mkNomReflCo typeNatKind
        forallEv = mkForAllCos (map (,natReflCo) tvs) ctEv
        finalEv = foldl' mkInstCo forallEv holeEvs
    in  Just (EvCoercion finalEv)
  _ -> Nothing