{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Core.TH
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.TH
  ( -- * Template Haskell procedures for building constructor wrappers
    makeUnionWrapper,
    makeUnionWrapper',
  )
where

import Control.Monad
import Grisette.Core.THCompat
import Language.Haskell.TH
import Language.Haskell.TH.Syntax

-- | Generate constructor wrappers that wrap the result in a union-like monad,
-- using the provided names.
--
-- The names are paired positionally with the type's constructors, in the order
-- in which the reified declaration lists them. The splice fails if the number
-- of names differs from the number of constructors.
--
-- > $(makeUnionWrapper' ["mrgTuple2"] ''(,))
--
-- generates
--
-- > mrgTuple2 :: (SymBoolOp bool, Monad u, Mergeable bool t1, Mergeable bool t2, MonadUnion bool u) => t1 -> t2 -> u (t1, t2)
-- > mrgTuple2 = \v1 v2 -> mrgSingle (v1, v2)
makeUnionWrapper' ::
  -- | Names for generated wrappers
  [String] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnionWrapper' wrapperNames tyName = do
  cons <- getConstructors tyName
  -- Every constructor needs exactly one wrapper name.
  unless (length wrapperNames == length cons) $
    fail "Number of names does not match the number of constructors"
  -- Each wrapper contributes a signature and a definition; flatten them.
  concat <$> zipWithM mkSingleWrapper wrapperNames cons

-- | The unqualified base of a 'Name': its occurrence string, with any module
-- qualifier dropped.
occName :: Name -> String
occName = nameBase

-- | Compute the base name used to build a constructor's wrapper name.
--
-- Prefix, record, and single-name GADT constructors yield their unqualified
-- name. Existential ('ForallC') constructors are unwrapped and their inner
-- constructor is inspected. Infix constructors fail with a hint to use
-- @makeUnionWrapper'@ and name them manually. Anything else, including GADT
-- constructors that declare several names at once, fails as unsupported.
getConstructorName :: Con -> Q String
getConstructorName con = case con of
  NormalC conName _ -> pure (occName conName)
  RecC conName _ -> pure (occName conName)
  InfixC {} ->
    fail "You should use makeUnionWrapper' to manually provide the name for infix constructors"
  ForallC _ _ inner -> getConstructorName inner
  GadtC [conName] _ _ -> pure (occName conName)
  RecGadtC [conName] _ _ -> pure (occName conName)
  _ -> fail $ "Unsupported constructor at this time: " ++ pprint con

-- | Reify a type name and return the constructors of its declaration.
--
-- Only plain @data@ and @newtype@ declarations are accepted. Any other reified
-- entity makes the splice fail with its pretty-printed 'Info'.
getConstructors :: Name -> Q [Con]
getConstructors tyName = reify tyName >>= extract
  where
    extract (TyConI (DataD _ _ _ _ cons _)) = pure cons
    extract (TyConI (NewtypeD _ _ _ _ con _)) = pure [con]
    extract info = fail $ "Unsupported declaration: " ++ pprint info

-- | Generate constructor wrappers that wrap the result in a union-like monad.
--
-- Each wrapper is named by prepending the given prefix to the constructor's
-- unqualified name. The names are derived by @getConstructorName@, so
-- constructors it rejects (for example, infix ones) make the splice fail.
--
-- > $(makeUnionWrapper "mrg" ''Maybe)
--
-- generates
--
-- > mrgNothing :: (SymBoolOp bool, Monad u, Mergeable bool t, MonadUnion bool u) => u (Maybe t)
-- > mrgNothing = mrgSingle Nothing
-- > mrgJust :: (SymBoolOp bool, Monad u, Mergeable bool t, MonadUnion bool u) => t -> u (Maybe t)
-- > mrgJust = \x -> mrgSingle (Just x)
makeUnionWrapper ::
  -- | Prefix for generated wrappers
  String ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnionWrapper prefix typName = do
  cons <- getConstructors typName
  baseNames <- traverse getConstructorName cons
  makeUnionWrapper' (map (prefix ++) baseNames) typName

-- | Build a lambda of the given arity that applies the constructor expression
-- to its arguments and wraps the result with @mrgSingle@:
--
-- > \x1 ... xn -> mrgSingle (con x1 ... xn)
--
-- The argument names are fresh (@newName "x"@). With arity 0 the lambda
-- has no binders.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr arity conExp = do
  argNames <- replicateM arity (newName "x")
  mrgSingleExp <- [|mrgSingle|]
  -- Apply the constructor to the arguments left to right: ((con x1) x2) ...
  let saturated = foldl (\acc v -> AppE acc (VarE v)) conExp argNames
  pure $ LamE (VarP <$> argNames) (AppE mrgSingleExp saturated)

-- | Turn a constructor's reified type into its wrapper's type.
--
-- The type below any outermost @forall@ is passed to @augmentFinalType@. The
-- binders and constraints it returns are placed in front of those the original
-- type already quantified over. A type with no outer @forall@ gets a fresh one
-- holding only the new binders and constraints.
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  ((newBndrs, newPreds), resultTy) <- augmentFinalType body
  pure $ ForallT (newBndrs ++ outerBndrs) (newPreds ++ outerCtx) resultTy
  where
    -- Split off the outermost quantifier, if any.
    (outerBndrs, outerCtx, body) = case ty of
      ForallT bndrs ctx inner -> (bndrs, ctx, inner)
      _ -> ([], [], ty)

-- | Generate the signature and definition of one wrapper, named @name@, for
-- the given constructor.
--
-- The wrapper's type comes from the constructor's reified type, transformed by
-- @augmentNormalCType@. Its body is a lambda built by @augmentNormalCExpr@,
-- with one argument per constructor field.
--
-- Supported constructors:
--
-- * prefix ('NormalC') and record ('RecC') constructors;
-- * infix ('InfixC') constructors, which always take two arguments;
-- * single-name GADT constructors ('GadtC', 'RecGadtC');
-- * existential ('ForallC') constructors, by looking through the quantifier.
--
-- Anything else fails with an "Unsupported constructor: " message.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name con = case con of
  NormalC oriName fields -> wrapperFor oriName (length fields)
  RecC oriName fields -> wrapperFor oriName (length fields)
  -- Infix constructors have exactly two fields (left and right operand).
  InfixC _ oriName _ -> wrapperFor oriName 2
  -- The type reified for the constructor name already carries the existential
  -- binders and context, so only the inner constructor's shape is needed.
  ForallC _ _ inner -> mkSingleWrapper name inner
  GadtC [oriName] fields _ -> wrapperFor oriName (length fields)
  RecGadtC [oriName] fields _ -> wrapperFor oriName (length fields)
  _ -> fail $ "Unsupported constructor: " ++ pprint con
  where
    -- Build the signature and definition for constructor @oriName@ with the
    -- given number of fields.
    wrapperFor :: Name -> Int -> Q [Dec]
    wrapperFor oriName arity = do
      info <- reify oriName
      constructorTyp <- case info of
        DataConI _ ty _ -> pure ty
        _ -> fail $ "Expected a data constructor, but got: " ++ pprint info
      augmentedTyp <- augmentNormalCType constructorTyp
      expr <- augmentNormalCExpr arity (ConE oriName)
      let retName = mkName name
      pure
        [ SigD retName augmentedTyp,
          FunD retName [Clause [] (NormalB expr) []]
        ]