{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.Normalise
( plugin )
where
import Control.Arrow (second)
#if !MIN_VERSION_ghc(8,4,1)
import Control.Monad (replicateM)
#endif
import Control.Monad.Trans.Writer.Strict
import Data.Either (rights)
import Data.List (intersect)
import Data.Maybe (mapMaybe)
import GHC.TcPluginM.Extra (tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens)
#endif
#if MIN_VERSION_ghc(8,5,0)
import CoreSyn (Expr (..))
#endif
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
#endif
import TcEvidence (EvTerm (..))
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM (zonkCt)
#endif
import TcPluginM (TcPluginM, tcPluginTrace)
import TcRnTypes (Ct, TcPlugin (..), TcPluginResult(..), ctEvidence, ctEvPred,
isWanted, mkNonCanonical)
import Type (EqRel (NomEq), Kind, PredTree (EqPred), PredType,
classifyPredType, eqType, getEqPredTys, mkTyVarTy)
import TysWiredIn (typeNatKind)
import Coercion (CoercionHole, Role (..), mkForAllCos, mkHoleCo, mkInstCo,
mkNomReflCo, mkUnivCo)
import TcPluginM (newCoercionHole, newFlexiTyVar)
import TcRnTypes (CtEvidence (..), CtLoc, TcEvDest (..), ctLoc, isGiven)
#if MIN_VERSION_ghc(8,2,0)
import TcRnTypes (ShadowInfo (WDeriv))
#endif
import TyCoRep (UnivCoProvenance (..))
import Type (mkPrimEqPred)
import TcType (typeKind)
import TyCoRep (Type (..))
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
typeNatSubTyCon)
import TcTypeNats (typeNatLeqTyCon)
import TysWiredIn (promotedFalseDataCon, promotedTrueDataCon)
import GHC.TypeLits.Normalise.Unify
-- | The @ghc-typelits-natnormalise@ type-checker plugin.
--
-- Passing @-fplugin-opt=GHC.TypeLits.Normalise:allow-negated-numbers@
-- disables the extra wanted constraints the plugin otherwise emits for
-- subtractions (see 'simplifyNats').
--
-- The option is recognised wherever it occurs in the option list.
-- Previously only the exact list @["allow-negated-numbers"]@ enabled it, so
-- passing it twice (e.g. from both a cabal file and an @OPTIONS_GHC@ pragma)
-- or next to another option silently turned it off.
plugin :: Plugin
plugin
  = defaultPlugin
  { tcPlugin = Just . normalisePlugin . elem "allow-negated-numbers"
#if MIN_VERSION_ghc(8,6,0)
  , pluginRecompile = purePlugin
#endif
  }
-- | Build the type-checker plugin record.
--
-- The plugin keeps no state (its state is @()@), and all solving is
-- delegated to 'decideEqualSOP'. The 'Bool' is the
-- @allow-negated-numbers@ flag.
normalisePlugin :: Bool -> TcPlugin
normalisePlugin negNumbers =
  tracePlugin "ghc-typelits-natnormalise" $
    TcPlugin
      { tcPluginInit  = pure ()
      , tcPluginSolve = \_ -> decideEqualSOP negNumbers
      , tcPluginStop  = \_ -> pure ()
      }
-- | The plugin's solver: try to solve the wanted constraints that are
-- equalities or inequalities over type-level naturals.
--
-- Deriveds are ignored. When no relevant wanted constraint is present, the
-- plugin reports no progress. Otherwise the givens are used to simplify the
-- wanteds: on GHC >= 8.4 the givens plus their flattened forms, on older
-- versions the zonked givens. The result is either the solved wanteds
-- together with any new wanteds, or a contradiction.
decideEqualSOP
  :: Bool
  -- ^ Allow negated numbers
  -> [Ct]
  -- ^ Givens
  -> [Ct]
  -- ^ Deriveds
  -> [Ct]
  -- ^ Wanteds
  -> TcPluginM TcPluginResult
decideEqualSOP _negNumbers _givens _deriveds [] = return (TcPluginOk [] [])
decideEqualSOP negNumbers givens _deriveds wanteds =
    -- GHC 7.10.1 puts deriveds with the wanteds, so filter them out
    case mapMaybe toNatEquality (filter (isWanted . ctEvidence) wanteds) of
      [] -> return (TcPluginOk [] [])
      natWanteds -> do
#if MIN_VERSION_ghc(8,4,0)
        let natGivens = mapMaybe toNatEquality (givens ++ flattenGivens givens)
#else
        natGivens <- mapMaybe toNatEquality <$> mapM zonkCt givens
