module TypeLet.Plugin.Constraints (
    -- * Constraints recognized by the plugin
    CEqual(..)
  , CLet(..)
    -- * Parsing
    -- ** Infrastructure
  , ParseResult(..)
  , parseAll
  , parseAll'
  , withOrig
    -- ** Specific parsers
  , InvalidLet(..)
  , parseEqual
  , parseLet
    -- * Evidence construction
  , evidenceEqual
    -- * Formatting errors
  , formatCLet
  , formatInvalidLet
  ) where

import Data.Bifunctor
import Data.Void

import TypeLet.Plugin.GhcTcPluginAPI
import TypeLet.Plugin.NameResolution

{-------------------------------------------------------------------------------
  Constraints recognized by the plugin
-------------------------------------------------------------------------------}

-- | A well-formed @Let@ constraint, written @lhs := rhs@ in error messages.
--
-- 'parseLet' only produces this when the left-hand side is a skolem type
-- variable.
data CLet = CLet
  { letKind :: Type   -- ^ Kind argument (first argument of the class)
  , letLHS  :: TyVar  -- ^ Left-hand side: the variable being bound
  , letRHS  :: Type   -- ^ Right-hand side: the type it is bound to
  }

-- | An @Equal@ constraint, as recognised by 'parseEqual'.
data CEqual = CEqual
  { equalKind :: Type  -- ^ Kind argument (first argument of the class)
  , equalLHS  :: Type  -- ^ Left-hand side type
  , equalRHS  :: Type  -- ^ Right-hand side type
  }

instance Outputable CLet where
  -- Debug rendering: @(Let kind lhs rhs)@
  ppr (CLet kind lhs rhs) =
    parens (text "Let" <+> ppr kind <+> ppr lhs <+> ppr rhs)

instance Outputable CEqual where
  -- Debug rendering: @(Equal kind lhs rhs)@
  ppr (CEqual kind lhs rhs) =
    parens (text "Equal" <+> ppr kind <+> ppr lhs <+> ppr rhs)

{-------------------------------------------------------------------------------
  Parsing infrastructure
-------------------------------------------------------------------------------}

-- | Outcome of running a constraint parser on a single value
--
-- 'fmap' maps over the successful result @a@ only; errors and non-matches
-- pass through untouched.
data ParseResult e a
  = ParseOk a
    -- ^ Parse successful
  | ParseNoMatch
    -- ^ Different constraint than we're looking for (does not imply an error)
  | ParseError e
    -- ^ Constraint of the shape we're looking for, but something is wrong
  deriving Functor

instance Bifunctor ParseResult where
  -- The first function maps the error, the second maps the parsed result;
  -- 'ParseNoMatch' carries neither and is returned unchanged.
  bimap onErr onOk result =
    case result of
      ParseOk x    -> ParseOk (onOk x)
      ParseNoMatch -> ParseNoMatch
      ParseError y -> ParseError (onErr y)

-- | Apply parser to each value in turn, bailing at the first error
--
-- Values for which the parser returns 'ParseNoMatch' are skipped. Successful
-- results are returned in the same order as their inputs. If any value yields
-- 'ParseError', the error of the earliest such value is returned.
parseAll :: forall e a b. (a -> ParseResult e b) -> [a] -> Either e [b]
parseAll f = foldr step (Right [])
  where
    -- The head is classified before the rest of the list is inspected, so
    -- an error at an earlier position takes precedence over later ones.
    step :: a -> Either e [b] -> Either e [b]
    step x rest =
      case f x of
        ParseOk y    -> fmap (y :) rest
        ParseNoMatch -> rest
        ParseError e -> Left e

-- | Variation on 'parseAll' which rules out the error case
--
-- With an error type of 'Void' no 'ParseError' can be constructed, so the
-- 'Left' branch is eliminated with 'absurd'.
parseAll' :: (a -> ParseResult Void b) -> [a] -> [b]
parseAll' f xs = either absurd id (parseAll f xs)

-- | Bundle the parse result with the original value
--
-- On success the input is paired with the parsed result; 'ParseNoMatch' and
-- 'ParseError' are passed through unchanged.
withOrig :: (a -> ParseResult e b) -> (a -> ParseResult e (a, b))
withOrig f x = (,) x <$> f x

{-------------------------------------------------------------------------------
  Parser for specific constraints

  We can assume here that the constraint is kind correct, so if the class
  matches, we know how many arguments it is applied to.
-------------------------------------------------------------------------------}

