{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE LambdaCase #-}
module Debug.Breakpoint.TypeChecker
  ( tcPlugin
  ) where

import           Data.Either
import           Data.Maybe
import           Data.Traversable (for)
import qualified GHC.Tc.Plugin as Plugin

import qualified Debug.Breakpoint.GhcFacade as Ghc

--------------------------------------------------------------------------------
-- Type Checker Plugin
--------------------------------------------------------------------------------

-- | GHC entities looked up once by 'initTcPlugin' and handed to every
-- invocation of the solver.
data TcPluginNames = MkTcPluginNames
  { showLevClassName :: !Ghc.Name
    -- ^ Name of the @ShowLev@ class from "Debug.Breakpoint".
  , showLevNameTc :: !Ghc.Name
    -- ^ Name of the @showLev@ value from "Debug.Breakpoint"; compared against
    -- constraint origins.
  , showClass :: !Ghc.Class
    -- ^ The @Show@ class from "GHC.Show".
  , succeedClass :: !Ghc.Class
    -- ^ The @Succeed@ class from "Debug.Breakpoint".
  , showWrapperTyCon :: !Ghc.TyCon
    -- ^ The @ShowWrapper@ type constructor from "Debug.Breakpoint".
  }

-- | The type-checker plugin record handed to GHC.
--
-- Initialisation resolves the names in 'TcPluginNames'. The solver is
-- 'solver', and stopping the plugin does nothing. On GHC >= 9.4 no type-family
-- rewriters are registered.
tcPlugin :: Ghc.TcPlugin
tcPlugin =
  Ghc.TcPlugin
    { Ghc.tcPluginInit = initTcPlugin
    , Ghc.tcPluginSolve = solver
    , Ghc.tcPluginStop = \_ -> pure ()
#if MIN_VERSION_ghc(9,4,0)
    , Ghc.tcPluginRewrite = const mempty
#endif
    }

-- | Resolve the classes, names and type constructors the solver needs.
--
-- Looks up @ShowLev@, @showLev@, @Succeed@ and @ShowWrapper@ in
-- "Debug.Breakpoint", and @Show@ in "GHC.Show".
--
-- If either module cannot be found, this fails with an 'error' naming the
-- module. The previous failable pattern bind went through 'MonadFail' for
-- 'TcPluginM', which drops the failure message.
initTcPlugin :: Ghc.TcPluginM TcPluginNames
initTcPlugin = do
  breakpointMod <- findModuleOrFail "Debug.Breakpoint"
  showMod <- findModuleOrFail "GHC.Show"

  showLevClassName <- Plugin.lookupOrig breakpointMod (Ghc.mkClsOcc "ShowLev")
  showLevNameTc <- Plugin.lookupOrig breakpointMod (Ghc.mkVarOcc "showLev")
  showClass <-
    Plugin.tcLookupClass =<< Plugin.lookupOrig showMod (Ghc.mkClsOcc "Show")
  succeedClass <-
    Plugin.tcLookupClass =<< Plugin.lookupOrig breakpointMod (Ghc.mkClsOcc "Succeed")
  -- NOTE(review): 'mkClsOcc' is used here to find a type constructor. This
  -- presumably relies on classes and tycons sharing one OccName namespace;
  -- confirm, or switch to 'mkTcOcc'.
  showWrapperTyCon <-
    Plugin.tcLookupTyCon =<< Plugin.lookupOrig breakpointMod (Ghc.mkClsOcc "ShowWrapper")

  pure MkTcPluginNames{..}
  where
    -- Find an importable module by name, or stop with a message that says
    -- which module was missing.
    findModuleOrFail modName =
      Ghc.findImportedModule' (Ghc.mkModuleName modName) >>= \case
        Ghc.Found _ foundMod -> pure foundMod
        _ ->
          error
            ( "Debug.Breakpoint.TypeChecker: could not find module "
                <> modName
            )

-- | Result of inspecting one wanted constraint for the shape @ShowLev rep ty@.
data FindWantedResult
  = FoundLifted Ghc.Type Ghc.Ct
    -- ^ @rep@ is @LiftedRep@. Carries @ty@ and the original constraint.
  | FoundUnlifted Ghc.Type Ghc.Ct
    -- ^ @rep@ is some other nullary type constructor. Carries @ty@ and the
    -- original constraint.
  | NotFound
    -- ^ The constraint is not a @ShowLev@ dictionary of the expected shape.

-- | Classify a wanted constraint.
--
-- A canonical dictionary constraint for the @ShowLev@ class with arguments
-- @[rep, ty]@ matches only when @rep@ is a type constructor applied to no
-- arguments. 'FoundLifted' is returned when that constructor is
-- @LiftedRep@, and 'FoundUnlifted' otherwise. Every other constraint yields
-- 'NotFound'.
findShowLevWanted :: TcPluginNames -> Ghc.Ct -> FindWantedResult
findShowLevWanted names ct =
  case ct of
    Ghc.CDictCan' _ cls [Ghc.TyConApp repTyCon [], argTy]
      | Ghc.getName cls == showLevClassName names ->
          if Ghc.getName repTyCon == Ghc.liftedRepName
            then FoundLifted argTy ct
            else FoundUnlifted argTy ct
    _ -> NotFound

-- | Recognise a wanted @Show ty@ dictionary constraint whose origin is an
-- occurrence of @showLev@.
--
-- Despite the function's name, the test is on the constraint's origin, not
-- on a superclass. Returns @ty@ together with the constraint on a match.
findShowWithSuperclass
  :: TcPluginNames
  -> Ghc.Ct
  -> Maybe (Ghc.Type, Ghc.Ct)
findShowWithSuperclass names ct =
  case ct of
    Ghc.CDictCan' ev cls [argTy]
      | Ghc.getName cls == Ghc.getName (showClass names)
      , arisesFromShowLev (Ghc.ctLocOrigin (Ghc.ctev_loc ev))
      -> Just (argTy, ct)
    _ -> Nothing
  where
    -- True only when the origin is an occurrence of the @showLev@ name.
    arisesFromShowLev = \case
      Ghc.OccurrenceOf usedName -> usedName == showLevNameTc names
      _ -> False

-- | The plugin's constraint solver. Only the wanted constraints (the last
-- argument) are inspected.
--
-- There are two passes over the wanteds:
--
--   * Each @ShowLev rep ty@ wanted is solved immediately.
--
--       - For an unlifted @ty@, the evidence is a placeholder function
--         printing the type.
--       - For a lifted @ty@, a fresh @Show ty@ wanted is emitted. Its
--         evidence is fed to the @Succeed@ instance's dfun.
--
--   * Each @Show ty@ wanted that originates from an occurrence of @showLev@
--     is solved via the @Show (ShowWrapper ty)@ instance and a placeholder
--     dictionary.
solver :: TcPluginNames -> Ghc.TcPluginSolver
solver names _ _ wanteds = do
  instEnvs <- Plugin.getInstEnvs

  -- Pass 1: ShowLev wanteds -> (evidence, constraint) plus any new Show wanted.
  (showLevDicts, mNewWanteds) <-
    unzip . catMaybes
      <$> for (findShowLevWanted names <$> wanteds) (solveShowLev instEnvs)

  -- Pass 2: Show wanteds arising from showLev -> placeholder evidence.
  unshowableDicts <-
    for (mapMaybe (findShowWithSuperclass names) wanteds) solveShowFromShowLev

  pure $
    Ghc.TcPluginOk
      (showLevDicts ++ unshowableDicts)
      (catMaybes mNewWanteds)
  where
    -- Evidence for one classified wanted, plus the Show wanted it spawns.
    solveShowLev instEnvs = \case
      NotFound -> pure Nothing
      FoundUnlifted ty ct -> do
        placeholder <- Ghc.unsafeTcPluginTcM (buildUnshowableDict ty)
        pure (Just ((placeholder, ct), Nothing))
      FoundLifted ty ct -> do
        (showDict, showWanted) <- buildShowLevDict names ct ty
        -- Lazy binding: the error only fires if the instance is actually used.
        let (succInst, _) =
              fromRight (error "impossible: no Succeed instance") $
                Ghc.lookupUniqueInstEnv instEnvs (succeedClass names) [ty]
            evidence = liftDict succInst ty (getEvExprFromDict showDict)
        pure (Just ((evidence, ct), Just showWanted))

    -- Pair the placeholder Show evidence with the constraint it solves.
    solveShowFromShowLev (ty, ct) =
      (\dict -> (dict, ct)) <$> lookupUnshowableDict names ty

