{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.KnownNat.Solver
( plugin )
import Control.Arrow ((&&&), first)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra (lookupModule, lookupName, newWanted,
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens, mkSubst', substType)
import GHC.TypeLits.Normalise.SOP (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)
import Class (Class, classMethods, className, classTyCon)
#if MIN_VERSION_ghc(8,6,0)
import Coercion (Role (Representational), mkUnivCo)
import FamInst (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id (idType)
import InstEnv (instanceDFunId,lookupUniqueInstEnv)
#if MIN_VERSION_ghc(8,5,0)
import MkCore (mkNaturalExpr)
import Module (mkModuleName, moduleName, moduleNameString)
import Name (nameModule_maybe, nameOccName)
import OccName (mkTcOcc, occNameString)
import Plugins (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
import PrelNames (knownNatClassName)
#if MIN_VERSION_ghc(8,5,0)
import TcEvidence (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
#if MIN_VERSION_ghc(8,5,0)
import TcPluginM (unsafeTcPluginTcM)
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM (zonkCt)
import TcPluginM (TcPluginM, tcLookupClass, getInstEnvs)
import TcRnTypes (TcPlugin(..), TcPluginResult (..))
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatDivTyCon)
import Type
dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind)
import TyCon (tyConName)
import TyCoRep (Type (..), TyLit (..))
#if MIN_VERSION_ghc(8,6,0)
import TyCoRep (UnivCoProvenance (PluginProv))
import TysWiredIn (boolTy)
import Var (DFunId)
#if MIN_VERSION_ghc(8,10,0)
import Constraint
(Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
mkNonCanonical, setCtLoc, setCtLocSpan)
import Predicate (EqRel (NomEq), Pred (ClassPred,EqPred), classifyPredType)
import TcRnTypes
(Ct, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted, mkNonCanonical,
setCtLoc, setCtLocSpan)
import Type (EqRel (NomEq), PredTree (ClassPred,EqPred), classifyPredType)
#if MIN_VERSION_ghc(8,5,0)
import TcRnTypes (ctEvExpr)
import TcRnTypes (ctEvTerm)
-- | Classes from module @GHC.TypeLits.KnownNat@ that the solver needs.
-- They are looked up once by 'lookupKnownNatDefs' at plugin initialisation.
data KnownNatDefs
  = KnownNatDefs
  { knownBool     :: Class
    -- ^ The @KnownBool@ class
  , knownBoolNat2 :: Class
    -- ^ The @KnownBoolNat2@ class
  , knownNat2Bool :: Class
    -- ^ The @KnownNat2Bool@ class
  , knownNatN     :: Int -> Maybe Class
    -- ^ @KnownNat1@, @KnownNat2@ or @KnownNat3@, selected by the number of
    -- arguments of a type-level operation; 'Nothing' for any other arity
  }
-- | A wanted constraint selected by 'toKnConstraint': the constraint itself,
-- its class (@KnownNat@ or @KnownBool@), and the single type argument of
-- that class.
type KnConstraint = (Ct    -- The wanted constraint
                    ,Class -- Its class
                    ,Type  -- The type argument of the class
                    )
-- | A type-checker plugin that tries to solve wanted @KnownNat@ and
-- @KnownBool@ constraints (see 'solveKnownNat').
--
-- Enable it with @-fplugin=GHC.TypeLits.KnownNat.Solver@. Command-line plugin
-- options are ignored. From GHC 8.6 onwards the plugin is marked as
-- 'purePlugin'.
plugin :: Plugin
plugin
  = defaultPlugin
  { tcPlugin = const $ Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
  , pluginRecompile = purePlugin
#endif
  }
-- | The type-checker plugin proper.
--
-- * Initialisation looks up the required classes with 'lookupKnownNatDefs'.
-- * Each solver invocation runs 'solveKnownNat'.
-- * Shutdown does nothing.
--
-- The plugin is wrapped in 'tracePlugin' under the name
-- @ghc-typelits-knownnat@.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-knownnat"
  TcPlugin { tcPluginInit  = lookupKnownNatDefs
           , tcPluginSolve = solveKnownNat
           , tcPluginStop  = const (return ())
           }
-- | Solve the wanted @KnownNat@ and @KnownBool@ constraints selected by
-- 'toKnConstraint'; every other wanted constraint is left untouched.
--
-- Each selected wanted is passed to 'constraintToEvTerm' together with a
-- lookup table built from the givens. The constraints it solves, and any new
-- wanted constraints it emits, are reported back in the 'TcPluginOk' result.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds []      = return (TcPluginOk [] [])
solveKnownNat defs  givens  _deriveds wanteds = do
  -- GHC 7.10 puts deriveds with the wanteds, so filter them out
  let wanteds' = filter (isWanted . ctEvidence) wanteds
#if MIN_VERSION_ghc(8,4,0)
      -- Apply the substitution that mkSubst' derives from the givens to the
      -- type argument of every selected wanted
      subst = map fst
            $ mkSubst' givens
      kn_wanteds = map (\(x,y,z) -> (x,y,substType subst z))
                 $ mapMaybe (toKnConstraint defs) wanteds'
#else
      kn_wanteds = mapMaybe (toKnConstraint defs) wanteds'
#endif
  case kn_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
      -- Lookup table from given predicate types to their evidence
#if MIN_VERSION_ghc(8,4,0)
      let given_map = map toGivenEntry (flattenGivens givens)
#else
      given_map <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
      -- Wanteds for which no evidence was found are simply not reported as
      -- solved
      (solved,new) <- unzip . catMaybes
                      <$> mapM (constraintToEvTerm defs given_map) kn_wanteds
      return (TcPluginOk solved (concat new))
-- | Select a wanted constraint this plugin may be able to solve.
--
-- Only constraints of the form @KnownNat t@ or @KnownBool t@ (a class
-- predicate with exactly one type argument) are selected. The result pairs
-- the constraint with its class and that type argument; every other
-- constraint yields 'Nothing'.
toKnConstraint :: KnownNatDefs -> Ct -> Maybe KnConstraint
toKnConstraint defs ct
  | ClassPred cls [ty] <- classifyPredType predTy
  , className cls `elem` handled
  = Just (ct, cls, ty)
  | otherwise
  = Nothing
  where
    predTy  = ctEvPred (ctEvidence ct)
    handled = [knownNatClassName, className (knownBool defs)]
-- | Turn a given constraint into an entry of the givens lookup table: its
-- predicate type (wrapped in 'CType', whose 'Eq' instance the table lookups
-- rely on) paired with the constraint's evidence.
#if MIN_VERSION_ghc(8,5,0)
toGivenEntry :: Ct -> (CType,EvExpr)
#else
toGivenEntry :: Ct -> (CType,EvTerm)
#endif
toGivenEntry ct = let ct_ev = ctEvidence ct
                      c_ty  = ctEvPred ct_ev
#if MIN_VERSION_ghc(8,5,0)
                      ev    = ctEvExpr ct_ev
#else
                      ev    = ctEvTerm ct_ev
#endif
                  in  (CType c_ty,ev)
-- | Look up the classes collected in 'KnownNatDefs'.
--
-- The classes are defined in module @GHC.TypeLits.KnownNat@ of package
-- @ghc-typelits-knownnat@. This function is used as the plugin's
-- 'tcPluginInit'.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md    <- lookupModule myModule myPackage
    kbC   <- look md "KnownBool"
    kbn2C <- look md "KnownBoolNat2"
    kn2bC <- look md "KnownNat2Bool"
    kn1C  <- look md "KnownNat1"
    kn2C  <- look md "KnownNat2"
    kn3C  <- look md "KnownNat3"
    return KnownNatDefs
           { knownBool     = kbC
           , knownBoolNat2 = kbn2C
           , knownNat2Bool = kn2bC
           , knownNatN     = \case { 1 -> Just kn1C
                                   ; 2 -> Just kn2C
                                   ; 3 -> Just kn3C
                                   ; _ -> Nothing
                                   }
           }
  where
    -- Resolve a class by its type-constructor name in module md
    look md s = do
      nm <- lookupName md (mkTcOcc s)
      tcLookupClass nm

    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"
-- Try to produce evidence for one wanted constraint selected by
-- 'toKnConstraint'.
--
-- Arguments:
--   * the 'KnownNatDefs' looked up at plugin initialisation;
--   * a lookup table from given predicate types (wrapped in 'CType') to
--     their evidence;
--   * the wanted constraint as (constraint, class, type argument).
--
-- Result: 'Nothing' if no evidence could be built. Otherwise the evidence,
-- paired with the wanted constraint it solves, together with the new wanted
-- constraints emitted (via 'makeWantedEv') for dictionary-function
-- arguments that were not found among the givens.
--
-- NOTE(review): the line naming this signature appears to be missing here;
-- presumably it is @constraintToEvTerm@, the function defined below.
:: KnownNatDefs
#if MIN_VERSION_ghc(8,5,0)
-> [(CType,EvExpr)]
-- NOTE(review): the CPP else/endif lines separating this GHC-version
-- alternative from the next line appear to be missing.
-> [(CType,EvTerm)]
-> KnConstraint
-> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
-- First try 'offset', which rewrites the wanted type relative to a type
-- already known via a given. Only if that fails, run 'go' on the wanted
-- type itself.
constraintToEvTerm defs givens (ct,cls,op) = do
offsetM <- offset op
evM <- case offsetM of
found@Just {} -> return found
_ -> go op
return ((first (,ct)) <$> evM)
-- NOTE(review): a `where` appears to be missing before the local
-- definitions below.
--
-- go ty: build evidence for the wanted @cls op@ from the structure of
-- @ty@, a type expected to denote the same value as @op@. Returns the
-- evidence and any new wanted constraints.
go :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
-- Case 1: a given for @cls ty@ exists (see go_other).
go (go_other -> Just ev) = return (Just (ev,[]))
-- Case 2: @ty@ is an application of a type constructor that has a defining
-- module. The constructor is encoded as the Symbol "Module.Name" (fn1), and
-- a unique instance of one of the helper classes is looked up for it.
go ty@(TyConApp tc args0)
| let tcNm = tyConName tc
, Just m <- nameModule_maybe tcNm
= do
ienv <- getInstEnvs
let mS = moduleNameString (moduleName m)
tcS = occNameString (nameOccName tcNm)
fn0 = mS ++ "." ++ tcS
fn1 = mkStrLitTy (fsLit fn0)
args1 = fn1:args0
instM = case () of
-- (a) KnownNat1/2/3 matching the constructor's arity, at (name : args)
() | Just knN_cls <- knownNatN defs (length args0)
, Right (inst, _) <- lookupUniqueInstEnv ienv knN_cls args1
-> Just (inst,knN_cls,args0,args1)
-- (b) any two-argument constructor: KnownBoolNat2 at
-- (kind of the first argument : name : args). `head` is safe here
-- because of the length guard.
| length args0 == 2
, let knN_cls = knownBoolNat2 defs
ki = typeKind (head args0)
args1N = ki:args1
, Right (inst, _) <- lookupUniqueInstEnv ienv knN_cls args1N
-> Just (inst,knN_cls,args0,args1N)
-- (c) Data.Type.Bool.If with four arguments: KnownNat2Bool at
-- (first argument : name : remaining arguments). The dictionary
-- function is instantiated with the remaining three arguments only.
-- `head`/`tail` are safe here because of the length guard.
| length args0 == 4
, fn0 == "Data.Type.Bool.If"
, let args0N = tail args0
args1N = head args0:fn1:tail args0
knN_cls = knownNat2Bool defs
, Right (inst, _) <- lookupUniqueInstEnv ienv knN_cls args1N
-> Just (inst,knN_cls,args0N,args1N)
| otherwise
-> Nothing
case instM of
-- Instantiate the dictionary function's type at the instance arguments
-- and take its argument (constraint) types. Each is solved from the
-- givens or emitted as a new wanted (see go_arg).
Just (inst,knN_cls,args0N,args1N) -> do
let df_id = instanceDFunId inst
df = (knN_cls,df_id)
df_args = fst
. splitFunTys
. (`piResultTys` args0N)
$ idType df_id
(evs,new) <- unzip <$> mapM go_arg df_args
-- KnownBool wanteds use makeOpDictByFiat; everything else uses
-- makeOpDict.
if className cls == className (knownBool defs)
then return ((,concat new) <$> makeOpDictByFiat df cls args1N args0N op evs)
else return ((,concat new) <$> makeOpDict df cls args1N args0N op evs)
-- NOTE(review): go_other ty already returned Nothing in the first clause
-- of go (it is pure and receives the same argument), so this alternative
-- always yields Nothing.
_ -> return ((,[]) <$> go_other ty)
-- Case 3: a numeric literal. If the wanted type @op@ is itself a literal,
-- this plugin does not solve it (presumably that is left to GHC; confirm).
-- Otherwise the literal's value is cast to @cls op@ by makeLitDict.
go (LitTy (NumTyLit i))
| LitTy _ <- op
= return Nothing
| otherwise
#if MIN_VERSION_ghc(8,5,0)
= (fmap (,[])) <$> makeLitDict cls op i
-- NOTE(review): the CPP else/endif lines around the alternative on the
-- next line appear to be missing.
= return ((,[]) <$> makeLitDict cls op i)
-- Case 4: anything else is not handled.
go _ = return Nothing
-- go_arg: evidence for one constraint argument of a dictionary function.
-- It is taken from the givens if present; otherwise a new wanted
-- constraint is emitted for it.
#if MIN_VERSION_ghc(8,5,0)
go_arg :: PredType -> TcPluginM (EvExpr,[Ct])
-- NOTE(review): the CPP else/endif lines around the next signature appear
-- to be missing.
go_arg :: PredType -> TcPluginM (EvTerm,[Ct])
go_arg ty = case lookup (CType ty) givens of
Just ev -> return (ev,[])
_ -> do
(ev,wanted) <- makeWantedEv ct ty
return (ev,[wanted])
-- go_other ty: look up a given for @cls ty@. When @ty@ and @op@ are equal
-- as CType, its evidence is used directly. Otherwise it is cast to
-- @cls op@ with makeKnCoercion.
go_other :: Type -> Maybe EvTerm
go_other ty =
let knClsTc = classTyCon cls
kn = mkTyConApp knClsTc [ty]
cast = if CType ty == CType op
#if MIN_VERSION_ghc(8,6,0)
then Just . EvExpr
-- NOTE(review): the CPP else/endif lines around the next alternative
-- appear to be missing.
then Just
else makeKnCoercion cls ty op
in cast =<< lookup (CType kn) givens
-- offset want: collect candidate types h for which KnownNat h is given,
-- plus types nominally equal to such an h according to a given equality.
-- For each candidate, normalise (h - want). Keep only candidates whose
-- difference is a single constant or a single variable. The want is then
-- re-expressed in terms of the FIRST such candidate and solved with go.
-- No further candidates are tried if that one fails.
offset :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
offset want = runMaybeT $ do
-- NOTE(review): a `let` appears to be missing before the local bindings
-- below; they are not do-statements.
unKn ty' = case classifyPredType ty' of
ClassPred cls' [ty'']
| className cls' == knownNatClassName
-> Just ty''
_ -> Nothing
unEq ty' = case classifyPredType ty' of
EqPred NomEq ty1 ty2 -> Just (ty1,ty2)
_ -> Nothing
rewrites = mapMaybe (unEq . unCType . fst) givens
rewriteTy tyK (ty1,ty2) | ty1 `eqType` tyK = Just ty2
| ty2 `eqType` tyK = Just ty1
| otherwise = Nothing
knowns = mapMaybe (unKn . unCType . fst) givens
knownsR = catMaybes $ concatMap (\t -> map (rewriteTy t) rewrites) knowns
knownsX = knowns ++ knownsR
subWant = mkTyConApp typeNatSubTyCon . (:[want])
-- NOTE(review): the list argument of this map appears to be missing;
-- presumably it is knownsX.
exploded = map (fst . runWriter . normaliseNat . subWant &&& id)
examineDiff (S [P [I n]]) entire = Just (entire,I n)
examineDiff (S [P [V v]]) entire = Just (entire,V v)
examineDiff _ _ = Nothing
interesting = mapMaybe (uncurry examineDiff) exploded
((h,corr):_) <- pure interesting
-- Re-express want in terms of h, knowing that h - want = corr:
--   corr = 0           gives want = h
--   corr = i, i < 0    gives want = h + (negate i)
--   corr = i, i > 0    gives want = h - i
--   corr equals want   gives h = 2 * want, so want = h `Div` 2
--                      (only when typeNatDivTyCon is available)
--   corr = v, v not among the known types: give up
--   otherwise          gives want = h - corr
x <- case corr of
I 0 -> pure h
I i | i < 0
-> pure (mkTyConApp typeNatAddTyCon [h,mkNumLitTy (negate i)])
| otherwise
-> pure (mkTyConApp typeNatSubTyCon [h,mkNumLitTy i])
c | CType (reifySOP (S [P [c]])) == CType want ->
#if MIN_VERSION_ghc(8,4,0)
pure (mkTyConApp typeNatDivTyCon [h,reifySOP (S [P [I 2]])])
-- NOTE(review): the CPP else/endif lines around the next alternative
-- appear to be missing.
MaybeT (pure Nothing)
V v | all (not . eqType (TyVarTy v)) knownsX
-> MaybeT (pure Nothing)
_ -> pure (mkTyConApp typeNatSubTyCon [h,reifySOP (S [P [corr]])])
MaybeT (go x)
-- | Emit a new wanted constraint for the predicate @ty@, arising from the
-- wanted constraint @ct@ that is currently being solved.
--
-- Returns the evidence of the new wanted together with the new constraint.
-- The new constraint's source span is set to that of @ct@.
makeWantedEv
  :: Ct
  -> Type
#if MIN_VERSION_ghc(8,5,0)
  -> TcPluginM (EvExpr,Ct)
#else
  -> TcPluginM (EvTerm,Ct)
#endif
makeWantedEv ct ty = do
  -- Create a new wanted constraint
  wantedCtEv <- newWanted (ctLoc ct) ty
#if MIN_VERSION_ghc(8,5,0)
  let ev      = ctEvExpr wantedCtEv
#else
  let ev      = ctEvTerm wantedCtEv
#endif
      wanted  = mkNonCanonical wantedCtEv
      -- Give the new wanted the source span of the wanted constraint we
      -- are currently trying to solve
      ct_ls   = ctLocSpan (ctLoc ct)
      ctl     = ctEvLoc wantedCtEv
      wanted' = setCtLoc wanted (setCtLocSpan ctl ct_ls)
  return (ev,wanted')
-- | Build evidence for the wanted @knCls z@ from an instance of one of the
-- operation classes.
--
-- Arguments:
--
-- * @(opCls, dfid)@: the operation class and the dictionary function of the
--   instance found for it
-- * @knCls@: the class of the wanted constraint
-- * @tyArgsC@: the type arguments of @opCls@
-- * @tyArgsI@: the type arguments the dictionary function is applied to
-- * @z@: the type argument of the wanted constraint
-- * @evArgs@: evidence for the constraint arguments of the dictionary
--   function
--
-- The applied dictionary function is cast to @knCls z@ in three steps:
--
-- 1. from the @opCls@ dictionary to the representation of its method;
-- 2. from that representation to the representation of @knCls@'s method;
-- 3. from there back to the @knCls@ dictionary.
--
-- Each step uses the newtype coercions from 'tcInstNewTyCon_maybe'. The
-- result is 'Nothing' unless both classes have exactly one method and every
-- coercion is available.
makeOpDict
  :: (Class,DFunId)
  -> Class
  -> [Type]
  -> [Type]
  -> Type
#if MIN_VERSION_ghc(8,5,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
    -- knCls dictionary at z ~ its method type
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
    -- type constructor of the method's result type
  , Just kn_tcRep <- tyConAppTyCon_maybe
                      $ funResultTy
                      $ dropForAlls
                      $ idType kn_meth
    -- method result type at z ~ its representation
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- the same two steps for opCls at tyArgsC
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe
                                 $ funResultTy
                                 $ (`piResultTys` tyArgsC)
                                 $ idType op_meth
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
    -- the dictionary function applied to its type and evidence arguments
#if MIN_VERSION_ghc(8,5,0)
    -- A pattern guard rather than an irrefutable let: an unexpected EvTerm
    -- shape yields Nothing instead of a lazy pattern-match failure (this
    -- matches makeOpDictByFiat).
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
#else
  , let dfun_inst = EvDFunApp dfid tyArgsI evArgs
#endif
    -- opCls dictionary ~ op representation ~ knCls dictionary (via sym)
  , let op_to_kn = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                               (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm    = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
-- | Turn evidence @xEv@ for @knCls x@ into evidence for @knCls z@.
--
-- The cast goes from the @knCls x@ dictionary to the representation of the
-- class's single method, then back from that representation to the
-- @knCls z@ dictionary, using the newtype coercions returned by
-- 'tcInstNewTyCon_maybe'.
--
-- Nothing here checks that @x@ and @z@ denote the same value; that is the
-- caller's responsibility. Returns 'Nothing' if @knCls@ does not have
-- exactly one method or if any coercion is unavailable.
makeKnCoercion :: Class
               -> Type
               -> Type
#if MIN_VERSION_ghc(8,5,0)
               -> EvExpr
#else
               -> EvTerm
#endif
               -> Maybe EvTerm
makeKnCoercion knCls x z xEv
  | Just (_, kn_co_dict_z) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe
                      $ funResultTy
                      $ dropForAlls
                      $ idType kn_meth
  , Just (_, kn_co_rep_z) <- tcInstNewTyCon_maybe kn_tcRep [z]
  , Just (_, kn_co_rep_x) <- tcInstNewTyCon_maybe kn_tcRep [x]
  , Just (_, kn_co_dict_x) <- tcInstNewTyCon_maybe (classTyCon knCls) [x]
    -- dictionary ~ method representation, at x and at z
  , let co_x = kn_co_dict_x `mkTcTransCo` kn_co_rep_x
        co_z = kn_co_dict_z `mkTcTransCo` kn_co_rep_z
  = Just (mkEvCast xEv (co_x `mkTcTransCo` mkTcSymCo co_z))
  | otherwise = Nothing
-- | Build evidence for @clas ty@ from the literal value @i@.
--
-- The literal is cast from the representation of the class's single method
-- to the class dictionary, using the newtype coercions from
-- 'tcInstNewTyCon_maybe'. Returns 'Nothing' if the class does not have
-- exactly one method or if a coercion is unavailable.
--
-- From GHC 8.5 onwards the literal is built with 'mkNaturalExpr' inside
-- 'TcPluginM' (via 'unsafeTcPluginTcM'), hence the monadic result.
#if MIN_VERSION_ghc(8,5,0)
makeLitDict :: Class -> Type -> Integer -> TcPluginM (Maybe EvTerm)
#else
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
#endif
makeLitDict clas ty i
  | Just (_, co_dict) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
  , [ meth ] <- classMethods clas
  , Just tcRep <- tyConAppTyCon_maybe
                    $ funResultTy
                    $ dropForAlls
                    $ idType meth
  , Just (_, co_rep) <- tcInstNewTyCon_maybe tcRep [ty]
#if MIN_VERSION_ghc(8,5,0)
  = do
    et <- unsafeTcPluginTcM (mkNaturalExpr i)
    let ev_tm = mkEvCast et (mkTcSymCo (mkTcTransCo co_dict co_rep))
    return (Just ev_tm)
  | otherwise
  = return Nothing
#else
  , let ev_tm = mkEvCast (EvLit (EvNum i)) (mkTcSymCo (mkTcTransCo co_dict co_rep))
  = Just ev_tm
  | otherwise
  = Nothing
#endif
-- | Like 'makeOpDict', but used for wanted @KnownBool@ constraints. It builds
-- evidence for @knCls z@ from the instance @(opCls, dfid)@, with the same
-- argument conventions as 'makeOpDict'.
--
-- One step differs: the coercion from the representation of @knCls@'s
-- method (applied to @z@) to 'Bool' is not derived with
-- 'tcInstNewTyCon_maybe'. Instead it is asserted with an unchecked,
-- representational plugin-provenance coercion ('mkUnivCo' with
-- 'PluginProv').
--
-- This is only implemented for GHC 8.6 and later; on older compilers it
-- always returns 'Nothing'.
makeOpDictByFiat
  :: (Class,DFunId)
  -> Class
  -> [Type]
  -> [Type]
  -> Type
#if MIN_VERSION_ghc(8,6,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -> Maybe EvTerm
#if MIN_VERSION_ghc(8,6,0)
makeOpDictByFiat (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe
                      $ funResultTy
                      $ dropForAlls
                      $ idType kn_meth
    -- Asserted, not derived: method result type at z ~R Bool
  , let kn_co_rep = mkUnivCo (PluginProv "ghc-typelits-knownnat")
                             Representational
                             (mkTyConApp kn_tcRep [z]) boolTy
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe
                                 $ funResultTy
                                 $ (`piResultTys` tyArgsC)
                                 $ idType op_meth
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
    -- opCls dictionary ~ op representation ~ knCls dictionary (via sym)
  , let op_to_kn = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                               (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm    = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
#else
makeOpDictByFiat _ _ _ _ _ _ = Nothing
#endif