{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}

module GHC.Magic.Dict.Plugin.Old (plugin) where

import Control.Applicative (liftA2)
import Data.Bitraversable (bitraverse)
import qualified Data.DList as DL
import Data.Maybe (mapMaybe)
import GHC.Builtin.Types.Prim (openAlphaTy, openAlphaTyVar, runtimeRep1TyVar)
import qualified GHC.Core as Core
import GHC.Core.Class
import GHC.Core.Coercion (mkSubCo, mkSymCo, mkTransCo)
import GHC.Core.DataCon
import GHC.Core.Make (mkCoreLams)
import GHC.Core.Predicate
import GHC.Core.TyCon
import GHC.Core.Type
import GHC.Data.FastString
import GHC.Plugins (Plugin (..), defaultPlugin, mkModuleName, purePlugin)
import GHC.Tc.Instance.Family (tcInstNewTyCon_maybe)
import GHC.Tc.Plugin hiding (newWanted)
import GHC.Tc.Types
import GHC.Tc.Types.Constraint
import GHC.Tc.Types.Evidence
import GHC.TcPluginM.Extra
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable

-- | Entry point for GHC: registers 'withDictPlugin' as a typechecker plugin.
--
-- Command-line options are ignored, and the plugin declares itself pure via
-- 'purePlugin'.
plugin :: Plugin
plugin = defaultPlugin {tcPlugin = installSolver, pluginRecompile = purePlugin}
  where
    installSolver _opts = Just withDictPlugin

-- | The typechecker plugin, wrapped with 'tracePlugin' under the label
-- @"WithDictPlugin"@.
--
-- It carries no state: initialisation and shutdown do nothing, and every
-- solver invocation is delegated to 'solveWithDict'.
withDictPlugin :: TcPlugin
withDictPlugin = tracePlugin "WithDictPlugin" solver
  where
    solver =
      TcPlugin
        { tcPluginInit = pure ()
        , tcPluginSolve = \_state -> solveWithDict
        , tcPluginStop = \_state -> pure ()
        }

-- | Handles on the compat @WithDict@ class, as resolved by 'lookupInfo'.
data Info = Info
  { _WithDict :: !Class
    -- ^ The @WithDict@ class itself.
  , _WithDictDataCon :: !DataCon
    -- ^ The dictionary data constructor of '_WithDict'.
  }

