{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}


-- | This module implements a constraint solver plugin for the
-- 'Stable' type class.

module AsyncRattus.Plugin.StableSolver (tcStable) where

import AsyncRattus.Plugin.Utils
    ( getNameModule, isRattModule, isStable )

import Prelude hiding ((<>))

import GHC.Plugins
  (Type, Var, CommandLineOption,tyConSingleDataCon,
   mkCoreConApps,getTyVar_maybe)
import GHC.Core
import GHC.Tc.Types.Evidence
import GHC.Core.Class
import GHC.Tc.Types
import GHC.Tc.Types.Constraint

import Data.Set (Set)
import qualified Data.Set as Set
#if __GLASGOW_HASKELL__ >= 904
import GHC.Types.Unique.FM
#endif



-- | Constraint solver plugin for the 'Stable' type class.
--
-- The plugin carries no state: its state type is @()@, initialisation
-- and shutdown do nothing, and every solver invocation is delegated to
-- 'stableSolver'. The command line options passed to the plugin are
-- ignored.
tcStable :: [CommandLineOption] -> Maybe TcPlugin
tcStable _ = Just $ TcPlugin
  { tcPluginInit = return ()
  , tcPluginSolve = \ () -> stableSolver
  , tcPluginStop = \ () -> return ()
#if __GLASGOW_HASKELL__ >= 904
    -- This field only exists from GHC 9.4 on; no rewriters are registered.
  , tcPluginRewrite = \ () -> emptyUFM
#endif
  }


-- | Build the evidence term for a constraint of class @cls@ at type
-- @ty@.
--
-- The evidence is the single data constructor of the class's type
-- constructor applied to the type argument @ty@ and nothing else.
--
-- NOTE(review): applying the constructor to the type alone presumably
-- relies on the class having neither methods nor superclasses -- confirm
-- against the definition of the class this is used for.
wrap :: Class -> Type -> EvTerm
wrap cls ty = EvExpr appDc
  where
    tyCon = classTyCon cls
    -- NOTE(review): 'tyConSingleDataCon' is expected to fail on a type
    -- constructor without exactly one data constructor -- confirm.
    dc = tyConSingleDataCon tyCon
    appDc = mkCoreConApps dc [Type ty]

-- | Try to solve a single wanted constraint.
--
-- The first argument is the set of variables handed on to 'isStable';
-- the second pairs the type argument of the constraint with the
-- constraint itself and its class.
--
-- Returns the evidence (built with 'wrap') together with the constraint
-- it solves when @isStable c ty@ holds, and 'Nothing' otherwise.
solveStable :: Set Var -> (Type, (Ct,Class)) -> Maybe (EvTerm, Ct)
solveStable c (ty,(ct,cl))
  | isStable c ty = Just (wrap cl ty, ct)
  | otherwise = Nothing

-- | The solver function of the plugin.
--
-- It selects, from the wanted constraints, the single-argument class
-- constraints whose class is named @Stable@ and whose module is accepted
-- by 'isRattModule'. From the given constraints of the same shape it
-- collects every type variable that occurs directly as the class
-- argument and passes that set to 'solveStable'.
--
-- Solving is all-or-nothing: if 'solveStable' fails for any selected
-- wanted constraint, no constraint is reported as solved. The solver
-- never emits new constraints.
#if __GLASGOW_HASKELL__ >= 904
stableSolver :: EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
stableSolver _ given wanted = do
#else
stableSolver :: [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
stableSolver given _derived wanted = do
#endif

  -- Wanted constraints this plugin is responsible for.
  let chSt = concatMap filterCt wanted
  -- Type variables that appear as the argument of a given constraint.
  let haveSt = Set.fromList $ concatMap (filterTypeVar . fst) $ concatMap filterCt given
  -- 'mapM' in 'Maybe' yields 'Nothing' as soon as one constraint fails.
  case mapM (solveStable haveSt) chSt of
    Just evs -> return $ TcPluginOk evs []
    Nothing -> return $ TcPluginOk [] []

  where
        -- Keep only single-argument dictionary constraints for the
        -- @Stable@ class, paired with their type argument and class.
        filterCt :: Ct -> [(Type, (Ct, Class))]
#if __GLASGOW_HASKELL__ >= 908
        filterCt ct@(CDictCan (DictCt {di_cls = cl, di_tys = [ty]}))
#else
        filterCt ct@(CDictCan {cc_class = cl, cc_tyargs = [ty]})
#endif
          = case getNameModule cl of
                Just (name,modName)
                  | isRattModule modName && name == "Stable" -> [(ty,(ct,cl))]
                _ -> []
        filterCt _ = []

        -- A type that is itself a type variable yields that variable;
        -- any other type yields nothing.
        filterTypeVar :: Type -> [Var]
        filterTypeVar ty = case getTyVar_maybe ty of
          Just v -> [v]
          Nothing -> []