{-# LANGUAGE TemplateHaskell #-}
module Control.Effect.Machinery.TH
(
makeEffect
, makeHandler
, makeFinder
, makeLifter
, makeTaggedEffect
, makeTaggedEffectWith
, makeTagger
, makeTaggerWith
, liftL
, runL
, removeApostrophe
) where
import Control.Monad (forM, replicateM)
import Data.Coerce (coerce)
import Data.List (isSuffixOf)
import Data.Maybe (maybeToList)
import Control.Monad.Trans.Control (liftWith, restoreT)
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax hiding (Lift, lift)
import Control.Monad.Trans.Class (lift)
import Control.Effect.Machinery.Tagger (Tagger(..), runTagger)
import Control.Effect.Machinery.Via (Control, EachVia(..), Find, G, Handle, Lift,
Via, runVia)
-- | The parts of a reified type class declaration (see 'classInfo') that the
-- code generators in this module work with. It is built positionally from a
-- 'ClassD', so the field order mirrors that constructor.
data ClassInfo = ClassInfo
  { clsCxt :: Cxt
    -- ^ The superclass context of the class.
  , clsName :: Name
    -- ^ The name of the class.
  , clsTyVars :: [TyVarBndr]
    -- ^ The class type variables; the last one is expected to be the monad.
  , _clsFunDeps :: [FunDep]
    -- ^ Functional dependencies (not read by the generators in this module).
  , clsDecs :: [Dec]
    -- ^ The member declarations; each must be a type signature (see 'signature').
  }
-- | Everything needed to generate the handler, finder and lifter instances of
-- an effect, derived from a 'ClassInfo' by 'effectInfo'. Built positionally,
-- so the field order matters.
data EffectInfo = EffectInfo
  { _effCxt :: Cxt
    -- ^ The superclass context of the effect class (not read by the generators here).
  , effType :: Q Type
    -- ^ The effect class applied to all of its type variables except the monad.
  , effParams :: [TyVarBndr]
    -- ^ The class type variables without the trailing monad variable.
  , effMonad :: TyVarBndr
    -- ^ The monad type variable (the last class type variable).
  , effName :: Name
    -- ^ The name of the effect class.
  , effTrafoName :: Name
    -- ^ A fresh type variable (named @t@) that the generated instances pass as
    -- the 'EachVia' argument following the effect list.
  , effSigs :: [Signature]
    -- ^ The method signatures of the effect class.
  }
-- | Information needed to generate the tagging machinery of a tagged effect,
-- derived from an 'EffectInfo' by 'taggedInfo'. Built positionally, so the
-- field order matters.
data TaggedInfo = TaggedInfo
  { tgTag :: TyVarBndr
    -- ^ The tag type variable (the first effect parameter).
  , tgParams :: [TyVarBndr]
    -- ^ The remaining effect parameters, excluding both the tag and the monad.
  , tgMonad :: TyVarBndr
    -- ^ The monad type variable.
  , tgEffName :: Name
    -- ^ The name of the tagged effect class.
  , tgNameMap :: String -> Q String
    -- ^ Maps tagged names to untagged names (e.g. 'removeApostrophe').
  , tgSigs :: [Signature]
    -- ^ The method signatures of the effect class.
  }
-- | A single method signature of an effect class, taken from a 'SigD'.
data Signature = Signature
  { sigName :: Name
    -- ^ The method name.
  , sigType :: Type
    -- ^ The method type as found in the reified class declaration.
  }
-- | The name of the untagged type synonym of a tagged effect: the effect
-- class name passed through the effect's name mapping.
synonymName :: TaggedInfo -> Q Name
synonymName TaggedInfo{tgNameMap = toUntagged, tgEffName = taggedName} =
  mapName toUntagged taggedName
-- | Extracts @a@ from the final result type @m a@ of a method type, where
-- @m@ is the given monad variable. Arrows, foralls, kind signatures and
-- parentheses are looked through; any other shape aborts code generation.
resultType :: Name -> Type -> Q Type
resultType m = go
  where
    go (AppT (VarT n) a)
      | n == m = pure a
    go (AppT (AppT ArrowT _) r) = go r
    go (ForallT _ _ t) = go t
    go (SigT t _) = go t
    go (ParensT t) = go t
    go other =
      fail $
        "Expected a return type of the form 'm a', but encountered: "
          ++ show other
-- | Walks a method type, starting with the polarity flag @neg@ and flipping
-- it on the left-hand side of every arrow. For each occurrence of @m a@
-- (with @m@ the monad variable) reached while the flag is 'True', @a@ is
-- collected. 'higherFunction' starts with @neg = False@, so this gathers the
-- result types of monadic computations appearing in argument positions;
-- 'derive' later restores arguments whose type is in this list.
--
-- Every other shape, e.g. a plain type variable or a constructor application,
-- contributes nothing.
restorables :: Bool -> Name -> Type -> [Type]
restorables neg m typ =
  case typ of
    VarT n `AppT` a | n == m && neg -> [a]
    ArrowT `AppT` a `AppT` r -> restorables (not neg) m a ++ restorables neg m r
    ForallT _ _ t -> restorables neg m t
    SigT t _ -> restorables neg m t
    ParensT t -> restorables neg m t
    -- Leaves and anything not headed by @m@ yield nothing. (This used to be
    -- written as 'fail', which in the list monad is exactly @[]@ and whose
    -- message was therefore never shown.)
    _ -> []
-- | Whether the monad variable, applied as @m a@, occurs in a negative
-- position of a method type, i.e. to the left of an odd number of arrows.
-- Such a method takes monadic computations as (possibly nested) arguments
-- and is therefore higher-order. Kind signatures and parentheses are looked
-- through, consistently with 'resultType' and 'restorables'.
isHigherType :: TyVarBndr -> Type -> Bool
isHigherType monad = go False
  where
    m = tyVarName monad
    go negPos typ =
      case typ of
        VarT n `AppT` _ | n == m -> negPos
        ArrowT `AppT` a `AppT` r ->
          go (not negPos) a || go negPos r
        ForallT _ _ t ->
          go negPos t
        SigT t _ ->
          go negPos t
        ParensT t ->
          go negPos t
        _ ->
          False
-- | Whether an effect method is higher-order, i.e. takes computations in the
-- effect monad as arguments (see 'isHigherType').
isHigherOrder :: TyVarBndr -> Signature -> Bool
isHigherOrder monad sig = isHigherType monad (sigType sig)
-- | Converts a class member declaration into a 'Signature'. Only plain type
-- signatures are supported; any other declaration aborts code generation.
signature :: Dec -> Q Signature
signature (SigD name typ) = pure Signature { sigName = name, sigType = typ }
signature other =
  fail $
    "The generation of the effect handling machinery currently supports"
      ++ " only signatures, but encountered: "
      ++ show other
-- | Drops the kind annotation of a type variable binder, if it has one.
unkindTyVar :: TyVarBndr -> TyVarBndr
unkindTyVar bndr =
  case bndr of
    KindedTV n _ -> PlainTV n
    _ -> bndr
-- | The name bound by a type variable binder.
tyVarName :: TyVarBndr -> Name
tyVarName bndr =
  case bndr of
    PlainTV n -> n
    KindedTV n _ -> n
-- | The type variable bound by a binder. If the binder carries a kind
-- annotation, the annotation is kept as a kind signature ('SigT').
tyVarType :: TyVarBndr -> Q Type
tyVarType bndr =
  case bndr of
    PlainTV n -> varT n
    KindedTV n k -> varT n `sigT` k
-- | Splits the class type variables into the effect parameters and the monad
-- variable, which must be the last one. Fails if the class has no type
-- variables at all.
effectVars :: ClassInfo -> Q ([TyVarBndr], TyVarBndr)
effectVars info =
  case reverse (clsTyVars info) of
    monad : revParams ->
      pure (reverse revParams, monad)
    [] ->
      fail $
        "The specified effect type class `"
          ++ nameBase (clsName info)
          ++ "' has no monad type variable. "
          ++ "It is expected to be the last type variable."
-- | Gathers the 'EffectInfo' of an effect class. This covers its parameters,
-- its monad variable, its method signatures and a fresh name for the
-- transformer type variable.
effectInfo :: ClassInfo -> Q EffectInfo
effectInfo info = do
  (params, monad) <- effectVars info
  trafo <- newName "t"
  sigs <- traverse signature (clsDecs info)
  let appliedClass = foldl appT (conT (clsName info)) (map tyVarType params)
  pure EffectInfo
    { _effCxt = clsCxt info
    , effType = appliedClass
    , effParams = params
    , effMonad = monad
    , effName = clsName info
    , effTrafoName = trafo
    , effSigs = sigs
    }
-- | Splits off the tag, which must be the first effect parameter.
extractTag :: [TyVarBndr] -> Q (TyVarBndr, [TyVarBndr])
extractTag params =
  case params of
    tag : rest -> pure (tag, rest)
    [] -> fail "The effect has no tag parameter."
-- | The default name mapping for tagged effects: it removes the trailing
-- apostrophe, so @"State'"@ becomes @"State"@ and @"get'"@ becomes @"get"@.
-- Aborts code generation if the name does not end with an apostrophe. The
-- error message names the offending identifier, so the failing method of a
-- larger class can be located.
removeApostrophe :: String -> Q String
removeApostrophe name
  -- The suffix check guarantees that 'init' is applied to a non-empty string.
  | "'" `isSuffixOf` name = pure (init name)
  | otherwise =
      fail $
        "Tagged effect and function names are expected to end with \"'\", "
          ++ "but `"
          ++ name
          ++ "' does not."
-- | Applies a string mapping to the unqualified part of a name and turns the
-- result into a new name via 'mkName'.
mapName :: (String -> Q String) -> Name -> Q Name
mapName f name = mkName <$> f (nameBase name)
-- | Derives the 'TaggedInfo' of a tagged effect from its 'EffectInfo'. The
-- given function maps tagged names to untagged names.
taggedInfo :: (String -> Q String) -> EffectInfo -> Q TaggedInfo
taggedInfo f info =
  fromSplit <$> extractTag (effParams info)
  where
    fromSplit (tag, params) =
      TaggedInfo
        { tgTag = tag
        , tgParams = params
        , tgMonad = effMonad info
        , tgEffName = effName info
        , tgNameMap = f
        , tgSigs = effSigs info
        }
-- | Reifies a name and checks that it refers to a type class. Any other kind
-- of declaration aborts code generation.
classInfo :: Name -> Q ClassInfo
classInfo className =
  reify className >>= \reified ->
    case reified of
      ClassI (ClassD context name tyVars funDeps decs) _ ->
        pure ClassInfo
          { clsCxt = context
          , clsName = name
          , clsTyVars = tyVars
          , _clsFunDeps = funDeps
          , clsDecs = decs
          }
      other ->
        fail $
          "The specified name `"
            ++ nameBase className
            ++ "' is not a type class, but the following instead: "
            ++ show other
-- | The single-constraint context @name eff effs t m@ of a finder instance.
-- Here @name@ is the constraining class and @effs@ is the type variable
-- standing for the rest of the effect list.
instanceFinderCxt :: Name -> Name -> EffectInfo -> Q Cxt
instanceFinderCxt name effs info =
  cxt [foldl appT (conT name) args]
  where
    args =
      [ effType info
      , varT effs
      , varT (effTrafoName info)
      , tyVarType (effMonad info)
      ]
-- | The single-constraint context @name eff t m@ of a generated instance.
instanceCxt :: Name -> EffectInfo -> Q Cxt
instanceCxt name info =
  cxt [foldl appT (conT name) args]
  where
    args =
      [ effType info
      , varT (effTrafoName info)
      , tyVarType (effMonad info)
      ]
-- | The instance head @eff (EachVia effs t m)@ for the given effect list
-- @effs@. It is shared by the generated handler, finder and lifter instances.
instanceHead :: Q Type -> EffectInfo -> Q Type
instanceHead effs info = effType info `appT` eachVia
  where
    eachVia =
      foldl appT (conT ''EachVia)
        [ effs
        , varT (effTrafoName info)
        , tyVarType (effMonad info)
        ]
-- | Generates the handler, finder and lifter instances (in that order) of an
-- effect type class.
makeEffect :: Name -> Q [Dec]
makeEffect className = do
  effInfo <- classInfo className >>= effectInfo
  traverse ($ effInfo) [handler, finder, lifter]
-- | Generates the tagging machinery of a tagged effect. Untagged names are
-- derived with 'removeApostrophe'.
makeTagger :: Name -> Q [Dec]
makeTagger className = makeTaggerWith removeApostrophe className
-- | Generates the tagging machinery of a tagged effect. Untagged names are
-- derived with the given mapping.
makeTaggerWith :: (String -> Q String) -> Name -> Q [Dec]
makeTaggerWith f className =
  classInfo className >>= effectInfo >>= taggedInfo f >>= tagger
-- | Generates the instances and the tagging machinery of a tagged effect.
-- Untagged names are derived with 'removeApostrophe'.
makeTaggedEffect :: Name -> Q [Dec]
makeTaggedEffect className = makeTaggedEffectWith removeApostrophe className
-- | Generates the handler, finder and lifter instances of a tagged effect,
-- followed by its tagging machinery. Untagged names are derived with the
-- given mapping.
makeTaggedEffectWith :: (String -> Q String) -> Name -> Q [Dec]
makeTaggedEffectWith f className = do
  effInfo <- classInfo className >>= effectInfo
  tagInfo <- taggedInfo f effInfo
  instances <- traverse ($ effInfo) [handler, finder, lifter]
  taggerDecs <- tagger tagInfo
  pure (instances ++ taggerDecs)
-- | Generates only the 'Handle' instance of an effect type class.
makeHandler :: Name -> Q [Dec]
makeHandler className =
  classInfo className >>= effectInfo >>= fmap pure . handler
-- | Generates only the 'Find' instance of an effect type class.
makeFinder :: Name -> Q [Dec]
makeFinder className =
  classInfo className >>= effectInfo >>= fmap pure . finder
-- | Generates only the lifting instance (for the empty effect list) of an
-- effect type class.
makeLifter :: Name -> Q [Dec]
makeLifter className =
  classInfo className >>= effectInfo >>= fmap pure . lifter
-- | All declarations of the tagging machinery, in this order:
--
--   * the untagged type synonym;
--   * the 'Tagger' instance;
--   * the tag/retag/untag functions;
--   * the untagged method functions.
tagger :: TaggedInfo -> Q [Dec]
tagger info = do
  conversionDecs <- taggerFunctions info
  synonymDec <- untaggedSynonym info
  untaggedDecs <- untaggedFunctions info
  instanceDec <- taggerInstance info
  pure ([synonymDec, instanceDec] ++ conversionDecs ++ untaggedDecs)
-- | Generates the instance for @EachVia (eff ': effs) t m@, i.e. for when the
-- effect is at the head of the list. It has a 'Handle' context and takes its
-- methods from 'handlerFunctions'.
handler :: EffectInfo -> Q Dec
handler info = do
  methods <- handlerFunctions info
  effs <- newName "effs"
  let effList = promotedConsT `appT` effType info `appT` varT effs
  instanceD
    (instanceCxt ''Handle info)
    (instanceHead effList info)
    (map pure methods)
-- | Generates the OVERLAPPABLE instance for @EachVia (other ': effs) t m@.
-- It has a 'Find' context and takes its methods from 'finderFunctions'.
-- Being overlappable, it gives way to the more specific instance generated
-- by 'handler' when the head of the list is this effect.
finder :: EffectInfo -> Q Dec
finder info = do
  methods <- finderFunctions info
  other <- newName "other"
  effs <- newName "effs"
  let effList = promotedConsT `appT` varT other `appT` varT effs
  instanceWithOverlapD
    (Just Overlappable)
    (instanceFinderCxt ''Find effs info)
    (instanceHead effList info)
    (map pure methods)
-- | Generates the instance for the empty effect list @EachVia '[] t m@,
-- with methods from 'lifterFunctions'. Its context uses 'Control' if any
-- method is higher-order and 'Lift' otherwise.
lifter :: EffectInfo -> Q Dec
lifter info = do
  methods <- lifterFunctions info
  let constraint
        | any (isHigherOrder (effMonad info)) (effSigs info) = ''Control
        | otherwise = ''Lift
  instanceD
    (instanceCxt constraint info)
    (instanceHead promotedNilT info)
    (map pure methods)
-- | Generates three conversion functions via 'taggerFunction'. With @X@ the
-- effect class name, and with a missing tag standing for 'G':
--
--   * @tagX@ converts from tag 'G' to a new tag;
--   * @retagX@ converts from a tag to a new tag;
--   * @untagX@ converts from a tag to 'G'.
taggerFunctions :: TaggedInfo -> Q [Dec]
taggerFunctions info = do
  tag <- newName (nameBase (tyVarName (tgTag info)))
  new <- newName "new"
  let effectName = tgEffName info
      prefixed prefix = mkName (prefix ++ nameBase effectName)
      generate prefix oldTag newTag =
        taggerFunction effectName (prefixed prefix) oldTag newTag (tgParams info)
  tagDecs <- generate "tag" Nothing (Just new)
  retagDecs <- generate "retag" (Just tag) (Just new)
  untagDecs <- generate "untag" (Just tag) Nothing
  pure (tagDecs ++ retagDecs ++ untagDecs)
-- | Generates one tag conversion function: its type signature, an INLINE
-- pragma and its definition, of the form
--
-- > funName :: forall tag new params m a.
-- >   (Eff tag params `Via` Tagger tag new) m a -> m a
-- > funName = runTagger . runVia
--
-- For either tag argument, 'Nothing' means the 'G' tag is used and no type
-- variable is quantified. 'Just' names the tag type variable to quantify.
taggerFunction :: Name -> Name -> Maybe Name -> Maybe Name -> [TyVarBndr] -> Q [Dec]
taggerFunction baseName funName tag new params = do
  mName <- newName "m"
  aName <- newName "a"
  let m = varT mName
      a = varT aName
      -- A missing tag is replaced by the 'G' tag.
      tagParam = maybe [t| G |] varT tag
      newParam = maybe [t| G |] varT new
      -- Only tags given as type variables need to be quantified.
      tagNames = maybeToList tag ++ maybeToList new
      paramNames = fmap tyVarName params
      -- Parameters are used without their kind annotations, matching the
      -- plain binders in 'forallTypes' below.
      paramTypes = fmap (tyVarType . unkindTyVar) params
      forallNames = tagNames ++ paramNames ++ [mName, aName]
      forallTypes = fmap PlainTV forallNames
      effectType = foldl appT (conT baseName) (tagParam : paramTypes)
  funSigType <- [t| ($effectType `Via` Tagger $tagParam $newParam) $m $a -> $m $a |]
  funSig <- sigD funName $ forallT forallTypes (cxt []) (pure funSigType)
  funDef <- [d| $(varP funName) = runTagger . runVia |]
  funInline <- pragInlD funName Inline FunLike AllPhases
  pure (funSig : funInline : funDef)
-- | Generates the untagged type synonym of a tagged effect. It applies the
-- effect to the 'G' tag and abstracts over the remaining (unkinded)
-- parameters, leaving out the monad. For example, with 'removeApostrophe':
--
-- > type State s = State' G s
untaggedSynonym :: TaggedInfo -> Q Dec
untaggedSynonym info = do
  synName <- synonymName info
  let params = map unkindTyVar (tgParams info)
      body =
        foldl appT (conT (tgEffName info)) (conT ''G : map tyVarType params)
  tySynD synName params body
-- | Generates the untagged variant of every method of a tagged effect. Each
-- variant is constrained by the untagged synonym applied to the unkinded
-- parameters and the monad.
untaggedFunctions :: TaggedInfo -> Q [Dec]
untaggedFunctions info = do
  synName <- synonymName info
  let args = map (tyVarType . unkindTyVar) (tgParams info ++ [tgMonad info])
      constraint = foldl appT (conT synName) args
  concat <$> traverse (untaggedFunction (tgNameMap info) constraint) (tgSigs info)
-- | Generates the untagged variant of a tagged method, consisting of:
--
--   * a signature: the method type (passed through 'unkindType') under the
--     given constraint;
--   * an INLINE pragma;
--   * a definition that applies the tagged method to the 'G' tag.
--
-- The variant's name is the method name passed through the given mapping.
untaggedFunction :: (String -> Q String) -> Q Type -> Signature -> Q [Dec]
untaggedFunction f effectType sig = do
  funName <- mapName f (sigName sig)
  let body = pure (unkindType (sigType sig))
  funSig <- sigD funName [t| $effectType => $body |]
  funDef <- [d| $(varP funName) = $(varE (sigName sig)) @G |]
  funInline <- pragInlD funName Inline FunLike AllPhases
  pure (funSig : funInline : funDef)
-- | Generates the instance that provides the effect at tag @tag@ for
-- @Tagger tag new m@, requiring the same effect at a fresh tag @new@:
--
-- > instance Eff new params m => Eff tag params (Tagger tag new m)
--
-- The methods come from 'taggerInstanceFunction'.
taggerInstance :: TaggedInfo -> Q Dec
taggerInstance info = do
  newTag <- varT <$> newName "new"
  let monadName = tyVarName (tgMonad info)
      monad = varT monadName
      tag = tyVarType (tgTag info)
      params = fmap tyVarType (tgParams info)
      applyEffect args = foldl appT (conT (tgEffName info)) args
      taggerMonad = [t| Tagger $tag $newTag $monad |]
  methods <-
    concat <$> traverse (taggerInstanceFunction newTag monadName) (tgSigs info)
  instanceD
    (cxt [applyEffect (newTag : params ++ [monad])])
    (applyEffect (tag : params ++ [taggerMonad]))
    (fmap pure methods)
-- | Implements one method of the 'Tagger' instance. It calls the same method
-- at the new tag, and 'derive' converts between @m@ and the 'Tagger' monad
-- using 'Tagger' and 'runTagger'. It also emits an INLINE pragma.
taggerInstanceFunction :: Q Type -> Name -> Signature -> Q [Dec]
taggerInstanceFunction new monad sig = do
  let funName = sigName sig
      adapter = derive [] [| Tagger |] [| runTagger |] monad (sigType sig)
      target = varE funName `appTypeE` new
  funDef <- [d| $(varP funName) = $adapter $target |]
  funInline <- pragInlD funName Inline FunLike AllPhases
  pure (funInline : funDef)
-- | The number of explicit arguments of a method type, i.e. the arrows along
-- its spine. Top-level and nested foralls are looked through.
paramCount :: Type -> Int
paramCount (AppT (AppT ArrowT _) r) = succ (paramCount r)
paramCount (ForallT _ _ t) = paramCount t
paramCount _ = 0
-- | The argument conversion used by the first-order lifter (see
-- 'lifterFunctions'). Whenever 'derive' needs it, code generation is
-- aborted.
invalid :: Q Exp
invalid =
  fail
    ( "Could not generate effect instance because the operation is "
        ++ "invalid for higher-order effects."
    )
-- | The method implementations of the handler instance. Results are wrapped
-- with 'EachVia' and arguments unwrapped with 'runVia' (see 'function').
handlerFunctions :: EffectInfo -> Q [Dec]
handlerFunctions info =
  concat <$> traverse implement (effSigs info)
  where
    implement =
      function [| EachVia |] [| runVia |] (effMonad info) (effParams info)
-- | Adds an effect to the front of the effect list of an 'EachVia'
-- computation. The representation is unchanged; the conversion is a
-- zero-cost 'coerce'.
liftL :: EachVia effs t m a -> EachVia (eff : effs) t m a
liftL computation = coerce computation
{-# INLINE liftL #-}
-- | Removes the effect at the front of the effect list of an 'EachVia'
-- computation. The representation is unchanged; the conversion is a
-- zero-cost 'coerce'.
runL :: EachVia (eff : effs) t m a -> EachVia effs t m a
runL computation = coerce computation
{-# INLINE runL #-}
-- | The method implementations of the finder instance. Results are converted
-- with 'liftL' and arguments with 'runL' (see 'function').
finderFunctions :: EffectInfo -> Q [Dec]
finderFunctions info =
  concat <$> traverse implement (effSigs info)
  where
    implement =
      function [| liftL |] [| runL |] (effMonad info) (effParams info)
-- | The method implementations of the lifting instance for the empty effect
-- list:
--
--   * higher-order methods go through 'higherFunction';
--   * first-order methods are lifted with 'lift', where arguments cannot be
--     converted (see 'invalid').
lifterFunctions :: EffectInfo -> Q [Dec]
lifterFunctions info =
  concat <$> traverse implement (effSigs info)
  where
    monad = effMonad info
    params = effParams info
    implement sig
      | isHigherOrder monad sig = higherFunction monad params sig
      | otherwise = function [| lift |] invalid monad params sig
-- | Implements a method by delegating to the same method, applied to the
-- effect parameters via type applications. 'derive' adapts the delegate:
-- @f@ converts monadic results and @inv@ converts monadic arguments. The
-- generated definition is marked INLINE.
function :: Q Exp -> Q Exp -> TyVarBndr -> [TyVarBndr] -> Signature -> Q [Dec]
function f inv monad params sig = do
  let funName = sigName sig
      delegate = foldl appTypeE (varE funName) (map tyVarType params)
      adapter = derive [] f inv (tyVarName monad) (sigType sig)
  funDef <- [d| $(varP funName) = $adapter $delegate |]
  funInline <- pragInlD funName Inline FunLike AllPhases
  pure (funInline : funDef)
-- | Implements a higher-order method for the lifting instance. The generated
-- definition takes one argument @x1 ... xn@ per explicit method argument
-- and builds
--
-- > EachVia $ (liftWith $ \run -> adapter method x1 ... xn)
-- >   >>= traverse ... (restoreT . pure)
--
-- The adapter comes from 'derive'. Computation arguments are converted with
-- @run . runVia@, and arguments whose type is listed by 'restorables' are
-- restored with 'restoreT'. The result is then restored under as many
-- 'traverse's as 'traverseExp' computes from the method's result type. The
-- definition is accompanied by an INLINE pragma.
higherFunction :: TyVarBndr -> [TyVarBndr] -> Signature -> Q [Dec]
higherFunction monad params sig = do
  let m = tyVarName monad
      typ = sigType sig
      funName = sigName sig
      paramTypes = fmap tyVarType params
      restores = restorables False m typ
      -- NOTE(review): @run@ is not bound in this module; it is meant to refer
      -- to the lambda binder introduced by the @[p|run|]@ splice in @body@
      -- below, relying on pattern-quote variables not being renamed. Confirm
      -- before changing either side.
      expr = derive restores [| id |] [| run . runVia |] m typ
  -- One fresh variable per explicit argument of the method.
  fParams <- replicateM (paramCount typ) (newName "x")
  res <- resultType m typ
  let typeAppliedName = foldl appTypeE (varE funName) paramTypes
      appliedExp = foldl appE expr (typeAppliedName : fmap varE fParams)
      -- The space in "\ $(" is required: "\$" would lex as an operator.
      body =
        [| EachVia $
             (liftWith $ \ $([p|run|]) -> $appliedExp)
               >>= $(traverseExp res) (restoreT . pure)
        |]
  funDef <- funD funName [clause (fmap varP fParams) (normalB body) []]
  funInline <- pragInlD funName Inline FunLike AllPhases
  pure [funDef, funInline]
-- | Removes kind annotations and explicit @forall@ binders from a type. The
-- result can then be reused in a generated signature in which its type
-- variables are quantified implicitly.
--
-- Constraints introduced by a @forall@ are kept. For example, in
-- @forall a. Show a => a -> m ()@ the @Show a@ remains, because the
-- generated definition still needs it. Only the binders are dropped. Kind
-- signatures are removed recursively, including inside annotated types.
unkindType :: Type -> Type
unkindType typ =
  case typ of
    ForallT _ ctx t
      | null ctx -> unkindType t
      -- A 'ForallT' without binders represents a plain qualified type.
      | otherwise -> ForallT [] (fmap unkindType ctx) (unkindType t)
    AppT l r -> AppT (unkindType l) (unkindType r)
    -- The annotated type may itself contain further kind annotations.
    SigT t _ -> unkindType t
    InfixT l n r -> InfixT (unkindType l) n (unkindType r)
    UInfixT l n r -> UInfixT (unkindType l) n (unkindType r)
    ParensT t -> ParensT (unkindType t)
    other -> other
-- | Whether a name occurs anywhere in a type, as a variable, a constructor, a
-- promoted constructor or an infix operator. Foralls, kind signatures and
-- parentheses are looked through.
contains :: Name -> Type -> Bool
contains m = go
  where
    go (ForallT _ _ t) = go t
    go (AppT l r) = go l || go r
    go (SigT t _) = go t
    go (VarT n) = n == m
    go (ConT n) = n == m
    go (PromotedT n) = n == m
    go (InfixT l n r) = n == m || go l || go r
    go (UInfixT l n r) = n == m || go l || go r
    go (ParensT t) = go t
    go _ = False
-- | Builds an adapter expression that follows the shape of a method type.
-- The roles of @f@ and @inv@ are swapped for argument positions, on the
-- left of each arrow:
--
--   * a value of type @m a@ (for the monad variable @m@) is converted with
--     @f@;
--   * a sub-type that does not mention @m@ is passed through with 'id';
--   * for a function type @arg -> res@, the adapter takes the function @x@
--     and an argument @b@, converts @b@ with the swapped conversions, applies
--     @x@, and converts the result.
--
-- An argument whose type @arg@ is in @rs@ (see 'restorables') is not
-- converted. It is instead restored via @EachVia . restoreT . pure@ and
-- bound into @x@ with '=<<'. Any other type shape aborts code generation.
derive :: [Type] -> Q Exp -> Q Exp -> Name -> Type -> Q Exp
derive rs f inv m typ =
  case typ of
    t | not (contains m t) ->
      [| id |]
    VarT n `AppT` _ | n == m ->
      f
    ArrowT `AppT` arg `AppT` res ->
      let rf = derive rs f inv m res
          -- Arguments are in the opposite position: swap f and inv.
          af = derive rs inv f m arg
      in if elem arg rs
           then [| \x b -> $rf (((x =<<) . EachVia . restoreT . pure) b) |]
           else [| \x b -> $rf (x ($af b)) |]
    ForallT _ _ t ->
      derive rs f inv m t
    other -> fail
      $ "Could not generate effect instance because an unknown structure "
      ++ "was encountered: "
      ++ show other
-- | Builds @traverse . traverse . ... . id@, with one 'traverse' per type
-- application along the right-hand spine of the given type. Mapping with it
-- reaches the innermost last type argument: for @Either e (Maybe a)@ it
-- reaches @a@, and for a bare type it is just 'id'.
traverseExp :: Type -> Q Exp
traverseExp typ =
  case typ of
    ForallT _ _ t -> traverseExp t
    SigT t _ -> traverseExp t
    ParensT t -> traverseExp t
    AppT _ r -> under r
    InfixT _ _ r -> under r
    UInfixT _ _ r -> under r
    _ -> [| id |]
  where
    under t = [| traverse . $(traverseExp t) |]