-- | Emit a fresh wanted @Show ty@ at the location of the given @ShowLev@
-- constraint.
--
-- Returns the new wanted's evidence term, together with the new wanted
-- wrapped as a non-canonical constraint so it can be reported back to GHC.
buildShowLevDict
  :: TcPluginNames
  -> Ghc.Ct
  -> Ghc.Type
  -> Ghc.TcPluginM (Ghc.EvTerm, Ghc.Ct)
buildShowLevDict names showLevWanted ty = do
  let loc = Ghc.ctLoc showLevWanted
      showPred = Ghc.mkTyConApp (Ghc.classTyCon (showClass names)) [ty]
  ev <- Plugin.newWanted loc showPred
  pure (Ghc.ctEvTerm ev, Ghc.mkNonCanonical ev)

-- | Build evidence for @Show ty@ when no real instance is available.
--
-- Applies the dfun of the @Show (ShowWrapper ty)@ instance to @ty@ and to a
-- placeholder dictionary from 'buildUnshowableDict', which renders as
-- @\"<ty>\"@.
lookupUnshowableDict
  :: TcPluginNames
  -> Ghc.Type
  -> Ghc.TcPluginM Ghc.EvTerm
lookupUnshowableDict names ty = do
  instEnvs <- Plugin.getInstEnvs
  placeholder <- Ghc.unsafeTcPluginTcM (buildUnshowableDict ty)
  let wrappedTy = Ghc.mkTyConApp (showWrapperTyCon names) [ty]
      (wrapperInst, _) =
        fromRight (error "impossible: no Show instance for ShowWrapper") $
          Ghc.lookupUniqueInstEnv instEnvs (showClass names) [wrappedTy]
  pure (liftDict wrapperInst ty (getEvExprFromDict placeholder))