-- | Solver for wanted @WithDict cls meth@ constraints.
--
-- How each wanted is handled:
--
--   * It is rewritten with the substitution built from the givens
--     ('mkSubst'').
--   * If the result decodes as a @WithDict@ predicate, it is handed to
--     'solveWithDictPred'.
--   * On success, the wanted is solved with the returned evidence, and the
--     equality wanted that evidence depends on is emitted.
--   * On failure, the wanted is recorded as a contradiction.
--   * Wanteds that are not @WithDict@ predicates are left untouched.
--
-- 'TcPluginContradiction' carries no solved constraints. Reporting a
-- contradiction in a round where other wanteds were solved would therefore
-- throw away their evidence and their new equality wanteds. So
-- contradictions are reported only in a round where nothing could be solved.
-- When something was solved, the failing wanteds are simply left unsolved.
-- NOTE(review): this relies on GHC rerunning plugins when they emit new
-- wanteds, so the failures are reported on a later round — confirm for the
-- targeted GHC versions. Even without a rerun, they remain unsolved and are
-- still reported as errors.
solveWithDict :: TcPluginSolver
solveWithDict _ _ [] = pure $ TcPluginOk [] []
solveWithDict gs _ wanteds = do
  let subs = map fst $ mkSubst' gs
  info <- lookupInfo
  let -- Keep the wanted's location (needed to build evidence) and the wanted
      -- itself (to attach the evidence to) next to its decoded predicate.
      decodeWanted ct =
        liftA2
          (,)
          ((ctLoc ct,) <$> decodeWithDictPred info (ctPred (substCt subs ct)))
          (pure ct)
      withDicts = mapMaybe decodeWanted wanteds
  results <- mapM (bitraverse (uncurry (solveWithDictPred info)) pure) withDicts
  let classify = \case
        (Nothing, ct) -> (DL.singleton ct, mempty, mempty)
        (Just (pf, newWants), ct) ->
          (mempty, DL.singleton (pf, ct), DL.fromList newWants)
      (contrs, solved, wants) = foldMap classify results
      contrL = DL.toList contrs
      solvedL = DL.toList solved
      wantsL = DL.toList wants
  tcPluginTrace "solveWithDict/contradictions" (ppr contrL)
  tcPluginTrace "solveWithDict/solveds" (ppr solvedL)
  tcPluginTrace "solveWithDict/newWanteds" (ppr wantsL)
  pure $
    -- Prefer making progress: defer contradictions while anything was solved.
    if null contrL || not (null solvedL)
      then TcPluginOk solvedL wantsL
      else TcPluginContradiction contrL

-- | Turn @ev@ into a non-canonical constraint.
--
-- The constraint keeps @ev@'s own location, but that location's source span
-- is replaced by the span of @origCtl@.
mkNonCanonical' :: CtLoc -> CtEvidence -> Ct
mkNonCanonical' origCtl ev = setCtLoc (mkNonCanonical ev) relocated
  where
    relocated = setCtLocSpan (ctEvLoc ev) (ctLocSpan origCtl)

-- | Try to build evidence for a single decoded @WithDict cls meth@ wanted.
--
-- If 'tcInstNewTyCon_maybe' can unwrap @cls@ to its single method type
-- @mty@, the result contains two things:
--
--   * a @WithDict@ dictionary as evidence;
--   * one new wanted, the nominal equality @meth ~# mty@, which the evidence
--     casts through.
--
-- Otherwise the result is 'Nothing'.
solveWithDictPred :: Info -> CtLoc -> DecodedPred -> TcPluginM (Maybe (EvTerm, [Ct]))
solveWithDictPred Info {..} loc DecodedPred {..} = do
  tcPluginTrace "solveWithDictPred" (ppr (constraint, argType))
  maybe giveUp buildEvidence (tcInstNewTyCon_maybe constrTyCon constrArgs)
  where
    -- No newtype representation was found for the class, so it cannot be
    -- built from a single method value.
    giveUp = do
      tcPluginTrace "solveWithDictPred: Failed!" (ppr (constraint, argType))
      pure Nothing

    buildEvidence (methTy, co) = do
      tcPluginTrace
        "solveWithDictPred: singleton class found"
        (ppr (constraint, argType, methTy, co))
      let eqPred = mkPrimEqPred argType methTy
      hole <- newCoercionHole eqPred
      let wantedEq = CtWanted eqPred (HoleDest hole) WDeriv loc
      svVar <-
        unsafeTcPluginTcM $
          mkSysLocalM (fsLit "withDict_s") Many argType
      kVar <-
        unsafeTcPluginTcM $
          mkSysLocalM (fsLit "withDict_k") Many (mkInvisFunTy Many constraint openAlphaTy)
      -- Given co2 : mty ~N# inst_meth_ty, the method of the WithDict
      -- dictionary is
      --
      --   \@(r :: RuntimeRep) @(a :: TYPE r) (sv :: mty) (k :: cls => a) ->
      --     k (sv |> (sub co2 ; sym co))
      --
      -- where co2 is the coercion solving 'wantedEq' and co comes from
      -- unwrapping the class newtype.
      let castCo = mkSubCo (ctEvCoercion wantedEq) `mkTransCo` mkSymCo co
          body = Core.Var kVar `Core.App` (Core.Var svVar `Core.Cast` castCo)
          method = mkCoreLams [runtimeRep1TyVar, openAlphaTyVar, svVar, kVar] body
          proof = evDataConApp _WithDictDataCon [constraint, argType] [method]
      pure (Just (proof, [mkNonCanonical' loc wantedEq]))

-- | The components of a @WithDict cls meth@ predicate.
data DecodedPred = DecodedPred
  { constraint :: !PredType
    -- ^ The constraint @cls@ whose dictionary is to be built.
  , constrTyCon :: !TyCon
    -- ^ The head type constructor of 'constraint'.
  , constrArgs :: ![Type]
    -- ^ The arguments 'constrTyCon' is applied to in 'constraint'.
  , argType :: !Type
    -- ^ The type @meth@ of the value supplied as the dictionary's method.
  }

-- | Recognise a @WithDict cls meth@ predicate and split it into a
-- 'DecodedPred'.
--
-- Returns 'Nothing' for any other predicate. It also returns 'Nothing' when
-- @cls@ is not headed by a class type constructor. In practice that means a
-- type-family application that has not reduced yet, which
-- 'solveWithDictPred' could never unwrap. Such a constraint is left for a
-- later round, by which time it may have reduced to a real class.
decodeWithDictPred :: Info -> PredType -> Maybe DecodedPred
decodeWithDictPred Info {..} pt
  | ClassPred withDic [cls, argType] <- classifyPredType pt
  , withDic == _WithDict
  , Just (dict_tc, dict_args) <- tcSplitTyConApp_maybe cls
  , isClassTyCon dict_tc =
      Just
        DecodedPred
          { constraint = cls
          , constrTyCon = dict_tc
          , constrArgs = dict_args
          , ..
          }
  | otherwise = Nothing

-- | Resolve the @WithDict@ class from module "GHC.Magic.Dict.Compat" of
-- package @ghc-magic-dict-compat@, together with its dictionary data
-- constructor.
lookupInfo :: TcPluginM Info
lookupInfo = do
  compatMod <-
    lookupModule
      (mkModuleName "GHC.Magic.Dict.Compat")
      (fsLit "ghc-magic-dict-compat")
  cls <- lookupOrig compatMod (mkTcOcc "WithDict") >>= tcLookupClass
  pure Info {_WithDict = cls, _WithDictDataCon = classDataCon cls}