module Test.FitSpec.Derive
( deriveMutable
, deriveMutableE
, module Test.FitSpec.Mutable
, module Test.FitSpec.ShowMutable
, module Test.LeanCheck
)
where
import Test.FitSpec.Mutable
import Test.FitSpec.ShowMutable
import Test.LeanCheck
import Language.Haskell.TH
import Control.Monad (when, unless, liftM, liftM2)
#if __GLASGOW_HASKELL__ < 706
-- | Compatibility shim for template-haskell versions that predate
-- 'reportWarning': reports the message as a warning (the 'False' flag of
-- 'report' selects a warning rather than an error).
reportWarning :: String -> Q ()
reportWarning msg = report False msg
#endif
-- | Derives a 'Listable' instance for the named type, unless one is already
-- in scope, in which case no declarations are produced.
deriveListableIfNeeded :: Name -> DecsQ
deriveListableIfNeeded t =
  t `isInstanceOf` ''Listable >>= \present ->
    if present then return [] else deriveListable t
-- | Derives 'Mutable' and 'ShowMutable' instances (and 'Listable', when
-- missing) for the named data type or newtype.
--
-- This is 'deriveMutableE' with no extra classes, so each type parameter is
-- constrained only by 'Eq', 'Listable' and 'Show'.
deriveMutable :: Name -> DecsQ
deriveMutable t = deriveMutableE [] t
-- | Derives 'Mutable' and 'ShowMutable' instances for the named data type or
-- newtype. A 'Listable' instance is derived first when none is in scope.
--
-- The first argument lists extra class names. Each type parameter is
-- constrained by these classes in addition to 'Eq', 'Listable' and 'Show'.
--
-- If a 'Mutable' instance already exists, a warning is reported and no
-- declarations are generated. The splice fails when the type lacks an 'Eq'
-- or 'Show' instance.
deriveMutableE :: [Name] -> Name -> DecsQ
deriveMutableE cs t = do
  alreadyMutable <- t `isInstanceOf` ''Mutable
  if alreadyMutable
    then skip
    else derive
  where
    -- An existing instance is left alone; only a warning is emitted.
    skip = do
      reportWarning $ "Instance Mutable " ++ show t
                   ++ " already exists, skipping derivation"
      return []
    -- Check the prerequisites, then emit Listable (if needed) followed by
    -- the Mutable/ShowMutable instances.
    derive = do
      derivable <- canDeriveMutable t
      unless derivable (fail $ "Unable to derive Mutable " ++ show t)
      liftM2 (++) (deriveListableIfNeeded t) (reallyDeriveMutable cs t)
-- | Checks whether the named type has both an 'Eq' and a 'Show' instance in
-- scope. Its parameters are instantiated as described for 'isInstanceOf'.
-- Both lookups are always performed, 'Eq' first.
canDeriveMutable :: Name -> Q Bool
canDeriveMutable t = liftM and (mapM (t `isInstanceOf`) [''Eq, ''Show])
-- | Unconditionally generates a 'Mutable' instance (with @mutiers = mutiersEq@)
-- and a 'ShowMutable' instance (with @mutantS = mutantSEq@) for the named type.
--
-- Every type parameter is constrained by 'Eq', 'Listable', 'Show' and by each
-- class named in the first argument.
reallyDeriveMutable :: [Name] -> Name -> DecsQ
reallyDeriveMutable cs t = do
  (nt, vs) <- normalizeType t
  -- One predicate per (parameter, class) pair; the parameter varies slowest.
  ctx <- mapM (uncurry constrain)
              [ (c, v) | v <- vs, c <- ''Eq : ''Listable : ''Show : cs ]
#if __GLASGOW_HASKELL__ >= 708
  ctx |=>| [d| instance Mutable $(return nt)
                 where mutiers = mutiersEq
               instance ShowMutable $(return nt)
                 where mutantS = mutantSEq |]
#else
  -- Older GHCs: assemble the instance declarations by hand.
  return [ InstanceD ctx
                     (AppT (ConT ''Mutable) nt)
                     [ValD (VarP 'mutiers) (NormalB (VarE 'mutiersEq)) []]
         , InstanceD ctx
                     (AppT (ConT ''ShowMutable) nt)
                     [ValD (VarP 'mutantS) (NormalB (VarE 'mutantSEq)) []]
         ]
#endif
  where
    -- Builds the predicate "class c applied to type v".
#if __GLASGOW_HASKELL__ >= 710
    constrain c v = [t| $(conT c) $(return v) |]
#else
    constrain c v = classP c [return v]
#endif
-- | Applies the named type constructor to fresh type variables, one per
-- declared parameter. Returns the applied type together with the list of
-- those variables.
--
-- The variables come from 'newName' using the base names @a@ to @z@ (cycled).
normalizeType :: Name -> Q (Type, [Type])
normalizeType t = do
  arity <- typeArity t
  vars  <- mapM (liftM VarT . newName) (take arity letters)
  return (foldl AppT (ConT t) vars, vars)
  where
    letters :: [String]
    letters = cycle [ [c] | c <- ['a'..'z'] ]
-- | Applies the named type constructor to the unit type @()@ (@TupleT 0@) in
-- every parameter position, for example @Maybe ()@ for @''Maybe@.
normalizeTypeUnits :: Name -> Q Type
normalizeTypeUnits t = liftM saturate (typeArity t)
  where
    saturate n = foldl AppT (ConT t) (replicate n (TupleT 0))
-- | @isInstanceOf tn cl@ asks whether class @cl@ has an instance for type
-- @tn@ with every parameter of @tn@ instantiated to @()@.
--
-- NOTE(review): substituting @()@ presumably breaks down for parameters whose
-- kind is not @*@ (e.g. @data T f = T (f Int)@). Confirm before relying on it
-- for such types.
isInstanceOf :: Name -> Name -> Q Bool
isInstanceOf tn cl = normalizeTypeUnits tn >>= \ty -> isInstance cl [ty]
-- | Number of type parameters declared by the named @data@ or @newtype@.
--
-- Any other kind of name fails the splice through 'fail' in 'Q'. That gives a
-- regular compile-time splice error at this point. A lazy 'error' would only
-- fire when the result is forced, and would surface as an exception thrown
-- from compile-time code.
typeArity :: Name -> Q Int
typeArity t = do
  ti <- reify t
  case ti of
#if __GLASGOW_HASKELL__ < 800
    TyConI (DataD    _ _ ks _ _) -> return (length ks)
    TyConI (NewtypeD _ _ ks _ _) -> return (length ks)
#else
    TyConI (DataD    _ _ ks _ _ _) -> return (length ks)
    TyConI (NewtypeD _ _ ks _ _ _) -> return (length ks)
#endif
    _ -> fail $ "error (arity): symbol "
             ++ show t
             ++ " is not a newtype or data"
-- | @extra |=>| qds@ runs the declaration quote @qds@ and appends the
-- predicates in @extra@ to the context of every instance declaration it
-- produced. Any other declaration is passed through unchanged.
(|=>|) :: Cxt -> DecsQ -> DecsQ
extra |=>| qds = liftM (map addCxt) qds
  where
#if __GLASGOW_HASKELL__ < 800
    addCxt (InstanceD ctx hd body) = InstanceD (ctx ++ extra) hd body
#else
    addCxt (InstanceD ov ctx hd body) = InstanceD ov (ctx ++ extra) hd body
#endif
    addCxt d = d