{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module AsyncRattus.Plugin.StableSolver (tcStable) where
import AsyncRattus.Plugin.Utils
( getNameModule, isRattModule, isStable )
import Prelude hiding ((<>))
import GHC.Plugins
(Type, Var, CommandLineOption,tyConSingleDataCon,
mkCoreConApps,getTyVar_maybe)
import GHC.Core
import GHC.Tc.Types.Evidence
import GHC.Core.Class
import GHC.Tc.Types
import GHC.Tc.Types.Constraint
import Data.Set (Set)
import qualified Data.Set as Set
#if __GLASGOW_HASKELL__ >= 904
import GHC.Types.Unique.FM
#endif
-- | Typechecker plugin that discharges AsyncRattus @Stable@ constraints.
--
-- Command-line options are ignored. The plugin keeps no state: its
-- state type is @()@, so init and stop are no-ops. All of its work
-- happens in 'stableSolver'.
tcStable :: [CommandLineOption] -> Maybe TcPlugin
tcStable _ = Just $ TcPlugin
  { tcPluginInit  = return ()
  , tcPluginSolve = \ () -> stableSolver
  , tcPluginStop  = \ () -> return ()
#if __GLASGOW_HASKELL__ >= 904
    -- GHC >= 9.4 requires a rewriter map; this plugin rewrites no type
    -- families, so the map is empty.
  , tcPluginRewrite = \ () -> emptyUFM
#endif
  }
-- | Build the evidence (dictionary) for @cls ty@.
--
-- The evidence is the class's single dictionary data constructor applied
-- to the type argument only. NOTE(review): this assumes the class (the
-- AsyncRattus @Stable@ class) has no methods and no superclasses, so its
-- dictionary constructor takes no value arguments. Confirm this against
-- the class definition.
wrap :: Class -> Type -> EvTerm
wrap cls ty = EvExpr (mkCoreConApps dictCon [Type ty])
  where
    -- A class's TyCon always has exactly one data constructor: its dictionary.
    dictCon = tyConSingleDataCon (classTyCon cls)
-- | Try to solve one wanted @Stable ty@ constraint.
--
-- The first argument is the set of type variables @a@ for which a given
-- @Stable a@ constraint is in scope; 'isStable' consults it when deciding
-- whether @ty@ is stable. On success, returns the evidence paired with the
-- constraint it solves. Otherwise returns 'Nothing'.
solveStable :: Set Var -> (Type, (Ct, Class)) -> Maybe (EvTerm, Ct)
solveStable stableVars (ty, (ct, cl))
  | isStable stableVars ty = Just (wrap cl ty, ct)
  | otherwise              = Nothing
-- | Constraint solver for AsyncRattus @Stable@ constraints.
--
-- Steps:
--
-- 1. Collect every wanted single-argument @Stable t@ constraint whose
--    class is defined in an AsyncRattus module (as judged by
--    'isRattModule').
-- 2. Collect the type variables @a@ for which a given @Stable a@ is in
--    scope.
-- 3. Discharge each wanted constraint that 'solveStable' can prove.
--
-- Each constraint is handled independently. An unprovable constraint is
-- left for GHC to report, and does not prevent the others from being
-- solved. (Solving all-or-nothing made GHC report provable constraints as
-- unsolved, too.)
#if __GLASGOW_HASKELL__ >= 904
stableSolver :: EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
stableSolver _ given wanted = do
#else
stableSolver :: [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
stableSolver given _derived wanted = do
#endif
  let wantedStable = concatMap filterCt wanted
      -- Type variables with a given @Stable a@ in scope.
      haveSt = Set.fromList
        [ v | (ty, _) <- concatMap filterCt given, v <- filterTypeVar ty ]
      -- Keep only the constraints we could prove; the rest stay unsolved.
      solved = [ ev | Just ev <- map (solveStable haveSt) wantedStable ]
  return (TcPluginOk solved [])
  where
    -- Select dictionary constraints of the form @Stable ty@ (exactly one
    -- type argument), where the class is named "Stable" and lives in an
    -- AsyncRattus module. The result pairs the type argument with the
    -- constraint and its class.
    filterCt :: Ct -> [(Type, (Ct, Class))]
#if __GLASGOW_HASKELL__ >= 908
    filterCt ct@(CDictCan (DictCt {di_cls = cl, di_tys = [ty]}))
#else
    filterCt ct@(CDictCan {cc_class = cl, cc_tyargs = [ty]})
#endif
      = case getNameModule cl of
          Just (name, modName)
            | isRattModule modName && name == "Stable" -> [(ty, (ct, cl))]
          _ -> []
    filterCt _ = []

    -- A type that is a bare type variable yields that variable; any other
    -- type yields nothing.
    filterTypeVar :: Type -> [Var]
    filterTypeVar ty = maybe [] (: []) (getTyVar_maybe ty)