{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}
module Grisette.Core.TH
(
makeUnionWrapper,
makeUnionWrapper',
)
where
import Control.Monad
import Grisette.Core.THCompat
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Generate one wrapper function per constructor of the type @typName@,
-- using explicitly supplied wrapper names.
--
-- The names are paired positionally with the constructors in the order
-- reported by getConstructors. The splice fails if the number of names
-- differs from the number of constructors.
makeUnionWrapper' ::
  -- | Names of the generated wrappers, one per constructor
  [String] ->
  -- | The type whose constructors are wrapped
  Name ->
  Q [Dec]
makeUnionWrapper' names typName = do
  constructors <- getConstructors typName
  unless (length names == length constructors) $
    fail "Number of names does not match the number of constructors"
  -- Each constructor yields a list of declarations (signature + body);
  -- flatten them into a single declaration list.
  concat <$> zipWithM mkSingleWrapper names constructors
-- | The unqualified base name of a 'Name', e.g. @bar@ for @Foo.bar@.
--
-- Delegates to template-haskell's documented 'nameBase' instead of pattern
-- matching on the internal 'Name' / 'OccName' constructors, which are an
-- implementation detail of the template-haskell library.
occName :: Name -> String
occName = nameBase
-- | The unqualified name of a constructor, used as the stem for a
-- generated wrapper name.
--
-- 'ForallC' wrappers are looked through. Single-name GADT constructors are
-- supported. Infix constructors are rejected with a message directing the
-- caller to makeUnionWrapper'. Any other shape (e.g. a multi-name GADT
-- constructor) fails as unsupported.
getConstructorName :: Con -> Q String
getConstructorName con =
  case con of
    NormalC n _ -> baseOf n
    RecC n _ -> baseOf n
    GadtC [n] _ _ -> baseOf n
    RecGadtC [n] _ _ -> baseOf n
    ForallC _ _ inner -> getConstructorName inner
    InfixC {} ->
      fail "You should use makeUnionWrapper' to manually provide the name for infix constructors"
    _ -> fail $ "Unsupported constructor at this time: " ++ pprint con
  where
    baseOf = pure . occName
-- | Reify a type name and return its data constructors.
--
-- Plain @data@ declarations yield all of their constructors; a @newtype@
-- yields a singleton list. Any other reified entity aborts the splice.
getConstructors :: Name -> Q [Con]
getConstructors typName = reify typName >>= fromInfo
  where
    fromInfo (TyConI (DataD _ _ _ _ cons _)) = pure cons
    fromInfo (TyConI (NewtypeD _ _ _ _ con _)) = pure [con]
    fromInfo info = fail $ "Unsupported declaration: " ++ pprint info
-- | Generate one wrapper function per constructor of @typName@. Each
-- wrapper is named by prepending @prefix@ to the constructor's unqualified
-- name.
--
-- Name derivation goes through getConstructorName, so infix or otherwise
-- unsupported constructors abort the splice. Use makeUnionWrapper' to
-- supply the names explicitly in that case.
makeUnionWrapper ::
  -- | Prefix prepended to every constructor name
  String ->
  -- | The type whose constructors are wrapped
  Name ->
  Q [Dec]
makeUnionWrapper prefix typName = do
  stems <- getConstructors typName >>= traverse getConstructorName
  makeUnionWrapper' (map (prefix ++) stems) typName
-- | Build a lambda over @n@ fresh arguments that applies @f@ to all of
-- them, left to right, and wraps the result with 'mrgSingle':
--
-- > \x1 ... xn -> mrgSingle (f x1 ... xn)
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  vars <- replicateM n (newName "x")
  wrap <- [|mrgSingle|]
  let applied = foldl AppE f (VarE <$> vars)
  pure (LamE (VarP <$> vars) (AppE wrap applied))
-- | Rewrite a constructor type with augmentFinalType.
--
-- Any outer @forall@ is split off first. augmentFinalType is applied to the
-- remaining body, and the binders and constraints it returns are placed in
-- front of the original ones. The result is always a 'ForallT', even when
-- the input had no quantifier (the original binders and context are then
-- empty).
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  let (outerBinders, outerCtx, body) = case ty of
        ForallT tvs ctx inner -> (tvs, ctx, inner)
        _ -> ([], [], ty)
  ((newBinders, newCtx), augmented) <- augmentFinalType body
  pure $ ForallT (newBinders ++ outerBinders) (newCtx ++ outerCtx) augmented
-- | Generate the declarations (type signature and definition) of a wrapper
-- function named @name@ for the constructor @con@.
--
-- Supported shapes:
--
-- * plain and record constructors;
-- * single-name GADT constructors (plain and record);
-- * any of the above under 'ForallC'.
--
-- These are the same shapes getConstructorName accepts, so makeUnionWrapper
-- no longer resolves a name it then cannot wrap. Infix constructors and
-- anything else abort the splice with a descriptive message.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name con = case con of
  NormalC oriName fields -> mkWrapperFor name oriName (length fields)
  RecC oriName fields -> mkWrapperFor name oriName (length fields)
  GadtC [oriName] fields _ -> mkWrapperFor name oriName (length fields)
  RecGadtC [oriName] fields _ -> mkWrapperFor name oriName (length fields)
  -- The wrapper's type is obtained by reifying the inner constructor's
  -- name, so the quantifier here does not need separate handling.
  ForallC _ _ inner -> mkSingleWrapper name inner
  _ -> fail $ "Unsupported constructor: " ++ pprint con

-- | Build the signature and definition of a wrapper named @name@ for the
-- data constructor @oriName@, which takes @arity@ arguments.
--
-- The signature is the reified constructor type passed through
-- augmentNormalCType. The body applies the constructor to all arguments
-- and wraps the result with mrgSingle (via augmentNormalCExpr).
mkWrapperFor :: String -> Name -> Int -> Q [Dec]
mkWrapperFor name oriName arity = do
  info <- reify oriName
  -- Explicit failure instead of a partial pattern in 'do', which would only
  -- report an uninformative pattern-match failure.
  constructorTyp <- case info of
    DataConI _ ty _ -> return ty
    _ -> fail $ "Not a data constructor: " ++ pprint oriName
  augmentedTyp <- augmentNormalCType constructorTyp
  expr <- augmentNormalCExpr arity (ConE oriName)
  let retName = mkName name
  return
    [ SigD retName augmentedTyp,
      FunD retName [Clause [] (NormalB expr) []]
    ]