{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}
module Grisette.Internal.Core.TH.MergeConstructor
( mkMergeConstructor,
mkMergeConstructor',
)
where
import Control.Monad (join, replicateM, when, zipWithM)
import Data.Bifunctor (Bifunctor (second))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge)
import Language.Haskell.TH
( Body (NormalB),
Clause (Clause),
Con (ForallC, GadtC, InfixC, NormalC, RecC, RecGadtC),
Dec (DataD, FunD, NewtypeD, SigD),
Exp (AppE, ConE, LamE, VarE),
Info (DataConI, TyConI),
Name,
Pat (VarP),
Pred,
Q,
TyVarBndr (PlainTV),
Type (AppT, ArrowT, ForallT, VarT),
mkName,
newName,
pprint,
reify,
)
#if MIN_VERSION_template_haskell(2,17,0)
import Language.Haskell.TH.Syntax
( Name (Name),
OccName (OccName),
Specificity (SpecifiedSpec),
Type (MulArrowT),
)
#else
import Language.Haskell.TH.Syntax (Name (Name), OccName (OccName))
#endif
-- | Generate one smart constructor per data constructor of a type, using the
-- caller-supplied names.
--
-- The names are paired positionally with the constructors returned by
-- 'getConstructors'. Each generated function takes the constructor's fields
-- and returns the constructed value wrapped with 'mrgSingle'.
--
-- Fails if the number of names differs from the number of constructors.
mkMergeConstructor' ::
  -- | Names of the generated functions, one per constructor
  [String] ->
  -- | The type whose constructors are wrapped
  Name ->
  Q [Dec]
mkMergeConstructor' names typName = do
  constructors <- getConstructors typName
  when (length names /= length constructors) $
    fail "Number of names does not match the number of constructors"
  -- Each wrapper yields a signature and a definition; flatten them together.
  join <$> zipWithM mkSingleWrapper names constructors
-- | The 'OccName' string of a 'Name', ignoring its flavour (the module
-- qualification / uniqueness information). For example, the name built by
-- @mkName "Data.Maybe.Just"@ yields @"Just"@.
occName :: Name -> String
occName (Name (OccName name) _) = name
-- | The unqualified name of a data constructor, used by 'mkMergeConstructor'
-- to derive the name of the generated function.
--
-- Supported forms are prefix ('NormalC'), record ('RecC') and single-name
-- GADT ('GadtC' / 'RecGadtC') constructors. Existentially quantified
-- constructors ('ForallC') are unwrapped and inspected recursively.
--
-- Infix constructors are rejected; the caller is told to use
-- 'mkMergeConstructor'' and supply the name manually (presumably because a
-- prefixed operator name would not be a valid identifier). Any other form
-- is rejected with 'fail'.
getConstructorName :: Con -> Q String
getConstructorName (NormalC name _) = return $ occName name
getConstructorName (RecC name _) = return $ occName name
getConstructorName InfixC {} =
  fail
    "You should use mkMergeConstructor' to manually provide the name for infix constructors"
getConstructorName (ForallC _ _ inner) = getConstructorName inner
getConstructorName (GadtC [name] _ _) = return $ occName name
getConstructorName (RecGadtC [name] _ _) = return $ occName name
getConstructorName c =
  fail $ "Unsupported constructor at this time: " ++ pprint c
-- | Reify a type name and return its data constructors.
--
-- Only @data@ ('DataD') and @newtype@ ('NewtypeD') declarations are accepted.
-- Any other reified 'Info' (e.g. a type synonym or a class) is rejected with
-- 'fail', and the pretty-printed 'Info' is included in the message.
getConstructors :: Name -> Q [Con]
getConstructors typName = do
  info <- reify typName
  case info of
    TyConI (DataD _ _ _ _ constructors _) -> return constructors
    TyConI (NewtypeD _ _ _ _ constructor _) -> return [constructor]
    _ -> fail $ "Unsupported declaration: " ++ pprint info
-- | Generate one smart constructor per data constructor of a type. Each
-- function is named by prepending @prefix@ to the constructor's name (as
-- returned by 'getConstructorName').
--
-- For example, with prefix @\"mrg\"@ a constructor @Foo@ yields a function
-- @mrgFoo@.
--
-- NOTE(review): the concatenated name is used as a function name in the
-- generated signature and definition. The prefix should therefore make it a
-- valid variable identifier (e.g. start with a lowercase letter).
mkMergeConstructor ::
  -- | Prefix prepended to every constructor name
  String ->
  -- | The type whose constructors are wrapped
  Name ->
  Q [Dec]
mkMergeConstructor prefix typName = do
  constructors <- getConstructors typName
  constructorNames <- mapM getConstructorName constructors
  mkMergeConstructor' (map (prefix ++) constructorNames) typName
