{-# LANGUAGE CPP             #-}
{-# LANGUAGE RecordWildCards #-}
module Language.Haskell.TH.CodeT.Plugin (
    plugin,
) where

import Control.Monad (forM)
import Data.Maybe    (catMaybes)
import Data.String   (fromString)

import qualified GHC.Plugins as Plugins

import Plugin.GHC

-- | A GHC type-checker plugin which solves 'Language.Haskell.TH.CodeT.LiftT' instances.
--
-- At the moment the plugin solves only type constructor instances,
-- e.g. for a data type
--
-- @
-- data Foo a = MkFoo a
-- @
--
-- the plugin will solve the instances needed by @'Language.Haskell.TH.CodeT.codeT' \@Foo@
-- and @'Language.Haskell.TH.CodeT.codeT' \@('MkFoo)@.
-- (Type applications are already covered by the existing
-- @(LiftT f, LiftT x) => LiftT (f x)@ instance, so the plugin doesn't need to solve them.)
--
-- Notably, the plugin solves only for algebraic type constructors (@data@, @newtype@, @class@) and
-- promoted data constructors. Specifically, it doesn't solve for type-family type constructors.
--
-- Enable the plugin with:
--
-- @
-- \{-# OPTIONS_GHC -fplugin=Language.Haskell.TH.CodeT.Plugin #-\}
-- @
--
plugin :: Plugins.Plugin
plugin = Plugins.defaultPlugin
    { Plugins.tcPlugin        = Just . tcPlugin
      -- 'tcPlugin' ignores its command-line options and the plugin performs no
      -- IO of its own, so it is declared pure. The default
      -- ('Plugins.impurePlugin') would make GHC recompile every module that
      -- enables the plugin on every build.
    , Plugins.pluginRecompile = Plugins.purePlugin
    }

-- | Assemble the type-checker plugin.
--
-- The argument (the plugin's command-line options, as passed by 'plugin')
-- is ignored.
tcPlugin :: a -> TcPlugin
tcPlugin _ = TcPlugin
    { tcPluginInit    = tcPluginInit_
    , tcPluginSolve   = tcPluginSolve_
      -- Nothing to release when the plugin stops.
    , tcPluginStop    = const (return ())
#if __GLASGOW_HASKELL__ >=904
      -- No type-family rewriting is done: the rewriter map is empty.
    , tcPluginRewrite = \_ -> Plugins.emptyUFM
#endif
    }

-- | Names resolved once by 'tcPluginInit_' and passed to every solver call.
data PluginCtx = PluginCtx
    { liftTClass        :: Class
      -- ^ The @LiftT@ class, looked up in "Language.Haskell.TH.CodeT.Unsafe".
    , unsafeCodeTNameD  :: Id
      -- ^ @unsafeCodeTNameD@, used to build evidence when the solved type is
      -- headed by a name in the data-constructor namespace.
    , unsafeCodeTNameTC :: Id
      -- ^ @unsafeCodeTNameTC@, used to build evidence when the solved type is
      -- headed by a name in the type-constructor namespace.
    }

-- | Plugin initialisation: resolve the @LiftT@ class and the two evidence
-- helpers the solver refers to.
--
-- All three live in "Language.Haskell.TH.CodeT.Unsafe" of the @codet@
-- package, so that module is located once and reused for every lookup.
tcPluginInit_ :: TcPluginM PluginCtx
tcPluginInit_ = do
    let codet :: FastString
        codet = fromString "codet"

    let codeTModuleName :: ModuleName
        codeTModuleName = mkModuleName "Language.Haskell.TH.CodeT.Unsafe"

    -- Locating the module is identical work for every name below; do it once.
    md <- findModulePluginM codeTModuleName codet

    -- Resolve a term-level name defined in the Unsafe module to its 'Id'.
    let lookupVar :: String -> TcPluginM Id
        lookupVar str = lookupOrig md (mkVarOcc str) >>= tcLookupId

    liftTClass        <- tcLookupClass =<< lookupOrig md (mkTcOcc "LiftT")
    unsafeCodeTNameD  <- lookupVar "unsafeCodeTNameD"
    unsafeCodeTNameTC <- lookupVar "unsafeCodeTNameTC"

    return PluginCtx {..}

-- | Try to discharge each wanted constraint with 'solveLiftT'.
--
-- Givens and the evidence-bindings variable are not used, and no new
-- constraints are ever emitted. Wanteds that 'solveLiftT' does not solve are
-- simply left out of the result.
tcPluginSolve_ :: PluginCtx -> TcPluginSolver
tcPluginSolve_ ctx _evBindsVar _given wanteds = do
    -- 'solveLiftT' runs in 'TcM', so the traversal is lifted into 'TcPluginM'.
    solved' <- unsafeTcPluginTcM $ forM wanteds $ \wanted -> solveLiftT ctx wanted

    let solved :: [(EvTerm, Ct)]
        solved = catMaybes solved'

    let new :: [Ct]
        new = []

    return $ TcPluginOk solved new

-- | Try to solve a single @LiftT \@k x@ constraint.
--
-- The constraint is solved when @x@ splits as a type-constructor application
-- whose head is an algebraic type constructor or a promoted data constructor
-- and whose name has a module. The evidence is @unsafeCodeTNameD@ (data
-- namespace) or @unsafeCodeTNameTC@ (type namespace) applied to @k@, @x@ and
-- the head's package, module and occurrence names as string literals.
--
-- Returns 'Nothing' when any of these conditions fails, leaving the
-- constraint unsolved by this plugin.
solveLiftT :: PluginCtx -> Ct -> TcM (Maybe (EvTerm, Ct))
solveLiftT ctx wanted
    | Just (ct, [k, x]) <- findClassConstraint (liftTClass ctx) wanted
    , Just (xTyCon, _args) <- splitTyConApp_maybe x
    , isAlgTyCon xTyCon || isPromotedDataCon xTyCon
    -- , let ki = tyConKind xTyCon
    -- in 9.0 splitPiTysInvisible
    -- , (_invPis, _) <- splitInvisPiTys ki
    , let xTyConName :: Name
          xTyConName = getName xTyCon
    , Just tcMod <- nameModule_maybe xTyConName
    -- TODO: check that 'args' count matches 'invPis'?
    = do
        let occ :: OccName
            occ = nameOccName xTyConName

        -- tcPluginIO $ logOutput logger $ text "wanted:" <+> ppr x
        let pkg_str, mod_str, occ_str :: String
            pkg_str = unitString (moduleUnit tcMod)
            mod_str = moduleNameString (moduleName tcMod)
            occ_str = occNameString occ

        pkg_str' <- mkStringExpr pkg_str
        mod_str' <- mkStringExpr mod_str
        occ_str' <- mkStringExpr occ_str

        -- Pick the helper by the name's namespace. The guard above should only
        -- admit names in the data or type namespace, so the panic is expected
        -- to be unreachable.
        let fun :: Id
            fun | isDataOcc occ = unsafeCodeTNameD ctx
                | isTcOcc occ   = unsafeCodeTNameTC ctx
                | otherwise     = Plugins.pprPanic "solveLiftT" (ppr xTyConName)

        let ev :: CoreExpr
            ev = mkCoreApps (Var fun)
                [Type k, Type x, pkg_str', mod_str', occ_str']

            evidence :: EvTerm
            evidence = makeClassEvidence (liftTClass ctx) [k, x] ev

        return (Just (evidence, ct))

    | otherwise
    = return Nothing