{-# LANGUAGE TemplateHaskell, CPP, NamedFieldPuns #-}
module Data.Acid.TemplateHaskell where
import Language.Haskell.TH
import Language.Haskell.TH.Ppr
import Language.Haskell.TH.ExpandSyns
import Data.Acid.Core
import Data.Acid.Common
import Data.List ((\\), nub, delete)
import Data.SafeCopy
import Data.Typeable
import Data.Char
import Data.Monoid ((<>))
import Control.Applicative
import Control.Monad
import Control.Monad.State (MonadState)
import Control.Monad.Reader (MonadReader)
-- | Generate the acid-state declarations for a state type and its event
-- functions, serialising events with 'SafeCopy'.
--
-- This is 'makeAcidicWithSerialiser' specialised to 'safeCopySerialiserSpec'.
makeAcidic :: Name -> [Name] -> Q [Dec]
makeAcidic stateName eventNames =
    makeAcidicWithSerialiser safeCopySerialiserSpec stateName eventNames
-- | How generated code serialises events.
data SerialiserSpec = SerialiserSpec
    { serialisationClassName :: Name
      -- ^ Class that every type variable of the state type is required to
      -- satisfy in the generated 'IsAcidic' instance.
    , methodSerialiserName :: Name
      -- ^ Value passed as the final argument to each generated
      -- 'UpdateEvent'/'QueryEvent' handler.
    , makeEventSerialiser :: Name -> Type -> DecQ
      -- ^ Produces the serialisation instance for one event, given the event
      -- function's name and its synonym-expanded type.
    }
-- | The default 'SerialiserSpec'. It requires 'SafeCopy', serialises with
-- 'safeCopyMethodSerialiser', and generates event instances with
-- 'makeSafeCopyInstance'.
safeCopySerialiserSpec :: SerialiserSpec
safeCopySerialiserSpec =
    SerialiserSpec ''SafeCopy 'safeCopyMethodSerialiser makeSafeCopyInstance
-- | Like 'makeAcidic', but with a caller-supplied 'SerialiserSpec'.
--
-- The state name must be reified to a @data@, @newtype@ or @type@
-- declaration. Any other kind of name is rejected with 'error' while the
-- splice runs.
makeAcidicWithSerialiser :: SerialiserSpec -> Name -> [Name] -> Q [Dec]
makeAcidicWithSerialiser ss stateName eventNames =
    reify stateName >>= \info -> case info of
      TyConI dec -> fromDeclaration dec
      _ -> error "Data.Acid.TemplateHaskell: Given state is not a type."
  where
    -- Run the generator on the state's type variables and constructors.
    generate tvs cons = makeAcidic' ss eventNames stateName tvs cons

    -- A type synonym has no constructors of its own, so it passes [].
    fromDeclaration dec = case dec of
#if MIN_VERSION_template_haskell(2,11,0)
      DataD _ _ tvs _ cons _ -> generate tvs cons
      NewtypeD _ _ tvs _ con _ -> generate tvs [con]
#else
      DataD _ _ tvs cons _ -> generate tvs cons
      NewtypeD _ _ tvs con _ -> generate tvs [con]
#endif
      TySynD _ tvs _ -> generate tvs []
      _ -> error "Data.Acid.TemplateHaskell: Unsupported state type. Only 'data', 'newtype' and 'type' are supported."
-- | Generate the declarations for every event, then the 'IsAcidic'
-- instance. In the returned list the 'IsAcidic' instance comes first,
-- followed by the event declarations in the order of the event names.
makeAcidic' :: SerialiserSpec -> [Name] -> Name -> [TyVarBndr] -> [Con] -> Q [Dec]
makeAcidic' ss eventNames stateName tyvars constructors = do
    eventDecls <- mapM (makeEvent ss) eventNames
    isAcidicDecl <- makeIsAcidic ss eventNames stateName tyvars constructors
    pure (isAcidicDecl : concat eventDecls)
-- | Generate the declarations for a single event.
--
-- If the name produced by 'toStructName' can already be reified, only the
-- serialisation instance is generated. Otherwise four declarations are
-- generated, in this order:
--
-- * the event's data type;
-- * its serialisation instance;
-- * its 'Method' instance;
-- * its 'UpdateEvent'/'QueryEvent' instance.
makeEvent :: SerialiserSpec -> Name -> Q [Dec]
makeEvent ss eventName = do
    -- 'recover' turns a failed 'reify' (unknown name) into False.
    structExists <- recover (return False) (reify (toStructName eventName) >> return True)
    eventType <- getEventType eventName
    let serialiser = makeEventSerialiser ss eventName eventType
    sequence $
      if structExists
        then [serialiser]
        else [ makeEventDataType eventName eventType
             , serialiser
             , makeMethodInstance eventName eventType
             , makeEventInstance eventName eventType
             ]
-- | Reify an event function and return its type with synonyms expanded.
-- A name that does not reify to a variable is rejected with 'error'.
getEventType :: Name -> Q Type
getEventType eventName = reify eventName >>= fromInfo
  where
#if MIN_VERSION_template_haskell(2,11,0)
    fromInfo (VarI _ ty _) = expandSyns ty
#else
    fromInfo (VarI _ ty _ _) = expandSyns ty
#endif
    fromInfo _ = error $ "Data.Acid.TemplateHaskell: Events must be functions: " ++ show eventName
-- | Build the 'IsAcidic' instance for the state type.
--
-- The instance context has two parts:
--
-- * every state type variable must satisfy the spec's serialisation class
--   and 'Typeable';
-- * the constraints the events place on the state type, translated to the
--   state declaration's variables by 'eventCxts' (with duplicates removed).
--
-- 'acidEvents' lists one handler per event, in the order given.
--
-- The constructor list is accepted for interface compatibility but is not
-- used.
makeIsAcidic :: SerialiserSpec -> [Name] -> Name -> [TyVarBndr] -> [Con] -> DecQ
makeIsAcidic ss eventNames stateName stateTyVars _constructors = do
    types <- mapM getEventType eventNames
    stateTy <- stateTypeQ
    let preds = [ serialisationClassName ss, ''Typeable ]
        ty = appT (conT ''IsAcidic) stateTypeQ
        handlers = zipWith (makeEventHandler ss) eventNames types
        cxtFromEvents = nub $ concat $ zipWith (eventCxts stateTy stateTyVars) eventNames types
    cxts' <- mkCxtFromTyVars preds stateTyVars cxtFromEvents
    instanceD (return cxts') ty
      [ valD (varP 'acidEvents) (normalB (listE handlers)) []
      ]
  where
    -- The state type constructor applied to its own type variables.
    -- The names avoid shadowing the 'TypeAnalysis' field selectors
    -- 'stateType' and 'tyvars'.
    stateTypeQ = foldl appT (conT stateName) (map varT (allTyVarBndrNames stateTyVars))
-- | Rewrite the class context of an event's type so that it refers to the
-- state type as declared. The result contributes to the 'IsAcidic' instance
-- context (see 'makeIsAcidic').
--
-- The rewrite happens in two steps:
--
-- 1. 'renameState' replaces occurrences of the event's state type inside
--    the context with @targetStateType@.
-- 2. 'unify' renames every type variable using a table. The table pairs the
--    event's state-type variables (in order of appearance, via
--    'findTyVars') positionally with the declared type variables.
--
-- A type variable missing from that table is reported with 'error'.
--
-- NOTE(review): step 1 can introduce @targetStateType@'s own variables,
-- which step 2 then looks up in a table keyed by the event's variable
-- names. Confirm these always coincide, or such contexts will hit the error
-- path.
eventCxts :: Type          -- ^ state type applied to its declared type variables
          -> [TyVarBndr]   -- ^ declared type variables of the state type
          -> Name          -- ^ event name (used in error messages)
          -> Type          -- ^ event type
          -> [Pred]
eventCxts targetStateType targetTyVars eventName eventType =
    let TypeAnalysis { context = cxt, stateType }
                    = analyseType eventName eventType
        -- the event's own names for the state type's variables
        eventTyVars = findTyVars stateType
        -- event variable name -> declared variable name, paired by position
        table       = zip eventTyVars (map tyVarBndrName targetTyVars)
    in map (unify table)
           (renameState stateType targetStateType cxt)
    where
      -- Rename every type variable in a predicate using the table.
      unify :: [(Name, Name)] -> Pred -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
      unify table p = rename p table p
#else
      unify table p@(ClassP n tys) = ClassP n (map (rename p table) tys)
      unify table p@(EqualP a b)   = EqualP (rename p table a) (rename p table b)
#endif
      -- Rename the type variables of a type. The enclosing predicate is
      -- threaded through only so it can appear in the error message.
      rename :: Pred -> [(Name, Name)] -> Type -> Type
      rename pred table t@(ForallT tyvarbndrs cxt typ) =
          ForallT (map renameTyVar tyvarbndrs) (map (unify table) cxt) (rename pred table typ)
          where
            renameTyVar (PlainTV name)    = PlainTV (renameName pred table name)
            renameTyVar (KindedTV name k) = KindedTV (renameName pred table name) k
      rename pred table (VarT n)   = VarT $ renameName pred table n
      rename pred table (AppT a b) = AppT (rename pred table a) (rename pred table b)
      rename pred table (SigT a k) = SigT (rename pred table a) k
      rename _ _ typ               = typ
      -- Look a variable up in the table. A variable missing from the table
      -- is fatal: the constraint could not be stated on the state type.
      renameName :: Pred -> [(Name, Name)] -> Name -> Name
      renameName pred table n =
          case lookup n table of
            Nothing -> error $ unlines [ "Data.Acid.TemplateHaskell: "
                                       , ""
                                       , show $ ppr_sig eventName eventType
                                       , ""
                                       , "can not be used as an UpdateEvent because the class context: "
                                       , ""
                                       , pprint pred
                                       , ""
                                       , "contains a type variable which is not found in the state type: "
                                       , ""
                                       , pprint targetStateType
                                       , ""
                                       , "You may be able to fix this by providing a type signature that fixes these type variable(s)"
                                       ]
            (Just n') -> n'
      -- Replace every occurrence of @tfrom@ (searched through applications
      -- and kind signatures) with @tto@ in each predicate of the context.
      renameState :: Type -> Type -> Cxt -> Cxt
      renameState tfrom tto cxt = map renamePred cxt
          where
#if MIN_VERSION_template_haskell(2,10,0)
            renamePred p = renameType p
#else
            renamePred (ClassP n tys) = ClassP n (map renameType tys)
            renamePred (EqualP a b)   = EqualP (renameType a) (renameType b)
#endif
            renameType n | n == tfrom = tto
            renameType (AppT a b) = AppT (renameType a) (renameType b)
            renameType (SigT a k) = SigT (renameType a) k
            renameType typ = typ
-- | Build the handler expression for one event.
--
-- For an update this is @UpdateEvent handler serialiser@; for a query it is
-- @QueryEvent handler serialiser@. Here @handler@ pattern-matches the
-- event's data constructor and applies the event function to its fields,
-- and @serialiser@ is the spec's 'methodSerialiserName'.
--
-- Every type variable bound in the event's type must occur in its state
-- type. Otherwise 'error' is raised while the splice runs, and the message
-- names the event's actual kind (update or query).
makeEventHandler :: SerialiserSpec -> Name -> Type -> ExpQ
makeEventHandler ss eventName eventType = do
    assertTyVarsOk
    vars <- replicateM (length args) (newName "arg")
    let lamClause = conP eventStructName [varP var | var <- vars ]
    conE constr `appE` lamE [lamClause] (foldl appE (varE eventName) (map varE vars))
                `appE` varE (methodSerialiserName ss)
  where
    constr = if isUpdate then 'UpdateEvent else 'QueryEvent
    -- How the event is described in diagnostics. This must agree with
    -- 'constr' so that query events are not reported as update events.
    eventKind = if isUpdate then "an UpdateEvent" else "a QueryEvent"
    TypeAnalysis { tyvars, argumentTypes = args, stateType, isUpdate } = analyseType eventName eventType
    eventStructName = toStructName eventName
    stateTypeTyVars = findTyVars stateType
    tyVarNames = map tyVarBndrName tyvars
    assertTyVarsOk =
      case tyVarNames \\ stateTypeTyVars of
        [] -> return ()
        ns -> error $ "Data.Acid.TemplateHaskell: " <> unlines
          [ show $ ppr_sig eventName eventType
          , ""
          , "can not be used as " ++ eventKind ++ " because it contains the type variables: "
          , ""
          , pprint ns
          , ""
          , "which do not appear in the state type:"
          , ""
          , pprint stateType
          ]
-- | Declare the data type that represents an event.
--
-- The type has one constructor, named by 'toStructName', with one
-- non-strict field per argument of the event function. The type parameters
-- are the event type's type variables, and the type derives 'Typeable'.
-- A one-argument event becomes a @newtype@; any other arity becomes a
-- @data@ declaration.
makeEventDataType :: Name -> Type -> DecQ
makeEventDataType eventName eventType =
    case fieldTypes of
#if MIN_VERSION_template_haskell(2,11,0)
      [_] -> newtypeD (return []) structName tyvars Nothing structCon derivings
      _   -> dataD (return []) structName tyvars Nothing [structCon] derivings
#else
      [_] -> newtypeD (return []) structName tyvars structCon derivings
      _   -> dataD (return []) structName tyvars [structCon] derivings
#endif
  where
    TypeAnalysis { tyvars, argumentTypes = fieldTypes } = analyseType eventName eventType
    structName = toStructName eventName
    structCon = normalC structName (map (strictType notStrict . return) fieldTypes)
    -- The deriving-clause representation differs between TH versions.
#if MIN_VERSION_template_haskell(2,12,0)
    derivings = [derivClause Nothing [conT ''Typeable]]
#elif MIN_VERSION_template_haskell(2,11,0)
    derivings = mapM conT [''Typeable]
#else
    derivings = [''Typeable]
#endif
-- | Generate a 'SafeCopy' instance for an event's data type.
--
-- * 'putCopy' writes each constructor field with 'safePut', in order.
-- * 'getCopy' rebuilds the constructor as @return Con <*> safeGet <*> ...@.
--
-- Both are wrapped in 'contain'. Every type variable of the event type must
-- itself be 'SafeCopy', in addition to the event's own context.
-- 'errorTypeName' returns the pretty-printed instance head.
makeSafeCopyInstance :: Name -> Type -> DecQ
makeSafeCopyInstance eventName eventType = do
    fieldVars <- replicateM (length args) (newName "arg")
    let instanceHead = AppT (ConT ''SafeCopy) structType
        -- return Con <*> safeGet <*> ...   (one safeGet per field)
        decoder = foldl (\acc _ -> infixE (Just acc) (varE '(<*>)) (Just (varE 'safeGet)))
                        (appE (varE 'return) (conE eventStructName))
                        args
        -- do { safePut arg1; ...; safePut argN; return () }
        encoder = doE $ map (noBindS . appE (varE 'safePut) . varE) fieldVars
                     ++ [noBindS (appE (varE 'return) (tupE []))]
        wrap body = varE 'contain `appE` body
    instanceD (mkCxtFromTyVars [''SafeCopy] tyvars context)
              (return instanceHead)
              [ funD 'putCopy [clause [conP eventStructName (map varP fieldVars)] (normalB (wrap encoder)) []]
              , valD (varP 'getCopy) (normalB (wrap decoder)) []
              , funD 'errorTypeName [clause [wildP] (normalB (litE (stringL (pprint instanceHead)))) []]
              ]
  where
    TypeAnalysis { tyvars, context, argumentTypes = args } = analyseType eventName eventType
    eventStructName = toStructName eventName
    structType = foldl AppT (ConT eventStructName) (map VarT (allTyVarBndrNames tyvars))
-- | Build an instance context in two parts:
--
-- * one constraint per (type variable, class) pair, requiring each class in
--   @preds@ of each variable in @bndrs@ (variables in the outer loop);
-- * followed by @extraContext@ unchanged.
mkCxtFromTyVars :: [Name] -> [TyVarBndr] -> Cxt -> CxtQ
mkCxtFromTyVars preds bndrs extraContext
    = cxt $ [ classP classPred [varT tyvar] | tyvar <- allTyVarBndrNames bndrs, classPred <- preds ] ++
            map return extraContext
-- | Generate the 'Method' instance for an event's data type.
--
-- The instance sets 'MethodResult' and 'MethodState' to the event's result
-- and state types. Every type variable of the event type must be
-- 'Typeable', in addition to the event's own context. The context is built
-- by 'mkCxtFromTyVars', as for the other generated instances.
makeMethodInstance :: Name -> Type -> DecQ
makeMethodInstance eventName eventType =
    instanceD
      (mkCxtFromTyVars [''Typeable] tyvars context)
      (return ty)
#if MIN_VERSION_template_haskell(2,15,0)
      [ tySynInstD $ tySynEqn Nothing (conT ''MethodResult `appT` structType) (return resultType)
      , tySynInstD $ tySynEqn Nothing (conT ''MethodState `appT` structType) (return stateType)
#elif __GLASGOW_HASKELL__ >= 707
      [ tySynInstD ''MethodResult (tySynEqn [structType] (return resultType))
      , tySynInstD ''MethodState (tySynEqn [structType] (return stateType))
#else
      [ tySynInstD ''MethodResult [structType] (return resultType)
      , tySynInstD ''MethodState [structType] (return stateType)
#endif
      ]
  where
    TypeAnalysis { tyvars, context, stateType, resultType } = analyseType eventName eventType
    eventStructName = toStructName eventName
    tyVarNames = allTyVarBndrNames tyvars
    -- Method (EventStruct a b ...)
    ty = AppT (ConT ''Method) (foldl AppT (ConT eventStructName) (map VarT tyVarNames))
    -- EventStruct a b ..., as a TypeQ for the type-family equations
    structType = foldl appT (conT eventStructName) (map varT tyVarNames)
-- | Generate the method-less 'UpdateEvent' instance (for updates) or
-- 'QueryEvent' instance (for queries) for an event's data type.
--
-- Every type variable of the event type must be 'Typeable', in addition to
-- the event's own context. The context is built by 'mkCxtFromTyVars', as
-- for the other generated instances.
makeEventInstance :: Name -> Type -> DecQ
makeEventInstance eventName eventType =
    instanceD (mkCxtFromTyVars [''Typeable] tyvars context)
              (return ty)
              []
  where
    TypeAnalysis { tyvars, context, isUpdate } = analyseType eventName eventType
    eventStructName = toStructName eventName
    eventClass = if isUpdate then ''UpdateEvent else ''QueryEvent
    ty = AppT (ConT eventClass) (foldl AppT (ConT eventStructName) (map VarT (allTyVarBndrNames tyvars)))
-- | What 'analyseType' extracts from an event function's type.
data TypeAnalysis
  = TypeAnalysis
      { tyvars        :: [TyVarBndr]
        -- ^ Type variables bound by @forall@s in the event type. For
        -- MonadState/MonadReader-style events, the monad variable is left out.
      , context       :: Cxt
        -- ^ Class context from those @forall@s. For MonadState/MonadReader-style
        -- events, the constraint that selected the state type is left out.
      , argumentTypes :: [Type]
        -- ^ Argument types of the event function, left to right.
      , stateType     :: Type
        -- ^ The state type the event runs against.
      , resultType    :: Type
        -- ^ The type of the event's result.
      , isUpdate      :: Bool
        -- ^ 'True' for updates, 'False' for queries.
      }
  deriving (Eq, Show)
-- | Decompose the synonym-expanded type of an event function into a
-- 'TypeAnalysis'.
--
-- Arrows contribute argument types, and any @forall@ met along the way
-- contributes its type variables and context. The final result type must
-- take one of three forms:
--
-- * @Update st r@ (an update);
-- * @Query st r@ (a query);
-- * @m r@, where @m@ is a type variable. If the context has exactly one
--   @MonadState st m@ and no MonadReader constraint on @m@, the event is an
--   update. If it has exactly one @MonadReader st m@ and no MonadState
--   constraint on @m@, it is a query. @m@ is removed from the type variables
--   and the matching constraint is removed from the context.
--
-- Any other shape is reported with 'error'. The event name is used only in
-- that message.
analyseType :: Name -> Type -> TypeAnalysis
analyseType eventName t = go [] [] [] t
  where
#if MIN_VERSION_template_haskell(2,10,0)
    -- Every @MonadReader x m@ constraint in the context for the given @m@,
    -- paired with @x@. Constraints of other shapes fail the pattern and are
    -- skipped (list-monad pattern failure).
    getMonadReader :: Cxt -> Name -> [(Type, Type)]
    getMonadReader cxt m = do
      constraint@(AppT (AppT (ConT c) x) m') <- cxt
      guard (c == ''MonadReader && m' == VarT m)
      return (constraint, x)
    -- Every @MonadState x m@ constraint for the given @m@, paired with @x@.
    getMonadState :: Cxt -> Name -> [(Type, Type)]
    getMonadState cxt m = do
      constraint@(AppT (AppT (ConT c) x) m') <- cxt
      guard (c == ''MonadState && m' == VarT m)
      return (constraint, x)
#else
    getMonadReader :: Cxt -> Name -> [(Pred, Type)]
    getMonadReader cxt m = do
      constraint@(ClassP c [x, m']) <- cxt
      guard (c == ''MonadReader && m' == VarT m)
      return (constraint, x)
    getMonadState :: Cxt -> Name -> [(Pred, Type)]
    getMonadState cxt m = do
      constraint@(ClassP c [x, m']) <- cxt
      guard (c == ''MonadState && m' == VarT m)
      return (constraint, x)
#endif
    -- An argument: append it and continue with the arrow's result.
    go tyvars cxt args (AppT (AppT ArrowT a) b)
      = go tyvars cxt (args ++ [a]) b
    -- @Update st r@ or @Query st r@. For any other constructor both guards
    -- fail and matching falls through to the later equations.
    go tyvars context argumentTypes (AppT (AppT (ConT con) stateType) resultType)
      | con == ''Update =
        TypeAnalysis
          { tyvars, context, argumentTypes, stateType, resultType
          , isUpdate = True
          }
      | con == ''Query =
        TypeAnalysis
          { tyvars, context, argumentTypes, stateType, resultType
          , isUpdate = False
          }
    -- A (possibly nested) forall: accumulate its binders and context.
    go tyvars cxt args (ForallT tyvars2 cxt2 a)
      = go (tyvars ++ tyvars2) (cxt ++ cxt2) args a
    -- mtl style: the result is @m r@, and the state type comes from the
    -- single MonadState (update) or MonadReader (query) constraint on @m@.
    -- Having both kinds, or more than one of a kind, fails the guards.
    go tyvars' cxt argumentTypes (AppT (VarT m) resultType)
      | [] <- queries, [(cx, stateType)] <- updates
      = TypeAnalysis
          { tyvars, argumentTypes , stateType, resultType
          , isUpdate = True
          , context = delete cx cxt
          }
      | [(cx, stateType)] <- queries, [] <- updates
      = TypeAnalysis
          { tyvars, argumentTypes , stateType, resultType
          , isUpdate = False
          , context = delete cx cxt
          }
      where
        queries = getMonadReader cxt m
        updates = getMonadState cxt m
        -- the monad variable is not a parameter of the generated event type
        tyvars = filter ((/= m) . tyVarBndrName) tyvars'
    go _ _ _ _ = error $ "Data.Acid.TemplateHaskell: Event has an invalid type signature: Not an Update, Query, MonadState, or MonadReader: " ++ show eventName
-- | Collect the type variables occurring in a type, left to right and with
-- duplicates.
--
-- The search descends through applications, kind signatures and the body
-- of a @forall@. It does not look at a @forall@'s binders or context.
findTyVars :: Type -> [Name]
findTyVars ty = case ty of
    ForallT _ _ body -> findTyVars body
    VarT name        -> [name]
    AppT fun arg     -> findTyVars fun ++ findTyVars arg
    SigT inner _     -> findTyVars inner
    _                -> []
-- | The name bound by a type variable binder, ignoring any kind annotation.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
    PlainTV name    -> name
    KindedTV name _ -> name
-- | The names bound by a list of type variable binders, in order.
allTyVarBndrNames :: [TyVarBndr] -> [Name]
allTyVarBndrNames = map tyVarBndrName
-- | The name of the data type generated for an event.
--
-- It is the event's base name (any module qualifier dropped) with the first
-- character upper-cased, e.g. @fooBar@ becomes @FooBar@.
toStructName :: Name -> Name
toStructName = mkName . capitalise . nameBase
  where
    capitalise ""     = ""
    capitalise (c:cs) = toUpper c : cs