{-# LANGUAGE DataKinds #-}

module Data.UnitsOfMeasure.Plugin.Unify
  ( SubstItem(..)
  , substsSubst
  , substsUnitEquality
  , UnitEquality(..)
  , toUnitEquality
  , fromUnitEquality
  , isUsefulUnitEquality
  , SimplifyState(..)
  , SimplifyResult(..)
  , simplifyUnits
  , initialState
  ) where

import GhcApi (text, (<+>), ($$), typeKind, ctEvPred, isGiven, mkSysTvName)

import GHC.TcPlugin.API as PluginAPI
import qualified GHC.TcPlugin.API.Internal as PluginAPI.Internal

import qualified GHC.Tc.Utils.Monad as GHC
import qualified GHC.Tc.Utils.TcMType as GHC

import Data.UnitsOfMeasure.Plugin.Convert
import Data.UnitsOfMeasure.Plugin.NormalForm


-- | A substitution is essentially a list of (variable, unit) pairs,
-- but we keep the original 'Ct' that led to the substitution being
-- made, for use when turning the substitution back into constraints.
type TySubst = [SubstItem]

-- | A single entry of a 'TySubst'.
data SubstItem = SubstItem
  { siVar  :: TyVar     -- ^ the variable being substituted for
  , siUnit :: NormUnit  -- ^ the normalised unit put in place of the variable
  , siCt   :: Ct        -- ^ the constraint recorded alongside this entry
  }

-- | Renders an item as @var := unit {ct}@ (for tracing).
instance Outputable SubstItem where
  ppr si = binding <+> origin
    where
      -- the variable and the unit it is mapped to
      binding = ppr (siVar si) <+> text " := " <+> ppr (siUnit si)
      -- the constraint stored with the item, in braces
      origin  = text "  {" <+> ppr (siCt si) <+> text "}"

-- | Apply a substitution to a single normalised unit.  The items of the
-- substitution are applied one after another, from the head of the list
-- to its end, each via 'substUnit'.
substsUnit :: NormUnit -> TySubst -> NormUnit
substsUnit = foldl applyItem
  where
    -- substitute one item's unit for its variable in the accumulated unit
    applyItem unit item = substUnit (siVar item) (siUnit item) unit

-- | Compose two substitutions: the first substitution is applied to the
-- unit of every item of the second.  The variables and constraints of the
-- second substitution's items are left untouched.
substsSubst :: TySubst -> TySubst -> TySubst
substsSubst s items =
  [ item { siUnit = substsUnit (siUnit item) s } | item <- items ]

-- | Apply a substitution to both units of a 'UnitEquality', keeping the
-- constraint it carries unchanged.
substsUnitEquality :: TySubst -> UnitEquality -> UnitEquality
substsUnitEquality s (UnitEquality ct u v) = UnitEquality ct (apply u) (apply v)
  where
    apply unit = substsUnit unit s

-- | Put a new item at the front of a substitution, after applying that
-- item (as a one-element substitution) to the units of the existing items.
extendSubst :: SubstItem -> TySubst -> TySubst
extendSubst si s = si : rewritten
  where
    rewritten = substsSubst [si] s


-- | Possible results of unifying a single pair of units.  In the
-- non-failing cases ('Win' and 'Draw'), we return a list of fresh
-- variables that were created, followed by two substitutions (bound as
-- @subst@ and @unsubst@ wherever they are consumed in this module).
-- 'Lose' carries no data.
data UnifyResult = Win [TyVar] TySubst TySubst
                 | Draw [TyVar] TySubst TySubst
                 | Lose

-- | Renders the constructor name followed by its fields (for tracing).
instance Outputable UnifyResult where
  ppr result =
      case result of
        Win  tvs subst unsubst -> labelled "Win"  tvs subst unsubst
        Draw tvs subst unsubst -> labelled "Draw" tvs subst unsubst
        Lose                   -> text "Lose"
    where
      -- shared layout for the two constructors that carry data
      labelled tag tvs subst unsubst =
        text tag <+> ppr tvs <+> ppr subst <+> ppr unsubst


-- | Attempt to unify two normalised units to produce a unifying
-- substitution.  The 'Ct' is the equality between the non-normalised
-- (and perhaps less substituted) unit type expressions.
--
-- The two units are traced, then their quotient is handed to 'unifyOne'
-- with no fresh variables and two empty substitutions.
unifyUnits :: UnitDefs -> UnitEquality -> PluginAPI.TcPluginM PluginAPI.Solve UnifyResult
unifyUnits uds (UnitEquality ct u0 v0) =
  PluginAPI.tcPluginTrace "unifyUnits" (ppr u0 $$ ppr v0)
    >> unifyOne uds ct [] [] [] (u0 /: v0)

-- | Worker for 'unifyUnits', operating on a single normalised unit @u@
-- ('unifyUnits' passes the quotient of the two sides).
--
-- Arguments, in order: the unit definitions, the constraint being worked
-- on, the fresh variables created so far, the two substitutions built so
-- far, and the unit itself.
--
-- * If @u@ satisfies 'isOne' the result is 'Win' with the accumulated data.
-- * If @u@ satisfies 'isConstant' the result is 'Lose'.
-- * Otherwise the atoms of @u@ (as listed by 'ascending') are scanned one
--   at a time by @go@; running out of atoms yields 'Draw'.
unifyOne :: UnitDefs -> Ct -> [TyVar] -> TySubst -> TySubst -> NormUnit -> PluginAPI.TcPluginM PluginAPI.Solve UnifyResult
unifyOne uds ct tvs subst unsubst u
  | isOne u      = return (Win tvs subst unsubst)
  | isConstant u = return Lose
  | otherwise    = {- tcPluginTrace "unifyOne" (ppr u) >> -} go [] (ascending u)
  where
    -- The result used whenever no progress can be made from here.
    stalemate :: UnifyResult
    stalemate = Draw tvs subst unsubst

    -- First argument: atoms already skipped (most recent first).
    -- Second argument: atoms still to be examined.
    go :: [(Atom, Integer)] -> [(Atom, Integer)] -> PluginAPI.TcPluginM PluginAPI.Solve UnifyResult
    go _ [] = return stalemate

    go seen (atom@(VarAtom a, i) : rest) = do
      -- When working on a given constraint every variable counts as
      -- touchable; otherwise the type checker is asked.
      touchable <- if givenMode
                     then return True
                     else PluginAPI.isTouchableTcPluginM a
      let r = divideExponents (-i) $ leftover a u

          -- Introduce a fresh unit variable beta, extend both
          -- substitutions, and restart on the unit with a replaced.
          viaFreshVar = do
            beta <- newUnitVar
            let replacement = varUnit beta *: r
                subst'      = extendSubst (SubstItem a replacement ct) subst
                unsubst'    = extendSubst (SubstItem beta (varUnit a /: r) ct) unsubst
            unifyOne uds ct (beta : tvs) subst' unsubst' $ substUnit a replacement u

          outcome
            | touchable && divisible i u = return $
                if occurs a r
                  then stalemate
                  else Win tvs (extendSubst (SubstItem a r ct) subst) unsubst
            | touchable && not (all (isBase . fst) rest) = viaFreshVar
            | otherwise = go (atom : seen) rest
      outcome

    go seen (atom@(FamAtom f tys, i) : rest) = do
      mb <- PluginAPI.matchFam f tys
      -- If the family application reduces to something that normalises
      -- as a unit, splice that unit in (raised to the atom's exponent)
      -- and start again; otherwise skip the atom.
      case mb >>= normaliseUnit uds . PluginAPI.reductionReducedType of
        Nothing -> go (atom : seen) rest
        Just v  -> unifyOne uds ct tvs subst unsubst $ mkNormUnit (seen ++ rest) *: v ^: i

    -- Base atoms are always skipped.
    go seen (atom@(BaseAtom _, _) : rest) = go (atom : seen) rest

    -- Whether the constraint being processed is a given.
    givenMode = isGiven (ctEvidence ct)

    -- A fresh variable of the unit kind, created by 'newSkolemTyVar' for
    -- givens and by 'GHC.newFlexiTyVar' otherwise.
    newUnitVar = PluginAPI.Internal.unsafeLiftTcM $
      if givenMode
        then newSkolemTyVar (unitKind uds)
        else GHC.newFlexiTyVar (unitKind uds)

    -- Build a type variable called "beta" from a fresh unique.
    newSkolemTyVar kind = do
      uniq <- GHC.newUnique
      let name = mkSysTvName uniq (fsLit "beta")
      return $ PluginAPI.mkTyVar name kind -- mkTcTyVar name kind vanillaSkolemTv



-- | A constraint ('Ct') together with two normalised units.  Values are
-- built by 'toUnitEquality' from the two type arguments of a constraint
-- it recognises; 'fromUnitEquality' gives the constraint back.
data UnitEquality = UnitEquality Ct NormUnit NormUnit

-- | Renders a header line followed by the constraint and the two units,
-- each on its own line (for tracing).
instance Outputable UnitEquality where
  ppr (UnitEquality ct u v) =
    foldl ($$) (text "UnitEquality") [ppr ct, ppr u, ppr v]

-- | Extract the unit equality constraints.
--
-- A constraint is returned as 'Left' (a 'UnitEquality') when either
--
-- * its predicate is a nominal equality where at least one of the two
--   sides has the unit kind (per 'isUnitKind'), or
--
-- * its predicate is irreducible and is 'equivTyCon' applied to exactly
--   two arguments,
--
-- and, in both cases, both sides are accepted by 'normaliseUnit'.  Every
-- other constraint is handed back unchanged as 'Right'.
toUnitEquality :: UnitDefs -> Ct -> Either UnitEquality Ct
toUnitEquality uds ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2
      -- Check the kind of *both* sides: the second disjunct used to
      -- repeat the test on t1, which made it redundant.
      | isUnitKind uds (typeKind t1) || isUnitKind uds (typeKind t2)
      , Just u1 <- normaliseUnit uds t1
      , Just u2 <- normaliseUnit uds t2 -> Left (UnitEquality ct u1 u2)
    IrredPred t
      | Just (tc, [t1, t2]) <- splitTyConApp_maybe t
      , tc == equivTyCon uds
      , Just u1 <- normaliseUnit uds t1
      , Just u2 <- normaliseUnit uds t2 -> Left (UnitEquality ct u1 u2)
    -- Also reached when the guards of the alternatives above fail.
    _                                   -> Right ct

-- | Recover the constraint stored in a 'UnitEquality', discarding the
-- two normalised units.
fromUnitEquality :: UnitEquality -> Ct
fromUnitEquality eq = case eq of
  UnitEquality ct _ _ -> ct


-- | Decide whether a unit equality is considered useful.
--
-- If the left-hand side is a single variable (per 'maybeSingleVariable'),
-- the answer is whether that variable occurs in the right-hand side.
-- Otherwise, if the right-hand side is a single variable, the answer is
-- whether it occurs in the left-hand side.  If neither side is a single
-- variable the equality is always useful.
isUsefulUnitEquality :: UnitEquality -> Bool
isUsefulUnitEquality (UnitEquality _ lhs rhs)
  | Just v <- maybeSingleVariable lhs = occurs v rhs
  | Just v <- maybeSingleVariable rhs = occurs v lhs
  | otherwise                         = True


-- | The state threaded through 'simplifyUnits'.
data SimplifyState = SimplifyState
  { simplifyFreshVars :: [TyVar]
    -- ^ fresh variables reported by the unifier so far
  , simplifySubst     :: TySubst
    -- ^ accumulated substitution; applied to each equality before it is unified
  , simplifyUnsubst   :: TySubst
    -- ^ accumulated second substitution reported by the unifier
  , simplifySolved    :: [UnitEquality]
    -- ^ equalities for which unification returned a win
  , simplifyStuck     :: [UnitEquality]
    -- ^ equalities on which no (further) progress was made
  }

-- | Renders one labelled line per field (for tracing).
instance Outputable SimplifyState where
  ppr ss = field "fresh   = " (ppr (simplifyFreshVars ss))
        $$ field "subst   = " (ppr (simplifySubst ss))
        $$ field "unsubst = " (ppr (simplifyUnsubst ss))
        $$ field "solved  = " (ppr (simplifySolved ss))
        $$ field "stuck   = " (ppr (simplifyStuck ss))
    where
      -- a label followed by an already rendered value
      field label doc = text label <+> doc

-- | The starting state for 'simplifyUnits': every component is empty.
initialState :: SimplifyState
initialState = SimplifyState
  { simplifyFreshVars = []
  , simplifySubst     = []
  , simplifyUnsubst   = []
  , simplifySolved    = []
  , simplifyStuck     = []
  }

-- | The outcome of 'simplifyUnits'.
data SimplifyResult
  = Simplified SimplifyState
    -- ^ every equality was processed; the final state is returned
  | Impossible
      { simplifyImpossible :: UnitEquality
        -- ^ the equality for which unification returned a loss
      , simplifyRemaining  :: [UnitEquality]
        -- ^ the stuck equalities plus those not yet processed
      }

-- | Renders the constructor name followed by its contents (for tracing).
instance Outputable SimplifyResult where
  ppr result = case result of
    Simplified ss     -> text "Simplified" $$ ppr ss
    Impossible eq eqs -> text "Impossible" <+> ppr eq <+> ppr eqs

-- | Process a list of unit equalities one at a time, starting from
-- 'initialState'.
--
-- Each equality has the current substitution applied and is passed to
-- 'unifyUnits'.  On a win, or a draw that produced a non-empty first
-- substitution, the state is updated (by 'win' or 'draw') and the
-- previously stuck equalities are put back at the front of the work
-- list.  A draw with an empty first substitution just marks the equality
-- as stuck.  A loss stops immediately with 'Impossible'.
simplifyUnits :: UnitDefs -> [UnitEquality] -> PluginAPI.TcPluginM PluginAPI.Solve SimplifyResult
simplifyUnits uds eqs0 = do
    PluginAPI.tcPluginTrace "simplifyUnits" (ppr eqs0)
    simples initialState eqs0
  where
    simples :: SimplifyState -> [UnitEquality] -> PluginAPI.TcPluginM PluginAPI.Solve SimplifyResult
    simples ss [] = return (Simplified ss)
    simples ss (eq : eqs) = do
      ur <- unifyUnits uds (substsUnitEquality (simplifySubst ss) eq)
      PluginAPI.tcPluginTrace "unifyUnits result" (ppr ur)
      -- continue with a new state, retrying the returned equalities first
      let requeue (ss', xs) = simples ss' (xs ++ eqs)
      case ur of
        Win tvs subst unsubst  -> requeue (win eq tvs subst unsubst ss)
        Draw _ [] _            -> simples (addStuck eq ss) eqs
        Draw tvs subst unsubst -> requeue (draw eq tvs subst unsubst ss)
        Lose                   -> return (Impossible eq (simplifyStuck ss ++ eqs))

-- | Update the state after an equality was solved.
--
-- The new fresh variables are appended, each accumulated substitution has
-- the corresponding new one applied to it and then appended, the equality
-- is recorded as solved, and the stuck list is emptied.  The second
-- component of the result is the old stuck list.
win :: UnitEquality -> [TyVar] -> TySubst -> TySubst -> SimplifyState -> (SimplifyState, [UnitEquality])
win eq tvs subst unsubst ss = (ss', simplifyStuck ss)
  where
    ss' = SimplifyState
      { simplifyFreshVars = simplifyFreshVars ss ++ tvs
      , simplifySubst     = substsSubst subst (simplifySubst ss) ++ subst
      , simplifyUnsubst   = substsSubst unsubst (simplifyUnsubst ss) ++ unsubst
      , simplifySolved    = eq : simplifySolved ss
      , simplifyStuck     = []
      }

-- | Update the state after an equality made partial progress.
--
-- Fresh variables and substitutions are merged exactly as in 'win', but
-- the solved list is left alone and the equality becomes the only stuck
-- one.  The second component of the result is the old stuck list.
draw :: UnitEquality -> [TyVar] -> TySubst -> TySubst -> SimplifyState -> (SimplifyState, [UnitEquality])
draw eq tvs subst unsubst ss = (ss', simplifyStuck ss)
  where
    ss' = SimplifyState
      { simplifyFreshVars = simplifyFreshVars ss ++ tvs
      , simplifySubst     = substsSubst subst (simplifySubst ss) ++ subst
      , simplifyUnsubst   = substsSubst unsubst (simplifyUnsubst ss) ++ unsubst
      , simplifySolved    = simplifySolved ss
      , simplifyStuck     = [eq]
      }

-- | Put an equality at the front of the state's stuck list, leaving the
-- rest of the state unchanged.
addStuck :: UnitEquality -> SimplifyState -> SimplifyState
addStuck eq ss = ss { simplifyStuck = stuck' }
  where
    stuck' = eq : simplifyStuck ss