#endif
        result <- simplifyNats negNumbers natGivens natWanteds
        tcPluginTrace "normalised" (ppr result)
        case result of
          Impossible eq ->
            return (TcPluginContradiction [fromNatEquality eq])
          Simplified evs -> do
            -- Evidence is produced for givens too; only report wanteds.
            let solvedWanteds =
                  [ e | e@((_, c), _) <- evs, isWanted (ctEvidence c) ]
                (solved, newWanteds) = second concat (unzip solvedWanteds)
            return (TcPluginOk solved newWanteds)
-- | An equality between two naturals in sum-of-products normal form: the
-- constraint it was derived from, then the left- and right-hand sides.
type NatEquality = (Ct, CoreSOP, CoreSOP)

-- | An inequality @(x <=? y) ~ b@: the constraint it was derived from,
-- paired with the normalised @x@, the normalised @y@ and the truth value @b@.
type NatInEquality = (Ct, (CoreSOP, CoreSOP, Bool))
-- | The constraint an equality or inequality was derived from.
fromNatEquality :: Either NatEquality NatInEquality -> Ct
fromNatEquality = either (\(origin, _, _) -> origin) fst
-- | Outcome of 'simplifyNats'.
data SimplifyResult
  = Simplified [((EvTerm,Ct),[Ct])]
    -- ^ For each solved constraint (givens included): its evidence paired
    -- with the constraint, plus the new wanted constraints that evidence
    -- depends on.
  | Impossible (Either NatEquality NatInEquality)
    -- ^ A constraint that was shown to be unsatisfiable.
-- | Pretty printing, used when tracing the result in 'decideEqualSOP'.
instance Outputable SimplifyResult where
  ppr result = case result of
    Simplified evs -> text "Simplified" $$ ppr evs
    Impossible eq  -> text "Impossible" <+> ppr eq
-- | Try to solve the wanted natural-number (in)equalities, using the givens.
--
-- Givens and wanteds are processed in a single worklist, givens first. The
-- side conditions attached to the givens are discarded. Those attached to a
-- wanted become extra wanted constraints when that wanted is solved, unless
-- negative numbers are allowed.
--
-- The result is one of:
--
-- * the evidence for every constraint that could be solved (givens
--   included; 'decideEqualSOP' filters them out), or
-- * the first constraint shown to be unsatisfiable.
simplifyNats
  :: Bool
  -- ^ Allow negated numbers: when 'True' no side conditions are emitted
  -> [(Either NatEquality NatInEquality,[(Type,Type)])]
  -- ^ Given constraints, with their side conditions
  -> [(Either NatEquality NatInEquality,[(Type,Type)])]
  -- ^ Wanted constraints, with their side conditions
  -> TcPluginM SimplifyResult
