{-# LANGUAGE DataKinds #-}

-- | This module defines a typechecker plugin that solves equations
-- involving units of measure.  To use it, add
--
-- > {-# OPTIONS_GHC -fplugin Data.UnitsOfMeasure.Plugin #-}
--
-- above the module header of your source files, or in the
-- @ghc-options@ field of your @.cabal@ file.  You do not need to
-- import this module.
module Data.UnitsOfMeasure.Plugin
  ( plugin
  ) where

import GhcApi (TcCoercion, ctEvPred, ctEvTerm, typeKind, heqDataCon, evDFunApp, dataConName, dataConWrapId, occName, occNameFS, tyConDataCons, (<+>), isWanted, isGivenCt, isGiven, UnivCoProvenance(PluginProv), mkPrimEqPred, Type(TyConApp), heqTyCon)

import qualified GHC.Plugins as Plugins
import GHC.TcPlugin.API as PluginAPI

import Data.Either

import Data.UnitsOfMeasure.Plugin.Convert
import Data.UnitsOfMeasure.Plugin.NormalForm
import Data.UnitsOfMeasure.Plugin.Unify

-- | The plugin that GHC will load when this module is used with the
-- @-fplugin@ option.
--
-- Any command-line options passed to the plugin are ignored.  The plugin
-- declares itself pure ('Plugins.purePlugin'), so using it never forces
-- recompilation on its own.
plugin :: Plugins.Plugin
plugin = Plugins.defaultPlugin
    { Plugins.tcPlugin        = \_opts -> Just (PluginAPI.mkTcPlugin uomPlugin)
    , Plugins.pluginRecompile = Plugins.purePlugin
    }

-- | The typechecker plugin itself.
--
-- * init: look up the type constructors the plugin reasons about;
-- * solve: simplify and solve unit equalities ('unitsOfMeasureSolver');
-- * rewrite: rewrite applications of @Unpack@ ('unitsOfMeasureRewrite');
-- * stop: nothing to clean up.
uomPlugin :: PluginAPI.TcPlugin
uomPlugin = PluginAPI.TcPlugin
    { PluginAPI.tcPluginInit    = lookupUnitDefs
    , PluginAPI.tcPluginSolve   = unitsOfMeasureSolver
    , PluginAPI.tcPluginRewrite = unitsOfMeasureRewrite
    , PluginAPI.tcPluginStop    = \_unitDefs -> pure ()
    }



-- | The constraint solver.
--
-- The first equation handles the "simplify givens" stage (no wanteds).  The
-- second equation handles the stage where there are wanted constraints.
unitsOfMeasureSolver :: UnitDefs -> [Ct] -> [Ct] -> PluginAPI.TcPluginM PluginAPI.Solve PluginAPI.TcPluginSolveResult
unitsOfMeasureSolver uds givens [] = do
    PluginAPI.tcPluginTrace "unitsOfMeasureSolver simplifying givens" $ ppr givens
    case unit_givens of
      []    -> return $ PluginAPI.TcPluginOk [] []
      (_:_) -> do
        sr <- simplifyUnits uds $ map snd unit_givens
        PluginAPI.tcPluginTrace "unitsOfMeasureSolver simplified givens only" $ ppr sr
        case sr of
          Simplified ss
            -- TODO: givens simplification is currently disabled, because if we emit a given
            -- constraint like x[sk] ~ Base "kg" then GHC will "simplify" all occurrences
            -- of the type family application Base "kg" to the skolem variable x[sk].
            -- This can then result in loops as the rewriter will turn the fam app into
            -- the variable, then the plugin will "solve" it again.
            | simplifyGivensEnabled -> do
                -- TODO: we ought to generate evidence that depends on the
                -- previous givens (and similarly when simplifying wanteds, the
                -- evidence we generate should depend on the new wanteds).
                -- Otherwise we could potentially have a soundness issue e.g. if a
                -- GADT pattern match brings a unit equality into scope, but we
                -- later float out something that depends on it.
                let usefuls = simplifySubst ss
                xs <- mapM (substItemToCt uds) usefuls
                pure $ PluginAPI.TcPluginOk (map (solvedGiven . siCt) usefuls) xs
            | otherwise -> pure $ PluginAPI.TcPluginOk [] []
          Impossible eq _ -> reportContradiction uds eq
  where
    -- The givens that are unit equalities, each paired with its original
    -- constraint.  Only those that are worth simplifying are kept.
    --
    -- TODO: if the simplify givens stage makes progress, we want to emit new
    -- givens in case GHC can substitute into constraints other than unit
    -- equalities.  However, we don't want to cause a loop by repeatedly
    -- re-simplifying the same givens.  We currently have a conservative check
    -- to see if it is useful to simplify a unit equality: if neither side of
    -- the original equality was a single variable.  There are "useful" cases
    -- this misses, however, e.g. v^2 ~ v.
    unit_givens :: [(Ct, UnitEquality)]
    unit_givens =
      [ (ct, ue)
      | ct <- givens
      , Left ue <- [toUnitEquality uds ct]
      , isUsefulUnitEquality ue
      ]

    -- Switch for the givens-simplification branch above.  It is kept
    -- compiled but switched off (see the TODO on that branch).  Before this
    -- flag existed, the branch sat behind a redundant pattern and was
    -- unreachable.
    simplifyGivensEnabled :: Bool
    simplifyGivensEnabled = False

    -- Mark a given as solved by its own evidence.
    solvedGiven :: Ct -> (EvTerm, Ct)
    solvedGiven ct = (ctEvTerm (ctEvidence ct), ct)

-- Wanteds stage: simplify the unit givens to a substitution, apply it to the
-- unit wanteds, and simplify those.  Solved wanteds get made-up evidence
-- ('evMagic').  Any remaining wanted substitution items become new wanteds.
unitsOfMeasureSolver uds givens wanteds
  | null unitWanteds = pure $ PluginAPI.TcPluginOk [] []
  | otherwise = do
      givenResult <- simplifyUnits uds unitGivens
      PluginAPI.tcPluginTrace "unitsOfMeasureSolver simplified givens" $ ppr givenResult
      -- TODO: it is somewhat questionable to simplify the givens again
      -- here. In principle we should be able to simplify them at the
      -- simplify-givens stage, turn them into a substitution, and have GHC
      -- apply the substitution.
      case givenResult of
        Impossible eq _ -> reportContradiction uds eq
        Simplified givenState -> do
          let givenSubst = simplifySubst givenState
          wantedResult <- simplifyUnits uds $ map (substsUnitEquality givenSubst) unitWanteds
          PluginAPI.tcPluginTrace "unitsOfMeasureSolver simplified wanteds" $ ppr wantedResult
          case wantedResult of
            -- Don't report a contradiction, see #22
            Impossible _ _ -> pure $ PluginAPI.TcPluginOk [] []
            Simplified wantedState -> do
              let solved =
                    [ (evMagic uds ct, ct)
                    | eq <- simplifySolved wantedState
                    , let ct = fromUnitEquality eq
                    ]
                  residual =
                    filter (isWanted . ctEvidence . siCt) $
                      substsSubst (simplifyUnsubst givenState) (simplifySubst wantedState)
              newWanteds <- mapM (substItemToCt uds) residual
              pure $ PluginAPI.TcPluginOk solved newWanteds
  where
    unitWanteds = lefts (map (toUnitEquality uds) wanteds)
    unitGivens  = lefts (map (toUnitEquality uds) givens)


-- | Report the given unit equality as insoluble.  The constraint it is
-- reported through is built by 'fromUnitEqualityForContradiction'.
reportContradiction :: UnitDefs -> UnitEquality -> PluginAPI.TcPluginM PluginAPI.Solve PluginAPI.TcPluginSolveResult
reportContradiction uds eq = do
    badCt <- fromUnitEqualityForContradiction uds eq
    pure (PluginAPI.TcPluginContradiction [badCt])

-- | Choose the constraint through which a contradictory unit equality is
-- reported.  See #22 for why we need this.
--
-- If the original constraint is already a nominal equality, it is returned
-- unchanged.  Otherwise a new primitive equality between the two reified
-- units is created at the same location:
--
-- * a given (with made-up evidence) if the original constraint was given;
-- * a wanted otherwise.
fromUnitEqualityForContradiction :: UnitDefs -> UnitEquality -> PluginAPI.TcPluginM PluginAPI.Solve Ct
fromUnitEqualityForContradiction uds (UnitEquality ct u v) =
    case classifyPredType ctTy of
      EqPred NomEq _ _ -> pure ct
      _ | isGivenCt ct ->
            PluginAPI.mkNonCanonical
              <$> PluginAPI.newGiven loc eqTy (evTermToExpr (mkFunnyEqEvidence ctTy uTy vTy))
        | otherwise ->
            PluginAPI.mkNonCanonical <$> PluginAPI.newWanted loc eqTy
  where
    ctTy = ctEvPred (ctEvidence ct)
    loc  = ctLoc ct
    uTy  = reifyUnit uds u
    vTy  = reifyUnit uds v
    eqTy = mkPrimEqPred uTy vTy


-- | Turn a substitution item (a type variable and the unit it maps to) into a
-- new primitive equality constraint.  The new constraint sits at the location
-- of the constraint the item came from.  If that constraint was given, the
-- new one is a given with evidence made up by fiat; otherwise it is a wanted.
substItemToCt :: UnitDefs -> SubstItem -> PluginAPI.TcPluginM PluginAPI.Solve Ct
substItemToCt uds si = PluginAPI.mkNonCanonical <$> newEv
  where
    origin  = siCt si
    loc     = ctLoc origin
    varTy   = mkTyVarTy (siVar si)
    unitTy  = reifyUnit uds (siUnit si)
    eqPred  = mkPrimEqPred varTy unitTy
    newEv
      | isGiven (ctEvidence origin) = PluginAPI.newGiven loc eqPred (evByFiatExpr "units" varTy unitTy)
      | otherwise                   = PluginAPI.newWanted loc eqPred


{-
TODO: the Unpack rewriter below (unitsOfMeasureRewrite/unpackRewriter) leads to Core Lint errors like the following on GHC 9.2, but seems to work on 9.4?

*** Core Lint errors : in result of Desugar (before optimization) ***
src/Data/UnitsOfMeasure/Defs.hs:19:4: warning:
    Trans coercion mis-match: (IsCanonical
                                 Univ(nominal plugin "units"
                                      :: Unpack (Base "m"), '["m"] ':/ '[]))_N
                              ; Sym (D:R:IsCanonical[0] <'["m"]>_N <'[]>_N)
      IsCanonical (Unpack (Base "m")) ~ IsCanonical ('["m"] ':/ '[])
      (AllIsCanonical '["m"], AllIsCanonical '[]) ~ IsCanonical
                                                      ('["m"] ':/ '[])
    In the RHS of $cp1HasCanonicalBaseUnit_alno :: IsCanonical
                                                     (Unpack (CanonicalBaseUnit "m"))
    In the body of letrec with binders $d(%%)_alnP :: () :: Constraint
    In the body of letrec with binders $d(%%)_alnN :: () :: Constraint
    In the body of letrec with binders $d~_alnO :: Base "m" ~ Base "m"
    In the body of letrec with binders $d(%,%)_alnM :: (Base "m"
                                                        ~ Base "m",
                                                        () :: Constraint)
    In the body of letrec with binders $d(%,%)_alnL :: ((Base "m"
                                                         ~ Base "m",
                                                         () :: Constraint),
                                                        () :: Constraint)
    Substitution: [TCvSubst
                     In scope: InScope {}
                     Type env: []
                     Co env: []]
-}

-- | The type family rewriters installed by the plugin.  There is a single
-- entry: the @Unpack@ type constructor, handled by 'unpackRewriter'.
unitsOfMeasureRewrite
  :: UnitDefs ->
    PluginAPI.UniqFM
        TyCon
        ([Ct] -> [Type] -> PluginAPI.TcPluginM PluginAPI.Rewrite PluginAPI.TcPluginRewriteResult)
unitsOfMeasureRewrite uds = PluginAPI.listToUFM [(unpackTc, rewriter)]
  where
    unpackTc = unpackTyCon uds
    rewriter = unpackRewriter uds

-- | Rewrite @Unpack ty@ when @ty@ normalises to a constant unit.
--
-- The result is the unpacked representation of that unit, justified by a
-- nominal plugin coercion.  If @ty@ is not a constant unit, or the wrong
-- number of arguments is supplied, no rewrite is performed.
unpackRewriter :: UnitDefs -> [Ct] -> [Type] -> PluginAPI.TcPluginM PluginAPI.Rewrite PluginAPI.TcPluginRewriteResult
unpackRewriter uds _givens [ty] =
    case normaliseUnit uds ty >>= maybeConstant of
      Just u -> do
        PluginAPI.tcPluginTrace "unpackRewriter: rewrite" (ppr ty <+> ppr u)
        let reduct = reifyUnitUnpacked uds u
            co     = PluginAPI.mkPluginUnivCo "units" Nominal (mkTyConApp (unpackTyCon uds) [ty]) reduct
        pure (PluginAPI.TcPluginRewriteTo (PluginAPI.Reduction co reduct) [])
      Nothing -> do
        PluginAPI.tcPluginTrace "unpackRewriter: no rewrite" (ppr ty)
        pure PluginAPI.TcPluginNoRewrite
unpackRewriter _ _ tys = do
    PluginAPI.tcPluginTrace "unpackRewriter: wrong number of arguments?" (ppr tys)
    pure PluginAPI.TcPluginNoRewrite

-- | Find the module with the given name among the modules visible to the
-- module being compiled.
--
-- The package argument is currently ignored: the lookup is unqualified
-- ('PluginAPI.NoPkgQual').  An earlier version passed it as an @OtherPkg@
-- qualifier.  The previous implementation repeated the identical
-- unqualified lookup after a failure; that retry could never succeed where
-- the first attempt failed, so it has been removed.
--
-- Calls 'error' if the module cannot be found.
lookupModule' :: PluginAPI.MonadTcPlugin m => PluginAPI.ModuleName -> p -> m PluginAPI.Module
lookupModule' modname _pkg = do
  r <- PluginAPI.findImportedModule modname PluginAPI.NoPkgQual
  case r of
    PluginAPI.Found _ md -> pure md
    _ -> error $ "lookupModule: not Found: " ++ Plugins.showSDocUnsafe (ppr modname)


-- | Look up every type constructor the plugin reasons about in
-- "Data.UnitsOfMeasure.Internal", plus the promoted @:/@ data constructor
-- of @UnitSyntax@.
lookupUnitDefs :: PluginAPI.TcPluginM PluginAPI.Init UnitDefs
lookupUnitDefs = do
    md <- lookupModule' myModule myPackage
    let tc = look md
    -- Effects run left to right: the lookup order matches the field order.
    withSyntax <- UnitDefs <$> tc "Unit"
                           <*> tc "Base"
                           <*> tc "One"
                           <*> tc "*:"
                           <*> tc "/:"
                           <*> tc "^:"
                           <*> tc "Unpack"
    syntaxTc <- tc "UnitSyntax"
    withSyntax syntaxTc (getDataCon syntaxTc ":/") <$> tc "~~"
  where
    -- The promoted data constructor of @tycon@ whose name is @name@.
    -- Calls 'error' unless exactly one such constructor exists.
    getDataCon :: TyCon -> String -> TyCon
    getDataCon tycon name =
        case filter hasName (tyConDataCons tycon) of
          [dc] -> promoteDataCon dc
          _    -> error $ "lookupUnitDefs/getDataCon: missing " ++ name
      where
        hasName dc = occNameFS (occName (dataConName dc)) == fsLit name

    look :: PluginAPI.Module -> String -> PluginAPI.TcPluginM PluginAPI.Init TyCon
    look md name = PluginAPI.lookupOrig md (mkTcOcc name) >>= PluginAPI.tcLookupTyCon

    myModule :: PluginAPI.ModuleName
    myModule  = mkModuleName "Data.UnitsOfMeasure.Internal"
    myPackage = fsLit "uom-plugin"


-- | Make up evidence for a fake equality constraint @t@ of the form
-- @t1 ~~ t2@.
--
-- A heterogeneous-equality dictionary (the wrapper of 'heqDataCon') is built
-- from bogus nominal evidence of @t1 ~ t2@.  That dictionary is then cast to
-- @t@ with a representational plugin coercion.
mkFunnyEqEvidence :: Type -> Type -> Type -> EvTerm
mkFunnyEqEvidence t t1 t2 = evCast' heqEvidence heqToTarget
  where
    heqEvidence :: EvTerm
    heqEvidence =
        evDFunApp (dataConWrapId heqDataCon)
                  [typeKind t1, typeKind t2, t1, t2]
                  [evByFiatExpr "units" t1 t2]

    heqToTarget :: TcCoercion
    heqToTarget = mkUnivCo (PluginProv "units") Representational (mkHEqPred t1 t2) t

-- | The predicate @'heqTyCon' k1 k2 t1 t2@, where @k1@ and @k2@ are the kinds
-- of @t1@ and @t2@.
mkHEqPred :: Type -> Type -> Type
mkHEqPred t1 t2 = TyConApp heqTyCon (map typeKind [t1, t2] ++ [t1, t2])


-- | Produce bogus evidence for a constraint.
--
-- Two kinds of constraint are handled:
--
-- * nominal equalities get evidence by fiat;
-- * the plugin's fake '(~~)' equality (an irreducible predicate headed by
--   'equivTyCon' with two arguments) gets evidence from 'mkFunnyEqEvidence'.
--
-- Any other constraint is an internal error.
evMagic :: UnitDefs -> Ct -> EvTerm
evMagic uds ct =
    case classifyPredType (ctEvPred (ctEvidence ct)) of
      EqPred NomEq lhs rhs -> evByFiat "units" lhs rhs
      IrredPred p
        | Just (lhs, rhs) <- asEquiv p -> mkFunnyEqEvidence p lhs rhs
      _ -> error "evMagic"
  where
    asEquiv p = case splitTyConApp_maybe p of
      Just (tc, [lhs, rhs]) | tc == equivTyCon uds -> Just (lhs, rhs)
      _                                            -> Nothing

-- | Evidence for the nominal equality @t1 ~ t2@, asserted by the plugin with
-- provenance string @s@.
evByFiat :: String -> PluginAPI.TcType -> PluginAPI.TcType -> EvTerm
evByFiat s = PluginAPI.mkPluginUnivEvTerm s Nominal

-- | The expression form of the plugin-asserted nominal equality @t1 ~ t2@
-- (provenance @s@).
evByFiatExpr :: String -> PluginAPI.TcType -> PluginAPI.TcType -> EvExpr
evByFiatExpr s t1 t2 =
    let tm = PluginAPI.mkPluginUnivEvTerm s Nominal t1 t2
    in evTermToExpr tm

-- | Extract the expression from an 'EvExpr' evidence term.  Any other form of
-- evidence term is an internal error.
evTermToExpr :: EvTerm -> EvExpr
evTermToExpr tm = case tm of
    EvExpr e -> e
    _        -> error "evTermToExpr"

-- | Cast an evidence term by a coercion.  The term must be an 'EvExpr'
-- (see 'evTermToExpr').
evCast' :: EvTerm -> TcCoercion -> EvTerm
evCast' tm co = evCast (evTermToExpr tm) co