{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.Normalise
( plugin )
where
import Control.Arrow (second)
import Control.Monad ((<=<), forM)
#if !MIN_VERSION_ghc(8,4,1)
import Control.Monad (replicateM)
#endif
import Control.Monad.Trans.Writer.Strict
import Data.Either (partitionEithers, rights)
import Data.List (intersect, partition, stripPrefix, find)
import Data.Maybe (mapMaybe, catMaybes)
import Data.Set (Set, empty, toList, notMember, fromList, union)
import GHC.TcPluginM.Extra (tracePlugin, newGiven, newWanted)
import qualified GHC.TcPluginM.Extra as TcPluginM
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens)
#endif
import Text.Read (readMaybe)
#if MIN_VERSION_ghc(8,5,0)
import CoreSyn (Expr (..))
#endif
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
#endif
import PrelNames (knownNatClassName)
import TcEvidence (EvTerm (..))
#if MIN_VERSION_ghc(8,6,0)
import TcEvidence (evCast)
#endif
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM (zonkCt)
#endif
import TcPluginM (TcPluginM, tcPluginTrace, tcPluginIO)
import Type (Kind, PredType, eqType, mkTyVarTy)
import TysWiredIn (typeNatKind)
import Coercion (CoercionHole, Role (..), mkUnivCo)
import TcPluginM (newCoercionHole, tcLookupClass)
import TcRnTypes (TcPlugin (..), TcPluginResult(..))
import TyCoRep (UnivCoProvenance (..))
import TcType (isEqPred)
import TyCoRep (Type (..))
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
typeNatSubTyCon)
import TcTypeNats (typeNatLeqTyCon)
import TysWiredIn (promotedFalseDataCon, promotedTrueDataCon)
import Data.IORef
#if MIN_VERSION_ghc(8,10,0)
import Constraint
(Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred,
ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan,
isWantedCt)
import Predicate
(EqRel (NomEq), Pred (EqPred), classifyPredType, getEqPredTys, mkClassPred,
mkPrimEqPred)
import Type (typeKind)
#else
import TcRnTypes
(Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred,
ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan,
isWantedCt)
import TcType (typeKind)
import Type
(EqRel (NomEq), PredTree (EqPred), classifyPredType, getEqPredTys, mkClassPred,
mkPrimEqPred)
#endif
#if MIN_VERSION_ghc(8,10,0)
import Constraint (ctEvExpr)
#elif MIN_VERSION_ghc(8,6,0)
import TcRnTypes (ctEvExpr)
#else
import TcRnTypes (ctEvTerm)
#endif
#if MIN_VERSION_ghc(8,2,0)
#if MIN_VERSION_ghc(8,10,0)
import Constraint (ShadowInfo (WDeriv))
#else
import TcRnTypes (ShadowInfo (WDeriv))
#endif
#endif
#if MIN_VERSION_ghc(8,10,0)
import TcType (isEqPrimPred)
#endif
import GHC.TypeLits.Normalise.SOP
import GHC.TypeLits.Normalise.Unify
#if !MIN_VERSION_ghc(8,10,0)
-- | Compatibility definition used when the GHC version is older than 8.10
-- (see the surrounding CPP condition): it simply defers to 'isEqPred'.
isEqPrimPred :: PredType -> Bool
isEqPrimPred ty = isEqPred ty
#endif
-- | The plugin record exported by this module.
--
-- Recognised plugin arguments:
--
-- * @allow-negated-numbers@: sets 'negNumbers' to 'True'
--
-- * @depth=N@: sets 'depth' to @N@ (the default is 5)
--
-- Every argument is turned into an update of 'Opts'; the updates are applied
-- on top of the defaults. If any argument is not recognised, the 'tcPlugin'
-- field yields 'Nothing'.
plugin :: Plugin
plugin
  = defaultPlugin
  { tcPlugin = \args -> normalisePlugin . applyAll <$> mapM parseArgument args
#if MIN_VERSION_ghc(8,6,0)
  , pluginRecompile = purePlugin
#endif
  }
 where
  -- Apply the option updates right-to-left, starting from the defaults.
  applyAll updates = foldr ($) defaultOpts updates

  defaultOpts = Opts { negNumbers = False, depth = 5 }

  -- Translate one argument into an update of the options record.
  parseArgument arg = case arg of
    "allow-negated-numbers" ->
      Just (\o -> o { negNumbers = True })
    (stripPrefix "depth=" -> Just (readMaybe -> Just d)) ->
      Just (\o -> o { depth = d })
    _ ->
      Nothing
