{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Defines splices that cut down on boilerplate associated with declaring new effects.
module Control.Effect.TH
  ( makeSmartConstructors,
  )
where

import Control.Algebra
import Control.Monad (join)
import Data.Char (toLower)
import Data.Foldable
import Data.Traversable
import Language.Haskell.TH as TH

-- | Information gathered once per reified effect type and shared by every
-- constructor of that type.
data PerEffect = PerEffect
  { typeName :: TH.Name
    -- ^ Name of the effect's data type, as reported by 'TH.reify'.
  , effectTypeVars :: [TH.TyVarBndr]
    -- ^ The effect's type parameters with the trailing monad and continuation
    -- parameters removed, followed by a binder for the @sig@ variable.
  , monadTypeVar :: TH.TyVarBndr
    -- ^ Binder of the effect's monad parameter (its second-to-last parameter).
  , forallConstructor :: TH.Con
    -- ^ One constructor of the effect type; 'makeDeclaration' expects it to be a
    -- 'TH.ForallC' wrapping a 'TH.GadtC'.
  }

-- | Everything needed to emit the signature, definition and pragma for a
-- single generated smart constructor.
data PerDecl = PerDecl
  { ctorName :: TH.Name
    -- ^ Name of the data constructor being wrapped.
  , functionName :: TH.Name
    -- ^ Name of the generated function: the constructor's base name with its
    -- first character lower-cased.
  , ctorArgs :: [TH.Type]
    -- ^ Types of the constructor's fields, in declaration order.
  , returnType :: TH.Type
    -- ^ Final type argument of the constructor's result type, i.e. the value
    -- the generated function yields inside the monad.
  , perEffect :: PerEffect
    -- ^ Facts about the enclosing effect type.
  , extraTyVars :: [TH.TyVarBndr]
    -- ^ Type variables bound by the constructor's own @forall@.
  }

-- | Generate a request function ("smart constructor") for every constructor
-- of an effect type.
--
-- The name must refer to a data type with at least two type parameters, the
-- last two being the monad @m@ and the continuation result @k@. Given the
-- standard @State@ type
--
-- @
--   data State s m k where
--     Get :: State s m s
--     Put :: s -> State s m ()
-- @
--
-- an invocation of @makeSmartConstructors ''State@ generates code along the
-- lines of
--
-- >   get ::
-- >     forall (s :: Type) (m :: Type -> Type) sig.
-- >     Has (State s) sig m =>
-- >     m s
-- >   get = send Get
-- >   {-# INLINABLE get #-}
-- >   put ::
-- >     forall (s :: Type) (m :: Type -> Type) sig.
-- >     Has (State s) sig m =>
-- >     s ->
-- >     m ()
-- >   put a = send (Put a)
-- >   {-# INLINABLE put #-}
--
-- Each function is named after its constructor with the first character
-- lower-cased. Each signature quantifies over the type variables bound by the
-- constructor itself, in the order they are reported for that constructor,
-- followed by @sig@.
--
-- The splice fails if the name does not reify to a data declaration, or if
-- the data type has fewer than two type parameters.
makeSmartConstructors :: Name -> TH.DecsQ
makeSmartConstructors typ = do
  info <- TH.reify typ
  case info of
    TH.TyConI (TH.DataD _ctx effectName binders _kind constructors _derive) -> do
      -- The last two parameters are the monad `m` and the continuation `k`;
      -- everything before them belongs to the effect itself. `k` is not needed.
      let (leading, trailing) = splitAt (length binders - 2) binders
      monadBinder <- case trailing of
        [monad, _cont] -> pure monad
        _ -> fail "Effect types need at least two type arguments: a monad `m` and continuation `k`."
      -- Record the per-effect data, then expand each constructor in turn.
      let perConstructor con =
            makeDeclaration
              PerEffect
                { typeName = effectName,
                  effectTypeVars = leading ++ [TH.PlainTV (mkName "sig")],
                  monadTypeVar = monadBinder,
                  forallConstructor = con
                }
      concat <$> for constructors perConstructor
    other ->
      fail ("Can't generate definitions for a non-data-constructor: " <> pprint other)

-- | Produce the declarations for every name introduced by one constructor
-- of the effect type.
--
-- The constructor must be a 'TH.ForallC' wrapping a 'TH.GadtC', and its
-- result type must be a type application; otherwise the splice fails. For
-- each constructor name three declarations are emitted, in this order: the
-- type signature, the function definition, and the inlining pragma.
makeDeclaration :: PerEffect -> TH.DecsQ
makeDeclaration effect = do
  (conNames, argTypes, fullResult, conVars) <-
    case forallConstructor effect of
      TH.ForallC binders _ctx (TH.GadtC gadtNames fields result) ->
        pure (gadtNames, map snd fields, result, binders)
      _ ->
        fail "BUG: expected forall-qualified constructor, but didn't get one"
  -- The value handed back to the caller is the last argument applied in the
  -- constructor's result type.
  resultArg <-
    case fullResult of
      AppT _ lastArg -> pure lastArg
      _ -> fail ("BUG: Couldn't get a return type out of " <> pprint fullResult)
  let -- Lower-case only the first character of a constructor's base name.
      lowerFirst = \case
        [] -> []
        c : cs -> toLower c : cs
      declsFor con =
        let decl =
              PerDecl
                { ctorName = con,
                  functionName = TH.mkName (lowerFirst (TH.nameBase con)),
                  ctorArgs = argTypes,
                  returnType = resultArg,
                  perEffect = effect,
                  extraTyVars = conVars
                }
         in sequenceA [makeSignature decl, makeFunction decl, makePragma decl]
  join <$> for conNames declsFor

-- | Build the @INLINABLE@ pragma for the generated function, applying to all
-- phases with function-like rule matching.
makePragma :: PerDecl -> TH.DecQ
makePragma decl =
  TH.pragInlD (functionName decl) TH.Inlinable TH.FunLike TH.AllPhases

-- | Build the definition of the generated function: a single clause, produced
-- by 'makeClause', bound to the declaration's function name.
makeFunction :: PerDecl -> Q Dec
makeFunction decl@PerDecl {functionName = name} =
  TH.funD name [makeClause decl]

-- | Build the single clause of a generated smart constructor.
--
-- The clause binds one variable per constructor field and applies the data
-- constructor to those variables inside a call to 'send', e.g.
-- @put a = send (Put a)@.
--
-- Argument names are drawn from an unbounded supply (@a@ .. @z@, then
-- @a1@ .. @z1@, @a2@ .. and so on), so the number of patterns always equals the
-- number of constructor fields. The first 26 names are the plain letters.
makeClause :: PerDecl -> ClauseQ
makeClause PerDecl {..} = TH.clause pats body []
  where
    body = TH.normalB [e|send ($(applies))|]
    pats = fmap TH.varP names
    -- Left-nested application: ((Ctor a) b) c ...
    applies = foldl' (\applied arg -> applied `appE` varE arg) (conE ctorName) names
    -- One name per field. Zipping against 'ctorArgs' keeps the arity exact
    -- instead of silently capping it at the length of the alphabet.
    names = zipWith (\candidate _field -> mkName candidate) argNameSupply ctorArgs
    argNameSupply =
      [ letter : suffix
        | suffix <- "" : fmap show [1 :: Int ..],
          letter <- ['a' .. 'z']
      ]

-- | Build the type signature of a generated smart constructor.
--
-- The signature has the shape
--
-- > forall <constructor type variables> sig.
-- >   Has (<Effect> <effect parameters>) sig m =>
-- >   <field type 1> -> ... -> <field type n> -> m <return type>
--
-- where @m@ is the effect's monad type variable.
--
-- The last entry of 'effectTypeVars' is taken to be the binder for the
-- signature variable and the entries before it to be the effect's own
-- parameters. If the list is empty the splice fails with a descriptive
-- message rather than crashing on a partial list function.
makeSignature :: PerDecl -> TH.DecQ
makeSignature PerDecl {perEffect = PerEffect {..}, ..} =
  case reverse effectTypeVars of
    [] ->
      fail "BUG: effect type variables are missing the trailing `sig` binder"
    sigVar : reversedParams ->
      let -- The effect type constructor applied to its own parameters.
          invocation =
            foldl' appT (conT typeName) (fmap (varT . binderName) (reverse reversedParams))
          -- Use the same binder that the forall quantifies over, so the
          -- constraint and the quantifier cannot drift apart.
          hasConstraint =
            [t|Has $(parensT invocation) $(varT (binderName sigVar)) $(monadName)|]
       in TH.sigD
            functionName
            (TH.forallT (extraTyVars ++ [sigVar]) (TH.cxt [hasConstraint]) folded)
  where
    -- Name bound by a type-variable binder, with or without a kind annotation.
    binderName = \case
      TH.PlainTV name -> name
      TH.KindedTV name _ -> name
    monadName = varT (binderName monadTypeVar)
    -- field1 -> field2 -> ... -> m returnType
    folded =
      foldr
        (\arg result -> arrowT `appT` pure arg `appT` result)
        (monadName `appT` pure returnType)
        ctorArgs