{-|
Copyright  :  (C) 2015-2016, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}
{-# LANGUAGE CPP             #-}
{-# LANGUAGE LambdaCase      #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TcPluginM.Extra
  ( -- * Create new constraints
    newWanted
  , newGiven
  , newDerived
#if __GLASGOW_HASKELL__ < 711
  , newWantedWithProvenance
#endif
    -- * Creating evidence
  , evByFiat
#if __GLASGOW_HASKELL__ < 711
    -- * Report contradictions
  , failWithProvenace
#endif
    -- * Lookup
  , lookupModule
  , lookupName
    -- * Trace state of the plugin
  , tracePlugin
    -- * Substitutions
  , flattenGivens
  , mkSubst
  , mkSubst'
  , substType
  , substCt
  )
where

-- External
import Data.Maybe (mapMaybe)
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List     (groupBy, partition, sortOn)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Core (Expr (..))
import GHC.Core.Coercion (Role (..), mkPrimEqPred, mkUnivCo)
import GHC.Core.Type  (PredType)
import GHC.Core.TyCo.Rep (Type (..), UnivCoProvenance (..))
import GHC.Data.FastString (FastString, fsLit)
import qualified GHC.Driver.Finder as Finder
import GHC.Unit.Module (Module, ModuleName)
import GHC.Tc.Plugin (FindResult (..), TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPluginM
import GHC.Tc.Utils.TcType (TcTyVar, TcType)
import GHC.Tc.Types (TcPlugin (..), TcPluginResult (..))
import GHC.Tc.Types.Constraint
  (Ct (..), CtLoc, CtEvidence (..), ctEvId, ctLoc, mkNonCanonical)
import GHC.Tc.Types.Evidence (EvTerm (..))
import GHC.Types.Name (Name)
import GHC.Types.Name.Occurrence (OccName)
import GHC.Utils.Outputable ((<+>), ($$), empty, ppr, text)
import GHC.Utils.Panic (panicDoc)
#else
#if __GLASGOW_HASKELL__ < 711
import BasicTypes (TopLevelFlag (..))
#endif
#if MIN_VERSION_ghc(8,5,0)
import CoreSyn    (Expr(..))
#endif
import Coercion   (Role (..), mkUnivCo)
import FastString (FastString, fsLit)
import Module     (Module, ModuleName)
import Name       (Name)
import OccName    (OccName)
import Outputable (($$), (<+>), empty, ppr, text)
import Panic      (panicDoc)
#if __GLASGOW_HASKELL__ >= 711
import TcEvidence (EvTerm (..))
#else
import TcEvidence (EvTerm (..), TcCoercion (..))
import TcMType    (newEvVar)
#endif
#if __GLASGOW_HASKELL__ < 711
import TcPluginM  (FindResult (..), TcPluginM, findImportedModule, lookupOrig,
                   tcPluginTrace, unsafeTcPluginTcM)
import TcRnTypes  (Ct, CtEvidence (..), CtLoc, TcIdBinder (..), TcLclEnv (..),
                   TcPlugin (..), TcPluginResult (..), ctEvLoc,
                   ctLocEnv, setCtLocEnv)
#else
import TcPluginM  (FindResult (..), TcPluginM, lookupOrig, tcPluginTrace)
import qualified  TcPluginM
import qualified  Finder
#if __GLASGOW_HASKELL__ < 809
import TcRnTypes  (CtEvidence (..), CtLoc,
                   TcPlugin (..), TcPluginResult (..))
#else
import TcRnTypes  (TcPlugin (..), TcPluginResult (..))
#endif
#endif
#if __GLASGOW_HASKELL__ < 802
import TcPluginM  (tcPluginIO)
#endif
#if __GLASGOW_HASKELL__ >= 711
import TyCoRep    (UnivCoProvenance (..))
import Type       (PredType, Type)
#else
import Type       (EqRel (..), PredTree (..), PredType, Type, classifyPredType)
import Var        (varType)
#endif
#if __GLASGOW_HASKELL__ < 809
import TcRnTypes     (Ct (..), ctLoc, ctEvId, mkNonCanonical)
#else
import Constraint
  (Ct (..), CtEvidence (..), CtLoc, ctLoc, ctEvId, mkNonCanonical)
#endif
import TcType        (TcTyVar, TcType)
#if __GLASGOW_HASKELL__ < 809
import Type          (mkPrimEqPred)
#else
import Predicate     (mkPrimEqPred)
#endif
#if __GLASGOW_HASKELL__ < 711
import TcRnTypes     (ctEvTerm)
import TypeRep       (Type (..))
#else
import TyCoRep       (Type (..))
#endif
#endif

-- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301
#if __GLASGOW_HASKELL__ < 802
import Data.IORef    (readIORef)
import Control.Monad (unless)
import StaticFlags   (initStaticOpts, v_opt_C_ready)
#endif

#if __GLASGOW_HASKELL__ >= 711
-- | Matches a successful module lookup and exposes only the 'Module' that was
-- found; the first field of 'Found' is ignored.
pattern FoundModule :: Module -> FindResult
pattern FoundModule a <- Found _ a
-- | Identity function; 'lookupModule' applies it to the module bound by a
-- 'FoundModule' match.
fr_mod :: a -> a
fr_mod = id
#endif


#if __GLASGOW_HASKELL__ < 711
{-# DEPRECATED newWantedWithProvenance "No longer available in GHC 8.0+" #-}
-- | Build a fresh [W]anted constraint that records the wanted constraint it
-- originates from.
--
-- The evidence variable of the parent is pushed (as a non-top-level binder)
-- onto the binder stack of the parent's location environment, and the new
-- constraint is given that extended location.
--
-- Panics when the supplied evidence is not a 'CtWanted'.
newWantedWithProvenance :: CtEvidence -- ^ Constraint from which the new
                                      -- wanted is derived
                        -> PredType   -- ^ The type of the new constraint
                        -> TcPluginM CtEvidence
newWantedWithProvenance ev p = case ev of
  CtWanted {} -> do
    evVar <- unsafeTcPluginTcM (newEvVar p)
    return CtWanted { ctev_pred = p
                    , ctev_evar = evVar
                    , ctev_loc  = setCtLocEnv parentLoc extendedEnv
                    }
  _ -> panicDoc "newWantedWithProvenance: not a Wanted: " (ppr ev)
  where
    parentLoc   = ctEvLoc ev
    parentEnv   = ctLocEnv parentLoc
    -- Remember the parent by recording its evidence variable as a binder.
    parentBndr  = TcIdBndr (ctEvId ev) NotTopLevel
    extendedEnv = parentEnv { tcl_bndrs = parentBndr : tcl_bndrs parentEnv }
#endif

-- | Create a new [W]anted constraint.
--
-- On GHC 8.0+ (@__GLASGOW_HASKELL__ >= 711@) this is 'TcPluginM.newWanted'.
-- On older compilers a fresh evidence variable is allocated for the predicate
-- and wrapped in a 'CtWanted' at the supplied location.
newWanted  :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newWanted = TcPluginM.newWanted
#else
newWanted loc pty = do
    new_ev <- unsafeTcPluginTcM $ newEvVar pty
    return CtWanted { ctev_pred = pty
                    , ctev_evar = new_ev
                    , ctev_loc  = loc
                    }
#endif

-- | Create a new [G]iven constraint, with the supplied evidence. This must not
-- be invoked from 'tcPluginInit' or 'tcPluginStop', or it will panic.
--
-- On GHC 8.5+ only evidence built with the 'EvExpr' constructor is accepted;
-- any other 'EvTerm' causes a panic.
newGiven :: CtLoc -> PredType -> EvTerm -> TcPluginM CtEvidence
#if MIN_VERSION_ghc(8,5,0)
newGiven loc pty (EvExpr ev) = TcPluginM.newGiven loc pty ev
newGiven _ _  ev = panicDoc "newGiven: not an EvExpr: " (ppr ev)
#elif __GLASGOW_HASKELL__ >= 711
newGiven = TcPluginM.newGiven
#else
newGiven loc pty evtm = return
  CtGiven { ctev_pred = pty
          , ctev_evtm = evtm
          , ctev_loc  = loc
          }
#endif

-- | Create a new [D]erived constraint.
--
-- On GHC 8.0+ (@__GLASGOW_HASKELL__ >= 711@) this is 'TcPluginM.newDerived'.
-- On older compilers the 'CtDerived' record is built directly from the
-- location and predicate.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newDerived = TcPluginM.newDerived
#else
newDerived loc pty = return
  CtDerived { ctev_pred = pty
            , ctev_loc  = loc
            }
#endif

-- | The 'EvTerm' equivalent for 'Unsafe.unsafeCoerce'
--
-- Builds a 'Nominal' universal coercion between the two types with
-- 'mkUnivCo' and wraps it as evidence. The name is attached to the coercion
-- as its provenance ('PluginProv' on GHC 8.0+, a 'FastString' before that).
evByFiat :: String -- ^ Name the coercion should have
         -> Type   -- ^ The LHS of the equivalence relation (~)
         -> Type   -- ^ The RHS of the equivalence relation (~)
         -> EvTerm
evByFiat name t1 t2 =
#if MIN_VERSION_ghc(8,5,0)
                      EvExpr
                    $ Coercion
#else
                      EvCoercion
#if __GLASGOW_HASKELL__ < 711
                    $ TcCoercion
#endif
#endif
                    $ mkUnivCo
#if __GLASGOW_HASKELL__ >= 711
                        (PluginProv name)
#else
                        (fsLit name)
#endif
                        Nominal t1 t2

#if __GLASGOW_HASKELL__ < 711
{-# DEPRECATED failWithProvenace "No longer available in GHC 8.0+" #-}
-- | Report the given constraint as a contradiction.
--
-- When the [W]anted constraint was created by 'newWantedWithProvenance', the
-- parent(s) it was derived from are reported as insoluble as well, even if
-- they were previously assumed to be solved.
--
-- Parents are recovered from the non-top-level binders stored in the
-- constraint's location environment; only binders whose type is a nominal
-- equality are turned back into wanted constraints.
failWithProvenace :: Ct -> TcPluginM TcPluginResult
failWithProvenace ct = return (TcPluginContradiction (ct : parents))
  where
    loc = ctLoc ct

    parents =
      [ mkNonCanonical (CtWanted pty bndr loc)
      | TcIdBndr bndr NotTopLevel <- tcl_bndrs (ctLocEnv loc)
      , let pty = varType bndr
      , isNominalEq (classifyPredType pty)
      ]

    isNominalEq (EqPred NomEq _ _) = True
    isNominalEq _                  = False
#endif

-- | Find a module
--
-- On GHC 8.0+ the module is first searched for with
-- 'Finder.findPluginModule'; on older compilers with 'findImportedModule' in
-- the given package. If that fails, a second lookup is attempted with
-- 'TcPluginM.findImportedModule' in the package named @"this"@. Panics when
-- both lookups fail.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package containing the module.
                           -- NOTE: This value is ignored on ghc>=8.0.
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
#if __GLASGOW_HASKELL__ >= 711
  hsc_env <- TcPluginM.getTopEnv
  found_module <- TcPluginM.tcPluginIO $ Finder.findPluginModule hsc_env mod_nm
#else
  found_module <- findImportedModule mod_nm $ Just _pkg
#endif
  case found_module of
#if __GLASGOW_HASKELL__ >= 711
    FoundModule h -> return (fr_mod h)
#else
    Found _ md -> return md
#endif
    _          -> do
      -- Fall back to looking the module up in the "this" package.
      found_module' <- TcPluginM.findImportedModule mod_nm $ Just $ fsLit "this"
      case found_module' of
#if __GLASGOW_HASKELL__ >= 711
        FoundModule h -> return (fr_mod h)
#else
        Found _ md -> return md
#endif
        _          -> panicDoc "Unable to resolve module looked up by plugin: "
                               (ppr mod_nm)

-- | Find a 'Name' in a 'Module' given an 'OccName'
--
-- Thin wrapper around 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName md occ = lookupOrig md occ

-- | Print out extra information about the initialisation, stop, and every run
-- of the plugin when @-ddump-tc-trace@ is enabled.
--
-- The 'String' argument is appended to every trace header. Each of the three
-- plugin callbacks is wrapped so that it emits a trace message and then
-- delegates to the original callback; results are passed through unchanged.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} = TcPlugin { tcPluginInit  = traceInit
                                      , tcPluginSolve = traceSolve
                                      , tcPluginStop  = traceStop
                                      }
  where
    traceInit = do
      -- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301
      initializeStaticFlags
      tcPluginTrace ("tcPluginInit " ++ s) empty >> tcPluginInit

    traceStop  z = tcPluginTrace ("tcPluginStop " ++ s) empty >> tcPluginStop z

    -- Trace the constraints handed to the solver, run it, then trace its
    -- result before returning it.
    traceSolve z given derived wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s)
                        (text "given   =" <+> ppr given
                      $$ text "derived =" <+> ppr derived
                      $$ text "wanted  =" <+> ppr wanted)
      r <- tcPluginSolve z given derived wanted
      case r of
        TcPluginOk solved new     -> tcPluginTrace ("tcPluginSolve ok " ++ s)
                                         (text "solved =" <+> ppr solved
                                       $$ text "new    =" <+> ppr new)
        TcPluginContradiction bad -> tcPluginTrace
                                         ("tcPluginSolve contradiction " ++ s)
                                         (text "bad =" <+> ppr bad)
      return r

-- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301
--
-- On GHC < 8.2 the static options are initialised via 'initStaticOpts'
-- unless 'v_opt_C_ready' reports that this has already happened. On newer
-- compilers this is a no-op.
initializeStaticFlags :: TcPluginM ()
#if __GLASGOW_HASKELL__ < 802
initializeStaticFlags = tcPluginIO $ do
  r <- readIORef v_opt_C_ready
  unless r initStaticOpts
#else
initializeStaticFlags = return ()
#endif

-- | Flattens evidence of constraints by substituting each others equalities.
--
-- The substitutions obtained from the givens (see 'mkSubst'') are grouped by
-- the type variable they define:
--
-- * Variables with exactly two definitions yield an additional given
--   equating both right-hand sides, reusing the evidence and location of the
--   first of the two constraints. Groups of three or more yield nothing.
--
-- * The substitutions for variables with a single definition are applied to
--   every original given.
--
-- The additional equalities come first in the result, followed by the
-- substituted givens.
--
-- __NB:__ Should only be used on /[G]iven/ constraints!
--
-- __NB:__ Doesn't flatten under binders
flattenGivens
  :: [Ct]
  -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt flat ++ map (substCt subst') givens
 where
  subst = mkSubst' givens
  -- 'flat': groups of substitutions sharing a type variable (two or more);
  -- 'subst'': the remaining, uniquely-defined substitutions.
  (flat,subst')
    = second (map fst . concat)
    $ partition ((>= 2) . length)
    $ groupBy ((==) `on` (fst.fst))
    $ sortOn (fst.fst) subst

  flatToCt :: [((TcTyVar,TcType),Ct)] -> Maybe Ct
  flatToCt [((_,lhs),ct),((_,rhs),_)]
    = Just
    $ mkNonCanonical
    $ CtGiven (mkPrimEqPred lhs rhs)
#if MIN_VERSION_ghc(8,4,0)
              (ctEvId ct)
#elif MIN_VERSION_ghc(8,0,0)
              (ctEvId (cc_ev ct))
#else
              (ctEvTerm (cc_ev ct))
#endif
              (ctLoc ct)

  flatToCt _ = Nothing


-- | Create flattened substitutions from type equalities, i.e. the substitutions
-- have been applied to each others right hand sides.
--
-- Constraints that 'mkSubst' does not turn into a substitution are dropped.
-- Every entry keeps the constraint it was created from.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' = foldr substSubst [] . mapMaybe mkSubst
 where
  -- Add one substitution to the accumulated ones: the accumulated
  -- substitutions are applied to the new right-hand side, and the new
  -- substitution is applied to every accumulated right-hand side.
  substSubst :: ((TcTyVar,TcType),Ct)
             -> [((TcTyVar,TcType),Ct)]
             -> [((TcTyVar,TcType),Ct)]
  substSubst ((tv,t),ct) s = ((tv,substType (map fst s) t),ct)
                           : map (first (second (substType [(tv,t)]))) s

-- | Create simple substitution from type equalities
--
-- * A 'CTyEqCan' maps its type variable to its right-hand side.
--
-- * A 'CFunEqCan' maps its @cc_fsk@ variable to the type family application
--   built from @cc_fun@ and @cc_tyargs@.
--
-- Any other constraint yields 'Nothing'. The originating constraint is
-- returned alongside the substitution.
mkSubst
  :: Ct
  -> Maybe ((TcTyVar, TcType),Ct)
mkSubst ct@(CTyEqCan {..})  = Just ((cc_tyvar,cc_rhs),ct)
mkSubst ct@(CFunEqCan {..}) = Just ((cc_fsk,TyConApp cc_fun cc_tyargs),ct)
mkSubst _                   = Nothing

-- | Apply substitution in the evidence of Cts
--
-- Only the predicate ('ctev_pred') stored in the constraint's evidence
-- ('cc_ev') is rewritten, using 'substType'; every other field is left as is.
substCt
  :: [(TcTyVar, TcType)]
  -> Ct
  -> Ct
substCt subst ct =
  ct { cc_ev = (cc_ev ct) {ctev_pred = substType subst (ctev_pred (cc_ev ct))}
     }

-- | Apply substitutions in Types
--
-- A type variable is replaced by the first matching entry of the
-- substitution, or kept when there is none. Applications, type constructor
-- applications, function types and (on GHC 8.0+) casts are traversed
-- structurally. Literals and coercion types are returned unchanged, and the
-- coercion inside a cast is not touched.
--
-- __NB:__ Doesn't substitute under binders
substType
  :: [(TcTyVar, TcType)]
  -> TcType
  -> TcType
substType subst tv@(TyVarTy v) = case lookup v subst of
  Just t  -> t
  Nothing -> tv
substType subst (AppTy t1 t2) =
  AppTy (substType subst t1) (substType subst t2)
substType subst (TyConApp tc xs) =
  TyConApp tc (map (substType subst) xs)
substType _subst t@(ForAllTy _tv _ty) =
  -- TODO: Is it safe to do "dumb" substitution under binders?
  -- ForAllTy tv (substType subst ty)
  t
#if __GLASGOW_HASKELL__ >= 900
substType subst (FunTy k1 k2 t1 t2) =
  FunTy k1 k2 (substType subst t1) (substType subst t2)
#elif __GLASGOW_HASKELL__ >= 809
substType subst (FunTy af t1 t2) =
  FunTy af (substType subst t1) (substType subst t2)
#elif __GLASGOW_HASKELL__ >= 802
substType subst (FunTy t1 t2) =
  FunTy (substType subst t1) (substType subst t2)
#elif __GLASGOW_HASKELL__ < 711
substType subst (FunTy t1 t2) =
  FunTy (substType subst t1) (substType subst t2)
#endif
substType _ l@(LitTy _) = l
#if __GLASGOW_HASKELL__ > 711
substType subst (CastTy ty co) =
  CastTy (substType subst ty) co
substType _ co@(CoercionTy _) = co
#endif