-- | Reasons why a @Let@ constraint is rejected by 'parseLet'
--
-- Each constructor carries the kind, left-hand side and right-hand side of the
-- offending constraint, in the same order as the class arguments.
data InvalidLet =
    -- | LHS should always be a variable
    NonVariableLHS Type Type Type

    -- | The LHS should be a /skolem/ variable
    --
    -- As far as ghc is concerned, the LHS should be an opaque type variable
    -- with unknown value (only the plugin knows); certainly, ghc should not
    -- try to unify it with anything.
  | NonSkolemLHS Type TyVar Type

-- | Parse a @Let@ constraint
--
-- Anything other than a class constraint on the resolved @Let@ class with
-- exactly three arguments yields 'ParseNoMatch'. A matching constraint is
-- accepted only if its left-hand side is a skolem type variable; otherwise an
-- 'InvalidLet' error is returned. Both the result and the error are tagged
-- with the location of the original constraint.
parseLet ::
     ResolvedNames
  -> Ct
  -> ParseResult (GenLocated CtLoc InvalidLet) (GenLocated CtLoc CLet)
parseLet ResolvedNames{..} ct =
    bimap located located $
      case classifyPredType (ctPred ct) of
        ClassPred cls [k, a, b]
          | cls == clsLet -> checkLHS k a b
        _otherwise -> ParseNoMatch
  where
    -- Attach the constraint's location (used for both errors and results)
    located :: e -> GenLocated CtLoc e
    located = L (ctLoc ct)

    -- The LHS must be a type variable, and that variable must be a skolem
    checkLHS k a b =
      case getTyVar_maybe a of
        Nothing -> ParseError (NonVariableLHS k a b)
        Just x
          | isSkolemTyVar x -> ParseOk    (CLet k x b)
          | otherwise       -> ParseError (NonSkolemLHS k x b)

-- | Parse 'Equal' constraints
--
-- Kind-correct 'Equal' constraints of any form are ok, so this cannot return
-- errors. Any constraint that is not a three-argument application of the
-- resolved @Equal@ class yields 'ParseNoMatch'; a successful result is tagged
-- with the location of the original constraint.
parseEqual :: ResolvedNames -> Ct -> ParseResult Void (GenLocated CtLoc CEqual)
parseEqual ResolvedNames{..} ct =
    L (ctLoc ct) <$> fromPred (classifyPredType (ctPred ct))
  where
    fromPred (ClassPred cls [k, a, b])
      | cls == clsEqual = ParseOk (CEqual k a b)
    fromPred _otherwise = ParseNoMatch

{-------------------------------------------------------------------------------
  Evidence construction
-------------------------------------------------------------------------------}

-- | Evidence for an 'Equal' constraint
--
-- The evidence applies the data constructor of the resolved @Equal@ class
-- dictionary to the constraint's kind and two type arguments. No term-level
-- evidence arguments are supplied.
--
-- TODO: should we worry about producing an evidence term that prevents floating
-- stuff out of scope...? (the whole "coercions cannot simply be zapped" thing)
-- See also https://gitlab.haskell.org/ghc/ghc/-/issues/8095#note_108189 .
evidenceEqual :: ResolvedNames -> CEqual -> EvTerm
evidenceEqual ResolvedNames{..} (CEqual k a b) =
    evDataConApp dictCon tyArgs []
  where
    dictCon = classDataCon clsEqual
    tyArgs  = [k, a, b]

{-------------------------------------------------------------------------------
  Formatting errors
-------------------------------------------------------------------------------}

-- | Render a 'CLet' for error messages as @lhs := rhs@ (the kind is omitted)
formatCLet :: CLet -> TcPluginErrorMessage
formatCLet CLet{letLHS = lhs, letRHS = rhs} =
        PrintType (mkTyVarTy lhs)
    :|: Txt " := "
    :|: PrintType rhs

-- | Render an 'InvalidLet' for error messages
--
-- The message names the problem, then shows the offending constraint as
-- @lhs := rhs@; the kind argument is not shown.
formatInvalidLet :: InvalidLet -> TcPluginErrorMessage
formatInvalidLet err =
    case err of
      NonVariableLHS _k lhs rhs ->
            Txt "Let with non-variable LHS: "
        :|: PrintType lhs
        :|: Txt " := "
        :|: PrintType rhs
      NonSkolemLHS _k lhsVar rhs ->
            Txt "Let with non-skolem LHS: "
        :|: PrintType (mkTyVarTy lhsVar)
        :|: Txt " := "
        :|: PrintType rhs