-- | Build the body of a smart constructor.
--
-- Given an arity @n@ and an expression @f@ (the original data constructor),
-- this produces @\\x1 ... xn -> mrgSingle (f x1 ... xn)@ with freshly
-- generated argument names. When there are no arguments the lambda is
-- omitted and the result is simply @mrgSingle f@.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  xs <- replicateM n (newName "x")
  mrgSingleFun <- [|mrgSingle|]
  -- Apply the constructor to every argument, left to right: f x1 x2 ... xn.
  let saturated = foldl AppE f (map VarE xs)
      wrapped = AppE mrgSingleFun saturated
  -- A zero-argument LamE is degenerate (it would pretty-print as a bare
  -- lambda with no binders), so only introduce the lambda when needed.
  return $
    if null xs
      then wrapped
      else LamE (map VarP xs) wrapped
-- | Rewrite the result type of a (curried) constructor type to be monadic.
--
-- For @a1 -> ... -> an -> t@ this returns @a1 -> ... -> an -> m t@. It also
-- returns the extra binder and constraints that the new type needs:
--
-- * a freshly named type variable @m@,
-- * the constraints @Mergeable t@, @Applicative m@ and @TryMerge m@.
--
-- With template-haskell >= 2.17, linear arrows (@a %p -> b@, represented via
-- 'MulArrowT') are rewritten into ordinary arrows, dropping the multiplicity.
#if MIN_VERSION_template_haskell(2,17,0)
augmentFinalType :: Type -> Q (([TyVarBndr Specificity], [Pred]), Type)
#else
augmentFinalType :: Type -> Q (([TyVarBndr], [Pred]), Type)
#endif
augmentFinalType (AppT arrowWithArg@(AppT ArrowT _) rest) = do
  augmentedRest <- augmentFinalType rest
  return $ second (AppT arrowWithArg) augmentedRest
#if MIN_VERSION_template_haskell(2,17,0)
augmentFinalType (AppT (AppT (AppT MulArrowT _) argTy) rest) = do
  augmentedRest <- augmentFinalType rest
  -- Replace the linear arrow with a plain one over the same argument type.
  return $ second (AppT (AppT ArrowT argTy)) augmentedRest
#endif
augmentFinalType t = do
  mName <- newName "m"
  let mTy = VarT mName
  mergeable <- [t|Mergeable|]
  applicative <- [t|Applicative|]
  tryMerge <- [t|TryMerge|]
  let constraints = [AppT mergeable t, AppT applicative mTy, AppT tryMerge mTy]
#if MIN_VERSION_template_haskell(2,17,0)
      binder = PlainTV mName SpecifiedSpec
#else
      binder = PlainTV mName
#endif
  return (([binder], constraints), AppT mTy t)
-- | Turn the type of a data constructor into the type of its smart
-- constructor.
--
-- The binders and constraints produced by 'augmentFinalType' are merged into
-- the constructor's existing top-level @forall@. If there is none, they form
-- a new one. The new binders go after the existing ones, and the new
-- constraints go before the existing context. For example,
-- @forall a. a -> T a@ becomes
-- @forall a m. (Mergeable (T a), Applicative m, TryMerge m) => a -> m (T a)@.
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  -- Treat a type without an outer forall as having empty binders/context.
  let (outerBndrs, outerCtx, body) = case ty of
        ForallT bndrs ctx inner -> (bndrs, ctx, inner)
        _ -> ([], [], ty)
  ((newBndrs, newPreds), augmentedBody) <- augmentFinalType body
  return $ ForallT (outerBndrs ++ newBndrs) (newPreds ++ outerCtx) augmentedBody
-- | Generate the signature and definition of a single smart constructor
-- called @name@ for the given data constructor.
--
-- For a constructor @C@ with fields @f1 ... fn@ and result type @T@ this
-- produces, roughly,
--
-- > name :: forall ... m. (Mergeable T, Applicative m, TryMerge m, ...) => f1 -> ... -> fn -> m T
-- > name = \x1 ... xn -> mrgSingle (C x1 ... xn)
--
-- Only prefix ('NormalC') and record ('RecC') constructors are supported.
-- Any other constructor form is rejected with 'fail'.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name (NormalC oriName fields) =
  mkSingleWrapperFor name oriName (length fields)
mkSingleWrapper name (RecC oriName fields) =
  mkSingleWrapperFor name oriName (length fields)
mkSingleWrapper _ v = fail $ "Unsupported constructor: " ++ pprint v

-- | Shared worker for 'mkSingleWrapper': wraps the data constructor
-- @oriName@, which takes @arity@ arguments, as a function called @name@.
mkSingleWrapperFor :: String -> Name -> Int -> Q [Dec]
mkSingleWrapperFor name oriName arity = do
  info <- reify oriName
  constructorTyp <- case info of
    DataConI _ ty _ -> return ty
    _ -> fail $ "Expected a data constructor, got: " ++ pprint info
  augmentedTyp <- augmentNormalCType constructorTyp
  expr <- augmentNormalCExpr arity (ConE oriName)
  let retName = mkName name
  return
    [ SigD retName augmentedTyp,
      FunD retName [Clause [] (NormalB expr) []]
    ]