-- | Options of the plugin, built from the plugin arguments in 'plugin'.
data Opts
  = Opts
  { negNumbers :: Bool
    -- ^ Set by the @allow-negated-numbers@ argument. When 'True',
    -- 'subToPred' produces no predicates for subtractions.
  , depth :: Word
    -- ^ Set by the @depth=N@ argument; handed to the inequality solvers
    -- in 'simplifyNats'.
  }
-- | Build the type-checker plugin for the given options.
--
-- The plugin state is an 'IORef' holding the (initially empty) set of
-- predicates for which givens were already generated; solving is done by
-- 'decideEqualSOP'; stopping does nothing.
normalisePlugin :: Opts -> TcPlugin
normalisePlugin opts = tracePlugin "ghc-typelits-natnormalise" basePlugin
 where
  basePlugin = TcPlugin
    { tcPluginInit  = tcPluginIO (newIORef empty)
    , tcPluginSolve = decideEqualSOP opts
    , tcPluginStop  = \_ -> return ()
    }
-- | Wrapper marking a constraint as the original one handed to the plugin,
-- as opposed to the copy that had a substitution applied to it
-- (see 'decideEqualSOP').
newtype OrigCt
  = OrigCt
  { runOrigCt :: Ct
  }
-- | The solver function of the plugin.
decideEqualSOP
  :: Opts
  -> IORef (Set CType)
  -- ^ Givens that have already been generated.
  -- New givens must be generated at most once;
  -- otherwise GHC will loop indefinitely.
  -> [Ct]
  -- ^ Given constraints
  -> [Ct]
  -- ^ Derived constraints (ignored)
  -> [Ct]
  -- ^ Wanted constraints
  -> TcPluginM TcPluginResult

-- Simplification phase: derives /simplified/ givens;
-- we can reduce given constraints like @Show (Foo (n + 2))@
-- to their normal form @Show (Foo (2 + n))@, which is eventually
-- useful in the solving phase.
--
-- This helps us to solve /indirect/ constraints;
-- without this phase, we cannot derive, e.g.,
-- @IsVector UVector (Fin (n + 1))@ from
-- @Unbox (1 + n)@!
decideEqualSOP opts gen'd givens _deriveds [] = do
    done <- tcPluginIO $ readIORef gen'd
#if MIN_VERSION_ghc(8,4,0)
    let simplGivens = flattenGivens givens
#else
    simplGivens <- mapM zonkCt givens