-- | Unwrap an 'Ghc.EvExpr' evidence term. Any other kind of evidence term is
-- rejected with an 'error'.
getEvExprFromDict :: Ghc.EvTerm -> Ghc.EvExpr
getEvExprFromDict term =
  case term of
    Ghc.EvExpr expr -> expr
    _ -> error "invalid argument to getEvExprFromDict"

-- | Placeholder evidence of the form @\\_ -> \"<ty>\"@.
--
-- The type is pretty-printed for the user on a single line. The lambda
-- ignores its argument, which has type @ty@.
buildUnshowableDict :: Ghc.Type -> Ghc.TcM Ghc.EvTerm
buildUnshowableDict ty = do
  let rendered = "<" <> Ghc.showSDocOneLine' (Ghc.pprTypeForUser' ty) <> ">"
  strExpr <- Ghc.mkStringExpr rendered
  let ignoredArg = Ghc.mkWildValBinder' ty
  pure (Ghc.EvExpr (Ghc.mkCoreLams [ignoredArg] strExpr))

-- | Apply an instance's dictionary function to a single type argument and a
-- single dictionary argument.
liftDict :: Ghc.ClsInst -> Ghc.Type -> Ghc.EvExpr -> Ghc.EvTerm
liftDict inst ty dict = Ghc.evDFunApp dfun [ty] [dict]
  where
    dfun = Ghc.is_dfun inst