{-# LANGUAGE CPP             #-}
{-# LANGUAGE RecordWildCards #-}
module Overloaded.Plugin.TcPlugin (
    tcPlugin,
) where

import Data.Maybe    (mapMaybe)

import qualified GHC.Compat.All  as GHC

#if MIN_VERSION_ghc(9,0,0)
import qualified GHC.Tc.Plugin as Plugins
#else
import qualified TcPluginM as Plugins
#endif

import Overloaded.Plugin.TcPlugin.Ctx
import Overloaded.Plugin.HasField
import Overloaded.Plugin.HasConstructor

-- TODO: take argument which options to enable.
-- | The type-checker plugin value exposed by this module.
--
-- Builds a 'GHC.TcPlugin' record out of three pieces:
--
-- * initialisation is delegated to @tcPluginInit@ (not defined in this
--   module; presumably provided by "Overloaded.Plugin.TcPlugin.Ctx"), which
--   yields the 'PluginCtx' passed to the solver;
-- * solving is delegated to 'tcPluginSolve';
-- * the stop action ignores the context and does nothing.
tcPlugin :: GHC.TcPlugin
tcPlugin = GHC.TcPlugin
    { GHC.tcPluginInit  = tcPluginInit
    , GHC.tcPluginSolve = tcPluginSolve
    , GHC.tcPluginStop  = const (return ())
    }

-- HasPolyField "petName" Pet Pet [Char] [Char]
-- | Solver callback of the plugin.
--
-- Only the last constraint list (the wanteds) is inspected; the two lists
-- that precede it are ignored.
--
-- The wanteds are handed to both 'solveHasField' and 'solveHasConstructor'.
-- Each of them returns a list of @(Maybe (evidence, newConstraints), ct)@
-- pairs. The two result lists are concatenated and turned into a
-- 'GHC.TcPluginOk':
--
-- * every entry with a 'Just' contributes @(evidence, ct)@ to the first
--   argument of 'GHC.TcPluginOk';
-- * the @newConstraints@ of every 'Just' entry are concatenated into the
--   second argument;
-- * entries with 'Nothing' contribute nothing.
tcPluginSolve :: PluginCtx -> GHC.TcPluginSolver
tcPluginSolve ctx _ _ wanteds = do
    -- acquire context
    dflags      <- Plugins.unsafeTcPluginTcM GHC.getDynFlags
    famInstEnvs <- Plugins.getFamInstEnvs
    rdrEnv      <- Plugins.unsafeTcPluginTcM GHC.getGlobalRdrEnv

    -- run both solvers over the same wanteds, with the same context
    solvedHasField       <- solveHasField       ctx dflags famInstEnvs rdrEnv wanteds
    solvedHasConstructor <- solveHasConstructor ctx dflags famInstEnvs rdrEnv wanteds

    let solved = solvedHasField ++ solvedHasConstructor

    return $ GHC.TcPluginOk (mapMaybe extractA solved) (concat $ mapMaybe extractB solved)
  where
    -- Pair the evidence of a solved entry with the constraint it belongs to.
    extractA :: (Maybe (a, b), c) -> Maybe (a, c)
    extractA (Nothing, _)     = Nothing
    extractA (Just (a, _), b) = Just (a, b)

    -- Pick the second component (the new constraints) of a solved entry.
    extractB :: (Maybe (a, b), c) -> Maybe b
    extractB (Nothing, _)      = Nothing
    extractB (Just (_, ct), _) = Just ct