#endif
    -- Keep a reduction only when it carries no extra side-condition
    -- predicates, unless negated numbers are allowed.
    let keep (_, (_, _, extra)) = null extra || negNumbers opts
        reds      = filter keep (reduceGivens opts done simplGivens)
        newlyDone = [ CType prd | (_, (prd, _, _)) <- reds ]
    tcPluginIO $ modifyIORef' gen'd (union (fromList newlyDone))
    newGivens <- forM reds $ \(origCt, (pred', evTerm, _)) -> do
      let loc = ctLoc origCt
      mkNonCanonical' loc <$> newGiven loc pred' evTerm
    return (TcPluginOk [] newGivens)

-- Solving phase: there is at least one wanted constraint.
decideEqualSOP opts gen'd givens _deriveds wanteds = do
#if MIN_VERSION_ghc(8,4,0)
    let simplGivens = givens ++ flattenGivens givens
        subst       = map fst (TcPluginM.mkSubst' givens)
        -- Pair every original wanted with its substituted version.
        wanteds0    = [ (OrigCt w, TcPluginM.substCt subst w) | w <- wanteds ]
#else
    let wanteds0 = [ (OrigCt w, w) | w <- wanteds ]
    simplGivens <- mapM zonkCt givens
#endif
    let isEqLike p   = isEqPred p || isEqPrimPred p
        wanteds'     = filter (isWanted . ctEvidence) wanteds
        unit_wanteds = mapMaybe toNatEquality wanteds'
        -- Wanted constraints that are not equalities.
        nonEqs       = [ w0
                       | w0@(_, ct) <- wanteds0
                       , isWanted (ctEvidence ct)
                       , not (isEqLike (ctEvPred (ctEvidence ct)))
                       ]
    done <- tcPluginIO $ readIORef gen'd
    let redGs     = reduceGivens opts done simplGivens
        newlyDone = [ CType prd | (_, (prd, _, _)) <- redGs ]
    redGivens <- forM redGs $ \(origCt, (pred', evTerm, _)) -> do
      let loc = ctLoc origCt
      mkNonCanonical' loc <$> newGiven loc pred' evTerm
    let allGivens = simplGivens ++ redGivens
    reducible_wanteds <- fmap catMaybes $ forM nonEqs $ \(origCt, ct) ->
      fmap (runOrigCt origCt,) <$> reduceNatConstr allGivens ct
    if null unit_wanteds && null reducible_wanteds
      then return (TcPluginOk [] [])
      else do
        -- The reduced givens come with side-condition predicates (the third
        -- tuple component, computed in 'tryReduceGiven'); each of them is
        -- emitted as a new wanted constraint.
        ineqForRedWants <- fmap concat $ forM redGs $ \(ct, (_, _, ws)) ->
          forM ws $ \w ->
            mkNonCanonical' (ctLoc ct) <$> newWanted (ctLoc ct) w
        tcPluginIO $ modifyIORef' gen'd (union (fromList newlyDone))
        let unit_givens = mapMaybe toNatEquality simplGivens
        sr <- simplifyNats opts unit_givens unit_wanteds
        tcPluginTrace "normalised" (ppr sr)
        reds <- forM reducible_wanteds $ \(origCt, (term, ws)) -> do
          wants <- evSubtPreds origCt (subToPred opts ws)
          return ((term, origCt), wants)
        case sr of
          Impossible eq ->
            return (TcPluginContradiction [fromNatEquality eq])
          Simplified evs -> do
            -- Only report evidence for wanted constraints as solved.
            let simpld = [ ev
                         | ev@((_, ct), _) <- evs
                         , isWanted (ctEvidence ct)
                         ]
                (solved', newWanteds) = second concat (unzip (simpld ++ reds))
            return (TcPluginOk solved' (newWanteds ++ ineqForRedWants))
-- | An equality between two normalised expressions, paired with the
-- constraint it was extracted from (see 'toNatEquality').
type NatEquality = (Ct, CoreSOP, CoreSOP)

-- | An inequality extracted from a constraint (see 'toNatEquality'): the two
-- normalised operands of the comparison, and a 'Bool' that is 'True' when the
-- comparison is equated with the promoted @True@ and 'False' when it is
-- equated with the promoted @False@.
type NatInEquality = (Ct, (CoreSOP, CoreSOP, Bool))
-- | Normalise the non-equality given constraints.
--
-- For every given that is not an equality and that 'tryReduceGiven' manages
-- to reduce, the result pairs the original constraint with the reduction.
-- Reductions whose predicate is already in the @done@ set are left out.
reduceGivens :: Opts -> Set CType -> [Ct] -> [(Ct, (Type, EvTerm, [PredType]))]
reduceGivens opts done givens =
  [ (ct, red)
  | ct <- givens
  , let ev = ctEvidence ct
  , isGiven ev
  , not (isEqLike (ctEvPred ev))
  , Just red@(prd, _, _) <- [tryReduceGiven opts givens ct]
  , CType prd `notMember` done
  ]
 where
  isEqLike p = isEqPred p || isEqPrimPred p
-- | Try to normalise the predicate of a single constraint.
--
-- Returns 'Nothing' when 'normaliseNatEverywhere' yields no new predicate.
-- Otherwise returns the normalised predicate, evidence for it obtained by
-- casting the evidence of the constraint ('toReducedDict'), and the
-- side-condition predicates (from 'subToPred') that do not already occur
-- among the given constraints.
tryReduceGiven
  :: Opts -> [Ct] -> Ct
  -> Maybe (PredType, EvTerm, [PredType])
tryReduceGiven opts simplGivens ct =
  case normalised of
    Nothing    -> Nothing
    Just pred' -> Just (pred', toReducedDict ev pred', sideConds)
 where
  ev = ctEvidence ct
  (normalised, subs) = runWriter (normaliseNatEverywhere (ctEvPred ev))
  -- Is this predicate already the predicate of one of the givens?
  alreadyGiven p = any ((`eqType` p) . ctEvPred . ctEvidence) simplGivens
  sideConds = [ p | (p, _) <- subToPred opts subs, not (alreadyGiven p) ]
-- | Project out the constraint an equality or inequality originated from.
fromNatEquality :: Either NatEquality NatInEquality -> Ct
fromNatEquality = either (\(ct, _, _) -> ct) fst
-- | Try to solve a constraint by normalising its predicate and looking the
-- result up among the given constraints.
--
-- When normalisation yields a predicate that equals the predicate of some
-- given, the result holds the evidence of that given cast to the original
-- predicate ('toReducedDict'), together with the pairs of types written out
-- by 'normaliseNatEverywhere'. In every other case the result is 'Nothing'.
reduceNatConstr :: [Ct] -> Ct -> TcPluginM (Maybe (EvTerm, [(Type, Type)]))
reduceNatConstr givens ct = return $ do
  pred' <- mans
  c     <- find ((`eqType` pred') . ctEvPred . ctEvidence) givens
  Just (toReducedDict (ctEvidence c) pred0, tests)
 where
  pred0         = ctEvPred (ctEvidence ct)
  (mans, tests) = runWriter (normaliseNatEverywhere pred0)
-- | Cast the evidence of a constraint to another predicate.
--
-- The cast goes from the predicate of the given evidence to @pred'@ and uses
-- a representational universal coercion with plugin provenance.
toReducedDict :: CtEvidence -> PredType -> EvTerm
toReducedDict ct pred' =
#if MIN_VERSION_ghc(8,6,0)
  ctEvExpr ct `evCast` evCo
#else
  ctEvTerm ct `EvCast` evCo
#endif
 where
  evCo = mkUnivCo (PluginProv "ghc-typelits-natnormalise")
                  Representational
                  (ctEvPred ct)
                  pred'
-- | Outcome of 'simplifyNats'.
data SimplifyResult
  = Simplified [((EvTerm, Ct), [Ct])]
  -- ^ Solved constraints with their evidence, each paired with the new
  -- constraints created while solving it
  | Impossible (Either NatEquality NatInEquality)
  -- ^ The equality or inequality that was found to be unsatisfiable

instance Outputable SimplifyResult where
  ppr = \case
    Simplified evs -> text "Simplified" $$ ppr evs
    Impossible eq  -> text "Impossible" <+> ppr eq
-- | Try to solve the wanted equalities and inequalities with the help of the
-- given ones.
--
-- The result is either 'Simplified' with the evidence that was produced, or
-- 'Impossible' with the (in)equality that was found to be unsatisfiable.
simplifyNats
:: Opts
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-- ^ Given constraints
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-- ^ Wanted constraints
-> TcPluginM SimplifyResult
simplifyNats opts@Opts {..} eqsG eqsW = do
-- Replace the side-condition list attached to every given with the empty list.
let eqsG1 = map (second (const ([] :: [(Type,Type)]))) eqsG
-- Split off the givens that equate one variable with another variable.
(varEqs,otherEqs) = partition isVarEqs eqsG1
-- Candidate orderings of the givens, one or two per variable equality
-- (see makeGivensSet below).
fancyGivens = concatMap (makeGivensSet otherEqs) varEqs
case varEqs of
-- No variable-to-variable givens: a single pass over givens then wanteds.
[] -> do
let eqs = otherEqs ++ eqsW
tcPluginTrace "simplifyNats" (ppr eqs)
simples [] [] [] [] eqs
-- Otherwise run one pass per candidate ordering and combine the results with
-- findFirstSimpliedWanted (defaulting to an empty 'Simplified').
_ -> do
tcPluginTrace ("simplifyNats(backtrack: " ++ show (length fancyGivens) ++ ")")
(ppr varEqs)
foldr findFirstSimpliedWanted (Simplified []) <$>
mapM (\v -> do let eqs = v ++ eqsW
tcPluginTrace "simplifyNats" (ppr eqs)
simples [] [] [] [] eqs)
fancyGivens
where
-- Worklist loop. The arguments are, in order:
--   1. the unifiers found so far, applied to every (in)equality before it
--      is examined;
--   2. the evidence produced so far;
--   3. the inequalities collected from given constraints so far;
--   4. the (in)equalities that were set aside because no progress could be
--      made on them;
--   5. the (in)equalities that still have to be examined.
simples :: [CoreUnify]
-> [((EvTerm, Ct), [Ct])]
-> [(CoreSOP,CoreSOP,Bool)]
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-> TcPluginM SimplifyResult
-- Nothing left to examine: report the collected evidence.
simples _subst evs _leqsG _xs [] = return (Simplified evs)
-- An equality.
simples subst evs leqsG xs (eq@(Left (ct,u,v),k):eqs') = do
let u' = substsSOP subst u
v' = substsSOP subst v
ur <- unifyNats ct u' v'
tcPluginTrace "unifyNats result" (ppr ur)
case ur of
-- Record the evidence (if evMagic yields any) and put the set-aside items
-- back in front of the remaining work.
Win -> do
evs' <- maybe evs (:evs) <$> evMagic ct empty (subToPred opts k)
simples subst evs' leqsG [] (xs ++ eqs')
-- Only reported as impossible when no evidence has been produced yet and
-- nothing remains to be examined; otherwise the equality is skipped.
Lose -> if null evs && null eqs'
then return (Impossible (fst eq))
else simples subst evs leqsG xs eqs'
-- No unifiers: set the equality aside.
-- NOTE(review): the collected given inequalities are replaced by [] on this
-- path rather than passed on -- confirm this is intended.
Draw [] -> simples subst evs [] (eq:xs) eqs'
-- New unifiers: they are turned into predicates (unifyItemToPredType) and
-- handed to evMagic together with the side conditions of the equality.
Draw subst' -> do
evM <- evMagic ct empty (map unifyItemToPredType subst' ++
subToPred opts k)
-- A given equality u' = v' also contributes both inequalities between u' and v'.
let leqsG' | isGiven (ctEvidence ct) = eqToLeq u' v' ++ leqsG
| otherwise = leqsG
case evM of
Nothing -> simples subst evs leqsG' xs eqs'
-- Extend the unifiers, keep the evidence, and retry the set-aside items.
Just ev ->
simples (substsSubst subst' subst ++ subst')
(ev:evs) leqsG' [] (xs ++ eqs')
-- An inequality.
simples subst evs leqsG xs (eq@(Right (ct,u@(x,y,b)),k):eqs') = do
let u' = substsSOP subst (subtractIneq u)
x' = substsSOP subst x
y' = substsSOP subst y
-- The inequality with the current unifiers applied to both operands.
uS = (x',y',b)
-- Inequalities from given constraints are remembered for later items.
leqsG' | isGiven (ctEvidence ct) = (x',y',b):leqsG
| otherwise = leqsG
-- Known inequalities to solve against: the collected ones, the collected
-- ones with the current unifiers applied, and every inequality among the
-- original givens.
ineqs = concat [ leqsG
, map (substLeq subst) leqsG
, map snd (rights (map fst eqsG))
]
tcPluginTrace "unifyNats(ineq) results" (ppr (ct,u,u',ineqs))
case runWriterT (isNatural u') of
-- Solved; knW is passed to evMagic, which creates one wanted per element
-- through mkKnWanted.
Just (True,knW) -> do
evs' <- maybe evs (:evs) <$> evMagic ct knW (subToPred opts k)
simples subst evs' leqsG' xs eqs'
-- Refuted, and there are no side conditions attached: report impossible.
Just (False,_) | null k -> return (Impossible (fst eq))
_ -> do
-- Collect the successful results of: instantSolveIneq on the inequality,
-- solveIneq of the inequality against each known inequality, and the same
-- for the inequality with the unifiers applied.
let solvedIneq = mapMaybe runWriterT
(instantSolveIneq depth u:
map (solveIneq depth u) ineqs ++
map (solveIneq depth uS) ineqs)
smallest = solvedInEqSmallestConstraint solvedIneq
case smallest of
(True,kW) -> do
evs' <- maybe evs (:evs) <$> evMagic ct kW (subToPred opts k)
simples subst evs' leqsG' xs eqs'
-- Not solved: set the inequality aside.
_ -> simples subst evs leqsG (eq:xs) eqs'
-- Both inequalities between the two sides of an equality.
eqToLeq x y = [(x,y,True),(y,x,True)]
-- Apply unifiers to both operands of an inequality.
substLeq s (x,y,b) = (substsSOP s x, substsSOP s y, b)
-- Does this item equate a single variable with a single variable?
isVarEqs (Left (_,S [P [V _]], S [P [V _]]), _) = True
isVarEqs _ = False
-- Build the candidate orderings of the givens for one variable equality.
-- The other givens are split into those that mention the left variable,
-- those that mention the right variable, and those that mention neither
-- (see matchesVarEq). If none mention either variable, the single candidate
-- consists of the other givens only. Otherwise there is one candidate per
-- non-empty group: the group itself comes first, followed by the variable
-- equality (swapped for the right-variable group) and the remaining givens.
makeGivensSet otherEqs varEq
= let (noMentionsV,mentionsV) = partitionEithers
(map (matchesVarEq varEq) otherEqs)
(mentionsLHS,mentionsRHS) = partitionEithers mentionsV
vS = swapVar varEq
givensLHS = case mentionsLHS of
[] -> []
_ -> [mentionsLHS ++ ((varEq:mentionsRHS) ++ noMentionsV)]
givensRHS = case mentionsRHS of
[] -> []
_ -> [mentionsRHS ++ (vS:mentionsLHS ++ noMentionsV)]
in case mentionsV of
[] -> [noMentionsV]
_ -> givensLHS ++ givensRHS
-- Classify an item relative to a variable equality v1 = v2. The result is
-- Right (Left r) when one side of r is exactly the variable v1,
-- Right (Right r) when one side is exactly v2, and Left r otherwise.
-- Calling this with a first argument that is not a variable equality is an
-- internal error.
matchesVarEq (Left (_, S [P [V v1]], S [P [V v2]]),_) r = case r of
(Left (_,S [P [V v3]],_),_)
| v1 == v3 -> Right (Left r)
| v2 == v3 -> Right (Right r)
(Left (_,_,S [P [V v3]]),_)
| v1 == v3 -> Right (Left r)
| v2 == v3 -> Right (Right r)
(Right (_,(S [P [V v3]],_,_)),_)
| v1 == v3 -> Right (Left r)
| v2 == v3 -> Right (Right r)
(Right (_,(_,S [P [V v3]],_)),_)
| v1 == v3 -> Right (Left r)
| v2 == v3 -> Right (Right r)
_ -> Left r
matchesVarEq _ _ = error "internal error"
-- Swap the two sides of a variable equality; any other input is an internal
-- error.
swapVar (Left (ct,S [P [V v1]], S [P [V v2]]),ps) =
(Left (ct,S [P [V v2]], S [P [V v1]]),ps)
swapVar _ = error "internal error"
-- Combining function for the fold over the per-candidate results: an
-- 'Impossible' result wins outright, a 'Simplified' result is kept when it
-- contains evidence for at least one wanted constraint, and otherwise the
-- result accumulated from the later candidates is used.
findFirstSimpliedWanted (Impossible e) _ = Impossible e
findFirstSimpliedWanted (Simplified evs) s2
| any (isWantedCt . snd . fst) evs
= Simplified evs
| otherwise
= s2
-- | Turn recorded subtractions into predicates via 'subtractionToPred'.
--
-- When negated numbers are allowed ('negNumbers'), no predicates are
-- produced at all.
subToPred :: Opts -> [(Type, Type)] -> [(PredType, Kind)]
subToPred opts subs
  | negNumbers opts = []
  | otherwise       = map subtractionToPred subs
-- | Extract an equality or inequality from a constraint.
--
-- Only nominal equality predicates are considered. The second component of
-- the result holds the pairs of types written out by 'normaliseNat' while
-- normalising both sides.
toNatEquality :: Ct -> Maybe (Either NatEquality NatInEquality,[(Type,Type)])
toNatEquality ct = case classifyPredType (ctEvPred (ctEvidence ct)) of
  EqPred NomEq t1 t2 -> go t1 t2
  _                  -> Nothing
 where
  arithTyCons =
    [typeNatAddTyCon, typeNatSubTyCon, typeNatMulTyCon, typeNatExpTyCon]

  bothNat x y = isNatKind (typeKind x) && isNatKind (typeKind y)

  -- Normalise both sides and package them up as an equality.
  natEq x y =
    let (x', k1) = runWriter (normaliseNat x)
        (y', k2) = runWriter (normaliseNat y)
    in  Just (Left (ct, x', y'), k1 ++ k2)

  go (TyConApp tc xs) (TyConApp tc' ys)
    -- The same non-arithmetic type constructor on both sides: usable when
    -- exactly one pair of arguments differs and both arguments of that pair
    -- have the kind tested by isNatKind.
    | tc == tc'
    , tc `notElem` arithTyCons
    = case filter (not . uncurry eqType) (zip xs ys) of
        [(x, y)] | bothNat x y -> natEq x y
        _                      -> Nothing
    -- A binary comparison on the left, equated with promoted True or False.
    | tc == typeNatLeqTyCon
    , [x, y] <- xs
    = let (x', k1) = runWriter (normaliseNat x)
          (y', k2) = runWriter (normaliseNat y)
          leq b    = Just (Right (ct, (x', y', b)), k1 ++ k2)
      in  if tc' == promotedTrueDataCon
            then leq True
            else if tc' == promotedFalseDataCon
              then leq False
              else Nothing
  -- Anything else, including type constructor applications for which none
  -- of the guards above held.
  go x y
    | bothNat x y = natEq x y
    | otherwise   = Nothing
-- | Is the given kind equal to 'typeNatKind'?
isNatKind :: Kind -> Bool
isNatKind k = k `eqType` typeNatKind
-- | Turn a unifier into a primitive equality predicate between its two
-- sides, paired with 'typeNatKind'.
--
-- For a 'SubstItem' the sides are the variable and the reified expression;
-- for a 'UnifyItem' they are the two reified expressions.
unifyItemToPredType :: CoreUnify -> (PredType,Kind)
unifyItemToPredType ui = (mkPrimEqPred lhs rhs, typeNatKind)
  where
    (lhs, rhs) = case ui of
      SubstItem {siVar, siSOP} -> (mkTyVarTy siVar, reifySOP siSOP)
      UnifyItem {siLHS, siRHS} -> (reifySOP siLHS, reifySOP siRHS)
-- | Create one new wanted constraint per predicate, each backed by a fresh
-- coercion hole and located at the location of the given constraint.
evSubtPreds :: Ct -> [(PredType,Kind)] -> TcPluginM [Ct]
evSubtPreds ct preds = do
#if MIN_VERSION_ghc(8,4,1)
    holes <- forM predTypes $ \p ->
      newCoercionHole (uncurry mkPrimEqPred (getEqPredTys p))
#else
    holes <- replicateM (length preds) newCoercionHole
#endif
    return [ unifyItemToCt loc p h | (p, h) <- zip predTypes holes ]
  where
    predTypes = map fst preds
    loc       = ctLoc ct
-- | Produce evidence for a nominal equality constraint.
--
-- The evidence is a nominal universal coercion with plugin provenance
-- between the two sides of the equality. It is returned together with the
-- new wanted constraints: one per element of the given set (created by
-- 'mkKnWanted'), followed by one per predicate (created by 'evSubtPreds').
-- For any other kind of constraint the result is 'Nothing'.
evMagic :: Ct -> Set CType -> [(PredType,Kind)] -> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
evMagic ct knW preds
  | EqPred NomEq t1 t2 <- classifyPredType (ctEvPred (ctEvidence ct)) = do
      holeWanteds <- evSubtPreds ct preds
      knWanted    <- forM (toList knW) (mkKnWanted ct)
      let co = mkUnivCo (PluginProv "ghc-typelits-natnormalise") Nominal t1 t2
#if MIN_VERSION_ghc(8,5,0)
          evTerm = EvExpr (Coercion co)
#else
          evTerm = EvCoercion co
#endif
      return (Just ((evTerm, ct), knWanted ++ holeWanteds))
  | otherwise = return Nothing
-- | Wrap evidence in a non-canonical constraint whose location is the
-- location of the evidence with its source span replaced by the span of the
-- given location.
mkNonCanonical' :: CtLoc -> CtEvidence -> Ct
mkNonCanonical' origCtl ev = setCtLoc (mkNonCanonical ev) loc'
  where
    loc' = setCtLocSpan (ctEvLoc ev) (ctLocSpan origCtl)
-- | Create a new wanted constraint that applies the class named by
-- 'knownNatClassName' to the given type, located at the given constraint.
mkKnWanted
  :: Ct
  -> CType
  -> TcPluginM Ct
mkKnWanted ct (CType ty) = do
    knownNatCls <- tcLookupClass knownNatClassName
    mkNonCanonical' loc
      <$> TcPluginM.newWanted loc (mkClassPred knownNatCls [ty])
  where
    loc = ctLoc ct
-- | Build a non-canonical wanted constraint for the given predicate, with
-- the given coercion hole as the destination of its evidence.
unifyItemToCt :: CtLoc
              -> PredType
              -> CoercionHole
              -> Ct
unifyItemToCt loc pred_type hole = mkNonCanonical wantedEv
  where
    wantedEv = CtWanted
      pred_type
      (HoleDest hole)
#if MIN_VERSION_ghc(8,2,0)
      WDeriv
#endif
      loc