{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Internal.Core.TH.MergeConstructor
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.TH.MergeConstructor
  ( mkMergeConstructor,
    mkMergeConstructor',
  )
where

import Control.Monad (join, replicateM, when, zipWithM)
import Data.Bifunctor (Bifunctor (second))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge)
import Language.Haskell.TH
  ( Body (NormalB),
    Clause (Clause),
    Con (ForallC, GadtC, InfixC, NormalC, RecC, RecGadtC),
    Dec (DataD, FunD, NewtypeD, SigD),
    Exp (AppE, ConE, LamE, VarE),
    Info (DataConI, TyConI),
    Name,
    Pat (VarP),
    Pred,
    Q,
    TyVarBndr (PlainTV),
    Type (AppT, ArrowT, ForallT, VarT),
    mkName,
    newName,
    pprint,
    reify,
  )
#if MIN_VERSION_template_haskell(2,17,0)
import Language.Haskell.TH.Syntax
  ( Name (Name),
    OccName (OccName),
    Specificity (SpecifiedSpec),
    Type (MulArrowT),
  )
#else
import Language.Haskell.TH.Syntax (Name (Name), OccName (OccName))
#endif

-- | Generate one merging wrapper per data constructor of a type, using the
-- caller-supplied wrapper names (matched to the constructors positionally).
--
-- > mkMergeConstructor' ["mrgTuple2"] ''(,)
--
-- generates
--
-- > mrgTuple2 :: (Mergeable (a, b), Applicative m, TryMerge m) => a -> b -> m (a, b)
-- > mrgTuple2 = \v1 v2 -> mrgSingle (v1, v2)
--
-- Fails in 'Q' when the number of names differs from the number of
-- constructors of the type.
mkMergeConstructor' ::
  -- | Names for generated wrappers
  [String] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
mkMergeConstructor' names typName =
  getConstructors typName >>= \constructors -> do
    -- Every constructor needs exactly one wrapper name.
    when (length names /= length constructors) $
      fail "Number of names does not match the number of constructors"
    -- Each wrapper contributes a signature and a definition; flatten them.
    join <$> zipWithM mkSingleWrapper names constructors

-- | Project the occurrence-name string out of a Template Haskell 'Name',
-- discarding its flavour component.
occName :: Name -> String
occName (Name occ _) = case occ of
  OccName str -> str

-- | Compute the textual name of a constructor, as used to derive the default
-- wrapper name.
--
-- * Normal, record, and single-name GADT constructors yield their
--   occurrence name.
-- * A 'ForallC' is looked through to the constructor it wraps.
-- * Infix constructors are rejected: the caller has to supply a name
--   explicitly.
-- * Anything else (e.g. a GADT declaration that lists several constructor
--   names at once) is rejected as unsupported.
getConstructorName :: Con -> Q String
getConstructorName con = case con of
  NormalC name _ -> named name
  RecC name _ -> named name
  InfixC {} ->
    fail "You should use mkMergeConstructor' to manually provide the name for infix constructors"
  ForallC _ _ inner -> getConstructorName inner
  GadtC [name] _ _ -> named name
  RecGadtC [name] _ _ -> named name
  _ -> fail $ "Unsupported constructor at this time: " ++ pprint con
  where
    named :: Name -> Q String
    named = return . occName

-- | Reify a type name and return its data constructors.
--
-- A @data@ declaration yields all of its constructors, a @newtype@
-- declaration yields its single constructor as a one-element list, and any
-- other reified 'Info' makes the computation fail with the pretty-printed
-- 'Info' in the message.
getConstructors :: Name -> Q [Con]
getConstructors typName = reify typName >>= extract
  where
    extract :: Info -> Q [Con]
    extract (TyConI (DataD _ _ _ _ constructors _)) = return constructors
    extract (TyConI (NewtypeD _ _ _ _ constructor _)) = return [constructor]
    extract info = fail $ "Unsupported declaration: " ++ pprint info

-- | Generate one merging wrapper per data constructor of a type, naming each
-- wrapper by prepending the given prefix to the constructor's name.
--
-- > mkMergeConstructor "mrg" ''Maybe
--
-- generates
--
-- > mrgNothing :: (Mergeable (Maybe a), Applicative m, TryMerge m) => m (Maybe a)
-- > mrgNothing = mrgSingle Nothing
-- > mrgJust :: (Mergeable (Maybe a), Applicative m, TryMerge m) => a -> m (Maybe a)
-- > mrgJust = \x -> mrgSingle (Just x)
--
-- Fails for constructors whose name cannot be derived automatically (for
-- example infix constructors); use @mkMergeConstructor'@ and supply the names
-- by hand in that case.
mkMergeConstructor ::
  -- | Prefix for generated wrappers
  String ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
mkMergeConstructor prefix typName = do
  baseNames <- getConstructors typName >>= mapM getConstructorName
  mkMergeConstructor' (map (prefix ++) baseNames) typName

