{-|
Copyright  :  (C) 2016     , University of Twente,
                  2017-2018, QBayLogic B.V.,
                  2017     , Google Inc.
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. For example, without
this plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@
constraint in the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can derive @KnownNat@ constraints for types consisting of:

* Type variables, when there is a corresponding @KnownNat@ constraint
* Type-level naturals
* Applications of the arithmetic operators @{+,-,*,^}@
* Type functions, when there is either:
  * a matching given @KnownNat@ constraint; or
  * a corresponding @KnownNat\<N\>@ instance for the type function

To elaborate the latter points, given the type family @Min@:

@
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min 0 b = 0
  Min a b = If (a <=? b) a b
@

the plugin can derive a @KnownNat (Min x y + 1)@ constraint given only a
@KnownNat (Min x y)@ constraint:

@
g :: forall x y . (KnownNat (Min x y)) => Proxy x -> Proxy y -> Integer
g _ _ = natVal (Proxy :: Proxy (Min x y + 1))
@

And, given the type family @Max@:

@
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max 0 b = b
  Max a b = If (a <=? b) b a
@

and corresponding @KnownNat2@ instance:

@
instance (KnownNat a, KnownNat b) => KnownNat2 \"TestFunctions.Max\" a b where
  natSing2 = let x = natVal (Proxy @a)
                 y = natVal (Proxy @b)
                 z = max x y
             in  SNatKn z
  \{\-# INLINE natSing2 \#-\}
@

the plugin can derive a @KnownNat (Max x y + 1)@ constraint given only a
@KnownNat x@ and @KnownNat y@ constraint:

@
h :: forall x y . (KnownNat x, KnownNat y) => Proxy x -> Proxy y -> Integer
h _ _ = natVal (Proxy :: Proxy (Max x y + 1))
@

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

pragma to the header of your file.

-}

{-# LANGUAGE CPP           #-}
{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns  #-}

{-# LANGUAGE Trustworthy   #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.KnownNat.Solver
  ( plugin )
where

-- external
import Control.Arrow                ((&&&), first)
import Control.Monad.Trans.Maybe    (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe                   (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra          (lookupModule, lookupName, newWanted,
                                     tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra          (flattenGivens, mkSubst', substType)
#endif
import GHC.TypeLits.Normalise.SOP   (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Names (knownNatClassName)
import GHC.Builtin.Types (boolTy)
import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatSubTyCon)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types (promotedFalseDataCon, promotedTrueDataCon)
import GHC.Builtin.Types.Literals (typeNatCmpTyCon)
#endif
import GHC.Core.Class (Class, classMethods, className, classTyCon)
import GHC.Core.Coercion (Role (Representational), mkUnivCo)
import GHC.Core.InstEnv (instanceDFunId, lookupUniqueInstEnv)
import GHC.Core.Make (mkNaturalExpr)
import GHC.Core.Predicate
  (EqRel (NomEq), Pred (ClassPred,EqPred), classifyPredType)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv))
import GHC.Core.TyCon (tyConName)
import GHC.Core.Type
  (PredType, dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
   piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind,
   irrelevantMult)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Tc.Instance.Family (tcInstNewTyCon_maybe)
import GHC.Tc.Plugin (TcPluginM, tcLookupClass, getInstEnvs)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult (..))
import GHC.Tc.Types.Constraint
  (Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
   mkNonCanonical, setCtLoc, setCtLocSpan)
import GHC.Tc.Types.Evidence
  (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
import GHC.Types.Id (idType)
import GHC.Types.Name (nameModule_maybe, nameOccName)
import GHC.Types.Name.Occurrence (mkTcOcc, occNameString)
import GHC.Types.Var (DFunId)
import GHC.Unit.Module (mkModuleName, moduleName, moduleNameString)
#else
import Class      (Class, classMethods, className, classTyCon)
#if MIN_VERSION_ghc(8,6,0)
import Coercion   (Role (Representational), mkUnivCo)
#endif
import FamInst    (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id         (idType)
import InstEnv    (instanceDFunId,lookupUniqueInstEnv)
#if MIN_VERSION_ghc(8,5,0)
import MkCore     (mkNaturalExpr)
#endif
import Module     (mkModuleName, moduleName, moduleNameString)
import Name       (nameModule_maybe, nameOccName)
import OccName    (mkTcOcc, occNameString)
import Plugins    (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins    (purePlugin)
#endif
import PrelNames  (knownNatClassName)
#if MIN_VERSION_ghc(8,5,0)
import TcEvidence (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
#else
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
#endif
#if MIN_VERSION_ghc(8,5,0)
import TcPluginM  (unsafeTcPluginTcM)
#endif
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM  (zonkCt)
#endif
import TcPluginM  (TcPluginM, tcLookupClass, getInstEnvs)
import TcRnTypes  (TcPlugin(..), TcPluginResult (..))
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatDivTyCon)
#endif
import Type
  (PredType,
   dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
   piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind)
import TyCon      (tyConName)
import TyCoRep    (Type (..), TyLit (..))
#if MIN_VERSION_ghc(8,6,0)
import TyCoRep    (UnivCoProvenance (PluginProv))
import TysWiredIn (boolTy)
#endif
import Var        (DFunId)

#if MIN_VERSION_ghc(8,10,0)
import Constraint
  (Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
   mkNonCanonical, setCtLoc, setCtLocSpan)
import Predicate (EqRel (NomEq), Pred (ClassPred,EqPred), classifyPredType)
#else
import TcRnTypes
  (Ct, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted, mkNonCanonical,
   setCtLoc, setCtLocSpan)
import Type      (EqRel (NomEq), PredTree (ClassPred,EqPred), classifyPredType)
#if MIN_VERSION_ghc(8,5,0)
import TcRnTypes (ctEvExpr)
#else
import TcRnTypes (ctEvTerm)
#endif
#endif
#endif

-- | Classes and instances from "GHC.TypeLits.KnownNat"
--
-- Populated once, at plugin initialisation, by 'lookupKnownNatDefs'.
data KnownNatDefs
  = KnownNatDefs
  { knownBool     :: Class              -- ^ The @KnownBool@ class
  , knownBoolNat2 :: Class              -- ^ The @KnownBoolNat2@ class
  , knownNat2Bool :: Class              -- ^ The @KnownNat2Bool@ class
  , knownNatN     :: Int -> Maybe Class -- ^ KnownNat{N}: the @KnownNat\<N\>@
                                        -- class for arity @N@, or 'Nothing'
                                        -- when no such class was looked up
  }

-- | Simple newtype wrapper to distinguish the original (flattened) argument of
-- @KnownNat@ from the un-flattened version that we work with internally.
newtype Orig a = Orig { unOrig :: a }

-- | A @KnownNat@-like constraint, as picked out by 'toKnConstraint'. The
-- tuple components are, in order:
--
-- 1. the constraint itself;
-- 2. its class (@KnownNat@, or the @KnownBool@ class);
-- 3. the class argument that we work with internally;
-- 4. the original, flattened, class argument.
type KnConstraint =
  ( Ct
  , Class
  , Type
  , Orig Type
  )

{-|
A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. For example, without
this plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@
constraint in the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can derive @KnownNat@ constraints for types consisting of:

* Type variables, when there is a corresponding @KnownNat@ constraint
* Type-level naturals
* Applications of the arithmetic operators @{+,-,*,^}@
* Type functions, when there is either:
  * a matching given @KnownNat@ constraint; or
  * a corresponding @KnownNat\<N\>@ instance for the type function

To elaborate the latter points, given the type family @Min@:

@
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min 0 b = 0
  Min a b = If (a <=? b) a b
@

the plugin can derive a @KnownNat (Min x y + 1)@ constraint given only a
@KnownNat (Min x y)@ constraint:

@
g :: forall x y . (KnownNat (Min x y)) => Proxy x -> Proxy y -> Integer
g _ _ = natVal (Proxy :: Proxy (Min x y + 1))
@

And, given the type family @Max@:

@
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max 0 b = b
  Max a b = If (a <=? b) b a
@

and corresponding @KnownNat2@ instance:

@
instance (KnownNat a, KnownNat b) => KnownNat2 \"TestFunctions.Max\" a b where
  natSing2 = let x = natVal (Proxy @a)
                 y = natVal (Proxy @b)
                 z = max x y
             in  SNatKn z
  \{\-# INLINE natSing2 \#-\}
@

the plugin can derive a @KnownNat (Max x y + 1)@ constraint given only a
@KnownNat x@ and @KnownNat y@ constraint:

@
h :: forall x y . (KnownNat x, KnownNat y) => Proxy x -> Proxy y -> Integer
h _ _ = natVal (Proxy :: Proxy (Max x y + 1))
@

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

pragma to the header of your file.

-}
plugin :: Plugin
plugin
  = defaultPlugin
  { -- The solver takes no command-line options, hence 'const'
    tcPlugin = const (Just normalisePlugin)
#if MIN_VERSION_ghc(8,6,0)
    -- Mark the plugin as pure so that using it does not, by itself, force
    -- recompilation of the modules it is enabled in
  , pluginRecompile = purePlugin
#endif
  }

-- | The type-checker plugin proper, wrapped in 'tracePlugin' under the name
-- @ghc-typelits-knownnat@.
--
-- Initialisation looks up the classes in "GHC.TypeLits.KnownNat", every solver
-- run is handled by 'solveKnownNat', and nothing needs to be torn down on stop.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-knownnat"
  TcPlugin { tcPluginInit  = lookupKnownNatDefs
           , tcPluginSolve = solveKnownNat
           , tcPluginStop  = const (return ())
           }

-- | Try to solve the [W]anted @KnownNat@ (and @KnownBool@) constraints using
-- the [G]iven constraints.
--
-- Constraints for which 'constraintToEvTerm' finds evidence are reported as
-- solved. Any new [W]anted constraints created while building that evidence
-- are emitted alongside them. Constraints it cannot solve are left to GHC.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds []      = return (TcPluginOk [] [])
solveKnownNat defs  givens  _deriveds wanteds = do
  -- GHC 7.10 puts deriveds with the wanteds, so filter them out
  let wanteds'   = filter (isWanted . ctEvidence) wanteds
#if MIN_VERSION_ghc(8,4,0)
      -- Apply the substitution built from the [G]ivens to the argument we
      -- work with internally; the untouched argument is kept in 'Orig'.
      subst      = map fst (mkSubst' givens)
      kn_wanteds = map (\(x,y,z,orig) -> (x,y,substType subst z,orig))
                       (mapMaybe (toKnConstraint defs) wanteds')
#else
      kn_wanteds = mapMaybe (toKnConstraint defs) wanteds'
#endif
  case kn_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
      -- Make a lookup table for all the [G]iven constraints
#if MIN_VERSION_ghc(8,4,0)
      let given_map = map toGivenEntry (flattenGivens givens)
#else
      given_map <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
      -- Try to solve the wanted KnownNat constraints given the [G]iven
      -- KnownNat constraints; unsolved ones ('Nothing') are dropped.
      (solved,new) <- unzip . catMaybes <$>
                        mapM (constraintToEvTerm defs given_map) kn_wanteds
      return (TcPluginOk solved (concat new))

-- | Get the KnownNat constraints
--
-- Recognises a constraint of the form @KnownNat ty@ or @KnownBool ty@. Both the
-- working argument and the 'Orig'inal argument are initially set to @ty@.
-- Every other constraint yields 'Nothing'.
toKnConstraint :: KnownNatDefs -> Ct -> Maybe KnConstraint
toKnConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
  ClassPred cls [ty]
    |  className cls == knownNatClassName ||
       className cls == className (knownBool defs)
    -> Just (ct,cls,ty,Orig ty)
  _ -> Nothing

-- | Create a look-up entry for a [G]iven constraint.
--
-- The entry pairs the constraint's predicate type with its evidence. The
-- predicate is wrapped in 'CType' so it can serve as a 'lookup' key.
#if MIN_VERSION_ghc(8,5,0)
toGivenEntry :: Ct -> (CType,EvExpr)
#else
toGivenEntry :: Ct -> (CType,EvTerm)
#endif
toGivenEntry ct = (CType (ctEvPred ctEv), evidence)
  where
    ctEv = ctEvidence ct
#if MIN_VERSION_ghc(8,5,0)
    evidence = ctEvExpr ctEv
#else
    evidence = ctEvTerm ctEv
#endif

-- | Find the \"magic\" classes and instances in "GHC.TypeLits.KnownNat"
--
-- Looks up, in module "GHC.TypeLits.KnownNat" of package
-- @ghc-typelits-knownnat@, the classes @KnownBool@, @KnownBoolNat2@,
-- @KnownNat2Bool@, @KnownNat1@, @KnownNat2@ and @KnownNat3@.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md     <- lookupModule myModule myPackage
    kbC    <- look md "KnownBool"
    kbn2C  <- look md "KnownBoolNat2"
    kn2bC  <- look md "KnownNat2Bool"
    kn1C   <- look md "KnownNat1"
    kn2C   <- look md "KnownNat2"
    kn3C   <- look md "KnownNat3"
    return KnownNatDefs
           { knownBool     = kbC
           , knownBoolNat2 = kbn2C
           , knownNat2Bool = kn2bC
             -- only arities 1, 2 and 3 have a KnownNat<N> class
           , knownNatN     = \case { 1 -> Just kn1C
                                   ; 2 -> Just kn2C
                                   ; 3 -> Just kn3C
                                   ; _ -> Nothing
                                   }
           }
  where
    -- Resolve a type-class by its (unqualified) name in the given module
    look md s = do
      nm <- lookupName md (mkTcOcc s)
      tcLookupClass nm

    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"

-- | Try to create evidence for a wanted constraint
constraintToEvTerm
  :: KnownNatDefs     -- ^ The "magic" KnownNatN classes
#if MIN_VERSION_ghc(8,5,0)
  -> [(CType,EvExpr)]
#else
  -> [(CType,EvTerm)]
#endif
  -- All the [G]iven constraints

  -> KnConstraint
  -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
constraintToEvTerm :: KnownNatDefs
-> [(CType, EvExpr)]
-> (Ct, Class, TcType, Orig TcType)
-> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
constraintToEvTerm KnownNatDefs
defs [(CType, EvExpr)]
givens (Ct
ct,Class
cls,TcType
op,Orig TcType
orig) = do
    -- 1. Determine if we are an offset apart from a [G]iven constraint
    Maybe (EvTerm, [Ct])
offsetM <- TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
offset TcType
op
    Maybe (EvTerm, [Ct])
evM     <- case Maybe (EvTerm, [Ct])
offsetM of
                 -- 3.a If so, we are done
                 found :: Maybe (EvTerm, [Ct])
found@Just {} -> forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
found
                 -- 3.b If not, we check if the outer type-level operation
                 -- has a corresponding KnownNat<N> instance.
                 Maybe (EvTerm, [Ct])
_ -> TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go TcType
op
    forall (m :: * -> *) a. Monad m => a -> m a
return ((forall (a :: * -> * -> *) b c d.
Arrow a =>
a b c -> a (b, d) (c, d)
first (,Ct
ct)) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (EvTerm, [Ct])
evM)
  where
    -- Determine whether the outer type-level operation has a corresponding
    -- KnownNat<N> instance, where /N/ corresponds to the arity of the
    -- type-level operation
    go :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    go :: TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go (TcType -> Maybe EvTerm
go_other -> Just EvTerm
ev) = forall (m :: * -> *) a. Monad m => a -> m a
return (forall a. a -> Maybe a
Just (EvTerm
ev,[]))
    go ty :: TcType
ty@(TyConApp TyCon
tc [TcType]
args0)
      | let tcNm :: Name
tcNm = TyCon -> Name
tyConName TyCon
tc
      , Just Module
m <- Name -> Maybe Module
nameModule_maybe Name
tcNm
      = do
        InstEnvs
ienv <- TcPluginM InstEnvs
getInstEnvs
        let mS :: CommandLineOption
mS  = ModuleName -> CommandLineOption
moduleNameString (forall unit. GenModule unit -> ModuleName
moduleName Module
m)
            tcS :: CommandLineOption
tcS = OccName -> CommandLineOption
occNameString (Name -> OccName
nameOccName Name
tcNm)
            fn0 :: CommandLineOption
fn0 = CommandLineOption
mS forall a. [a] -> [a] -> [a]
++ CommandLineOption
"." forall a. [a] -> [a] -> [a]
++ CommandLineOption
tcS
            fn1 :: TcType
fn1 = FastString -> TcType
mkStrLitTy (CommandLineOption -> FastString
fsLit CommandLineOption
fn0)
            args1 :: [TcType]
args1 = TcType
fn1forall a. a -> [a] -> [a]
:[TcType]
args0
            instM :: Maybe (ClsInst, Class, [TcType], [TcType])
instM = case () of
              () | Just Class
knN_cls    <- KnownNatDefs -> Int -> Maybe Class
knownNatN KnownNatDefs
defs (forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0)
                 , Right (ClsInst
inst, [TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either SDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1
                 -> forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0,[TcType]
args1)
#if MIN_VERSION_base(4,16,0)
                 | CommandLineOption
fn0 forall a. Eq a => a -> a -> Bool
== CommandLineOption
"Data.Type.Ord.OrdCond"
                 , [TcType
_,TcType
cmpNat,TyConApp TyCon
t1 [],TyConApp TyCon
t2 [],TyConApp TyCon
f1 []] <- [TcType]
args0
                 , TyConApp TyCon
cmpNatTc [TcType]
args2 <- TcType
cmpNat
                 , TyCon
cmpNatTc forall a. Eq a => a -> a -> Bool
== TyCon
typeNatCmpTyCon
                 , TyCon
t1 forall a. Eq a => a -> a -> Bool
== TyCon
promotedTrueDataCon
                 , TyCon
t2 forall a. Eq a => a -> a -> Bool
== TyCon
promotedTrueDataCon
                 , TyCon
f1 forall a. Eq a => a -> a -> Bool
== TyCon
promotedFalseDataCon
                 , let knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownBoolNat2 KnownNatDefs
defs
                       ki :: TcType
ki      = HasDebugCallStack => TcType -> TcType
typeKind (forall a. [a] -> a
head [TcType]
args2)
                       args1N :: [TcType]
args1N  = TcType
kiforall a. a -> [a] -> [a]
:TcType
fn1forall a. a -> [a] -> [a]
:[TcType]
args2
                 , Right (ClsInst
inst,[TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either SDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args2,[TcType]
args1N)
#endif
                 | forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0 forall a. Eq a => a -> a -> Bool
== Int
2
                 , let knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownBoolNat2 KnownNatDefs
defs
                       ki :: TcType
ki      = HasDebugCallStack => TcType -> TcType
typeKind (forall a. [a] -> a
head [TcType]
args0)
                       args1N :: [TcType]
args1N  = TcType
kiforall a. a -> [a] -> [a]
:[TcType]
args1
                 , Right (ClsInst
inst, [TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either SDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0,[TcType]
args1N)
                 | forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0 forall a. Eq a => a -> a -> Bool
== Int
4
                 , CommandLineOption
fn0 forall a. Eq a => a -> a -> Bool
== CommandLineOption
"Data.Type.Bool.If"
                 , let args0N :: [TcType]
args0N = forall a. [a] -> [a]
tail [TcType]
args0
                       args1N :: [TcType]
args1N = forall a. [a] -> a
head [TcType]
args0forall a. a -> [a] -> [a]
:TcType
fn1forall a. a -> [a] -> [a]
:forall a. [a] -> [a]
tail [TcType]
args0
                       knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownNat2Bool KnownNatDefs
defs
                 , Right (ClsInst
inst, [TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either SDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0N,[TcType]
args1N)
                 | Bool
otherwise
                 -> forall a. Maybe a
Nothing
        case Maybe (ClsInst, Class, [TcType], [TcType])
instM of
          Just (ClsInst
inst,Class
knN_cls,[TcType]
args0N,[TcType]
args1N) -> do
            let df_id :: TcTyVar
df_id   = ClsInst -> TcTyVar
instanceDFunId ClsInst
inst
                df :: (Class, TcTyVar)
df      = (Class
knN_cls,TcTyVar
df_id)
                df_args :: [Scaled TcType]
df_args = forall a b. (a, b) -> a
fst                  -- [KnownNat x, KnownNat y]
                        forall b c a. (b -> c) -> (a -> b) -> a -> c
. TcType -> ([Scaled TcType], TcType)
splitFunTys          -- ([KnownNat x, KnowNat y], DKnownNat2 "+" x y)
                        forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HasDebugCallStack => TcType -> [TcType] -> TcType
`piResultTys` [TcType]
args0N) -- (KnowNat x, KnownNat y) => DKnownNat2 "+" x y
                        forall a b. (a -> b) -> a -> b
$ TcTyVar -> TcType
idType TcTyVar
df_id         -- forall a b . (KnownNat a, KnownNat b) => DKnownNat2 "+" a b
#if MIN_VERSION_ghc(9,0,0)
            ([EvExpr]
evs,[[Ct]]
new) <- forall a b. [(a, b)] -> ([a], [b])
unzip forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (TcType -> TcPluginM (EvExpr, [Ct])
go_arg forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Scaled a -> a
irrelevantMult) [Scaled TcType]
df_args
#else
            (evs,new) <- unzip <$> mapM go_arg df_args
#endif
            if Class -> Name
className Class
cls forall a. Eq a => a -> a -> Bool
== Class -> Name
className (KnownNatDefs -> Class
knownBool KnownNatDefs
defs)
               -- Create evidence using the original, flattened, argument of
               -- the KnownNat we're trying to solve. Not doing this results in
               -- GHC panics for:
               -- https://gist.github.com/christiaanb/0d204fe19f89b28f1f8d24feb63f1e63
               --
               -- That's because the flattened KnownNat we're asked to solve is
               -- [W] KnownNat fsk
               -- given:
               -- [G] fsk ~ CLog 2 n + 1
               -- [G] fsk2 ~ n
               -- [G] fsk2 ~ n + m
               --
               -- Our flattening picks one of the solution, so we try to solve
               -- [W] KnownNat (CLog 2 n + 1)
               --
               -- Turns out, GHC wanted us to solve:
               -- [W] KnownNat (CLog 2 (n + m) + 1)
               --
               -- But we have no way of knowing this! Solving the "wrong" expansion
               -- of 'fsk' results in:
               --
               -- ghc: panic! (the 'impossible' happened)
               -- (GHC version 8.6.5 for x86_64-unknown-linux):
               --       buildKindCoercion
               -- CLog 2 (n_a681K + m_a681L)
               -- CLog 2 n_a681K
               -- n_a681K + m_a681L
               -- n_a681K
               --
               -- down the line.
               --
               -- So while the "shape" of the KnownNat evidence that we return
               -- follows 'CLog 2 n + 1', the type of the evidence will be
               -- 'KnownNat fsk'; the one GHC originally asked us to solve.
               then forall (m :: * -> *) a. Monad m => a -> m a
return ((,forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Ct]]
new) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Class, TcTyVar)
-> Class
-> [TcType]
-> [TcType]
-> TcType
-> [EvExpr]
-> Maybe EvTerm
makeOpDictByFiat (Class, TcTyVar)
df Class
cls [TcType]
args1N [TcType]
args0N (forall a. Orig a -> a
unOrig Orig TcType
orig) [EvExpr]
evs)
               else forall (m :: * -> *) a. Monad m => a -> m a
return ((,forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Ct]]
new) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Class, TcTyVar)
-> Class
-> [TcType]
-> [TcType]
-> TcType
-> [EvExpr]
-> Maybe EvTerm
makeOpDict (Class, TcTyVar)
df Class
cls [TcType]
args1N [TcType]
args0N (forall a. Orig a -> a
unOrig Orig TcType
orig) [EvExpr]
evs)
          Maybe (ClsInst, Class, [TcType], [TcType])
_ -> forall (m :: * -> *) a. Monad m => a -> m a
return ((,[]) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TcType -> Maybe EvTerm
go_other TcType
ty)

    go (LitTy (NumTyLit Integer
i))
      -- Let GHC solve simple Literal constraints
      | LitTy TyLit
_ <- TcType
op
      = forall (m :: * -> *) a. Monad m => a -> m a
return forall a. Maybe a
Nothing
      -- This plugin only solves Literal KnownNat's that needed to be normalised
      -- first
      | Bool
otherwise
#if MIN_VERSION_ghc(8,5,0)
      = (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (,[])) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Class -> TcType -> Integer -> TcPluginM (Maybe EvTerm)
makeLitDict Class
cls TcType
op Integer
i
#else
      = return ((,[]) <$> makeLitDict cls op i)
#endif
    go TcType
_ = forall (m :: * -> *) a. Monad m => a -> m a
return forall a. Maybe a
Nothing

    -- Get EvTerm arguments for type-level operations. If they do not exist
    -- as [G]iven constraints, then generate new [W]anted constraints
#if MIN_VERSION_ghc(8,5,0)
    go_arg :: PredType -> TcPluginM (EvExpr,[Ct])
#else
    go_arg :: PredType -> TcPluginM (EvTerm,[Ct])
#endif
    go_arg :: TcType -> TcPluginM (EvExpr, [Ct])
go_arg TcType
ty = case forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup (TcType -> CType
CType TcType
ty) [(CType, EvExpr)]
givens of
      Just EvExpr
ev -> forall (m :: * -> *) a. Monad m => a -> m a
return (EvExpr
ev,[])
      Maybe EvExpr
_ -> do
        (EvExpr
ev,Ct
wanted) <- Ct -> TcType -> TcPluginM (EvExpr, Ct)
makeWantedEv Ct
ct TcType
ty
        forall (m :: * -> *) a. Monad m => a -> m a
return (EvExpr
ev,[Ct
wanted])

    -- Fall through case: look up the normalised [W]anted constraint in the list
    -- of [G]iven constraints.
    go_other :: Type -> Maybe EvTerm
    go_other :: TcType -> Maybe EvTerm
go_other TcType
ty =
      let knClsTc :: TyCon
knClsTc = Class -> TyCon
classTyCon Class
cls
          kn :: TcType
kn      = TyCon -> [TcType] -> TcType
mkTyConApp TyCon
knClsTc [TcType
ty]
          cast :: EvExpr -> Maybe EvTerm
cast    = if TcType -> CType
CType TcType
ty forall a. Eq a => a -> a -> Bool
== TcType -> CType
CType TcType
op
#if MIN_VERSION_ghc(8,6,0)
                       then forall a. a -> Maybe a
Just forall b c a. (b -> c) -> (a -> b) -> a -> c
. EvExpr -> EvTerm
EvExpr
#else
                       then Just
#endif
                       else Class -> TcType -> TcType -> EvExpr -> Maybe EvTerm
makeKnCoercion Class
cls TcType
ty TcType
op
      in  EvExpr -> Maybe EvTerm
cast forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup (TcType -> CType
CType TcType
kn) [(CType, EvExpr)]
givens

    -- Find a known constraint for a wanted, so that (modulo normalization)
    -- the two are a constant offset apart.
    offset :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    offset :: TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
offset TcType
want = forall (m :: * -> *) a. MaybeT m a -> m (Maybe a)
runMaybeT forall a b. (a -> b) -> a -> b
$ do
      let -- Get the knownnat contraints
          unKn :: TcType -> Maybe TcType
unKn TcType
ty' = case TcType -> Pred
classifyPredType TcType
ty' of
                       ClassPred Class
cls' [TcType
ty'']
                         | Class -> Name
className Class
cls' forall a. Eq a => a -> a -> Bool
== Name
knownNatClassName
                         -> forall a. a -> Maybe a
Just TcType
ty''
                       Pred
_ -> forall a. Maybe a
Nothing
          -- Get the rewrites
          unEq :: TcType -> Maybe (TcType, TcType)
unEq TcType
ty' = case TcType -> Pred
classifyPredType TcType
ty' of
                       EqPred EqRel
NomEq TcType
ty1 TcType
ty2 -> forall a. a -> Maybe a
Just (TcType
ty1,TcType
ty2)
                       Pred
_ -> forall a. Maybe a
Nothing
          rewrites :: [(TcType, TcType)]
rewrites = forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (TcType -> Maybe (TcType, TcType)
unEq forall b c a. (b -> c) -> (a -> b) -> a -> c
. CType -> TcType
unCType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(CType, EvExpr)]
givens
          -- Rewrite
          rewriteTy :: TcType -> (TcType, TcType) -> Maybe TcType
rewriteTy TcType
tyK (TcType
ty1,TcType
ty2) | TcType
ty1 TcType -> TcType -> Bool
`eqType` TcType
tyK = forall a. a -> Maybe a
Just TcType
ty2
                                  | TcType
ty2 TcType -> TcType -> Bool
`eqType` TcType
tyK = forall a. a -> Maybe a
Just TcType
ty1
                                  | Bool
otherwise        = forall a. Maybe a
Nothing
          -- Get only the [G]iven KnownNat constraints
          knowns :: [TcType]
knowns   = forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (TcType -> Maybe TcType
unKn forall b c a. (b -> c) -> (a -> b) -> a -> c
. CType -> TcType
unCType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(CType, EvExpr)]
givens
          -- Get all the rewritten KNs
          knownsR :: [TcType]
knownsR  = forall a. [Maybe a] -> [a]
catMaybes forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\TcType
t -> forall a b. (a -> b) -> [a] -> [b]
map (TcType -> (TcType, TcType) -> Maybe TcType
rewriteTy TcType
t) [(TcType, TcType)]
rewrites) [TcType]
knowns
          knownsX :: [TcType]
knownsX  = [TcType]
knowns forall a. [a] -> [a] -> [a]
++ [TcType]
knownsR
          -- pair up the sum-of-products KnownNat constraints
          -- with the original Nat operation
          subWant :: TcType -> TcType
subWant  = TyCon -> [TcType] -> TcType
mkTyConApp TyCon
typeNatSubTyCon forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall a. a -> [a] -> [a]
:[TcType
want])
          exploded :: [(CoreSOP, TcType)]
exploded = forall a b. (a -> b) -> [a] -> [b]
map (forall a b. (a, b) -> a
fst forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall w a. Writer w a -> (a, w)
runWriter forall b c a. (b -> c) -> (a -> b) -> a -> c
. TcType -> Writer [(TcType, TcType)] CoreSOP
normaliseNat forall b c a. (b -> c) -> (a -> b) -> a -> c
. TcType -> TcType
subWant forall (a :: * -> * -> *) b c c'.
Arrow a =>
a b c -> a b c' -> a b (c, c')
&&& forall a. a -> a
id)
                         [TcType]
knownsX
          -- interesting cases for us are those where
          -- wanted and given only differ by a constant
          examineDiff :: SOP v c -> a -> Maybe (a, Symbol v c)
examineDiff (S [P [I Integer
n]]) a
entire = forall a. a -> Maybe a
Just (a
entire,forall v c. Integer -> Symbol v c
I Integer
n)
          examineDiff (S [P [V v
v]]) a
entire = forall a. a -> Maybe a
Just (a
entire,forall v c. v -> Symbol v c
V v
v)
          examineDiff SOP v c
_ a
_ = forall a. Maybe a
Nothing
          interesting :: [(TcType, Symbol TcTyVar c)]
interesting = forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall {v} {c} {a} {c}. SOP v c -> a -> Maybe (a, Symbol v c)
examineDiff) [(CoreSOP, TcType)]
exploded
      -- convert the first suitable evidence
      ((TcType
h,Symbol TcTyVar CType
corr):[(TcType, Symbol TcTyVar CType)]
_) <- forall (f :: * -> *) a. Applicative f => a -> f a
pure forall {c}. [(TcType, Symbol TcTyVar c)]
interesting
      TcType
x <- case Symbol TcTyVar CType
corr of
                I Integer
0 -> forall (f :: * -> *) a. Applicative f => a -> f a
pure TcType
h
                I Integer
i | Integer
i forall a. Ord a => a -> a -> Bool
< Integer
0
                    -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (TyCon -> [TcType] -> TcType
mkTyConApp TyCon
typeNatAddTyCon [TcType
h,Integer -> TcType
mkNumLitTy (forall a. Num a => a -> a
negate Integer
i)])
                    | Bool
otherwise
                    -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (TyCon -> [TcType] -> TcType
mkTyConApp TyCon
typeNatSubTyCon [TcType
h,Integer -> TcType
mkNumLitTy Integer
i])
                -- If the offset between a given and a wanted is again the wanted
                -- then the given is twice the wanted; so we can just divide
                -- the given by two. Only possible in GHC 8.4+; for 8.2 we simply
                -- fail because we don't know how to divide.
                Symbol TcTyVar CType
c   | TcType -> CType
CType (CoreSOP -> TcType
reifySOP (forall v c. [Product v c] -> SOP v c
S [forall v c. [Symbol v c] -> Product v c
P [Symbol TcTyVar CType
c]])) forall a. Eq a => a -> a -> Bool
== TcType -> CType
CType TcType
want ->
#if MIN_VERSION_ghc(8,4,0)
                     forall (f :: * -> *) a. Applicative f => a -> f a
pure (TyCon -> [TcType] -> TcType
mkTyConApp TyCon
typeNatDivTyCon [TcType
h,CoreSOP -> TcType
reifySOP (forall v c. [Product v c] -> SOP v c
S [forall v c. [Symbol v c] -> Product v c
P [forall v c. Integer -> Symbol v c
I Integer
2]])])
#else
                     MaybeT (pure Nothing)
#endif
                -- Only solve with a variable offset if we have [G]iven knownnat for it
                -- Failing to do this check results in #30
                V TcTyVar
v | forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Bool -> Bool
not forall b c a. (b -> c) -> (a -> b) -> a -> c
. TcType -> TcType -> Bool
eqType (TcTyVar -> TcType
TyVarTy TcTyVar
v)) [TcType]
knownsX
                    -> forall (m :: * -> *) a. m (Maybe a) -> MaybeT m a
MaybeT (forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a. Maybe a
Nothing)
                Symbol TcTyVar CType
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (TyCon -> [TcType] -> TcType
mkTyConApp TyCon
typeNatSubTyCon [TcType
h,CoreSOP -> TcType
reifySOP (forall v c. [Product v c] -> SOP v c
S [forall v c. [Symbol v c] -> Product v c
P [Symbol TcTyVar CType
corr]])])
      forall (m :: * -> *) a. m (Maybe a) -> MaybeT m a
MaybeT (TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go TcType
x)

-- | Create a new [W]anted constraint for the given predicate type.
--
-- The new constraint gets the source span of the [W]anted constraint that
-- is currently being solved, so any error about it points at that location.
--
-- Returns the evidence of the new constraint (an 'EvExpr' on GHC >= 8.5,
-- an 'EvTerm' before that) together with the new non-canonical constraint.
makeWantedEv
  :: Ct
  -- ^ The [W]anted constraint we are currently trying to solve
  -> Type
  -- ^ Predicate type of the new [W]anted constraint
#if MIN_VERSION_ghc(8,5,0)
  -> TcPluginM (EvExpr,Ct)
#else
  -> TcPluginM (EvTerm,Ct)
#endif
makeWantedEv ct ty = do
  -- Create a new wanted constraint
  wantedCtEv <- newWanted (ctLoc ct) ty
#if MIN_VERSION_ghc(8,5,0)
  let ev      = ctEvExpr wantedCtEv
#else
  let ev      = ctEvTerm wantedCtEv
#endif
      wanted  = mkNonCanonical wantedCtEv
      -- Set the source-location of the new wanted constraint to the source
      -- location of the [W]anted constraint we are currently trying to solve
      ct_ls   = ctLocSpan (ctLoc ct)
      ctl     = ctEvLoc  wantedCtEv
      wanted' = setCtLoc wanted (setCtLocSpan ctl ct_ls)
  return (ev,wanted')

{- |
Given:

* A "magic" class, and corresponding instance dictionary function, for a
  type-level arithmetic operation
* The KnownNat dictionaries for the arguments of that operation

makeOpDict instantiates the dictionary function with the KnownNat dictionaries,
and coerces it to a KnownNat dictionary. i.e. for KnownNat2, the "magic"
dictionary for binary functions, the coercion happens in the following steps:

1. KnownNat2 "+" a b           -> SNatKn (KnownNatF2 "+" a b)
2. SNatKn (KnownNatF2 "+" a b) -> Integer
3. Integer                     -> SNat (a + b)
4. SNat (a + b)                -> KnownNat (a + b)

this process is mirrored for the dictionary functions of a higher arity

Returns 'Nothing' when either class does not have the expected
single-method, newtype-like shape.
-}
makeOpDict
  :: (Class,DFunId)
  -- ^ "magic" class and the id of its instance dictionary function
  -> Class
  -- ^ KnownNat class
  -> [Type]
  -- ^ Argument types for the Class
  -> [Type]
  -- ^ Argument types for the Instance
  -> Type
  -- ^ Type of the result
#if MIN_VERSION_ghc(8,5,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -- ^ Evidence arguments
  -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
    -- KnownNat n ~ SNat n
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SNat
                      $ funResultTy      -- SNat n
                      $ dropForAlls      -- KnownNat n => SNat n
                      $ idType kn_meth   -- forall n. KnownNat n => SNat n
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- SNat n ~ Integer
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
    -- KnownNatAdd a b ~ SNatKn (a+b)
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe        -- (SNatKn, [KnownNatF2 f x y])
                                 $ funResultTy            -- SNatKn (KnownNatF2 f x y)
                                 $ (`piResultTys` tyArgsC) -- KnownNatAdd f x y => SNatKn (KnownNatF2 f x y)
                                 $ idType op_meth         -- forall f a b . KnownNat2 f a b => SNatKn (KnownNatF2 f a b)
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
    -- SNatKn (a+b) ~ Integer
#if MIN_VERSION_ghc(8,5,0)
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
#else
  , let dfun_inst = EvDFunApp dfid tyArgsI evArgs
#endif
        -- KnownNatAdd a b
  , let op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        -- KnownNatAdd a b ~ KnownNat (a+b)
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing

{- |
Given:

* A KnownNat dictionary evidence over a type x
* a desired type z

makeKnCoercion assembles a coercion from a KnownNat x
dictionary to a KnownNat z dictionary and applies it
to the passed-in evidence.
The coercion happens in the following steps:

1. KnownNat x -> SNat x
2. SNat x     -> Integer
3. Integer    -> SNat z
4. SNat z     -> KnownNat z

The coercion is built purely from the newtype unwrappings of KnownNat and
SNat, so this function does not check that x and z denote the same natural
number. That is the caller's responsibility.
-}
makeKnCoercion :: Class          -- ^ KnownNat class
               -> Type           -- ^ Type of the argument
               -> Type           -- ^ Type of the result
#if MIN_VERSION_ghc(8,5,0)
               -> EvExpr
#else
               -> EvTerm
#endif
               -- ^ KnownNat dictionary for the argument
               -> Maybe EvTerm
makeKnCoercion knCls x z xEv
  | Just (_, kn_co_dict_z) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
    -- KnownNat z ~ SNat z
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SNat
                      $ funResultTy      -- SNat n
                      $ dropForAlls      -- KnownNat n => SNat n
                      $ idType kn_meth   -- forall n. KnownNat n => SNat n
  , Just (_, kn_co_rep_z) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- SNat z ~ Integer
  , Just (_, kn_co_rep_x) <- tcInstNewTyCon_maybe kn_tcRep [x]
    -- SNat x ~ Integer
  , Just (_, kn_co_dict_x) <- tcInstNewTyCon_maybe (classTyCon knCls) [x]
    -- KnownNat x ~ SNat x
  , let x_to_rep = mkTcTransCo kn_co_dict_x kn_co_rep_x
        -- KnownNat x ~ Integer
        z_to_rep = mkTcTransCo kn_co_dict_z kn_co_rep_z
        -- KnownNat z ~ Integer
  = Just (mkEvCast xEv (mkTcTransCo x_to_rep (mkTcSymCo z_to_rep)))
    -- KnownNat x ~ Integer ~ KnownNat z
  | otherwise = Nothing

-- | THIS CODE IS COPIED FROM:
-- https://github.com/ghc/ghc/blob/8035d1a5dc7290e8d3d61446ee4861e0b460214e/compiler/typecheck/TcInteract.hs#L1973
--
-- makeLitDict adds a coercion that will convert the literal into a dictionary
-- of the appropriate type.  See Note [KnownNat & KnownSymbol and EvLit]
-- in TcEvidence.  The coercion happens in 2 steps:
--
--     Integer -> SNat n     -- representation of literal to singleton
--     SNat n  -> KnownNat n -- singleton to dictionary
--
-- Returns 'Nothing' when the class does not have the expected
-- single-method, newtype-like shape.
--
-- NOTE(review): the steps above say Integer, but on GHC >= 8.5 the literal
-- is built with 'mkNaturalExpr'. Confirm the representation type of SNat
-- for each supported GHC version.
#if MIN_VERSION_ghc(8,5,0)
makeLitDict :: Class -> Type -> Integer -> TcPluginM (Maybe EvTerm)
#else
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
#endif
makeLitDict clas ty i
  | Just (_, co_dict) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
    -- co_dict :: KnownNat n ~ SNat n
  , [ meth ]   <- classMethods clas
  , Just tcRep <- tyConAppTyCon_maybe -- SNat
                    $ funResultTy     -- SNat n
                    $ dropForAlls     -- KnownNat n => SNat n
                    $ idType meth     -- forall n. KnownNat n => SNat n
  , Just (_, co_rep) <- tcInstNewTyCon_maybe tcRep [ty]
        -- SNat n ~ Integer
#if MIN_VERSION_ghc(8,5,0)
  = do
#if MIN_VERSION_ghc(9,0,0)
    -- mkNaturalExpr is pure from GHC 9.0 onwards
    let et = mkNaturalExpr i
#else
    -- before GHC 9.0, mkNaturalExpr needs to run in the typechecker monad
    et <- unsafeTcPluginTcM (mkNaturalExpr i)
#endif
    let ev_tm = mkEvCast et (mkTcSymCo (mkTcTransCo co_dict co_rep))
    return (Just ev_tm)
  | otherwise
  = return Nothing
#else
  , let ev_tm = mkEvCast (EvLit (EvNum i)) (mkTcSymCo (mkTcTransCo co_dict co_rep))
  = Just ev_tm
  | otherwise
  = Nothing
#endif

{- |
Given:

* A "magic" class, and corresponding instance dictionary function, for a
  type-level boolean operation
* The evidence arguments for that dictionary function (for
  KnownBoolNat2 "<=?" x y these are the KnownNat dictionaries of x and y)

makeOpDictByFiat instantiates the dictionary function with the evidence
arguments, and coerces it to a KnownBool dictionary. i.e. for KnownBoolNat2,
the "magic" dictionary for binary functions, the coercion happens in the
following steps:

1. KnownBoolNat2 "<=?" x y     -> SBoolF "<=?"
2. SBoolF "<=?"                -> Bool
3. Bool                        -> SBool (x <=? y)  THE BY FIAT PART!
4. SBool (x <=? y)             -> KnownBool (x <=? y)

Step 3 is an unchecked representational coercion ('mkUnivCo' with a
'PluginProv' provenance) rather than a newtype unwrapping.

this process is mirrored for the dictionary functions of a higher arity

Always returns 'Nothing' on GHC < 8.6.
-}
makeOpDictByFiat
  :: (Class,DFunId)
  -- ^ "magic" class and the id of its instance dictionary function
  -> Class
  -- ^ KnownBool class
  -> [Type]
  -- ^ Argument types for the Class
  -> [Type]
  -- ^ Argument types for the Instance
  -> Type
  -- ^ Type of the result
#if MIN_VERSION_ghc(8,6,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -- ^ Evidence arguments
  -> Maybe EvTerm
#if MIN_VERSION_ghc(8,6,0)
makeOpDictByFiat (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
    -- KnownBool b ~ SBool b
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SBool
                       $ funResultTy     -- SBool b
                       $ dropForAlls     -- KnownBool b => SBool b
                       $ idType kn_meth  -- forall b. KnownBool b => SBool b
    -- SBool b R~ Bool (The "Lie")
  , let kn_co_rep = mkUnivCo (PluginProv "ghc-typelits-knownnat")
                             Representational
                             (mkTyConApp kn_tcRep [z]) boolTy
    -- KnownBoolNat2 f a b ~ SBoolF f
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe        -- (SBoolF, [f])
                                 $ funResultTy            -- SBoolF f
                                 $ (`piResultTys` tyArgsC) -- KnownBoolNat2 f a b => SBoolF f
                                 $ idType op_meth         -- forall f a b . KnownBoolNat2 f a b => SBoolF f
    -- SBoolF f ~ Bool
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
    -- KnownBoolNat2 f a b ~ KnownBool b
  , let op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
#else
makeOpDictByFiat _ _ _ _ _ _ = Nothing
#endif