{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}


-- | This module implements a constraint solver plugin for the
-- 'Stable' type class.

module Rattus.Plugin.StableSolver (tcStable) where

import Rattus.Plugin.Utils

import Prelude hiding ((<>))

#if __GLASGOW_HASKELL__ >= 900
import GHC.Plugins
  (Type, Var, CommandLineOption,tyConSingleDataCon,
   mkCoreConApps,getTyVar_maybe)
import GHC.Core
import GHC.Tc.Types.Evidence
import GHC.Core.Class
import GHC.Tc.Types
#else
import GhcPlugins
  (Type, Var, CommandLineOption,tyConSingleDataCon,
   mkCoreConApps,getTyVar_maybe)
import CoreSyn
import TcEvidence
import Class
import TcRnTypes
#endif

#if __GLASGOW_HASKELL__ >= 900
import GHC.Tc.Types.Constraint
#elif __GLASGOW_HASKELL__ >= 810
import Constraint
#endif

import Data.Set (Set)
import qualified Data.Set as Set





-- | Constraint solver plugin for the 'Stable' type class.
--
-- The command line options handed to the plugin are ignored. The
-- plugin keeps no state: its initialisation and shutdown actions do
-- nothing, and every solver invocation is delegated to 'stableSolver'.
tcStable :: [CommandLineOption] -> Maybe TcPlugin
tcStable _ = Just $ TcPlugin
  { tcPluginInit = return ()
  , tcPluginSolve = \ () -> stableSolver
  , tcPluginStop = \ () -> return ()
  }


-- | Build the evidence term for a constraint of class @cls@ on the
-- type @ty@.
--
-- The evidence is a Core expression that applies the single data
-- constructor of the class's type constructor to the type argument
-- @ty@ and nothing else.
wrap :: Class -> Type -> EvTerm
wrap cls ty = EvExpr appDc
  where
    tyCon = classTyCon cls
    dc = tyConSingleDataCon tyCon
    -- Only the type argument is passed to the constructor.
    -- NOTE(review): this presumes the class has neither superclasses
    -- nor methods, so its dictionary has no value fields -- confirm
    -- against the class declaration.
    appDc = mkCoreConApps dc [Type ty]

-- | Try to solve a single candidate constraint.
--
-- The candidate is given as the constrained type paired with the
-- original constraint and its class. If 'isStable' accepts the type
-- with respect to the set of variables @c@, the result is the evidence
-- built by 'wrap' together with the constraint it solves; otherwise
-- the result is 'Nothing'.
--
-- NOTE(review): @c@ is presumably the set of type variables that may
-- be assumed stable -- confirm against 'isStable'.
solveStable :: Set Var -> (Type, (Ct,Class)) -> Maybe (EvTerm, Ct)
solveStable c (ty,(ct,cl))
  | isStable c ty = Just (wrap cl ty, ct)
  | otherwise = Nothing

-- | The solver run by the plugin.
--
-- From the wanted constraints it selects the single-argument
-- dictionary constraints whose class is named @Stable@ and lives in a
-- module accepted by 'isRattModule'. The given constraints are
-- filtered the same way, and those whose argument is a bare type
-- variable contribute that variable to the set passed to
-- 'solveStable'. The derived constraints are ignored.
--
-- Solving is all-or-nothing: the selected wanted constraints are
-- reported as solved only if 'solveStable' succeeds for every one of
-- them. Otherwise the plugin reports no solved constraints. No new
-- constraints are ever emitted.
stableSolver :: [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
stableSolver given _derived wanted = do
  -- Wanted 'Stable' constraints that we attempt to solve.
  let chSt = concatMap filterCt wanted
  -- Type variables that occur as the argument of a given 'Stable' constraint.
  let haveSt = Set.fromList $ concatMap (filterTypeVar . fst) $ concatMap filterCt given
  -- 'mapM' in 'Maybe' yields 'Nothing' as soon as one candidate fails.
  case mapM (solveStable haveSt) chSt of
    Just evs -> return $ TcPluginOk evs []
    Nothing -> return $ TcPluginOk [] []

  where
    -- Keep only single-argument dictionary constraints on the
    -- @Stable@ class, returning the argument type together with the
    -- constraint and its class.
    filterCt :: Ct -> [(Type, (Ct, Class))]
    filterCt ct@(CDictCan {cc_class = cl, cc_tyargs = [ty]})
      = case getNameModule cl of
          Just (name,modName)
            | isRattModule modName && name == "Stable" -> [(ty,(ct,cl))]
          _ -> []
    filterCt _ = []

    -- The type variable that a type consists of, if it is a bare type
    -- variable; otherwise nothing.
    filterTypeVar :: Type -> [Var]
    filterTypeVar ty = case getTyVar_maybe ty of
      Just v -> [v]
      Nothing -> []