{-# LANGUAGE RecordWildCards, ViewPatterns #-}
module TypeLevel.Rewrite (plugin) where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Foldable
import Data.Traversable
import GHC.TcPluginM.Extra (evByFiat)
import Plugins (Plugin(pluginRecompile, tcPlugin), CommandLineOption, defaultPlugin, purePlugin)
import TcEvidence (EvTerm)
import TcPluginM (TcPluginM, newCoercionHole)
import TcRnTypes
import TcType (TcPredType)
import TyCon (synTyConDefn_maybe)
import Type (EqRel(NomEq), PredTree(EqPred), Type, classifyPredType, mkPrimEqPred)
import TypeLevel.Rewrite.Internal.Lookup
import TypeLevel.Rewrite.Internal.PrettyPrint
import TypeLevel.Rewrite.Internal.TypeRule
import TypeLevel.Rewrite.Internal.TypeTerm
-- | One rewrite performed by the plugin.
--
-- The original wanted constraint is reported as solved using
-- 'evidenceOfCorrectness'. The 'replacementConstraints' are handed back to GHC
-- as new constraints in its place (see 'combineReplaceCts').
data ReplaceCt = ReplaceCt
  { evidenceOfCorrectness  :: EvTerm
    -- ^ Evidence used to discharge 'replacedConstraint'.
  , replacedConstraint     :: Ct
    -- ^ The constraint being discharged.
  , replacementConstraints :: [Ct]
    -- ^ Constraints emitted in place of 'replacedConstraint'.
  }
-- | Merge the individual rewrites into a single 'TcPluginOk' result.
--
-- Every replaced constraint is reported as solved, paired with its evidence.
-- The replacement constraints of all rewrites, concatenated in order, are
-- returned as new constraints.
combineReplaceCts
  :: [ReplaceCt]
  -> TcPluginResult
combineReplaceCts replaceCts
  = TcPluginOk solved emitted
  where
    solved :: [(EvTerm, Ct)]
    solved =
      [ (evidenceOfCorrectness r, replacedConstraint r)
      | r <- replaceCts
      ]

    emitted :: [Ct]
    emitted = concatMap replacementConstraints replaceCts
-- | Abort with a usage message.
--
-- The message explains how to pass rewrite rules to the plugin and reports
-- what was @expected@ versus what was @actual@ly supplied. The result type is
-- fully polymorphic because this never returns normally: it calls 'error'.
usage
  :: String
  -> String
  -> TcPluginM a
usage expected actual = error (concat messageLines)
  where
    messageLines :: [String]
    messageLines =
      [ "usage:\n"
      , " {-# OPTIONS_GHC -fplugin TypeLevel.Rewrite\n"
      , " -fplugin-opt=TypeLevel.Rewrite:TypeLevel.Append.RightIdentity\n"
      , " -fplugin-opt=TypeLevel.Rewrite:TypeLevel.Append.RightAssociative #-}\n"
      , "Where 'TypeLevel.Append' is a module containing a type synonym named 'RightIdentity':\n"
      , " type RightIdentity as = (as ++ '[]) ~ as\n"
      , "Type expressions which match the left of the '~' will get rewritten to the type\n"
      , "expression on the right of the '~'. Be careful not to introduce cycles!\n"
      , "\n"
      , "expected:\n"
      , " " ++ expected ++ "\n"
      , "got:\n"
      , " " ++ actual
      ]
-- | Turn the plugin's command-line options into rewrite rules.
--
-- Each option must be a fully qualified type synonym name such as
-- @TypeLevel.Append.RightIdentity@. The synonym's definition must have the
-- shape @... ~ ...@. An empty option list, a name without a module prefix, a
-- name that is not a type synonym, or a synonym whose body is not an equality
-- all raise an error via 'usage'. Options are processed left to right.
lookupTypeRules
  :: [CommandLineOption]
  -> TcPluginM [TypeRule]
lookupTypeRules [] =
  usage (show ["TypeLevel.Append.RightIdentity", "TypeLevel.Append.RightAssociative"])
        "[]"
lookupTypeRules names =
  traverse lookupOne names
  where
    -- Resolve a single "Module.Name.Synonym" option.
    lookupOne :: CommandLineOption -> TcPluginM TypeRule
    lookupOne name =
      case splitLastDot name of
        Nothing ->
          usage (show "TypeLevel.Append.RightIdentity") (show name)
        Just (moduleNameStr, tyConNameStr) ->
          lookupTyCon moduleNameStr tyConNameStr >>= ruleFromTyCon

    -- Convert a looked-up TyCon into a rule, insisting it is a synonym whose
    -- right-hand side is an equality.
    ruleFromTyCon tyCon =
      case synTyConDefn_maybe tyCon of
        Nothing ->
          usage ("type " ++ pprTyCon tyCon ++ " ... = ...")
                (pprTyCon tyCon ++ " is not a type synonym")
        Just (_tyVars, definition) ->
          maybe (usage "... ~ ..." (pprType definition))
                pure
                (toTypeRule_maybe definition)
-- | Plugin entry point.
--
-- The plugin's command-line arguments are resolved into rewrite rules once,
-- in 'tcPluginInit'. Those rules are then passed to 'solve' on every solver
-- call. Stopping the plugin is a no-op, and the plugin is declared pure for
-- recompilation purposes.
plugin
  :: Plugin
plugin = defaultPlugin
  { tcPlugin        = typeCheckerPlugin
  , pluginRecompile = purePlugin
  }
  where
    typeCheckerPlugin args = Just TcPlugin
      { tcPluginInit  = lookupTypeRules args
      , tcPluginSolve = solve
      , tcPluginStop  = const (pure ())
      }
-- | Return both sides of the constraint when it is a nominal equality
-- @lhs ~ rhs@.
--
-- Any other predicate, including equalities whose relation is not 'NomEq',
-- yields 'Nothing'.
asEqualityConstraint
  :: Ct
  -> Maybe (Type, Type)
asEqualityConstraint (classifyPredType . ctEvPred . ctEvidence -> EqPred NomEq lhs rhs)
  = Just (lhs, rhs)
asEqualityConstraint _
  = Nothing
-- | Create a new wanted constraint @lhs ~ rhs@ at the given location.
--
-- The constraint is backed by a fresh coercion hole as its evidence
-- destination and is wrapped with 'mkNonCanonical'. The 'WDeriv' shadow flag
-- is copied from the original implementation (NOTE(review): presumably the
-- usual default for new wanteds — confirm against the targeted GHC version).
toEqualityConstraint
  :: Type -> Type -> CtLoc -> TcPluginM Ct
toEqualityConstraint lhs rhs loc =
  fmap toCt (newCoercionHole predTy)
  where
    predTy :: TcPredType
    predTy = mkPrimEqPred lhs rhs

    toCt hole = mkNonCanonical (CtWanted predTy (HoleDest hole) WDeriv loc)
-- | Solver callback.
--
-- For every wanted nominal equality @lhs ~ rhs@, apply the rewrite rules to
-- both sides. If either side changed, the original constraint is discharged
-- and a new wanted @lhs' ~ rhs'@ over the rewritten types is emitted in its
-- place. Constraints that are unchanged, or that are not nominal equalities,
-- are left alone. Givens and deriveds are not inspected.
solve
  :: [TypeRule]
  -> [Ct]  -- ^ givens (unused)
  -> [Ct]  -- ^ deriveds (unused)
  -> [Ct]  -- ^ wanteds
  -> TcPluginM TcPluginResult
solve _ _ _ [] =
  pure $ TcPluginOk [] []
solve rules _ _ cts = do
  replaceCts <- execWriterT $ do
    for_ cts $ \ct -> do
      for_ (asEqualityConstraint ct) $ \(lhs, rhs) -> do
        let lhsTypeTerm  = toTypeTerm lhs
            rhsTypeTerm  = toTypeTerm rhs
            lhsTypeTerm' = applyRules rules lhsTypeTerm
            rhsTypeTerm' = applyRules rules rhsTypeTerm
        unless (lhsTypeTerm' == lhsTypeTerm && rhsTypeTerm' == rhsTypeTerm) $ do
          let lhs' = fromTypeTerm lhsTypeTerm'
              rhs' = fromTypeTerm rhsTypeTerm'
          ct' <- lift $ toEqualityConstraint lhs' rhs' (ctLoc ct)
          -- The evidence discharges the ORIGINAL constraint, so it must
          -- witness lhs ~ rhs, not the rewritten lhs' ~ rhs'. The rewritten
          -- equality is left for GHC to prove via the new wanted ct'. The
          -- step lhs ~ rhs is taken on faith (evByFiat): the rewrite rules
          -- are trusted to be sound.
          let replaceCt :: ReplaceCt
              replaceCt = ReplaceCt
                { evidenceOfCorrectness  = evByFiat "TypeLevel.Rewrite" lhs rhs
                , replacedConstraint     = ct
                , replacementConstraints = [ct']
                }
          tell [replaceCt]
  pure $ combineReplaceCts replaceCts