{-|
Copyright  :  (C) 2016     , University of Twente,
                  2017-2018, QBayLogic B.V.,
                  2017     , Google Inc.
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. i.e. without
this plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@
constraint in the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can derive @KnownNat@ constraints for types consisting of:

* Type variables, when there is a corresponding @KnownNat@ constraint
* Type-level naturals
* Applications of the arithmetic operators: @{+,-,*,^}@
* Type functions, when there is either:
  * a matching given @KnownNat@ constraint; or
  * a corresponding @KnownNat\<N\>@ instance for the type function

To elaborate the latter points, given the type family @Min@:

@
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min 0 b = 0
  Min a b = If (a <=? b) a b
@

the plugin can derive a @KnownNat (Min x y + 1)@ constraint given only a
@KnownNat (Min x y)@ constraint:

@
g :: forall x y . (KnownNat (Min x y)) => Proxy x -> Proxy y -> Integer
g _ _ = natVal (Proxy :: Proxy (Min x y + 1))
@

And, given the type family @Max@:

@
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max 0 b = b
  Max a b = If (a <=? b) b a
@

and corresponding @KnownNat2@ instance:

@
instance (KnownNat a, KnownNat b) => KnownNat2 \"TestFunctions.Max\" a b where
  natSing2 = let x = natVal (Proxy @a)
                 y = natVal (Proxy @b)
                 z = max x y
             in  SNatKn z
  \{\-# INLINE natSing2 \#-\}
@

the plugin can derive a @KnownNat (Max x y + 1)@ constraint given only a
@KnownNat x@ and @KnownNat y@ constraint:

@
h :: forall x y . (KnownNat x, KnownNat y) => Proxy x -> Proxy y -> Integer
h _ _ = natVal (Proxy :: Proxy (Max x y + 1))
@

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

pragma to the header of your file.

-}

{-# LANGUAGE CPP           #-}
{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns  #-}

{-# LANGUAGE Trustworthy   #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.KnownNat.Solver
  ( plugin )
where

-- external
import Control.Arrow                ((&&&), first)
import Control.Monad.Trans.Maybe    (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe                   (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra          (lookupModule, lookupName, newWanted,
                                     tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra          (flattenGivens, mkSubst', substType)
#endif
import GHC.TypeLits.Normalise.SOP   (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Names (knownNatClassName)
import GHC.Builtin.Types (boolTy)
import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatSubTyCon)
import GHC.Core.Class (Class, classMethods, className, classTyCon)
import GHC.Core.Coercion (Role (Representational), mkUnivCo)
import GHC.Core.InstEnv (instanceDFunId, lookupUniqueInstEnv)
import GHC.Core.Make (mkNaturalExpr)
import GHC.Core.Predicate
  (EqRel (NomEq), Pred (ClassPred,EqPred), classifyPredType)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv))
import GHC.Core.TyCon (tyConName)
import GHC.Core.Type
  (PredType, dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
   piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind,
   irrelevantMult)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Tc.Instance.Family (tcInstNewTyCon_maybe)
import GHC.Tc.Plugin (TcPluginM, tcLookupClass, getInstEnvs)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult (..))
import GHC.Tc.Types.Constraint
  (Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
   mkNonCanonical, setCtLoc, setCtLocSpan)
import GHC.Tc.Types.Evidence
  (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
import GHC.Types.Id (idType)
import GHC.Types.Name (nameModule_maybe, nameOccName)
import GHC.Types.Name.Occurrence (mkTcOcc, occNameString)
import GHC.Types.Var (DFunId)
import GHC.Unit.Module (mkModuleName, moduleName, moduleNameString)
#else
import Class      (Class, classMethods, className, classTyCon)
#if MIN_VERSION_ghc(8,6,0)
import Coercion   (Role (Representational), mkUnivCo)
#endif
import FamInst    (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id         (idType)
import InstEnv    (instanceDFunId,lookupUniqueInstEnv)
#if MIN_VERSION_ghc(8,5,0)
import MkCore     (mkNaturalExpr)
#endif
import Module     (mkModuleName, moduleName, moduleNameString)
import Name       (nameModule_maybe, nameOccName)
import OccName    (mkTcOcc, occNameString)
import Plugins    (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins    (purePlugin)
#endif
import PrelNames  (knownNatClassName)
#if MIN_VERSION_ghc(8,5,0)
import TcEvidence (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
#else
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
#endif
#if MIN_VERSION_ghc(8,5,0)
import TcPluginM  (unsafeTcPluginTcM)
#endif
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM  (zonkCt)
#endif
import TcPluginM  (TcPluginM, tcLookupClass, getInstEnvs)
import TcRnTypes  (TcPlugin(..), TcPluginResult (..))
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatDivTyCon)
#endif
import Type
  (PredType,
   dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
   piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind)
import TyCon      (tyConName)
import TyCoRep    (Type (..), TyLit (..))
#if MIN_VERSION_ghc(8,6,0)
import TyCoRep    (UnivCoProvenance (PluginProv))
import TysWiredIn (boolTy)
#endif
import Var        (DFunId)

#if MIN_VERSION_ghc(8,10,0)
import Constraint
  (Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
   mkNonCanonical, setCtLoc, setCtLocSpan)
import Predicate (EqRel (NomEq), Pred (ClassPred,EqPred), classifyPredType)
#else
import TcRnTypes
  (Ct, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted, mkNonCanonical,
   setCtLoc, setCtLocSpan)
import Type      (EqRel (NomEq), PredTree (ClassPred,EqPred), classifyPredType)
#if MIN_VERSION_ghc(8,5,0)
import TcRnTypes (ctEvExpr)
#else
import TcRnTypes (ctEvTerm)
#endif
#endif
#endif

-- | Classes from "GHC.TypeLits.KnownNat" that the solver works with. The
-- record is filled in once, by 'lookupKnownNatDefs', when the plugin starts.
data KnownNatDefs
  = KnownNatDefs
  { knownBool     :: Class
    -- ^ The @KnownBool@ class
  , knownBoolNat2 :: Class
    -- ^ The @KnownBoolNat2@ class
  , knownNat2Bool :: Class
    -- ^ The @KnownNat2Bool@ class
  , knownNatN     :: Int -> Maybe Class
    -- ^ KnownNat{N}: the class for type-level operations of arity /N/;
    -- 'Nothing' when no such class was looked up for that arity
  }

-- | Simple newtype wrapper to distinguish the original (flattened) argument of
-- @KnownNat@ from the un-flattened version that we work with internally.
--
-- Wrap with 'Orig', unwrap with 'unOrig'.
newtype Orig a = Orig { unOrig :: a }

-- | A constraint the solver will attempt to discharge, together with the
-- pieces of it that the solver needs.
type KnConstraint =
  ( Ct         -- The constraint itself
  , Class      -- The class of the constraint
  , Type       -- The class argument, as the solver works with it
  , Orig Type  -- The class argument as originally received (flattened)
  )

{-|
A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. i.e. without
this plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@
constraint in the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can derive @KnownNat@ constraints for types consisting of:

* Type variables, when there is a corresponding @KnownNat@ constraint
* Type-level naturals
* Applications of the arithmetic operators: @{+,-,*,^}@
* Type functions, when there is either:
  * a matching given @KnownNat@ constraint; or
  * a corresponding @KnownNat\<N\>@ instance for the type function

To elaborate the latter points, given the type family @Min@:

@
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min 0 b = 0
  Min a b = If (a <=? b) a b
@

the plugin can derive a @KnownNat (Min x y + 1)@ constraint given only a
@KnownNat (Min x y)@ constraint:

@
g :: forall x y . (KnownNat (Min x y)) => Proxy x -> Proxy y -> Integer
g _ _ = natVal (Proxy :: Proxy (Min x y + 1))
@

And, given the type family @Max@:

@
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max 0 b = b
  Max a b = If (a <=? b) b a

$(genDefunSymbols [''Max]) -- creates the 'MaxSym0' symbol
@

and corresponding @KnownNat2@ instance:

@
instance (KnownNat a, KnownNat b) => KnownNat2 \"TestFunctions.Max\" a b where
  type KnownNatF2 \"TestFunctions.Max\" = MaxSym0
  natSing2 = let x = natVal (Proxy @ a)
                 y = natVal (Proxy @ b)
                 z = max x y
             in  SNatKn z
  \{\-# INLINE natSing2 \#-\}
@

the plugin can derive a @KnownNat (Max x y + 1)@ constraint given only a
@KnownNat x@ and @KnownNat y@ constraint:

@
h :: forall x y . (KnownNat x, KnownNat y) => Proxy x -> Proxy y -> Integer
h _ _ = natVal (Proxy :: Proxy (Max x y + 1))
@

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

pragma to the header of your file.

-}
plugin :: Plugin
plugin
  = defaultPlugin
  { -- The plugin's command-line options are ignored ('const'); the same
    -- type-checker plugin is always installed.
    tcPlugin = const $ Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
    -- 'purePlugin' is GHC's stock recompilation policy for plugins without
    -- side effects; the field only exists from GHC 8.6 onwards.
  , pluginRecompile = purePlugin
#endif
  }

-- | The type-checker plugin proper, wrapped by 'tracePlugin' under the name
-- @ghc-typelits-knownnat@.
--
-- * Initialisation looks up the helper classes ('lookupKnownNatDefs').
-- * Solving is done by 'solveKnownNat'.
-- * Shutdown does nothing.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-knownnat"
  TcPlugin { tcPluginInit  = lookupKnownNatDefs
           , tcPluginSolve = solveKnownNat
           , tcPluginStop  = const (return ())
           }

-- | Solver entry point: given the looked-up classes, the givens, the deriveds
-- and the wanteds, try to produce evidence for every wanted constraint that
-- 'toKnConstraint' recognises.
--
-- Returns @TcPluginOk [] []@ when there are no wanteds at all, or when none
-- of them is recognised. Otherwise returns the solved constraints paired with
-- their evidence, plus any new wanted constraints that were emitted while
-- building that evidence. Constraints that could not be solved are simply
-- left out of the result.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds []      = return (TcPluginOk [] [])
solveKnownNat defs  givens  _deriveds wanteds = do
  -- GHC 7.10 puts deriveds with the wanteds, so filter them out
  let wanteds'   = filter (isWanted . ctEvidence) wanteds
#if MIN_VERSION_ghc(8,4,0)
      -- Substitution obtained from the givens via 'mkSubst''; it is applied
      -- to the working copy of each constraint argument only, the 'Orig'
      -- copy is kept untouched.
      subst      = map fst
                 $ mkSubst' givens
      kn_wanteds = map (\(x,y,z,orig) -> (x,y,substType subst z,orig))
                 $ mapMaybe (toKnConstraint defs) wanteds'
#else
      kn_wanteds = mapMaybe (toKnConstraint defs) wanteds'
#endif
  case kn_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
      -- Make a lookup table for all the [G]iven constraints
#if MIN_VERSION_ghc(8,4,0)
      let given_map = map toGivenEntry (flattenGivens givens)
#else
      given_map <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
      -- Try to solve the wanted KnownNat constraints given the [G]iven
      -- KnownNat constraints
      (solved,new) <- (unzip . catMaybes) <$> (mapM (constraintToEvTerm defs given_map) kn_wanteds)
      return (TcPluginOk solved (concat new))

-- | Get the KnownNat constraints
--
-- A constraint is accepted when it is a class predicate with exactly one
-- argument and its class name is either 'knownNatClassName' or the name of
-- the 'knownBool' class in the given 'KnownNatDefs'. Anything else yields
-- 'Nothing'.
--
-- In the result both the working argument and the 'Orig' argument start out
-- as the same type.
toKnConstraint :: KnownNatDefs -> Ct -> Maybe KnConstraint
toKnConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
  ClassPred cls [ty]
    |  className cls == knownNatClassName ||
       className cls == className (knownBool defs)
    -> Just (ct,cls,ty,Orig ty)
  _ -> Nothing

-- | Create a look-up entry for a [G]iven constraint.
--
-- The key is the constraint's predicate type wrapped in 'CType'; the value is
-- the evidence carried by the constraint ('ctEvExpr' from GHC 8.5 onwards,
-- 'ctEvTerm' before that).
#if MIN_VERSION_ghc(8,5,0)
toGivenEntry :: Ct -> (CType,EvExpr)
#else
toGivenEntry :: Ct -> (CType,EvTerm)
#endif
toGivenEntry ct = let ct_ev = ctEvidence ct
                      c_ty  = ctEvPred   ct_ev
#if MIN_VERSION_ghc(8,5,0)
                      ev    = ctEvExpr   ct_ev
#else
                      ev    = ctEvTerm   ct_ev
#endif
                  in  (CType c_ty,ev)

-- | Find the \"magic\" classes and instances in "GHC.TypeLits.KnownNat"
--
-- Looks up the module @GHC.TypeLits.KnownNat@ in the package
-- @ghc-typelits-knownnat@ and resolves the classes @KnownBool@,
-- @KnownBoolNat2@, @KnownNat2Bool@, @KnownNat1@, @KnownNat2@ and @KnownNat3@
-- from it.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md     <- lookupModule myModule myPackage
    kbC    <- look md "KnownBool"
    kbn2C  <- look md "KnownBoolNat2"
    kn2bC  <- look md "KnownNat2Bool"
    kn1C   <- look md "KnownNat1"
    kn2C   <- look md "KnownNat2"
    kn3C   <- look md "KnownNat3"
    return KnownNatDefs
           { knownBool     = kbC
           , knownBoolNat2 = kbn2C
           , knownNat2Bool = kn2bC
             -- Only arities 1, 2 and 3 have a KnownNat<N> class
           , knownNatN     = \case { 1 -> Just kn1C
                                   ; 2 -> Just kn2C
                                   ; 3 -> Just kn3C
                                   ; _ -> Nothing
                                   }
           }
  where
    -- Resolve a class by its (type-constructor namespace) name in a module
    look md s = do
      nm   <- lookupName md (mkTcOcc s)
      tcLookupClass nm

    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"

-- | Try to create evidence for a wanted constraint
constraintToEvTerm
  :: KnownNatDefs     -- ^ The "magic" KnownNatN classes
#if MIN_VERSION_ghc(8,5,0)
  -> [(CType,EvExpr)]
#else
  -> [(CType,EvTerm)]
#endif
  -- All the [G]iven constraints

  -> KnConstraint
  -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
constraintToEvTerm :: KnownNatDefs
-> [(CType, EvExpr)]
-> (Ct, Class, TcType, Orig TcType)
-> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
constraintToEvTerm KnownNatDefs
defs [(CType, EvExpr)]
givens (Ct
ct,Class
cls,TcType
op,Orig TcType
orig) = do
    -- 1. Determine if we are an offset apart from a [G]iven constraint
    Maybe (EvTerm, [Ct])
offsetM <- TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
offset TcType
op
    Maybe (EvTerm, [Ct])
evM     <- case Maybe (EvTerm, [Ct])
offsetM of
                 -- 3.a If so, we are done
                 found :: Maybe (EvTerm, [Ct])
found@Just {} -> Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
found
                 -- 3.b If not, we check if the outer type-level operation
                 -- has a corresponding KnownNat<N> instance.
                 Maybe (EvTerm, [Ct])
_ -> TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go TcType
op
    Maybe ((EvTerm, Ct), [Ct])
-> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return (((EvTerm -> (EvTerm, Ct)) -> (EvTerm, [Ct]) -> ((EvTerm, Ct), [Ct])
forall (a :: * -> * -> *) b c d.
Arrow a =>
a b c -> a (b, d) (c, d)
first (,Ct
ct)) ((EvTerm, [Ct]) -> ((EvTerm, Ct), [Ct]))
-> Maybe (EvTerm, [Ct]) -> Maybe ((EvTerm, Ct), [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (EvTerm, [Ct])
evM)
  where
    -- Determine whether the outer type-level operation has a corresponding
    -- KnownNat<N> instance, where /N/ corresponds to the arity of the
    -- type-level operation
    go :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    go :: TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go (TcType -> Maybe EvTerm
go_other -> Just EvTerm
ev) = Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((EvTerm, [Ct]) -> Maybe (EvTerm, [Ct])
forall a. a -> Maybe a
Just (EvTerm
ev,[]))
    go ty :: TcType
ty@(TyConApp TyCon
tc [TcType]
args0)
      | let tcNm :: Name
tcNm = TyCon -> Name
tyConName TyCon
tc
      , Just Module
m <- Name -> Maybe Module
nameModule_maybe Name
tcNm
      = do
        InstEnvs
ienv <- TcPluginM InstEnvs
getInstEnvs
        let mS :: CommandLineOption
mS  = ModuleName -> CommandLineOption
moduleNameString (Module -> ModuleName
moduleName Module
m)
            tcS :: CommandLineOption
tcS = OccName -> CommandLineOption
occNameString (Name -> OccName
nameOccName Name
tcNm)
            fn0 :: CommandLineOption
fn0 = CommandLineOption
mS CommandLineOption -> CommandLineOption -> CommandLineOption
forall a. [a] -> [a] -> [a]
++ CommandLineOption
"." CommandLineOption -> CommandLineOption -> CommandLineOption
forall a. [a] -> [a] -> [a]
++ CommandLineOption
tcS
            fn1 :: TcType
fn1 = FastString -> TcType
mkStrLitTy (CommandLineOption -> FastString
fsLit CommandLineOption
fn0)
            args1 :: [TcType]
args1 = TcType
fn1TcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:[TcType]
args0
            instM :: Maybe (ClsInst, Class, [TcType], [TcType])
instM = case () of
              () | Just Class
knN_cls    <- KnownNatDefs -> Int -> Maybe Class
knownNatN KnownNatDefs
defs ([TcType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0)
                 , Right (ClsInst
inst, [TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either MsgDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1
                 -> (ClsInst, Class, [TcType], [TcType])
-> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0,[TcType]
args1)
                 | [TcType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
2
                 , let knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownBoolNat2 KnownNatDefs
defs
                       ki :: TcType
ki      = HasDebugCallStack => TcType -> TcType
TcType -> TcType
typeKind ([TcType] -> TcType
forall a. [a] -> a
head [TcType]
args0)
                       args1N :: [TcType]
args1N  = TcType
kiTcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:[TcType]
args1
                 , Right (ClsInst
inst, [TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either MsgDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> (ClsInst, Class, [TcType], [TcType])
-> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0,[TcType]
args1N)
                 | [TcType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
4
                 , CommandLineOption
fn0 CommandLineOption -> CommandLineOption -> Bool
forall a. Eq a => a -> a -> Bool
== CommandLineOption
"Data.Type.Bool.If"
                 , let args0N :: [TcType]
args0N = [TcType] -> [TcType]
forall a. [a] -> [a]
tail [TcType]
args0
                       args1N :: [TcType]
args1N = [TcType] -> TcType
forall a. [a] -> a
head [TcType]
args0TcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:TcType
fn1TcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:[TcType] -> [TcType]
forall a. [a] -> [a]
tail [TcType]
args0
                       knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownNat2Bool KnownNatDefs
defs
                 , Right (ClsInst
inst, [TcType]
_) <- InstEnvs -> Class -> [TcType] -> Either MsgDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> (ClsInst, Class, [TcType], [TcType])
-> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0N,[TcType]
args1N)
                 | Bool
otherwise
                 -> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. Maybe a
Nothing
        case Maybe (ClsInst, Class, [TcType], [TcType])
instM of
          Just (ClsInst
inst,Class
knN_cls,[TcType]
args0N,[TcType]
args1N) -> do
            let df_id :: TcTyVar
df_id   = ClsInst -> TcTyVar
instanceDFunId ClsInst
inst
                df :: (Class, TcTyVar)
df      = (Class
knN_cls,TcTyVar
df_id)
                df_args :: [TcType]
df_args = ([TcType], TcType) -> [TcType]
forall a b. (a, b) -> a
fst                  -- [KnownNat x, KnownNat y]
                        (([TcType], TcType) -> [TcType])
-> (TcType -> ([TcType], TcType)) -> TcType -> [TcType]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TcType -> ([TcType], TcType)
splitFunTys          -- ([KnownNat x, KnowNat y], DKnownNat2 "+" x y)
                        (TcType -> ([TcType], TcType))
-> (TcType -> TcType) -> TcType -> ([TcType], TcType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HasDebugCallStack => TcType -> [TcType] -> TcType
TcType -> [TcType] -> TcType
`piResultTys` [TcType]
args0N) -- (KnowNat x, KnownNat y) => DKnownNat2 "+" x y
                        (TcType -> [TcType]) -> TcType -> [TcType]
forall a b. (a -> b) -> a -> b
$ TcTyVar -> TcType
idType TcTyVar
df_id         -- forall a b . (KnownNat a, KnownNat b) => DKnownNat2 "+" a b
#if MIN_VERSION_ghc(9,0,0)
            (evs,new) <- unzip <$> mapM (go_arg . irrelevantMult) df_args
#else
            ([EvExpr]
evs,[[Ct]]
new) <- [(EvExpr, [Ct])] -> ([EvExpr], [[Ct]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(EvExpr, [Ct])] -> ([EvExpr], [[Ct]]))
-> TcPluginM [(EvExpr, [Ct])] -> TcPluginM ([EvExpr], [[Ct]])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TcType -> TcPluginM (EvExpr, [Ct]))
-> [TcType] -> TcPluginM [(EvExpr, [Ct])]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TcType -> TcPluginM (EvExpr, [Ct])
go_arg [TcType]
df_args
#endif
            if Class -> Name
className Class
cls Name -> Name -> Bool
forall a. Eq a => a -> a -> Bool
== Class -> Name
className (KnownNatDefs -> Class
knownBool KnownNatDefs
defs)
               -- Create evidence using the original, flattened, argument of
               -- the KnownNat we're trying to solve. Not doing this results in
               -- GHC panics for:
               -- https://gist.github.com/christiaanb/0d204fe19f89b28f1f8d24feb63f1e63
               --
               -- That's because the flattened KnownNat we're asked to solve is
               -- [W] KnownNat fsk
               -- given:
               -- [G] fsk ~ CLog 2 n + 1
               -- [G] fsk2 ~ n
               -- [G] fsk2 ~ n + m
               --
               -- Our flattening picks one of the solution, so we try to solve
               -- [W] KnownNat (CLog 2 n + 1)
               --
               -- Turns out, GHC wanted us to solve:
               -- [W] KnownNat (CLog 2 (n + m) + 1)
               --
               -- But we have no way of knowing this! Solving the "wrong" expansion
               -- of 'fsk' results in:
               --
               -- ghc: panic! (the 'impossible' happened)
               -- (GHC version 8.6.5 for x86_64-unknown-linux):
               --       buildKindCoercion
               -- CLog 2 (n_a681K + m_a681L)
               -- CLog 2 n_a681K
               -- n_a681K + m_a681L
               -- n_a681K
               --
               -- down the line.
               --
               -- So while the "shape" of the KnownNat evidence that we return
               -- follows 'CLog 2 n + 1', the type of the evidence will be
               -- 'KnownNat fsk'; the one GHC originally asked us to solve.
               then Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((,[[Ct]] -> [Ct]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Ct]]
new) (EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Class, TcTyVar)
-> Class
-> [TcType]
-> [TcType]
-> TcType
-> [EvExpr]
-> Maybe EvTerm
makeOpDictByFiat (Class, TcTyVar)
df Class
cls [TcType]
args1N [TcType]
args0N (Orig TcType -> TcType
forall a. Orig a -> a
unOrig Orig TcType
orig) [EvExpr]
evs)
               else Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((,[[Ct]] -> [Ct]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Ct]]
new) (EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Class, TcTyVar)
-> Class
-> [TcType]
-> [TcType]
-> TcType
-> [EvExpr]
-> Maybe EvTerm
makeOpDict (Class, TcTyVar)
df Class
cls [TcType]
args1N [TcType]
args0N (Orig TcType -> TcType
forall a. Orig a -> a
unOrig Orig TcType
orig) [EvExpr]
evs)
          Maybe (ClsInst, Class, [TcType], [TcType])
_ -> Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((,[]) (EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TcType -> Maybe EvTerm
go_other TcType
ty)

    go (LitTy (NumTyLit Integer
i))
      -- Let GHC solve simple Literal constraints
      | LitTy TyLit
_ <- TcType
op
      = Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
forall a. Maybe a
Nothing
      -- This plugin only solves Literal KnownNat's that needed to be normalised
      -- first
      | Bool
otherwise
#if MIN_VERSION_ghc(8,5,0)
      = ((EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (,[])) (Maybe EvTerm -> Maybe (EvTerm, [Ct]))
-> TcPluginM (Maybe EvTerm) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Class -> TcType -> Integer -> TcPluginM (Maybe EvTerm)
makeLitDict Class
cls TcType
op Integer
i
#else
      = return ((,[]) <$> makeLitDict cls op i)
#endif
    go TcType
_ = Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
forall a. Maybe a
Nothing

    -- Get the evidence argument for a type-level operation.
    --
    -- The predicate is first looked up (wrapped in 'CType') among the
    -- [G]iven constraints. When a given exists its evidence is returned
    -- together with no new constraints; otherwise a fresh [W]anted constraint
    -- is created with 'makeWantedEv' and returned alongside its evidence.
#if MIN_VERSION_ghc(8,5,0)
    go_arg :: PredType -> TcPluginM (EvExpr,[Ct])
#else
    go_arg :: PredType -> TcPluginM (EvTerm,[Ct])
#endif
    go_arg ty = case lookup (CType ty) givens of
      -- A matching [G]iven: reuse its evidence, nothing new to solve
      Just ev -> return (ev,[])
      -- No [G]iven: emit a new [W]anted and use its (future) evidence
      _ -> do
        (ev,wanted) <- makeWantedEv ct ty
        return (ev,[wanted])

    -- Fall through case: look up the normalised [W]anted constraint in the list
    -- of [G]iven constraints.
    --
    -- The class is applied to @ty@ and the resulting predicate is looked up
    -- in @givens@. When @ty@ and the originally requested @op@ are equal as
    -- 'CType's the found evidence is used directly; otherwise it is cast with
    -- 'makeKnCoercion' from a dictionary over @ty@ to a dictionary over @op@.
    -- The result is 'Nothing' when no given matches or the cast cannot be
    -- built.
    go_other :: Type -> Maybe EvTerm
    go_other ty =
      let knClsTc = classTyCon cls
          kn      = mkTyConApp knClsTc [ty]
          cast    = if CType ty == CType op
#if MIN_VERSION_ghc(8,6,0)
                       then Just . EvExpr
#else
                       then Just
#endif
                       else makeKnCoercion cls ty op
      in  cast =<< lookup (CType kn) givens

    -- Find a known constraint for a wanted, so that (modulo normalization)
    -- the two are a constant offset apart.
    --
    -- Every [G]iven KnownNat type @k@ (and every type it can be rewritten to
    -- using a [G]iven nominal equality) is compared to @want@ by normalising
    -- @k - want@. If that difference is a single integer literal or a single
    -- variable, @want@ is re-expressed in terms of @k@ and handed back to
    -- 'go'. Only the first suitable given is tried; the result is 'Nothing'
    -- when there is none.
    offset :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    offset want = runMaybeT $ do
      let -- Get the knownnat constraints
          unKn ty' = case classifyPredType ty' of
                       ClassPred cls' [ty'']
                         | className cls' == knownNatClassName
                         -> Just ty''
                       _ -> Nothing
          -- Get the rewrites
          unEq ty' = case classifyPredType ty' of
                       EqPred NomEq ty1 ty2 -> Just (ty1,ty2)
                       _ -> Nothing
          rewrites = mapMaybe (unEq . unCType . fst) givens
          -- Rewrite: an equality can be used in either direction
          rewriteTy tyK (ty1,ty2) | ty1 `eqType` tyK = Just ty2
                                  | ty2 `eqType` tyK = Just ty1
                                  | otherwise        = Nothing
          -- Get only the [G]iven KnownNat constraints
          knowns   = mapMaybe (unKn . unCType . fst) givens
          -- Get all the rewritten KNs
          knownsR  = catMaybes $ concatMap (\t -> map (rewriteTy t) rewrites) knowns
          knownsX  = knowns ++ knownsR
          -- pair up the sum-of-products KnownNat constraints
          -- with the original Nat operation
          subWant  = mkTyConApp typeNatSubTyCon . (:[want])
          exploded = map (fst . runWriter . normaliseNat . subWant &&& id)
                         knownsX
          -- interesting cases for us are those where
          -- wanted and given only differ by a constant
          examineDiff (S [P [I n]]) entire = Just (entire,I n)
          examineDiff (S [P [V v]]) entire = Just (entire,V v)
          examineDiff _ _ = Nothing
          interesting = mapMaybe (uncurry examineDiff) exploded
      -- convert the first suitable evidence; an empty list fails the
      -- pattern match and thereby the whole 'MaybeT' computation
      ((h,corr):_) <- pure interesting
      x <- case corr of
                -- given and wanted are equal after normalisation
                I 0 -> pure h
                -- given = wanted + i, hence wanted = given - i
                -- (for negative i: wanted = given + (negate i))
                I i | i < 0
                    -> pure (mkTyConApp typeNatAddTyCon [h,mkNumLitTy (negate i)])
                    | otherwise
                    -> pure (mkTyConApp typeNatSubTyCon [h,mkNumLitTy i])
                -- If the offset between a given and a wanted is again the wanted
                -- then the given is twice the wanted; so we can just divide
                -- the given by two. Only possible in GHC 8.4+; for 8.2 we simply
                -- fail because we don't know how to divide.
                c   | CType (reifySOP (S [P [c]])) == CType want ->
#if MIN_VERSION_ghc(8,4,0)
                     pure (mkTyConApp typeNatDivTyCon [h,reifySOP (S [P [I 2]])])
#else
                     MaybeT (pure Nothing)
#endif
                -- Only solve with a variable offset if we have [G]iven knownnat for it
                -- Failing to do this check results in #30
                V v | all (not . eqType (TyVarTy v)) knownsX
                    -> MaybeT (pure Nothing)
                _ -> pure (mkTyConApp typeNatSubTyCon [h,reifySOP (S [P [corr]])])
      -- Solve the re-expressed type with the regular machinery
      MaybeT (go x)

-- | Create a new [W]anted constraint for the given predicate type.
--
-- The new constraint is created at the 'CtLoc' of the constraint that is
-- currently being solved, and its source span is explicitly reset to the span
-- of that constraint. The result is the evidence of the new constraint paired
-- with the constraint itself.
makeWantedEv
  :: Ct
  -- ^ The [W]anted constraint that is currently being solved
  -> Type
  -- ^ Predicate type of the new [W]anted constraint
#if MIN_VERSION_ghc(8,5,0)
  -> TcPluginM (EvExpr,Ct)
#else
  -> TcPluginM (EvTerm,Ct)
#endif
makeWantedEv ct ty = do
  -- Create a new wanted constraint
  wantedCtEv <- newWanted (ctLoc ct) ty
#if MIN_VERSION_ghc(8,5,0)
  let ev      = ctEvExpr wantedCtEv
#else
  let ev      = ctEvTerm wantedCtEv
#endif
      wanted  = mkNonCanonical wantedCtEv
      -- Set the source-location of the new wanted constraint to the source
      -- location of the [W]anted constraint we are currently trying to solve
      ct_ls   = ctLocSpan (ctLoc ct)
      ctl     = ctEvLoc  wantedCtEv
      wanted' = setCtLoc wanted (setCtLocSpan ctl ct_ls)
  return (ev,wanted')

{- |
Given:

* A "magic" class, and corresponding instance dictionary function, for a
  type-level arithmetic operation
* Two KnownNat dictionaries

makeOpDict instantiates the dictionary function with the KnownNat dictionaries,
and coerces it to a KnownNat dictionary. i.e. for KnownNat2, the "magic"
dictionary for binary functions, the coercion happens in the following steps:

1. KnownNat2 "+" a b           -> SNatKn (KnownNatF2 "+" a b)
2. SNatKn (KnownNatF2 "+" a b) -> Integer
3. Integer                     -> SNat (a + b)
4. SNat (a + b)                -> KnownNat (a + b)

this process is mirrored for the dictionary functions of a higher arity

The result is 'Nothing' when any of the newtype coercions cannot be
constructed, or when a class does not have exactly one method.
-}
makeOpDict
  :: (Class,DFunId)
  -- ^ "magic" class function and dictionary function id
  -> Class
  -- ^ KnownNat class
  -> [Type]
  -- ^ Argument types for the Class
  -> [Type]
  -- ^ Argument types for the Instance
  -> Type           -- ^ Type of the result
#if MIN_VERSION_ghc(8,5,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -- ^ Evidence arguments
  -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
    -- KnownNat n ~ SNat n
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SNat
                      $ funResultTy      -- SNat n
                      $ dropForAlls      -- KnownNat n => SNat n
                      $ idType kn_meth   -- forall n. KnownNat n => SNat n
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- SNat n ~ Integer
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
    -- KnownNatAdd a b ~ SNatKn (a+b)
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe        -- (SNatKn, [KnownNatF2 f x y])
                                 $ funResultTy            -- SNatKn (KnownNatF2 f x y)
                                 $ (`piResultTys` tyArgsC) -- KnownNatAdd f x y => SNatKn (KnownNatF2 f x y)
                                 $ idType op_meth         -- forall f a b . KnownNat2 f a b => SNatKn (KnownNatF2 f a b)
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
    -- SNatKn (a+b) ~ Integer
    -- KnownNatAdd a b
#if MIN_VERSION_ghc(8,5,0)
    -- Matched with a pattern guard (as in makeOpDictByFiat) rather than a
    -- partial let-pattern, so a non-'EvExpr' result yields 'Nothing'.
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
#else
  , let dfun_inst = EvDFunApp dfid tyArgsI evArgs
#endif
        -- KnownNatAdd a b ~ KnownNat (a+b)
  , let op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing

{- |
Given:

* A KnownNat dictionary evidence over a type x
* a desired type z

makeKnCoercion assembles a coercion from a KnownNat x
dictionary to a KnownNat z dictionary and applies it
to the passed-in evidence.
The coercion happens in the following steps:

1. KnownNat x -> SNat x
2. SNat x     -> Integer
3. Integer    -> SNat z
4. SNat z     -> KnownNat z

The result is 'Nothing' when any of the newtype coercions cannot be
constructed, or when the class does not have exactly one method.
-}
makeKnCoercion :: Class          -- ^ KnownNat class
               -> Type           -- ^ Type of the argument
               -> Type           -- ^ Type of the result
#if MIN_VERSION_ghc(8,5,0)
               -> EvExpr
#else
               -> EvTerm
#endif
               -- ^ KnownNat dictionary for the argument
               -> Maybe EvTerm
makeKnCoercion knCls x z xEv
  | Just (_, kn_co_dict_z) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
    -- KnownNat z ~ SNat z
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SNat
                      $ funResultTy      -- SNat n
                      $ dropForAlls      -- KnownNat n => SNat n
                      $ idType kn_meth   -- forall n. KnownNat n => SNat n
  , Just (_, kn_co_rep_z) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- SNat z ~ Integer
  , Just (_, kn_co_rep_x) <- tcInstNewTyCon_maybe kn_tcRep [x]
    -- SNat x ~ Integer
  , Just (_, kn_co_dict_x) <- tcInstNewTyCon_maybe (classTyCon knCls) [x]
    -- KnownNat x ~ SNat x
    -- Go from KnownNat x to Integer, then backwards (sym) from Integer
    -- to KnownNat z
  = Just . mkEvCast xEv $ (kn_co_dict_x `mkTcTransCo` kn_co_rep_x) `mkTcTransCo` mkTcSymCo (kn_co_dict_z `mkTcTransCo` kn_co_rep_z)
  | otherwise = Nothing

-- | THIS CODE IS COPIED FROM:
-- https://github.com/ghc/ghc/blob/8035d1a5dc7290e8d3d61446ee4861e0b460214e/compiler/typecheck/TcInteract.hs#L1973
--
-- makeLitDict adds a coercion that will convert the literal into a dictionary
-- of the appropriate type.  See Note [KnownNat & KnownSymbol and EvLit]
-- in TcEvidence.  The coercion happens in 2 steps:
--
--     Integer -> SNat n     -- representation of literal to singleton
--     SNat n  -> KnownNat n -- singleton to dictionary
--
-- The result is 'Nothing' when one of the newtype coercions cannot be
-- constructed, or when the class does not have exactly one method.
#if MIN_VERSION_ghc(8,5,0)
makeLitDict :: Class -> Type -> Integer -> TcPluginM (Maybe EvTerm)
#else
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
#endif
makeLitDict clas ty i
  | Just (_, co_dict) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
    -- co_dict :: KnownNat n ~ SNat n
  , [ meth ]   <- classMethods clas
  , Just tcRep <- tyConAppTyCon_maybe -- SNat
                    $ funResultTy     -- SNat n
                    $ dropForAlls     -- KnownNat n => SNat n
                    $ idType meth     -- forall n. KnownNat n => SNat n
  , Just (_, co_rep) <- tcInstNewTyCon_maybe tcRep [ty]
        -- SNat n ~ Integer
#if MIN_VERSION_ghc(8,5,0)
  = do
    -- Build the Core expression for the literal
#if MIN_VERSION_ghc(9,0,0)
    let et = mkNaturalExpr i
#else
    et <- unsafeTcPluginTcM (mkNaturalExpr i)
#endif
    -- Cast the literal backwards along KnownNat n ~ SNat n ~ Integer
    let ev_tm = mkEvCast et (mkTcSymCo (mkTcTransCo co_dict co_rep))
    return (Just ev_tm)
  | otherwise
  = return Nothing
#else
  , let ev_tm = mkEvCast (EvLit (EvNum i)) (mkTcSymCo (mkTcTransCo co_dict co_rep))
  = Just ev_tm
  | otherwise
  = Nothing
#endif

{- |
Given:

* A "magic" class, and corresponding instance dictionary function, for a
  type-level boolean operation
* Two KnownBool dictionaries

makeOpDictByFiat instantiates the dictionary function with the KnownBool
dictionaries, and coerces it to a KnownBool dictionary. i.e. for KnownBoolNat2,
the "magic" dictionary for binary functions, the coercion happens in the
following steps:

1. KnownBoolNat2 "<=?" x y     -> SBoolF "<=?"
2. SBoolF "<=?"                -> Bool
3. Bool                        -> SBool (x <=? y)  THE BY FIAT PART!
4. SBool (x <=? y)             -> KnownBool (x <=? y)

this process is mirrored for the dictionary functions of a higher arity

The result is 'Nothing' when any of the newtype coercions cannot be
constructed, or when a class does not have exactly one method. On GHC
versions before 8.6 the result is always 'Nothing'.
-}
makeOpDictByFiat
  :: (Class,DFunId)
  -- ^ "magic" class function and dictionary function id
  -> Class
  -- ^ KnownBool class
  -> [Type]
  -- ^ Argument types for the Class
  -> [Type]
  -- ^ Argument types for the Instance
  -> Type
  -- ^ Type of the result
#if MIN_VERSION_ghc(8,6,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -- ^ Evidence arguments
  -> Maybe EvTerm
#if MIN_VERSION_ghc(8,6,0)
makeOpDictByFiat (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
    -- KnownBool b ~ SBool b
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SBool
                       $ funResultTy     -- SBool b
                       $ dropForAlls     -- KnownBool b => SBool b
                       $ idType kn_meth  -- forall b. KnownBool b => SBool b
    -- SBool b R~ Bool (The "Lie"): a universal coercion asserted by this
    -- plugin instead of one derived from a newtype
  , let kn_co_rep = mkUnivCo (PluginProv "ghc-typelits-knownnat")
                             Representational
                             (mkTyConApp kn_tcRep [z]) boolTy
    -- KnownBoolNat2 f a b ~ SBoolF f
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe        -- (SBoolF, [f])
                                 $ funResultTy            -- SBoolF f
                                 $ (`piResultTys` tyArgsC) -- KnownBoolNat2 f x y => SBoolF f
                                 $ idType op_meth         -- forall f x y . KnownBoolNat2 f x y => SBoolF f
    -- SBoolF f ~ Bool
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
    -- KnownBoolNat2 f x y ~ KnownBool b
  , let op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
#else
makeOpDictByFiat _ _ _ _ _ _ = Nothing
#endif