-- | Build the body of a wrapper for an @n@-ary constructor expression @f@:
--
-- > \x1 ... xn -> mrgSingle (f x1 ... xn)
--
-- The lambda binders are fresh names, so they cannot capture anything in @f@.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  argNames <- replicateM n (newName "x")
  wrapFun <- [|mrgSingle|]
  -- Saturate the constructor with the bound variables, left to right.
  let saturated = foldl AppE f (VarE <$> argNames)
  return $ LamE (VarP <$> argNames) (AppE wrapFun saturated)

-- | Walk down the spine of a (possibly curried) function type and wrap its
-- final result type @r@ as @m r@ for a fresh type variable @m@.
--
-- Returns the extra binder for @m@ together with the constraints
-- @Mergeable r@, @Applicative m@ and @TryMerge m@, paired with the rewritten
-- type. Argument types are kept as they are; on template-haskell >= 2.17 a
-- multiplicity-annotated arrow is rebuilt as an ordinary arrow.
#if MIN_VERSION_template_haskell(2,17,0)
augmentFinalType :: Type -> Q (([TyVarBndr Specificity], [Pred]), Type)
#else
augmentFinalType :: Type -> Q (([TyVarBndr], [Pred]), Type)
#endif
augmentFinalType (AppT arrow@(AppT ArrowT _) rest) =
  -- @arg -> rest@: keep the argument, keep rewriting the result side.
  second (AppT arrow) <$> augmentFinalType rest
#if MIN_VERSION_template_haskell(2,17,0)
augmentFinalType (AppT (AppT (AppT MulArrowT _) arg) rest) =
  -- Multiplicity arrow: drop the multiplicity and emit a plain arrow.
  second (AppT (AppT ArrowT arg)) <$> augmentFinalType rest
#endif
augmentFinalType resultTy = do
  mName <- newName "m"
  mergeable <- [t|Mergeable|]
  applicative <- [t|Applicative|]
  tryMerge <- [t|TryMerge|]
  let mTy = VarT mName
      constraints =
        [ AppT mergeable resultTy,
          AppT applicative mTy,
          AppT tryMerge mTy
        ]
#if MIN_VERSION_template_haskell(2,17,0)
      binders = [PlainTV mName SpecifiedSpec]
#else
      binders = [PlainTV mName]
#endif
  return ((binders, constraints), AppT mTy resultTy)

-- | Turn a constructor's type into the type of its merging wrapper.
--
-- The result type is rewritten by 'augmentFinalType', and the binder and
-- constraints it introduces are placed in an outer @forall@. If the
-- constructor type already starts with a @forall@, its binders come first and
-- its context comes after the newly added constraints.
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  ((newBndrs, newPreds), wrappedTy) <- augmentFinalType body
  return $ ForallT (outerBndrs ++ newBndrs) (newPreds ++ outerCtx) wrappedTy
  where
    -- Split off an existing top-level quantifier, if there is one.
    (outerBndrs, outerCtx, body) = case ty of
      ForallT bndrs ctx inner -> (bndrs, ctx, inner)
      _ -> ([], [], ty)

-- | Generate the signature and definition of the merging wrapper for a single
-- constructor.
--
-- The first argument is the name of the wrapper to generate, the second is
-- the constructor to wrap. The wrapper's type is obtained by reifying the
-- constructor and rewriting its type with 'augmentNormalCType'; its body is
-- built by 'augmentNormalCExpr'.
--
-- Supported constructor forms:
--
-- * normal and record constructors;
-- * infix constructors (always binary);
-- * constructors under a @forall@ / context ('ForallC'), handled by looking
--   through to the wrapped constructor;
-- * GADT-style constructors that declare exactly one name.
--
-- Anything else (e.g. a GADT declaration listing several constructor names at
-- once, for which a single wrapper name is not enough) fails in 'Q'.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name con = case con of
  NormalC oriName fields -> wrap oriName (length fields)
  RecC oriName fields -> wrap oriName (length fields)
  InfixC _ oriName _ -> wrap oriName 2
  ForallC _ _ inner -> mkSingleWrapper name inner
  GadtC [oriName] fields _ -> wrap oriName (length fields)
  RecGadtC [oriName] fields _ -> wrap oriName (length fields)
  _ -> fail $ "Unsupported constructor: " ++ pprint con
  where
    -- Emit the wrapper for the constructor @oriName@ taking @arity@ arguments.
    wrap :: Name -> Int -> Q [Dec]
    wrap oriName arity = do
      info <- reify oriName
      -- The full constructor type (including any quantifiers and context)
      -- comes from reification rather than from the 'Con' fields.
      constructorTyp <- case info of
        DataConI _ typ _ -> return typ
        _ -> fail $ "Not a data constructor: " ++ pprint oriName
      augmentedTyp <- augmentNormalCType constructorTyp
      expr <- augmentNormalCExpr arity (ConE oriName)
      let retName = mkName name
      return
        [ SigD retName augmentedTyp,
          FunD retName [Clause [] (NormalB expr) []]
        ]