{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Defines splices that cut down on boilerplate associated with declaring new effects.
module Control.Effect.TH
  ( makeSmartConstructors,
  )
where

import Control.Algebra
import Control.Monad (join)
import Data.Char (toLower)
import Data.Foldable
import Data.Monoid (Ap (..))
import Data.Traversable
import Language.Haskell.TH (appT, arrowT, mkName, varT)
import qualified Language.Haskell.TH as TH

-- | Information shared by every constructor of a single effect type.
data PerEffect = PerEffect
  { effectType :: TH.TypeQ
  -- ^ The effect's type constructor, as built by 'TH.conT'.
  , effectTyVarCount :: Int
  -- ^ How many type variables the effect's @data@ declaration binds.
  , forallConstructor :: TH.Con
  -- ^ The reified constructor currently being turned into smart constructors;
  -- expected to be a 'TH.ForallC'-wrapped GADT constructor.
  }

-- Hideous hacks to deal with kindedness changes in newer TH versions:
-- from template-haskell 2.17 on, type-variable binders carry an extra flag
-- (here a specificity), so constructing and inspecting them needs CPP.
#if MIN_VERSION_template_haskell(2,17,0)
type TyVarBinder = TH.TyVarBndrSpec

-- | A plain (unkinded) binder for the given name, marked as inferred.
makeTV :: TH.Name -> TyVarBinder
makeTV = (`TH.PlainTV` TH.inferredSpec)

-- | The name bound by a binder, whether or not it carries a kind annotation.
tvName :: TyVarBinder -> TH.Name
tvName (TH.PlainTV n _) = n
tvName (TH.KindedTV n _ _) = n
#else
type TyVarBinder = TH.TyVarBndr

-- | A plain (unkinded) binder for the given name.
makeTV :: TH.Name -> TyVarBinder
makeTV n = TH.plainTV n

-- | The name bound by a binder, whether or not it carries a kind annotation.
tvName :: TyVarBinder -> TH.Name
tvName (TH.PlainTV n) = n
tvName (TH.KindedTV n _) = n
#endif


-- | Everything needed to generate the declarations for one constructor name.
data PerDecl = PerDecl
  { ctorArgs :: [TH.TypeQ]
  -- ^ Types of the constructor's arguments, in order.
  , ctorConstraints :: [TH.TypeQ]
  -- ^ The constructor's context, copied into the generated signature.
  , ctorName :: TH.Name
  -- ^ Name of the GADT constructor being wrapped.
  , ctorTyVars :: [TyVarBinder]
  -- ^ Type variables quantified by the reified constructor.
  , functionName :: TH.Name
  -- ^ Name of the generated smart constructor.
  , gadtReturnType :: TH.TypeQ
  -- ^ Last argument of the constructor's result type, i.e. what the generated
  -- function returns inside the monad.
  , perEffect :: PerEffect
  -- ^ Information shared by all constructors of the effect.
  }

-- | Given the name of an effect type, this splice generates one smart
-- constructor for each of its GADT constructors: a function that 'send's the
-- corresponding request.
--
-- That is to say, given the standard @State@ type
--
-- @
--   data State s m k where
--     Get :: State s m s
--     Put :: s -> State s m ()
-- @
--
-- an invocation of @makeSmartConstructors ''State@ will generate code that looks roughly like
--
--
-- >   get ::
-- >     forall (s :: Type) (m :: Type -> Type) sig.
-- >     Has (State s) sig m =>
-- >     m s
-- >   get = send Get
-- >   {-# INLINABLE get #-}
-- >   put ::
-- >     forall (s :: Type) (m :: Type -> Type) sig.
-- >     Has (State s) sig m =>
-- >     s ->
-- >     m ()
-- >   put a = send (Put a)
-- >   {-# INLINABLE put #-}
--
--
-- Each function is named after its constructor, with the first letter
-- lower-cased. The type variables in each signature appear in the order they
-- are quantified in the reified constructor, followed by @sig@. When built
-- against template-haskell 2.17 or newer, @sig@ is an inferred variable
-- (@{sig}@).
--
-- The splice fails if the name does not refer to a @data@ type.
makeSmartConstructors :: TH.Name -> TH.DecsQ
makeSmartConstructors typ =
  -- Lookup the provided type name.
  TH.reify typ >>= \case
    -- If it's a data type, record its type constructor and arity, then
    -- generate declarations for every constructor and concatenate them.
    TH.TyConI (TH.DataD _ctx tn tvs _kind constructors _derive) ->
      let perEffect = PerEffect (TH.conT tn) (length tvs)
       in getAp (foldMap (Ap . makeDeclaration . perEffect) constructors)
    -- Die otherwise.
    other ->
      fail ("Can't generate definitions for a non-data-constructor: " <> TH.pprint other)

-- | Generate the declarations for one reified constructor of an effect.
--
-- For each name the constructor binds, this produces a type signature, a
-- definition, and an @INLINABLE@ pragma.
--
-- Both ordinary GADT constructors (@C :: a -> E m k@) and record-syntax GADT
-- constructors (@C :: { f :: a } -> E m k@) are accepted. Anything else is
-- reported with 'fail'.
makeDeclaration :: PerEffect -> TH.DecsQ
makeDeclaration perEffect@PerEffect {..} = do
  -- Start by extracting the relevant parts of this particular constructor:
  -- its bound names, argument types, context, the last argument of its result
  -- type (e.g. @s@ in @State s m s@), and its quantified type variables.
  (names, ctorArgs, constraints, returnType, ctorTyVars) <- case forallConstructor of
    TH.ForallC vars ctx (TH.GadtC names bangtypes (TH.AppT _ final)) ->
      pure (names, fmap snd bangtypes, ctx, final, vars)
    -- Record GADT syntax differs only in that each field is a
    -- (selector, strictness, type) triple; only the type matters here.
    TH.ForallC vars ctx (TH.RecGadtC names varbangtypes (TH.AppT _ final)) ->
      pure (names, fmap (\(_, _, ty) -> ty) varbangtypes, ctx, final, vars)
    _ ->
      fail "BUG: expected forall-qualified constructor, but didn't get one"
  -- Then iterate over the names of the constructors, emitting an injected
  -- method per name.
  fmap join . for names $ \ctorName -> do
    let downcase (x : xs) = mkName (toLower x : xs)
        downcase [] = error "attempted to downcase empty name"
        decl =
          PerDecl
            { ctorName = ctorName,
              functionName = downcase . TH.nameBase $ ctorName,
              ctorArgs = fmap pure ctorArgs,
              gadtReturnType = pure returnType,
              perEffect = perEffect,
              ctorTyVars = ctorTyVars,
              ctorConstraints = fmap pure constraints
            }
    sign <- makeSignature decl
    func <- makeFunction decl
    prag <- makePragma decl
    pure [sign, func, prag]

-- | Emit an @INLINABLE@ pragma (function-like, active in all phases) for the
-- generated smart constructor.
makePragma :: PerDecl -> TH.DecQ
makePragma decl =
  TH.pragInlD (functionName decl) TH.Inlinable TH.FunLike TH.AllPhases

-- | Emit the definition of a smart constructor, consisting of the single
-- clause built by 'makeClause'.
makeFunction :: PerDecl -> TH.DecQ
makeFunction decl = TH.funD name clauses
  where
    name = functionName decl
    clauses = [makeClause decl]

-- | Build the single clause of a smart constructor.
--
-- The clause binds one parameter per constructor argument and 'send's the
-- constructor applied to all of them, e.g. @put a = send (Put a)@.
makeClause :: PerDecl -> TH.ClauseQ
makeClause PerDecl {..} = TH.clause pats body []
  where
    body = TH.normalB [e|send ($(applies))|]
    pats = fmap TH.varP names
    -- Glue together the parameter to 'send', fully applied
    applies = foldl' (\e n -> e `TH.appE` TH.varE n) (TH.conE ctorName) names
    -- One parameter name per constructor argument: a, b, ..., z, then
    -- a1, a2, .... Using only ['a' .. 'z'] would drop every argument past
    -- the 26th and generate a definition of the wrong arity.
    names = fmap mkName (take (length ctorArgs) candidates)
    candidates = fmap pure ['a' .. 'z'] <> fmap (("a" <>) . show) [1 :: Int ..]

-- | Build the type signature of a generated smart constructor.
--
-- The constructor's last type variable is used as the carrier monad @m@. The
-- first @effectTyVarCount - 2@ of the remaining variables are applied to the
-- effect type to form the @Has (Eff ...) sig m@ constraint. The result has
-- the shape
--
-- > forall <other tyvars> m sig. (Has (Eff ...) sig m, <ctor context>) => a1 -> ... -> m k
--
-- Fails with a descriptive message if the constructor binds no type
-- variables.
--
-- NOTE(review): "the monad is the last type variable" depends on the order in
-- which the reified constructor lists its type variables. Confirm this for
-- constructors that introduce extra (e.g. existential) variables.
makeSignature :: PerDecl -> TH.DecQ
makeSignature PerDecl {perEffect = PerEffect {..}, ..} =
  case splitLast ctorTyVars of
    -- An empty list would otherwise crash 'init'/'last' with an opaque error.
    Nothing ->
      fail
        ( "Can't generate a signature for "
            <> TH.nameBase ctorName
            <> ": the constructor binds no type variables"
        )
    Just (rest, monadTV) ->
      let sigVar = mkName "sig"
          getTyVar = varT . tvName
          monadName = getTyVar monadTV
          -- Build the parameter to Has by consulting the number of required type parameters.
          invocation =
            foldl' appT effectType (fmap getTyVar (take (effectTyVarCount - 2) rest))
          hasConstraint = [t|Has ($(invocation)) $(varT sigVar) $(monadName)|]
          -- Build the type signature by folding with (->) over the function arguments as needed.
          foldedSig =
            foldr (\a b -> arrowT `appT` a `appT` b) (monadName `appT` gadtReturnType) ctorArgs
          -- Glue together the Has and the per-constructor constraints.
          allConstraints = TH.cxt (hasConstraint : ctorConstraints)
       in TH.sigD
            functionName
            (TH.forallT (rest ++ [monadTV, makeTV sigVar]) allConstraints foldedSig)
  where
    -- Split off the final element of a list, if there is one.
    splitLast xs = case reverse xs of
      [] -> Nothing
      y : ys -> Just (reverse ys, y)