{-# LANGUAGE LambdaCase, OverloadedStrings, RecordWildCards, ViewPatterns #-}
module TypeLevel.Rewrite (plugin) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Foldable
import Data.Traversable

-- GHC API
import Coercion (Role(Representational), mkUnivCo)
import Constraint (CtEvidence(ctev_loc), Ct, ctEvExpr, ctLoc, mkNonCanonical)
import GhcPlugins (PredType, SDoc, eqType, fsep, ppr)
import Plugins (Plugin(pluginRecompile, tcPlugin), CommandLineOption, defaultPlugin, purePlugin)
import TcEvidence (EvExpr, EvTerm, evCast)
import TcPluginM (newWanted)
import TcRnTypes
import TyCoRep (UnivCoProvenance(PluginProv))
import TyCon (synTyConDefn_maybe)

import TypeLevel.Rewrite.Internal.ApplyRules
import TypeLevel.Rewrite.Internal.DecomposedConstraint
import TypeLevel.Rewrite.Internal.Lookup
import TypeLevel.Rewrite.Internal.PrettyPrint
import TypeLevel.Rewrite.Internal.TypeEq
import TypeLevel.Rewrite.Internal.TypeRule
import TypeLevel.Rewrite.Internal.TypeTerm

-- printf-debugging:
--import TcPluginM (tcPluginIO)
--import Outputable
----tcPluginIO $ print ("foo", showSDocUnsafe $ ppr foo)


-- | A plan for replacing one wanted constraint by other constraints.
--
-- 'combineReplaceCts' reports 'replacedConstraint' as solved by
-- 'evidenceOfCorrectness', and reports 'replacementConstraints' as new
-- constraints for GHC to solve.
data ReplaceCt = ReplaceCt
  { evidenceOfCorrectness  :: EvTerm
    -- ^ Evidence which solves 'replacedConstraint'.
  , replacedConstraint     :: Ct
    -- ^ The constraint being replaced.
  , replacementConstraints :: [Ct]
    -- ^ The constraints emitted in its place.
  }

-- | Merge all the replacement plans into a single plugin result.
--
-- Every replaced constraint is reported as solved, paired with its evidence.
-- All replacement constraints, concatenated in order, are reported as new
-- constraints.
combineReplaceCts
  :: [ReplaceCt]
  -> TcPluginResult
combineReplaceCts replaceCts = TcPluginOk solvedCts newCts
  where
    solvedCts :: [(EvTerm, Ct)]
    solvedCts =
      [ (evidenceOfCorrectness r, replacedConstraint r)
      | r <- replaceCts
      ]

    newCts :: [Ct]
    newCts = concatMap replacementConstraints replaceCts


-- | Abort compilation with an explanation of how the plugin must be
-- configured, followed by what was expected and what was actually found.
--
-- Never returns. The polymorphic result type lets it stand in for any
-- 'TcPluginM' action.
usage
  :: String  -- ^ expected
  -> String  -- ^ actual
  -> TcPluginM a
usage expected actual =
  -- This reports a mistake in the user's plugin configuration, not a bug in
  -- the plugin. The internal call stack which 'error' would append to the
  -- message would only be noise, so use 'errorWithoutStackTrace'.
  errorWithoutStackTrace (unlines usageLines ++ "  " ++ actual)
  where
    -- 'unlines' terminates every line with "\n"; the final "  " ++ actual
    -- is appended without a trailing newline.
    usageLines :: [String]
    usageLines =
      [ "usage:"
      , "  {-# OPTIONS_GHC -fplugin TypeLevel.Rewrite"
      , "                  -fplugin-opt=TypeLevel.Rewrite:TypeLevel.Append.RightIdentity"
      , "                  -fplugin-opt=TypeLevel.Rewrite:TypeLevel.Append.RightAssociative #-}"
      , "Where 'TypeLevel.Append' is a module containing a type synonym named 'RightIdentity':"
      , "  type RightIdentity as = (as ++ '[]) ~ as"
      , "Type expressions which match the left of the '~' will get rewritten to the type"
      , "expression on the right of the '~'. Be careful not to introduce cycles!"
      , ""
      , "expected:"
      , "  " ++ expected
      , "got:"
      ]

-- | Turn each plugin argument, a fully-qualified type synonym name such as
-- @TypeLevel.Append.RightIdentity@, into the rewrite rule which that
-- synonym defines.
--
-- Aborts via 'usage' in any of these cases:
--
-- * no arguments were given;
-- * an argument cannot be split at its last dot;
-- * the named type constructor is not a type synonym;
-- * the synonym's definition is not of the form @lhs ~ rhs@.
lookupTypeRules
  :: [CommandLineOption]
  -> TcPluginM [TypeRule]
lookupTypeRules [] =
  usage (show [ "TypeLevel.Append.RightIdentity" :: String
              , "TypeLevel.Append.RightAssociative"
              ])
        "[]"
lookupTypeRules fullyQualifiedTypeSynonyms =
  -- e.g. ["TypeLevel.Append.RightIdentity", "TypeLevel.Append.RightAssociative"]
  for fullyQualifiedTypeSynonyms lookupTypeRule
  where
    -- e.g. "TypeLevel.Append.RightIdentity"
    lookupTypeRule :: CommandLineOption -> TcPluginM TypeRule
    lookupTypeRule qualifiedName =
      case splitLastDot qualifiedName of
        Nothing ->
          usage (show ("TypeLevel.Append.RightIdentity" :: String))
                (show qualifiedName)
        Just (moduleNameStr, tyConNameStr) -> do
          -- e.g. ("TypeLevel.Append", "RightIdentity")
          -- FIXME: if tyConNameStr is not found in the module, the error
          -- message is poor
          tyCon <- lookupTyCon moduleNameStr tyConNameStr
          let notASynonym =
                usage ("type " ++ pprTyCon tyCon ++ " ... = ...")
                      (pprTyCon tyCon ++ " is not a type synonym")
          -- the synonym's parameters are not needed, only its definition
          maybe notASynonym
                (ruleFromDefinition . snd)
                (synTyConDefn_maybe tyCon)

    -- e.g. the definition "(as ++ '[]) ~ as" yields a rule whose left side is
    -- "(as ++ '[])" and whose right side is "as"
    ruleFromDefinition definition =
      maybe (usage "... ~ ..." (pprType definition))
            pure
            (toTypeRule_maybe definition)


-- | The plugin entry point.
--
-- The command-line options name the type synonyms to use as rewrite rules
-- (see 'usage'). They are resolved once by 'lookupTypeRules' when the
-- type-checker plugin starts, then handed to 'solve' on every invocation.
-- Stopping the plugin does nothing. Recompilation is governed by
-- 'purePlugin'.
plugin
  :: Plugin
plugin = defaultPlugin
  { tcPlugin        = Just . rewritePlugin
  , pluginRecompile = purePlugin
  }
  where
    rewritePlugin :: [CommandLineOption] -> TcPlugin
    rewritePlugin args = TcPlugin
      { tcPluginInit  = lookupTypeRules args
      , tcPluginSolve = solve
      , tcPluginStop  = const (pure ())
      }


-- | Wrap a fixed message as an error-context entry, suitable for
-- 'pushErrCtxtSameOrigin'.
--
-- The tidying environment is passed through unchanged.
mkErrCtx
  :: SDoc
  -> ErrCtxt
mkErrCtx errDoc = (True, addMessage)
  where
    -- NOTE(review): GHC documents the Bool of an 'ErrCtxt' as marking a
    -- "landmark" context (one kept when contexts are trimmed for display);
    -- confirm this against the GHC version in use.
    addMessage tidyEnv = pure (tidyEnv, errDoc)

-- | Create the wanted constraint @newPredType@ which replaces @oldCt@ after
-- @rule@ has rewritten it.
--
-- The new wanted gets @oldCt@'s location plus an extra error context quoting
-- the rewrite rule, so that an unsolvable replacement explains where it came
-- from.
newRuleInducedWanted
  :: Ct
  -> TypeRule
  -> PredType
  -> TcPluginM CtEvidence
newRuleInducedWanted oldCt rule newPredType = do
  evidence <- newWanted loc' newPredType
  -- The freshly created wanted only picks up the "arising from ..." part of
  -- loc', not the source location etc., so install the full location by hand.
  pure (evidence { ctev_loc = loc' })
  where
    -- oldCt's location, extended with the rule which caused the rewrite
    loc' = pushErrCtxtSameOrigin (mkErrCtx ruleDoc) (ctLoc oldCt)

    ruleDoc = fsep
      [ "From the typelevel rewrite rule:"
      , ppr (fromTypeRule rule)
      ]

-- | The constraint solver.
--
-- Each wanted constraint is rewritten with the rewrite rules, passing the
-- equalities found among the givens to 'applyRules' as well. When a rule
-- turns the wanted into a syntactically different constraint, the original
-- wanted is solved from a fresh wanted for the rewritten constraint, and
-- that fresh wanted is handed back to GHC. Derived constraints are ignored.
solve
  :: [TypeRule]
  -> [Ct]  -- ^ Given constraints
  -> [Ct]  -- ^ Derived constraints
  -> [Ct]  -- ^ Wanted constraints
  -> TcPluginM TcPluginResult
solve _ _ _ [] =
  -- no wanteds, so nothing to rewrite
  pure (TcPluginOk [] [])
solve rules givens _ wanteds =
  combineReplaceCts <$> execWriterT (traverse_ rewriteWanted wanteds)
  where
    -- One entry per given equality @lhs ~ rhs@, where lhs is typically an
    -- expression and rhs is typically a variable. The entry pairs rhs with
    -- lhs.
    -- NOTE(review): how these pairs are used is defined by 'applyRules' in
    -- TypeLevel.Rewrite.Internal.ApplyRules; confirm there.
    typeSubst :: [(TypeEq, TypeTerm)]
    typeSubst =
      [ (TypeEq rhs, toTypeTerm lhs)
      | Just (lhs, rhs) <- map asEqualityConstraint givens
      ]

    -- A wanted @C a b c@ which a rule rewrites to @C a' b' c'@ is recorded
    -- for replacement.
    rewriteWanted :: Ct -> WriterT [ReplaceCt] TcPluginM ()
    rewriteWanted wanted =
      for_ (asDecomposedConstraint wanted) $ \types ->
        for_ (applyRules typeSubst rules (fmap toTypeTerm types)) $
          \(rule, typeTerms') -> do
            let predType  = fromDecomposeConstraint types
                predType' = fromDecomposeConstraint (fmap fromTypeTerm typeTerms')
            -- a rewrite which leaves the constraint unchanged replaces nothing
            unless (predType' `eqType` predType) $ do
              replaceCt <- lift (replaceWanted wanted rule predType predType')
              tell [replaceCt]

    -- Replace @wanted@ (of type predType) by a fresh wanted of type
    -- predType'. The fresh wanted's evidence is cast back to predType through
    -- a plugin-provided (UnivCo) representational coercion.
    replaceWanted
      :: Ct -> TypeRule -> PredType -> PredType -> TcPluginM ReplaceCt
    replaceWanted wanted rule predType predType' = do
      evWanted' <- newRuleInducedWanted wanted rule predType'
      -- co :: C a' b' c'  ~R  C a b c
      let co = mkUnivCo (PluginProv "TypeLevel.Rewrite")
                        Representational
                        predType'
                        predType
      pure ReplaceCt
        { evidenceOfCorrectness  = evCast (ctEvExpr evWanted') co
        , replacedConstraint     = wanted
        , replacementConstraints = [mkNonCanonical evWanted']
        }