{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.KnownNat.Solver
( plugin )
where
import Control.Arrow ((&&&), first)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra (lookupModule, lookupName, newWanted,
tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens, mkSubst', substType)
#endif
import GHC.TypeLits.Normalise.SOP (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)
import Class (Class, classMethods, className, classTyCon)
#if MIN_VERSION_ghc(8,6,0)
import Coercion (Role (Representational), mkUnivCo)
#endif
import FamInst (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id (idType)
import InstEnv (instanceDFunId,lookupUniqueInstEnv)
#if MIN_VERSION_ghc(8,5,0)
import MkCore (mkNaturalExpr)
#endif
import Module (mkModuleName, moduleName, moduleNameString)
import Name (nameModule_maybe, nameOccName)
import OccName (mkTcOcc, occNameString)
import Plugins (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
#endif
import PrelNames (knownNatClassName)
#if MIN_VERSION_ghc(8,5,0)
import TcEvidence (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
#else
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
#endif
#if MIN_VERSION_ghc(8,5,0)
import TcPluginM (unsafeTcPluginTcM)
#endif
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM (zonkCt)
#endif
import TcPluginM (TcPluginM, tcLookupClass, getInstEnvs)
import TcRnTypes (TcPlugin(..), TcPluginResult (..))
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatDivTyCon)
#endif
import Type
(PredType,
dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind)
import TyCon (tyConName)
import TyCoRep (Type (..), TyLit (..))
#if MIN_VERSION_ghc(8,6,0)
import TyCoRep (UnivCoProvenance (PluginProv))
import TysWiredIn (boolTy)
#endif
import Var (DFunId)
#if MIN_VERSION_ghc(8,10,0)
import Constraint
(Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
mkNonCanonical, setCtLoc, setCtLocSpan)
import Predicate (EqRel (NomEq), Pred (ClassPred,EqPred), classifyPredType)
#else
import TcRnTypes
(Ct, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted, mkNonCanonical,
setCtLoc, setCtLocSpan)
import Type (EqRel (NomEq), PredTree (ClassPred,EqPred), classifyPredType)
#if MIN_VERSION_ghc(8,5,0)
import TcRnTypes (ctEvExpr)
#else
import TcRnTypes (ctEvTerm)
#endif
#endif
-- | Classes from "GHC.TypeLits.KnownNat" that the plugin needs; they are
-- looked up once, when the plugin is initialised, by 'lookupKnownNatDefs'.
data KnownNatDefs
  = KnownNatDefs
  { knownBool     :: Class -- ^ The @KnownBool@ class
  , knownBoolNat2 :: Class -- ^ The @KnownBoolNat2@ class
  , knownNat2Bool :: Class -- ^ The @KnownNat2Bool@ class
  , knownNatN     :: Int -> Maybe Class
    -- ^ @KnownNat1@, @KnownNat2@ or @KnownNat3@, selected by the number of
    -- arguments of a type-level operation; @Nothing@ for any other arity
  }
-- | A wanted constraint that the plugin will attempt to solve: the
-- constraint itself, its class (@KnownNat@ or @KnownBool@, as recognised by
-- 'toKnConstraint') and the single type argument of that class.
type KnConstraint = (Ct
                    ,Class
                    ,Type
                    )
-- | The plugin entry point. It installs 'normalisePlugin' as a type-checker
-- plugin and ignores any plugin command-line options. On GHC 8.6 and later
-- it is declared pure ('purePlugin'), so using it does not force
-- recompilation.
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = \_opts -> Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
                       , pluginRecompile = purePlugin
#endif
                       }
-- | The type-checker plugin record.
--
-- * Initialisation looks up the helper classes ('lookupKnownNatDefs').
-- * Solving is done by 'solveKnownNat'.
-- * No clean-up is needed when the plugin stops.
--
-- The record is wrapped in 'tracePlugin' under the name
-- @ghc-typelits-knownnat@.
normalisePlugin :: TcPlugin
normalisePlugin =
    tracePlugin "ghc-typelits-knownnat" TcPlugin
      { tcPluginInit  = lookupKnownNatDefs
      , tcPluginSolve = solveKnownNat
      , tcPluginStop  = \_defs -> return ()
      }
-- | Solve the wanted @KnownNat@ / @KnownBool@ constraints recognised by
-- 'toKnConstraint', using the givens as a source of evidence.
--
-- Solved constraints are reported together with their evidence. Wanted
-- constraints created while building that evidence are reported as new
-- constraints. Constraints that cannot be solved are simply left alone.
-- An empty list of wanteds yields no recognised constraints, so it is
-- handled by the same early exit.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat defs givens _deriveds wanteds
  | null knWanteds = return (TcPluginOk [] [])
  | otherwise = do
#if MIN_VERSION_ghc(8,4,0)
      let givenMap = map toGivenEntry (flattenGivens givens)
#else
      givenMap <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
      attempts <- mapM (constraintToEvTerm defs givenMap) knWanteds
      let (solved, emitted) = unzip (catMaybes attempts)
      return (TcPluginOk solved (concat emitted))
  where
    -- GHC 7.10 puts deriveds with the wanteds, so filter them out
    onlyWanteds = filter (isWanted . ctEvidence) wanteds
#if MIN_VERSION_ghc(8,4,0)
    -- Apply the substitution derived from the givens (mkSubst') to the
    -- type argument of every recognised wanted.
    subst     = map fst (mkSubst' givens)
    knWanteds = [ (c, cls, substType subst ty)
                | (c, cls, ty) <- mapMaybe (toKnConstraint defs) onlyWanteds ]
#else
    knWanteds = mapMaybe (toKnConstraint defs) onlyWanteds
#endif
-- | Recognise a constraint of the form @KnownNat t@ or @KnownBool t@ and
-- return it with its class and type argument. Any other constraint yields
-- @Nothing@.
toKnConstraint :: KnownNatDefs -> Ct -> Maybe KnConstraint
toKnConstraint defs ct
  | ClassPred cls [ty] <- classifyPredType (ctEvPred (ctEvidence ct))
  , className cls `elem` [knownNatClassName, className (knownBool defs)]
  = Just (ct, cls, ty)
  | otherwise
  = Nothing
-- | Turn a given constraint into an entry of the givens lookup table. The
-- entry is the predicate type, wrapped in 'CType' so entries can be
-- compared and looked up, paired with the evidence of the given.
#if MIN_VERSION_ghc(8,5,0)
toGivenEntry :: Ct -> (CType,EvExpr)
#else
toGivenEntry :: Ct -> (CType,EvTerm)
#endif
toGivenEntry ct = (CType (ctEvPred ev), evidenceOf ev)
  where
    ev = ctEvidence ct
#if MIN_VERSION_ghc(8,5,0)
    evidenceOf = ctEvExpr
#else
    evidenceOf = ctEvTerm
#endif
-- | Look up the helper classes in module "GHC.TypeLits.KnownNat" of
-- package @ghc-typelits-knownnat@. The classes are @KnownBool@,
-- @KnownBoolNat2@, @KnownNat2Bool@ and @KnownNat1@ to @KnownNat3@.
-- The lookups run in exactly that order.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md <- lookupModule myModule myPackage
    let look s = tcLookupClass =<< lookupName md (mkTcOcc s)
        -- Select the KnownNatN class for a type-level operation of arity n.
        byArity kn1C kn2C kn3C n = case n of
          1 -> Just kn1C
          2 -> Just kn2C
          3 -> Just kn3C
          _ -> Nothing
    KnownNatDefs
      <$> look "KnownBool"
      <*> look "KnownBoolNat2"
      <*> look "KnownNat2Bool"
      <*> (byArity <$> look "KnownNat1" <*> look "KnownNat2" <*> look "KnownNat3")
  where
    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"
-- | Try to solve a single recognised wanted constraint @cls op@.
--
-- Returns @Nothing@ when no evidence could be constructed. Otherwise it
-- returns the evidence paired with the original constraint, together with
-- any new wanted constraints. New wanteds are emitted for instance-context
-- constraints that are not available as givens.
--
-- Strategy:
--
--   1. @offset@: find a type @h@ with a known given such that @op@ can be
--      written as @h + n@, @h - n@, @h - v@ or @Div h 2@, and solve for that
--      expression instead.
--   2. @go@: otherwise, decompose @op@.
--
--      * A matching given, possibly coerced, is used directly.
--      * A numeric literal gets a literal dictionary.
--      * A type-constructor application is solved via an instance of one of
--        the helper classes, keyed on the fully-qualified name of the type
--        constructor as a type-level string.
constraintToEvTerm
  :: KnownNatDefs
#if MIN_VERSION_ghc(8,5,0)
  -> [(CType,EvExpr)]
#else
  -> [(CType,EvTerm)]
#endif
  -> KnConstraint
  -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
constraintToEvTerm defs givens (ct,cls,op) = do
    -- Prefer the offset strategy; fall back to structural decomposition.
    offsetM <- offset op
    evM <- case offsetM of
      found@Just {} -> return found
      _             -> go op
    return ((first (,ct)) <$> evM)
  where
    -- Build evidence for (cls ty) by decomposing ty.
    go :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    -- A matching given (possibly via a coercion) solves it outright.
    go (go_other -> Just ev) = return (Just (ev,[]))
    go ty@(TyConApp tc args0)
      | let tcNm = tyConName tc
      , Just m <- nameModule_maybe tcNm
      = do
        ienv <- getInstEnvs
        let mS    = moduleNameString (moduleName m)
            tcS   = occNameString (nameOccName tcNm)
            -- Fully-qualified name of the type constructor. Passed as a
            -- type-level string to select the helper-class instance.
            fn0   = mS ++ "." ++ tcS
            fn1   = mkStrLitTy (fsLit fn0)
            args1 = fn1:args0
            -- Find a unique helper-class instance. Each candidate gives the
            -- instance, the helper class, the type arguments for the
            -- dictionary function, and the helper-class arguments.
            instM = case () of
              () | Just knN_cls <- knownNatN defs (length args0)
                 , Right (inst, _) <- lookupUniqueInstEnv ienv knN_cls args1
                 -> Just (inst,knN_cls,args0,args1)
                 -- Binary operator: KnownBoolNat2, also indexed by the kind
                 -- of the first argument.
                 | [a0, _] <- args0
                 , let knN_cls = knownBoolNat2 defs
                       ki      = typeKind a0
                       args1N  = ki:args1
                 , Right (inst, _) <- lookupUniqueInstEnv ienv knN_cls args1N
                 -> Just (inst,knN_cls,args0,args1N)
                 -- Data.Type.Bool.If: the first of its four arguments is
                 -- moved in front of the name (presumably the kind
                 -- argument); the remaining three are the operands.
                 | a0 : args0N@[_, _, _] <- args0
                 , fn0 == "Data.Type.Bool.If"
                 , let args1N  = a0:fn1:args0N
                       knN_cls = knownNat2Bool defs
                 , Right (inst, _) <- lookupUniqueInstEnv ienv knN_cls args1N
                 -> Just (inst,knN_cls,args0N,args1N)
                 | otherwise
                 -> Nothing
        case instM of
          Just (inst,knN_cls,args0N,args1N) -> do
            let df_id   = instanceDFunId inst
                df      = (knN_cls,df_id)
                -- Argument types of the dictionary function after type
                -- application, i.e. the instance-context constraints.
                df_args = fst
                        . splitFunTys
                        . (`piResultTys` args0N)
                        $ idType df_id
            (evs,new) <- unzip <$> mapM go_arg df_args
            if className cls == className (knownBool defs)
              then return ((,concat new) <$> makeOpDictByFiat df cls args1N args0N op evs)
              else return ((,concat new) <$> makeOpDict df cls args1N args0N op evs)
          _ -> return ((,[]) <$> go_other ty)
    -- Numeric literal: build the dictionary from the literal value.
    -- If the original wanted type op is itself a literal, give up here
    -- (presumably GHC solves that case on its own).
    go (LitTy (NumTyLit i))
      | LitTy _ <- op
      = return Nothing
      | otherwise
#if MIN_VERSION_ghc(8,5,0)
      = (fmap (,[])) <$> makeLitDict cls op i
#else
      = return ((,[]) <$> makeLitDict cls op i)
#endif
    go _ = return Nothing

    -- Evidence for one instance-context constraint. A matching given is
    -- reused; otherwise a new wanted is emitted.
#if MIN_VERSION_ghc(8,5,0)
    go_arg :: PredType -> TcPluginM (EvExpr,[Ct])
#else
    go_arg :: PredType -> TcPluginM (EvTerm,[Ct])
#endif
    go_arg ty = case lookup (CType ty) givens of
      Just ev -> return (ev,[])
      _ -> do
        (ev,wanted) <- makeWantedEv ct ty
        return (ev,[wanted])

    -- Look for a given (cls ty) and turn it into evidence for (cls op).
    -- A coercion is used when ty and op differ.
    -- NOTE(review): the EvExpr wrapping switches at GHC 8.6 here while the
    -- givens switch to EvExpr at 8.5; only relevant for 8.5 development
    -- snapshots - confirm.
    go_other :: Type -> Maybe EvTerm
    go_other ty =
      let knClsTc = classTyCon cls
          kn      = mkTyConApp knClsTc [ty]
          cast    = if CType ty == CType op
#if MIN_VERSION_ghc(8,6,0)
                      then Just . EvExpr
#else
                      then Just
#endif
                      else makeKnCoercion cls ty op
      in  cast =<< lookup (CType kn) givens

    -- Try to express want relative to a type h that has a known given:
    -- normalise (x - want) for every candidate x, and if the difference is
    -- a single constant or variable, solve for the corresponding expression
    -- in h instead.
    offset :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    offset want = runMaybeT $ do
      let
        -- Argument of a given KnownNat constraint, if it is one.
        unKn ty' = case classifyPredType ty' of
          ClassPred cls' [ty'']
            | className cls' == knownNatClassName
            -> Just ty''
          _ -> Nothing
        -- Both sides of a given nominal equality, if it is one.
        unEq ty' = case classifyPredType ty' of
          EqPred NomEq ty1 ty2 -> Just (ty1,ty2)
          _                    -> Nothing
        rewrites = mapMaybe (unEq . unCType . fst) givens
        rewriteTy tyK (ty1,ty2)
          | ty1 `eqType` tyK = Just ty2
          | ty2 `eqType` tyK = Just ty1
          | otherwise        = Nothing
        -- Types with a given KnownNat, plus the types that a given
        -- equality equates to one of them.
        knowns   = mapMaybe (unKn . unCType . fst) givens
        knownsR  = catMaybes $ concatMap (\t -> map (rewriteTy t) rewrites) knowns
        knownsX  = knowns ++ knownsR
        subWant  = mkTyConApp typeNatSubTyCon . (:[want])
        -- Pairs of (normalised x - want, x) for each candidate x.
        exploded = map (fst . runWriter . normaliseNat . subWant &&& id)
                       knownsX
        -- Keep candidates whose difference is a single constant or variable.
        examineDiff (S [P [I n]]) entire = Just (entire,I n)
        examineDiff (S [P [V v]]) entire = Just (entire,V v)
        examineDiff _ _ = Nothing
        interesting = mapMaybe (uncurry examineDiff) exploded
      ((h,corr):_) <- pure interesting
      x <- case corr of
        I 0 -> pure h
        I i | i < 0
            -> pure (mkTyConApp typeNatAddTyCon [h,mkNumLitTy (negate i)])
            | otherwise
            -> pure (mkTyConApp typeNatSubTyCon [h,mkNumLitTy i])
        -- h - want = want, hence want = Div h 2.
        c | CType (reifySOP (S [P [c]])) == CType want ->
#if MIN_VERSION_ghc(8,4,0)
          pure (mkTyConApp typeNatDivTyCon [h,reifySOP (S [P [I 2]])])
#else
          MaybeT (pure Nothing)
#endif
        -- A variable difference only helps if it is itself known.
        V v | all (not . eqType (TyVarTy v)) knownsX
            -> MaybeT (pure Nothing)
        _ -> pure (mkTyConApp typeNatSubTyCon [h,reifySOP (S [P [corr]])])
      MaybeT (go x)
-- | Create a new wanted constraint for the predicate @ty@, located like the
-- constraint @ct@ it was derived from. Returns its evidence together with
-- the new (non-canonical) constraint.
makeWantedEv
  :: Ct
  -> Type
#if MIN_VERSION_ghc(8,5,0)
  -> TcPluginM (EvExpr,Ct)
#else
  -> TcPluginM (EvTerm,Ct)
#endif
makeWantedEv ct ty = do
    ctEv <- newWanted origLoc ty
    -- Re-apply the source span of ct to the new location, presumably
    -- because newWanted does not keep it, so that errors about the new
    -- wanted point at ct.
    let fixedLoc = setCtLocSpan (ctEvLoc ctEv) (ctLocSpan origLoc)
        newCt    = setCtLoc (mkNonCanonical ctEv) fixedLoc
    return (evidenceOf ctEv, newCt)
  where
    origLoc = ctLoc ct
#if MIN_VERSION_ghc(8,5,0)
    evidenceOf = ctEvExpr
#else
    evidenceOf = ctEvTerm
#endif
-- | Build a dictionary for the wanted class @knCls@ at type @z@ from an
-- instance of a helper class (e.g. @KnownNat2@).
--
-- Arguments, in order:
--
--   * the helper class and the dictionary function of the matched instance
--   * the class of the wanted constraint
--   * the helper-class type arguments
--   * the type arguments for the dictionary function
--   * the wanted type @z@
--   * evidence for the instance context
--
-- Both classes are unwrapped as newtypes (single-method classes), and so
-- are their method result types. The applied dictionary function is then
-- cast along this chain:
--
--   helper dictionary ~R method type ~R wanted dictionary
--
-- Returns @Nothing@ if any of these unwrappings fails.
makeOpDict
  :: (Class,DFunId)
  -> Class
  -> [Type]
  -> [Type]
  -> Type
#if MIN_VERSION_ghc(8,5,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
  -- Wanted class dictionary at z, unwrapped to its method type.
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe
                       $ funResultTy
                       $ dropForAlls
                       $ idType kn_meth
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
  -- The same two unwrapping steps for the helper class at its arguments.
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe
                                 $ funResultTy
                                 $ (`piResultTys` tyArgsC)
                                 $ idType op_meth
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
  -- Apply the dictionary function to its type and evidence arguments.
#if MIN_VERSION_ghc(8,5,0)
  , let EvExpr dfun_inst = evDFunApp dfid tyArgsI evArgs
#else
  , let dfun_inst = EvDFunApp dfid tyArgsI evArgs
#endif
        -- helper dictionary ~R representation ~R wanted dictionary
        op_to_kn = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                               (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
-- | Reuse evidence @xEv@ for @knCls x@ as evidence for @knCls z@.
--
-- The class dictionary at each type is unwrapped as a newtype to the
-- result type of its single method. That result type is unwrapped as a
-- newtype too. The two chains are then composed:
--
--   dictionary at x ~R representation ~R dictionary at z
--
-- Returns @Nothing@ if the class does not have exactly one method, or if
-- any newtype unwrapping fails.
--
-- NOTE(review): nothing here checks that x and z denote the same value;
-- callers appear to rely on that (e.g. the offset search) - confirm.
makeKnCoercion :: Class
               -> Type
               -> Type
#if MIN_VERSION_ghc(8,5,0)
               -> EvExpr
#else
               -> EvTerm
#endif
               -> Maybe EvTerm
makeKnCoercion knCls x z xEv = do
    kn_tcRep <- case classMethods knCls of
      [ kn_meth ] -> tyConAppTyCon_maybe . funResultTy . dropForAlls $ idType kn_meth
      _           -> Nothing
    co_x <- dictToRep kn_tcRep x
    co_z <- dictToRep kn_tcRep z
    return (mkEvCast xEv (co_x `mkTcTransCo` mkTcSymCo co_z))
  where
    -- Coercion from the class dictionary at t to the method representation.
    dictToRep kn_tcRep t = do
      (_, co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [t]
      (_, co_rep)  <- tcInstNewTyCon_maybe kn_tcRep [t]
      return (co_dict `mkTcTransCo` co_rep)
-- | Build a dictionary for @clas ty@ directly from the integer @i@.
--
-- The literal is cast through the inverse of the newtype coercions: class
-- dictionary to method result type, then method result type to its own
-- newtype representation.
--
-- On GHC 8.5 and later the literal expression is built with
-- @mkNaturalExpr@, which runs in the type-checker monad (lifted via
-- @unsafeTcPluginTcM@), so the result is monadic there.
--
-- Returns @Nothing@ when the class does not have exactly one method or a
-- newtype unwrapping fails.
#if MIN_VERSION_ghc(8,5,0)
makeLitDict :: Class -> Type -> Integer -> TcPluginM (Maybe EvTerm)
#else
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
#endif
makeLitDict clas ty i
  | Just (_, co_dict) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
  , [ meth ] <- classMethods clas
  , Just tcRep <- tyConAppTyCon_maybe
                    $ funResultTy
                    $ dropForAlls
                    $ idType meth
  , Just (_, co_rep) <- tcInstNewTyCon_maybe tcRep [ty]
#if MIN_VERSION_ghc(8,5,0)
  = do
    et <- unsafeTcPluginTcM (mkNaturalExpr i)
    let ev_tm = mkEvCast et (mkTcSymCo (mkTcTransCo co_dict co_rep))
    return (Just ev_tm)
  | otherwise
  = return Nothing
#else
  , let ev_tm = mkEvCast (EvLit (EvNum i)) (mkTcSymCo (mkTcTransCo co_dict co_rep))
  = Just ev_tm
  | otherwise
  = Nothing
#endif
-- | Variant of 'makeOpDict' used when the wanted class is @KnownBool@.
--
-- Instead of unwrapping the method result type of the wanted class as a
-- newtype, it asserts, with a plugin-provenance 'mkUnivCo', that this type
-- at @z@ is representationally equal to 'Bool'. That coercion is taken on
-- trust, not proven.
--
-- Only implemented on GHC 8.6 and later; on older compilers it always
-- returns @Nothing@.
--
-- NOTE(review): the evidence parameter switches from @EvTerm@ to @EvExpr@
-- at GHC 8.6 here, but at GHC 8.5 in 'makeOpDict' and at the call site.
-- This only matters for GHC 8.5 development snapshots - confirm.
makeOpDictByFiat
  :: (Class,DFunId)
  -> Class
  -> [Type]
  -> [Type]
  -> Type
#if MIN_VERSION_ghc(8,6,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -> Maybe EvTerm
#if MIN_VERSION_ghc(8,6,0)
makeOpDictByFiat (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe
                       $ funResultTy
                       $ dropForAlls
                       $ idType kn_meth
  -- Asserted, not proven: (method result type at z) ~R Bool.
  , let kn_co_rep = mkUnivCo (PluginProv "ghc-typelits-knownnat")
                             Representational
                             (mkTyConApp kn_tcRep [z]) boolTy
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe
                                 $ funResultTy
                                 $ (`piResultTys` tyArgsC)
                                 $ idType op_meth
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
  -- helper dictionary ~R representation ~R wanted dictionary
  , let op_to_kn = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                               (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
#else
makeOpDictByFiat _ _ _ _ _ _ = Nothing
#endif