{-# LANGUAGE TemplateHaskell, CPP #-}
module Data.Express.Utils.TH
( reallyDeriveCascading
, deriveWhenNeeded
, deriveWhenNeededOrWarn
, typeConArgs
, typeConArgsThat
, typeConCascadingArgsThat
, normalizeType
, normalizeTypeUnits
, isInstanceOf
, isntInstanceOf
, typeArity
, typeConstructors
, isTypeSynonym
, typeSynonymType
, mergeIFns
, mergeI
, lookupValN
, showJustName
, typeConstructorsArgNames
, (|=>|)
, (|++|)
, whereI
, module Language.Haskell.TH
)
where
import Control.Monad
import Data.List
import Language.Haskell.TH
-- | Runs @reallyDerive t@ only if @t@ is not already an instance of the
-- class @cls@; otherwise produces no declarations, silently.
deriveWhenNeeded :: Name -> (Name -> DecsQ) -> Name -> DecsQ
deriveWhenNeeded cls reallyDerive t = deriveWhenNeededX False cls reallyDerive t
-- | Like 'deriveWhenNeeded', but emits a compile-time warning when the
-- instance already exists and derivation is skipped.
deriveWhenNeededOrWarn :: Name -> (Name -> DecsQ) -> Name -> DecsQ
deriveWhenNeededOrWarn cls reallyDerive t = deriveWhenNeededX True cls reallyDerive t
-- | Shared implementation of 'deriveWhenNeeded' and 'deriveWhenNeededOrWarn'.
--
-- If @t@ is not yet an instance of @cls@, returns @reallyDerive t@.
-- Otherwise returns no declarations, first reporting a warning when
-- @warnExisting@ is 'True'.
deriveWhenNeededX :: Bool -> Name -> (Name -> DecsQ) -> Name -> DecsQ
deriveWhenNeededX warnExisting cls reallyDerive t = do
  alreadyInstance <- t `isInstanceOf` cls
  if not alreadyInstance
    then reallyDerive t
    else do
      when warnExisting $
        reportWarning $ "Instance " ++ showJustName cls ++ " " ++ showJustName t
                     ++ " already exists, skipping derivation"
      return []
-- | Shows a 'Name' without its module qualification: @Data.Foo.Bar@ becomes
-- @Bar@, and the operator @Data.Bits..&.@ becomes @.&.@.
--
-- For module-qualified names ('nameModule' returns 'Just'), the base name is
-- taken with 'nameBase'. This avoids truncating operator names that
-- themselves contain dots (e.g. @:.@ or @.&.@). Taking the text after the
-- last dot of 'show' would yield an empty or partial string for those.
-- All other names keep the plain "text after the last dot of 'show'" rule.
showJustName :: Name -> String
showJustName n =
  case nameModule n of
    Just _  -> nameBase n
    Nothing -> reverse . takeWhile (/= '.') . reverse $ show n
-- | Derives an instance of @cls@ for @t@ and, cascading, for every type
-- reachable from @t@'s definition that is not yet an instance of @cls@.
--
-- @t@ itself always comes first. Type synonyms are skipped, and the
-- declarations produced by @reallyDerive@ for each target are concatenated.
reallyDeriveCascading :: Name -> (Name -> DecsQ) -> Name -> DecsQ
reallyDeriveCascading cls reallyDerive t = do
  missing <- t `typeConCascadingArgsThat` (`isntInstanceOf` cls)
  let candidates = t : delete t missing
  targets <- filterM (\n -> liftM not (isTypeSynonym n)) candidates
  declGroups <- mapM reallyDerive targets
  return (concat declGroups)
-- | Lists the distinct type constructor names occurring in the definition of
-- @t@, in ascending 'Ord' order (duplicates removed via 'nubMerge').
--
-- For a type synonym these come from its right-hand side. For a @data@ or
-- @newtype@ they come from the field types of all its constructors. Type
-- variables contribute nothing.
typeConArgs :: Name -> Q [Name]
typeConArgs t = do
  synonym <- isTypeSynonym t
  if synonym
    then liftM conNames (typeSynonymType t)
    else liftM (nubMerges . map conNames . concatMap snd) (typeConstructors t)
  where
    conNames :: Type -> [Name]
    conNames ty = case ty of
      AppT l r        -> conNames l `nubMerge` conNames r
      SigT inner _    -> conNames inner
      VarT _          -> []
      ConT n          -> [n]
#if __GLASGOW_HASKELL__ >= 800
      -- NOTE(review): the infix operator name itself is not collected here,
      -- unlike the head of an 'AppT' chain -- confirm this is intended.
      InfixT l _ r    -> conNames l `nubMerge` conNames r
      UInfixT l _ r   -> conNames l `nubMerge` conNames r
      ParensT inner   -> conNames inner
#endif
      _               -> []
-- | The type constructor arguments of @t@ (see 'typeConArgs') that satisfy
-- the monadic predicate @p@, in the same order as 'typeConArgs' returns them.
typeConArgsThat :: Name -> (Name -> Q Bool) -> Q [Name]
typeConArgsThat t p = filterM p =<< typeConArgs t
typeConCascadingArgsThat :: Name -> (Name -> Q Bool) -> Q [Name]
t `typeConCascadingArgsThat` p = do
ts <- t `typeConArgsThat` p
let p' t' = do is <- p t'; return $ t' `notElem` (t:ts) && is
tss <- mapM (`typeConCascadingArgsThat` p') ts
return $ nubMerges (ts:tss)
-- | Applies the type constructor @t@ to as many fresh type variables as its
-- arity. Returns the applied type together with the list of variables used.
--
-- Variables are created with 'newName' from the base names @a@, @b@, ...,
-- @z@, cycling if the arity exceeds 26.
normalizeType :: Name -> Q (Type, [Type])
normalizeType t = do
  arity <- typeArity t
  names <- mapM newName (take arity letters)
  let vars = map VarT names
  return (foldl AppT (ConT t) vars, vars)
  where
    letters = [ [c] | c <- cycle ['a'..'z'] ]
-- | Applies the type constructor @t@ to the unit type @()@ once per type
-- parameter, e.g. @Either@ becomes @Either () ()@.
normalizeTypeUnits :: Name -> Q Type
normalizeTypeUnits t = liftM applyUnits (typeArity t)
  where
    applyUnits n = foldl AppT (ConT t) (replicate n (TupleT 0))
-- | Whether the type @tn@, with every parameter instantiated to @()@, is an
-- instance of the class @cl@.
isInstanceOf :: Name -> Name -> Q Bool
isInstanceOf tn cl = normalizeTypeUnits tn >>= \ty -> isInstance cl [ty]
-- | Negation of 'isInstanceOf'.
isntInstanceOf :: Name -> Name -> Q Bool
tn `isntInstanceOf` cl = do
  isInst <- tn `isInstanceOf` cl
  return (not isInst)
-- | Number of type parameters declared by the @data@, @newtype@ or type
-- synonym named @t@.
--
-- For any other kind of symbol the result is an 'error', raised lazily when
-- the returned 'Int' is forced.
typeArity :: Name -> Q Int
typeArity t = liftM (length . binders) (reify t)
  where
#if __GLASGOW_HASKELL__ < 800
    binders (TyConI (DataD _ _ ks _ _))      = ks
    binders (TyConI (NewtypeD _ _ ks _ _))   = ks
#else
    binders (TyConI (DataD _ _ ks _ _ _))    = ks
    binders (TyConI (NewtypeD _ _ ks _ _ _)) = ks
#endif
    binders (TyConI (TySynD _ ks _))         = ks
    binders _ = error $ "error (typeArity): symbol " ++ show t
                     ++ " is not a newtype, data or type synonym"
-- | Lists the data constructors of the @data@ or @newtype@ named @t@, each
-- paired with the types of its fields in declaration order.
--
-- Existentially quantified constructors ('ForallC') are unwrapped to the
-- constructor they wrap. On GHC 8.0 and later, GADT-syntax constructors
-- ('GadtC', 'RecGadtC') are also supported, yielding one entry per
-- constructor name they declare. Previously these shapes caused a
-- non-exhaustive pattern failure.
--
-- If @t@ is neither a @data@ nor a @newtype@, the result is an 'error',
-- raised lazily when the list is forced.
typeConstructors :: Name -> Q [(Name,[Type])]
typeConstructors t = do
  ti <- reify t
  return . concatMap simplify $ case ti of
#if __GLASGOW_HASKELL__ < 800
    TyConI (DataD _ _ _ cs _)     -> cs
    TyConI (NewtypeD _ _ _ c _)   -> [c]
#else
    TyConI (DataD _ _ _ _ cs _)   -> cs
    TyConI (NewtypeD _ _ _ _ c _) -> [c]
#endif
    _ -> error $ "error (typeConstructors): symbol " ++ show t
              ++ " is neither newtype nor data"
  where
    -- Flattens one constructor into (name, field types) entries; GADT
    -- constructors may declare several names sharing the same fields.
    simplify :: Con -> [(Name,[Type])]
    simplify (NormalC n ts)     = [(n, map snd ts)]
    simplify (RecC n ts)        = [(n, map fieldType ts)]
    simplify (InfixC t1 n t2)   = [(n, [snd t1, snd t2])]
    simplify (ForallC _ _ c)    = simplify c
#if __GLASGOW_HASKELL__ >= 800
    simplify (GadtC ns ts _)    = [(n, map snd ts) | n <- ns]
    simplify (RecGadtC ns ts _) = [(n, map fieldType ts) | n <- ns]
#endif

    fieldType (_, _, ty) = ty
-- | Whether the symbol @t@ is declared as a type synonym.
isTypeSynonym :: Name -> Q Bool
isTypeSynonym t = liftM isSynonym (reify t)
  where
    isSynonym (TyConI TySynD{}) = True
    isSynonym _                 = False
-- | The right-hand side of the type synonym named @t@.
--
-- If @t@ is not a type synonym, the result is an 'error', raised lazily
-- when the returned 'Type' is forced.
typeSynonymType :: Name -> Q Type
typeSynonymType t = liftM synonymRhs (reify t)
  where
    synonymRhs (TyConI (TySynD _ _ rhs)) = rhs
    synonymRhs _ = error $ "error (typeSynonymType): symbol " ++ show t
                        ++ " is not a type synonym"
-- | Appends the constraints @extra@ to the context of every instance
-- declaration produced by @qds@. Other declarations pass through unchanged.
(|=>|) :: Cxt -> DecsQ -> DecsQ
extra |=>| qds = liftM (map addContext) qds
  where
#if __GLASGOW_HASKELL__ < 800
    addContext (InstanceD ctx ty body)   = InstanceD (ctx ++ extra) ty body
#else
    addContext (InstanceD o ctx ty body) = InstanceD o (ctx ++ extra) ty body
#endif
    addContext d = d
-- | Runs both declaration splices, left first, and concatenates their
-- results.
(|++|) :: DecsQ -> DecsQ -> DecsQ
qxs |++| qys = do
  xs <- qxs
  ys <- qys
  return (xs ++ ys)
-- | Within every instance declaration produced by @qds@, merges function
-- declarations ('FunD') sharing a name into a single 'FunD'. The merged
-- declaration's clauses are concatenated in their original order, and it
-- takes the position of the first declaration of that name.
--
-- Other members of an instance are kept as they are. Empty instance bodies
-- stay empty, and non-instance declarations pass through unchanged.
-- (Previously empty bodies, non-'FunD' members, or non-instance
-- declarations caused a crash, and differently-named functions were fused
-- under the first name.)
mergeIFns :: DecsQ -> DecsQ
mergeIFns qds = liftM (map mergeInstance) qds
  where
#if __GLASGOW_HASKELL__ < 800
    mergeInstance (InstanceD c ts ds)   = InstanceD c ts (mergeFunDs ds)
#else
    mergeInstance (InstanceD o c ts ds) = InstanceD o c ts (mergeFunDs ds)
#endif
    mergeInstance d = d

    -- Folds all later clauses of the same function into its first
    -- occurrence, then continues with the remaining declarations.
    mergeFunDs :: [Dec] -> [Dec]
    mergeFunDs [] = []
    mergeFunDs (FunD n cs : ds) =
      FunD n (cs ++ concat [cs' | FunD n' cs' <- ds, n' == n])
        : mergeFunDs (filter (not . isFunNamed n) ds)
    mergeFunDs (d : ds) = d : mergeFunDs ds

    -- Whether a declaration is a function declaration with the given name.
    isFunNamed :: Name -> Dec -> Bool
    isFunNamed n (FunD n' _) = n' == n
    isFunNamed _ _           = False
-- | Merges two splices that each produce a single instance declaration. The
-- result is one instance, with the first instance's context and head, whose
-- body is the first body followed by the second.
--
-- An empty splice on either side acts as identity. Any other shape fails in
-- 'Q' with a descriptive message, instead of an opaque pattern-match crash.
mergeI :: DecsQ -> DecsQ -> DecsQ
qds1 `mergeI` qds2 = do
  ds1 <- qds1
  ds2 <- qds2
  case (ds1, ds2) of
    ([], _) -> return ds2
    (_, []) -> return ds1
#if __GLASGOW_HASKELL__ < 800
    ([InstanceD c ts body1], [InstanceD _ _ body2]) ->
      return [InstanceD c ts (body1 ++ body2)]
#else
    ([InstanceD o c ts body1], [InstanceD _ _ _ body2]) ->
      return [InstanceD o c ts (body1 ++ body2)]
#endif
    _ -> fail "mergeI: each argument must produce exactly one instance declaration"
-- | Appends the declarations @extra@ to the body of every instance
-- declaration produced by @qds@. Other declarations pass through unchanged.
whereI :: DecsQ -> [Dec] -> DecsQ
qds `whereI` extra = liftM (map addBody) qds
  where
#if __GLASGOW_HASKELL__ < 800
    addBody (InstanceD ctx ty body)   = InstanceD ctx ty (body ++ extra)
#else
    addBody (InstanceD o ctx ty body) = InstanceD o ctx ty (body ++ extra)
#endif
    addBody d = d
-- | Merges two ascending, duplicate-free lists into one ascending,
-- duplicate-free list. When both heads compare neither less nor greater,
-- the left element is kept and both are consumed.
nubMerge :: Ord a => [a] -> [a] -> [a]
nubMerge [] ys = ys
nubMerge xs [] = xs
nubMerge left@(x:xs) right@(y:ys)
  | x < y     = x : nubMerge xs right
  | x > y     = y : nubMerge left ys
  | otherwise = x : nubMerge xs ys
-- | Merges a list of ascending, duplicate-free lists with 'nubMerge',
-- folding from the right.
nubMerges :: Ord a => [[a]] -> [a]
nubMerges []         = []
nubMerges (xs : xss) = xs `nubMerge` nubMerges xss
-- | For each data constructor of @t@ (see 'typeConstructors'), pairs the
-- constructor name with one fresh variable name (based on @x@) per field.
typeConstructorsArgNames :: Name -> Q [(Name,[Name])]
typeConstructorsArgNames t = typeConstructors t >>= mapM freshArgs
  where
    freshArgs (con, fieldTypes) = do
      argNames <- mapM (const (newName "x")) fieldTypes
      return (con, argNames)
-- | Looks up a value-level name in scope at the splice site. Fails in 'Q'
-- with a message naming @s@ if it cannot be found.
lookupValN :: String -> Q Name
lookupValN s =
  lookupValueName s >>= maybe (fail $ "lookupValN: cannot find " ++ s) return