{-# LANGUAGE TemplateHaskell, CPP #-}
module Test.LeanCheck.Derive
( deriveListable
, deriveListableIfNeeded
, deriveListableCascading
, deriveTiers
, deriveList
)
where
#ifdef __GLASGOW_HASKELL__
import Control.Monad (filterM, liftM, liftM2, unless, when)
import Data.List (delete)

import Language.Haskell.TH

import Test.LeanCheck.Basic
#if __GLASGOW_HASKELL__ < 706
-- | Compatibility shim, only defined for GHC < 7.6 (presumably because the
-- template-haskell of that era does not export 'reportWarning').
-- Reports the message as a non-fatal warning: 'report' with 'False'.
reportWarning :: String -> Q ()
reportWarning msg = report False msg
#endif
-- | Derive a 'Listable' instance for the named type.
-- If an instance already exists, a warning is reported and no
-- declarations are produced. Argument types are not derived.
deriveListable :: Name -> DecsQ
deriveListable t = deriveListableX True False t
-- | Like 'deriveListable', but silently produces no declarations
-- (no warning) when a 'Listable' instance already exists.
deriveListableIfNeeded :: Name -> DecsQ
deriveListableIfNeeded t = deriveListableX False False t
-- | Like 'deriveListable', but also derives 'Listable' instances for the
-- type constructors used in the type's fields that do not have one yet.
-- Warns when the instance for the named type already exists.
deriveListableCascading :: Name -> DecsQ
deriveListableCascading t = deriveListableX True True t
-- | Shared implementation behind the @deriveListable*@ entry points.
--
-- * @warnExisting@: report a compile-time warning when a 'Listable'
--   instance for the type already exists.
-- * @cascade@: also derive instances for field types that lack one
--   (via 'reallyDeriveListableCascading').
--
-- Returns no declarations when the instance already exists.
deriveListableX :: Bool -> Bool -> Name -> DecsQ
deriveListableX warnExisting cascade t = do
  is <- t `isInstanceOf` ''Listable
  if is
    then do
      when warnExisting $
        reportWarning $ "Instance Listable " ++ show t
                     ++ " already exists, skipping derivation"
      return []
    else if cascade
      then reallyDeriveListableCascading t
      else reallyDeriveListable t
-- | Unconditionally generate
-- @instance (Listable a, Listable b, ...) => Listable (T a b ...)@
-- whose 'tiers' is built by 'deriveTiers'. Every type parameter of @T@
-- receives a 'Listable' constraint.
reallyDeriveListable :: Name -> DecsQ
reallyDeriveListable t = do
  (nt, vs) <- normalizeType t
  -- named 'ctx' so as not to shadow Language.Haskell.TH.cxt
#if __GLASGOW_HASKELL__ >= 710
  ctx <- mapM (\v -> [t| Listable $(return v) |]) vs
#else
  ctx <- mapM (\v -> classP ''Listable [return v]) vs
#endif
#if __GLASGOW_HASKELL__ >= 708
  ctx |=>| [d| instance Listable $(return nt) where
                 tiers = $(deriveTiers t) |]
#else
  tiersE <- deriveTiers t
  return [ InstanceD ctx
                     (AppT (ConT ''Listable) nt)
                     [ValD (VarP 'tiers) (NormalB tiersE) []]
         ]
#endif
-- | Build an expression for the 'tiers' of the named data type or newtype:
-- one @consN Con@ term per constructor (N being its number of fields),
-- combined with '\/'.
--
-- A type with no constructors yields @[]@ (no values). If the required
-- @consN@ function cannot be found by 'lookupValueName', splicing fails
-- with a message naming the constructor and the missing function.
deriveTiers :: Name -> ExpQ
deriveTiers t = conse =<< typeConstructors t
  where
    -- tiers of a single constructor: @consN Con@
    cone n fields = do
      let consName = "cons" ++ show (length fields)
      mConsN <- lookupValueName consName
      case mConsN of
        Just consN -> [| $(varE consN) $(conE n) |]
        Nothing    -> fail $ "error (deriveTiers): cannot derive tiers for constructor "
                          ++ show n ++ " of " ++ show t
                          ++ ": " ++ consName ++ " is not in scope"
    -- join all constructors' tiers; foldr1 alone would crash on []
    conse [] = [| [] |]
    conse cs = foldr1 (\e1 e2 -> [| $e1 \/ $e2 |]) (map (uncurry cone) cs)
-- | Build an expression listing every value of the named type, in tier
-- order: 'concat' applied to the expression from 'deriveTiers'.
deriveList :: Name -> ExpQ
deriveList t = appE [| concat |] (deriveTiers t)
-- | Derive 'Listable' for the named type and for every type constructor
-- reachable through its fields (see 'typeConCascadingArgsThat') that is not
-- yet a 'Listable' instance. Type synonyms are skipped, and the named type
-- itself always comes first (exactly once).
reallyDeriveListableCascading :: Name -> DecsQ
reallyDeriveListableCascading t = do
  missing <- t `typeConCascadingArgsThat` (`isntInstanceOf` ''Listable)
  let candidates = t : delete t missing
  dataTypes <- filterM (liftM not . isTypeSynonym) candidates
  declGroups <- mapM reallyDeriveListable dataTypes
  return (concat declGroups)
-- | Names of the type constructors appearing in the named type's
-- definition: in the fields of its constructors, or in the right-hand side
-- when it is a type synonym. The result is sorted and duplicate-free
-- (built with 'nubMerge'/'nubMerges').
typeConArgs :: Name -> Q [Name]
typeConArgs t = do
  isSyn <- isTypeSynonym t
  if isSyn
    then liftM typeConTs (typeSynonymType t)
    else liftM (nubMerges . map typeConTs . concatMap snd) (typeConstructors t)
  where
    typeConTs :: Type -> [Name]
    typeConTs (AppT t1 t2) = typeConTs t1 `nubMerge` typeConTs t2
    typeConTs (SigT ty _)  = typeConTs ty
    typeConTs (VarT _)     = []
    typeConTs (ConT n)     = [n]
#if __GLASGOW_HASKELL__ >= 800
    -- the infix operator is itself a type constructor, so keep its name
    typeConTs (InfixT t1 n t2)  = [n] `nubMerge` typeConTs t1 `nubMerge` typeConTs t2
    typeConTs (UInfixT t1 n t2) = [n] `nubMerge` typeConTs t1 `nubMerge` typeConTs t2
    typeConTs (ParensT ty)      = typeConTs ty
#endif
    typeConTs _ = []
-- | The direct type-constructor arguments of @t@ (see 'typeConArgs') that
-- satisfy the monadic predicate @p@. Order is preserved, and @p@ runs on
-- each candidate in order.
typeConArgsThat :: Name -> (Name -> Q Bool) -> Q [Name]
typeConArgsThat t p = filterM p =<< typeConArgs t
-- | Transitively collect the type constructors reachable from @t@'s
-- definition that satisfy @p@.
--
-- At each level, the recursive calls use a stricter predicate that also
-- rejects @t@ and the names already found at that level. Exclusions
-- accumulate along each recursion path, which stops cycles between
-- mutually recursive types. Results are merged with 'nubMerges'.
typeConCascadingArgsThat :: Name -> (Name -> Q Bool) -> Q [Name]
typeConCascadingArgsThat t p = do
  direct <- typeConArgsThat t p
  let seen = t : direct
      -- @p@ is still run first, as before; its result is then combined
      p' n = liftM (\ok -> n `notElem` seen && ok) (p n)
  indirect <- mapM (\n -> typeConCascadingArgsThat n p') direct
  return (nubMerges (direct : indirect))
-- | Apply the named type constructor to one fresh type variable per
-- parameter. The variables are based on the names @a@, @b@, ..., @z@,
-- cycling if there are more than 26.
-- Returns the applied type together with the variables used.
normalizeType :: Name -> Q (Type, [Type])
normalizeType t = do
  arity <- typeArity t
  tyVars <- mapM (liftM VarT . newName) (take arity letterNames)
  return (foldl AppT (ConT t) tyVars, tyVars)
  where
    letterNames :: [String]
    letterNames = [[c] | c <- cycle ['a'..'z']]
-- | Apply the named type constructor to @()@ for every parameter,
-- e.g. @T () ()@ for a two-parameter @T@.
normalizeTypeUnits :: Name -> Q Type
normalizeTypeUnits t = liftM applyUnits (typeArity t)
  where
    applyUnits n = foldl AppT (ConT t) (replicate n (TupleT 0))
-- | Does class @cl@ have an instance for type @tn@, with every parameter of
-- @tn@ instantiated to @()@ (see 'normalizeTypeUnits')?
-- NOTE(review): parameters of a kind other than @*@ would be ill-kinded
-- when applied to @()@. Confirm callers only pass such types.
isInstanceOf :: Name -> Name -> Q Bool
isInstanceOf tn cl = normalizeTypeUnits tn >>= \ty -> isInstance cl [ty]
-- | Negation of 'isInstanceOf'.
isntInstanceOf :: Name -> Name -> Q Bool
isntInstanceOf tn cl = fmap not (tn `isInstanceOf` cl)
-- | Number of type parameters declared by the named data type, newtype or
-- type synonym. Any other kind of symbol results in an 'error' (raised
-- lazily, when the result is forced).
typeArity :: Name -> Q Int
typeArity t = liftM (length . tyVarBinders) (reify t)
  where
#if __GLASGOW_HASKELL__ < 800
    tyVarBinders (TyConI (DataD _ _ ks _ _))    = ks
    tyVarBinders (TyConI (NewtypeD _ _ ks _ _)) = ks
#else
    tyVarBinders (TyConI (DataD _ _ ks _ _ _))    = ks
    tyVarBinders (TyConI (NewtypeD _ _ ks _ _ _)) = ks
#endif
    tyVarBinders (TyConI (TySynD _ ks _)) = ks
    tyVarBinders _ = error $ "error (typeArity): symbol " ++ show t
                          ++ " is not a newtype, data or type synonym"
-- | The constructors of the named data type or newtype, each paired with
-- the types of its fields, in declaration order. Any other kind of symbol
-- results in an 'error' (raised lazily, when the result is forced).
--
-- Supported constructor forms:
--
-- * normal, record and infix constructors;
-- * constructors wrapped in 'ForallC' (explicit foralls/contexts), which
--   are unwrapped;
-- * on GHC >= 8.0, GADT-syntax constructors ('GadtC', 'RecGadtC'). A
--   single GADT signature may name several constructors, so each name
--   becomes its own entry.
typeConstructors :: Name -> Q [(Name,[Type])]
typeConstructors t = do
  ti <- reify t
  return . concatMap simplify $ case ti of
#if __GLASGOW_HASKELL__ < 800
    TyConI (DataD _ _ _ cs _)   -> cs
    TyConI (NewtypeD _ _ _ c _) -> [c]
#else
    TyConI (DataD _ _ _ _ cs _)   -> cs
    TyConI (NewtypeD _ _ _ _ c _) -> [c]
#endif
    _ -> error $ "error (typeConstructors): symbol " ++ show t
              ++ " is neither newtype nor data"
  where
    simplify :: Con -> [(Name,[Type])]
    simplify (NormalC n ts)     = [(n, map snd ts)]
    simplify (RecC n ts)        = [(n, map thd ts)]
    simplify (InfixC t1 n t2)   = [(n, [snd t1, snd t2])]
    simplify (ForallC _ _ c)    = simplify c
#if __GLASGOW_HASKELL__ >= 800
    simplify (GadtC ns ts _)    = [(n, map snd ts) | n <- ns]
    simplify (RecGadtC ns ts _) = [(n, map thd ts) | n <- ns]
#endif
    thd (_, _, z) = z
-- | Is the named symbol a type synonym (@type T ... = ...@)?
isTypeSynonym :: Name -> Q Bool
isTypeSynonym t = liftM isSyn (reify t)
  where
    isSyn (TyConI TySynD{}) = True
    isSyn _                 = False
-- | Right-hand side of the named type synonym. Any other kind of symbol
-- results in an 'error' (raised lazily, when the result is forced).
typeSynonymType :: Name -> Q Type
typeSynonymType t = liftM synRhs (reify t)
  where
    synRhs (TyConI (TySynD _ _ rhs)) = rhs
    synRhs _ = error $ "error (typeSynonymType): symbol " ++ show t
                    ++ " is not a type synonym"
-- | Append the given constraints to the context of every instance
-- declaration produced by the action. The instance's own constraints come
-- first. Other declarations pass through unchanged.
(|=>|) :: Cxt -> DecsQ -> DecsQ
extra |=>| qds = liftM (map addCxt) qds
  where
#if __GLASGOW_HASKELL__ < 800
    addCxt (InstanceD c ts ds) = InstanceD (c ++ extra) ts ds
#else
    addCxt (InstanceD o c ts ds) = InstanceD o (c ++ extra) ts ds
#endif
    addCxt d = d
-- | Merge two ascending lists into one ascending list. An element present
-- in both heads at the same time is kept once. Assumes each input is
-- strictly ascending; duplicates inside one input are not removed.
nubMerge :: Ord a => [a] -> [a] -> [a]
nubMerge [] ys = ys
nubMerge xs [] = xs
nubMerge xxs@(x:xs) yys@(y:ys) =
  case compare x y of
    LT -> x : nubMerge xs yys
    GT -> y : nubMerge xxs ys
    EQ -> x : nubMerge xs ys
-- | 'nubMerge' a list of strictly ascending lists into a single ascending,
-- duplicate-free list.
nubMerges :: Ord a => [[a]] -> [a]
nubMerges []       = []
nubMerges (xs:xss) = xs `nubMerge` nubMerges xss
#else
-- | Placeholder used when not compiling with GHC, where Template Haskell
-- derivation is unavailable: forcing it raises an error.
errorNotGHC :: a
errorNotGHC = error "Only defined when using GHC"

-- | Non-GHC stand-ins for the exported derivation functions. Each one
-- is 'errorNotGHC'.
deriveListable, deriveListableIfNeeded, deriveListableCascading,
  deriveTiers, deriveList :: a
deriveListable          = errorNotGHC
deriveListableIfNeeded  = errorNotGHC
deriveListableCascading = errorNotGHC
deriveTiers             = errorNotGHC
deriveList              = errorNotGHC
#endif