simplifyNats negNumbers eqsG eqsW =
    let eqs = map (second (const [])) eqsG ++ eqsW
    in  tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] [] eqs
  where
    -- Side conditions collected by 'normaliseNat' are turned into new
    -- wanted predicates via 'subtractionToPred', unless negative numbers
    -- are allowed.
    subToPred | negNumbers = const []
              | otherwise  = map subtractionToPred

    -- The worklist loop. Its arguments are, in order:
    --
    --   * the substitution accumulated from equalities that unified;
    --   * the evidence produced so far;
    --   * inequalities known to hold because they are derived from givens;
    --   * postponed constraints, which are retried whenever a constraint is
    --     solved or a new substitution is found;
    --   * the constraints still to be processed.
    simples :: [CoreUnify]
            -> [((EvTerm, Ct), [Ct])]
            -> [(CoreSOP,CoreSOP,Bool)]
            -> [(Either NatEquality NatInEquality,[(Type,Type)])]
            -> [(Either NatEquality NatInEquality,[(Type,Type)])]
            -> TcPluginM SimplifyResult
    simples _subst evs _leqsG _xs [] = return (Simplified evs)

    -- Equalities
    simples subst evs leqsG xs (eq@(Left (ct,u,v),k):eqs') = do
      let u' = substsSOP subst u
          v' = substsSOP subst v
      ur <- unifyNats ct u' v'
      tcPluginTrace "unifyNats result" (ppr ur)
      case ur of
        -- Solved: record the evidence and retry the postponed constraints.
        Win -> do
          evs' <- maybe evs (:evs) <$> evMagic ct (subToPred k)
          simples subst evs' leqsG [] (xs ++ eqs')
        Lose -> return (Impossible (fst eq))
        -- Undecided and no substitution: postpone this constraint. The
        -- given-derived inequalities collected so far are still valid, so
        -- keep them. They used to be reset to [] here, which dropped them
        -- for good and made solving later inequalities order-dependent.
        Draw [] -> simples subst evs leqsG (eq:xs) eqs'
        -- Undecided, but unification produced a substitution. Each item
        -- becomes a new wanted that the evidence for @ct@ depends on, and
        -- the substitution is applied to the remaining constraints.
        Draw subst' -> do
          evM <- evMagic ct (map unifyItemToPredType subst' ++
                             subToPred k)
          let leqsG' | isGiven (ctEvidence ct) = eqToLeq u' v' ++ leqsG
                     | otherwise               = leqsG
          case evM of
            Nothing -> simples subst evs leqsG' xs eqs'
            Just ev ->
              simples (substsSubst subst' subst ++ subst')
                      (ev:evs) leqsG' [] (xs ++ eqs')

    -- Inequalities (x <=? y) ~ b
    simples subst evs leqsG xs (eq@(Right (ct,u@(x,y,b)),k):eqs') = do
      let u' = substsSOP subst (subtractIneq u)
          x' = substsSOP subst x
          y' = substsSOP subst y
          leqsG' | isGiven (ctEvidence ct) = (x',y',b):leqsG
                 | otherwise               = leqsG
          -- Everything known from the givens: the collected inequalities,
          -- those same inequalities under the current substitution, and the
          -- given inequality constraints themselves.
          ineqs = concat [ leqsG
                         , map (substLeq subst) leqsG
                         , map snd (rights (map fst eqsG))
                         ]
      tcPluginTrace "unifyNats(ineq) results" (ppr (ct,u,u',ineqs))
      case isNatural u' of
        Just True -> do
          evs' <- maybe evs (:evs) <$> evMagic ct (subToPred k)
          simples subst evs' leqsG' xs eqs'
        Just False -> return (Impossible (fst eq))
        -- Not decidable on its own: try to derive it from the known
        -- inequalities (the 5 is the depth passed to 'solveIneq'),
        -- otherwise postpone it.
        Nothing
          | or (mapMaybe (solveIneq 5 u) ineqs)
          -> do
            evs' <- maybe evs (:evs) <$> evMagic ct (subToPred k)
            simples subst evs' leqsG' xs eqs'
          | otherwise
          -> simples subst evs leqsG (eq:xs) eqs'

    -- A given equality x ~ y yields both x <= y and y <= x.
    eqToLeq x y = [(x,y,True),(y,x,True)]

    -- Apply a substitution to both sides of an inequality.
    substLeq s (x,y,b) = (substsSOP s x, substsSOP s y, b)
-- | Recognise a constraint the plugin can work on.
--
-- Only nominal equalities @t1 ~ t2@ are considered:
--
-- * @T .. x .. ~ T .. y ..@ becomes the equality @x ~ y@ when @T@ is not one
--   of @+@, @-@, @*@, @^@ and the two applications differ in exactly one
--   argument, which is 'typeNatKind'-kinded on both sides.
-- * @(x <=? y) ~ 'True@ and @(x <=? y) ~ 'False@ become inequalities.
-- * Any other equality between two 'typeNatKind'-kinded types becomes an
--   equality.
--
-- Both sides are normalised with 'normaliseNat', and the pairs it reports
-- are returned alongside the result.
--
-- NOTE(review): the first case also fires when @T@ is a (possibly
-- non-injective) type family, e.g. @F x ~ F y@ is treated as @x ~ y@.
-- Confirm this is intended.
toNatEquality :: Ct -> Maybe (Either NatEquality NatInEquality,[(Type,Type)])
toNatEquality ct
  | EqPred NomEq lhs rhs <- classifyPredType (ctEvPred (ctEvidence ct))
  = classify lhs rhs
  | otherwise
  = Nothing
  where
    arithmetic =
      [typeNatAddTyCon, typeNatSubTyCon, typeNatMulTyCon, typeNatExpTyCon]

    -- Normalise both sides, concatenating the reported side conditions.
    normalisePair a b =
      let (a', ka) = runWriter (normaliseNat a)
          (b', kb) = runWriter (normaliseNat b)
      in  (a', b', ka ++ kb)

    natKinded a b = isNatKind (typeKind a) && isNatKind (typeKind b)

    classify (TyConApp tcL argsL) (TyConApp tcR argsR)
      | tcL == tcR
      , tcL `notElem` arithmetic
      = case [ d | d@(a, b) <- zip argsL argsR, not (eqType a b) ] of
          [(a, b)]
            | natKinded a b
            , let (a', b', ks) = normalisePair a b
            -> Just (Left (ct, a', b'), ks)
          _ -> Nothing
      | tcL == typeNatLeqTyCon
      , [a, b] <- argsL
      , let (a', b', ks) = normalisePair a b
      = if tcR == promotedTrueDataCon
          then Just (Right (ct, (a', b', True)), ks)
          else if tcR == promotedFalseDataCon
            then Just (Right (ct, (a', b', False)), ks)
            else Nothing
    classify a b
      | natKinded a b
      , let (a', b', ks) = normalisePair a b
      = Just (Left (ct, a', b'), ks)
      | otherwise
      = Nothing
-- | Is this kind the kind of type-level naturals ('typeNatKind')?
isNatKind :: Kind -> Bool
isNatKind k = k `eqType` typeNatKind
-- | The primitive equality predicate asserted by a unifier item, paired
-- with 'typeNatKind'.
--
-- * A 'SubstItem' asserts @v ~ sop@.
-- * A 'UnifyItem' asserts @lhs ~ rhs@.
--
-- Sums-of-products are converted back to types with 'reifySOP'.
unifyItemToPredType :: CoreUnify -> (PredType,Kind)
unifyItemToPredType ui = (mkPrimEqPred lhsTy rhsTy, typeNatKind)
  where
    (lhsTy, rhsTy) = case ui of
      SubstItem {..} -> (mkTyVarTy siVar, reifySOP siSOP)
      UnifyItem {..} -> (reifySOP siLHS, reifySOP siRHS)
-- | Build evidence for the nominal equality constraint @ct@ (@t1 ~ t2@),
-- conditional on the extra predicates @preds@.
--
-- For every predicate a fresh coercion hole is created, and a new wanted
-- constraint writing into that hole is produced at @ct@'s location.
--
-- The evidence for @ct@ itself is a 'UnivCo' with 'PluginProv' provenance,
-- i.e. the plugin simply asserts @t1 ~ t2@. That coercion is wrapped in
-- foralls over fresh type variables (one per predicate, with the kind paired
-- with that predicate) and then instantiated with the hole coercions, so the
-- final evidence term mentions every hole.
-- NOTE(review): presumably this is what makes the evidence depend on the
-- emitted wanteds being solved. Confirm.
--
-- Returns 'Nothing' when @ct@ is not a nominal equality. Otherwise it returns
-- the evidence paired with @ct@, together with the new wanted constraints.
evMagic :: Ct -> [(PredType,Kind)] -> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
evMagic ct preds = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2 -> do
      let predTypes = map fst preds
          predKinds = map snd preds
      -- One fresh coercion hole per extra predicate.
#if MIN_VERSION_ghc(8,4,1)
      holes <- mapM (newCoercionHole . uncurry mkPrimEqPred . getEqPredTys) predTypes
#else
      holes <- replicateM (length preds) newCoercionHole
#endif
      -- New wanteds (one per predicate, filling the matching hole) and the
      -- plugin-asserted coercion t1 ~ t2.
      let newWanted = zipWith (unifyItemToCt (ctLoc ct)) predTypes holes
          ctEv = mkUnivCo (PluginProv "ghc-typelits-natnormalise") Nominal t1 t2
#if MIN_VERSION_ghc(8,4,1)
          holeEvs = map mkHoleCo holes
#else
          holeEvs = zipWith (\h p -> uncurry (mkHoleCo h Nominal) (getEqPredTys p)) holes predTypes
#endif
      -- Quantify over one fresh variable per predicate, then instantiate
      -- those binders with the hole coercions, in order.
      forallEv <- mkForAllCos <$> (mapM mkCoVar predKinds) <*> pure ctEv
      let finalEv = foldl mkInstCo forallEv holeEvs
#if MIN_VERSION_ghc(8,5,0)
      return (Just ((EvExpr (Coercion finalEv), ct),newWanted))
#else
      return (Just ((EvCoercion finalEv, ct),newWanted))
#endif
    _ -> return Nothing
  where
    -- A fresh type variable of kind @k@, paired with the reflexive nominal
    -- coercion on @k@ (presumably the binder's kind coercion expected by
    -- 'mkForAllCos').
    mkCoVar k = (,natReflCo) <$> (newFlexiTyVar k)
      where
        natReflCo = mkNomReflCo k
-- | Create a new, non-canonical wanted constraint for the predicate
-- @predTy@ at location @loc@, whose evidence is to be written into @hole@.
-- On GHC >= 8.2 the constraint is marked 'WDeriv'.
unifyItemToCt :: CtLoc
              -> PredType
              -> CoercionHole
              -> Ct
unifyItemToCt loc predTy hole = mkNonCanonical wanted
  where
    wanted = CtWanted predTy (HoleDest hole)
#if MIN_VERSION_ghc(8,2,0)
               WDeriv
#endif
               loc