{-# LANGUAGE TemplateHaskell #-}
module Test.StrictCheck.TH
( derivePatternSynonyms
) where
import Generics.SOP (NP(..), NS(..))
import Test.StrictCheck.Demand
import Test.StrictCheck.Shaped
import Control.Monad (when)
import Language.Haskell.TH
-- | Build the signature type of a derived pattern synonym.
--
-- Each constructor field type @a@ becomes an argument of type @Demand a@ and
-- the result is @Demand ty@; for fields @[a, b]@ this yields
-- @Demand a -> Demand b -> Demand ty@.
patternTypeDec :: [Type] -> Type -> Type
patternTypeDec fieldTys resultTy = foldr arrow (demandOf resultTy) fieldTys
  where
    -- Wrap a type in the 'Demand' type constructor.
    demandOf = AppT (ConT ''Demand)
    -- Prepend one @Demand field ->@ argument to the remaining signature.
    arrow fieldTy rest = AppT (AppT ArrowT (demandOf fieldTy)) rest
-- | Declare an implicitly bidirectional, prefix pattern synonym.
--
-- @prefixPatternDec idx name binders np@ produces a synonym named @name@ with
-- the given binders whose right-hand side is @Wrap (Eval (GS s))@, where @s@
-- is the sum pattern selecting alternative @idx@ around the product pattern
-- @np@ (see 'sumPattern').
prefixPatternDec :: Int -> Name -> [Name] -> Pat -> Dec
prefixPatternDec idx patName binderNames npPat =
  PatSynD patName details ImplBidir rhs
  where
    details = PrefixPatSyn binderNames
    -- Outermost constructor first: Wrap (Eval (GS <sum>)).
    rhs = foldr (\con inner -> ConP con [inner])
                (sumPattern idx npPat)
                ['Wrap, 'Eval, 'GS]
-- | Declare an implicitly bidirectional, infix pattern synonym.
--
-- @infixPatternDec idx name lhs rhs np@ produces @pattern lhs `name` rhs = ...@
-- whose right-hand side is @Wrap (Eval (GS s))@, where @s@ is the sum pattern
-- selecting alternative @idx@ around the product pattern @np@.
infixPatternDec :: Int -> Name -> Name -> Name -> Pat -> Dec
infixPatternDec idx patName lhsBinder rhsBinder npPat =
  PatSynD patName (InfixPatSyn lhsBinder rhsBinder) ImplBidir wrapped
  where
    wrapped   = ConP 'Wrap [evaluated]
    evaluated = ConP 'Eval [generic]
    generic   = ConP 'GS [sumPattern idx npPat]
-- | Pattern selecting alternative @idx@ (0-based) of an n-ary sum.
--
-- Wraps @p@ in 'Z' and then in @idx@ layers of 'S'; a non-positive index
-- selects the first alternative.
sumPattern :: Int -> Pat -> Pat
sumPattern idx p = foldr (\_ inner -> ConP 'S [inner]) (ConP 'Z [p]) [1 .. idx]
-- | Build an n-ary product pattern @x1 :* x2 :* ... :* Nil@ with one fresh
-- variable per field.
--
-- Only the length of the type list is used. Returns the pattern together
-- with the bound variables in field order.
productPattern :: [Type] -> Q (Pat, [Name])
productPattern fieldTys = do
  -- Reversing gives the first field the most recently allocated name,
  -- matching the allocation order of a tail-first recursion.
  vars <- reverse <$> mapM (const (newName "x")) fieldTys
  let consPat v rest = InfixP (VarP v) '(:*) rest
  return (foldr consPat (ConP 'Nil []) vars, vars)
-- | Derive the signature and the declaration of the pattern synonym for the
-- constructor at position @idx@ (0-based) of the data type @ty@.
--
-- The synonym is named after the constructor plus a suffix that keeps the
-- name lexically valid: @%@ for symbolic constructors (@:+:@ becomes @:+:%@)
-- and an apostrophe for alphanumeric ones (@Foo@ becomes @Foo'@).
--
-- Constructors declared infix yield infix synonyms. Prefix and record
-- constructors yield prefix synonyms, and record fields are matched
-- positionally. Any other constructor form (existential, GADT) fails in 'Q'.
constructor2PatternDec :: Type -> Int -> Con -> Q (Dec, Dec)
constructor2PatternDec ty idx (NormalC conName bangTys) =
  prefixConstructor ty idx conName (map snd bangTys)
constructor2PatternDec ty idx (RecC conName varBangTys) =
  prefixConstructor ty idx conName [fieldTy | (_, _, fieldTy) <- varBangTys]
constructor2PatternDec ty idx (InfixC (_, lhsTy) conName (_, rhsTy)) = do
  let fieldTys = [lhsTy, rhsTy]
      patName  = patternSynonymName conName
  (npPat, names) <- productPattern fieldTys
  -- productPattern yields one binder per field, so this always matches.
  -- Failing (rather than reportError) stops splicing instead of continuing
  -- into a partial pattern match.
  case names of
    [lhs, rhs] ->
      return ( PatSynSigD patName (patternTypeDec fieldTys ty)
             , infixPatternDec idx patName lhs rhs npPat )
    _ -> fail "The impossible happened: Infix Pattern have more than 2 binders"
constructor2PatternDec _ _ _ =
  fail "Test.StrictCheck.TH cannot derive pattern synonyms for fancy types"

-- | Signature and prefix pattern-synonym declaration for a constructor whose
-- fields (in order) have the given types.
prefixConstructor :: Type -> Int -> Name -> [Type] -> Q (Dec, Dec)
prefixConstructor ty idx conName fieldTys = do
  (npPat, names) <- productPattern fieldTys
  let patName = patternSynonymName conName
  return ( PatSynSigD patName (patternTypeDec fieldTys ty)
         , prefixPatternDec idx patName names npPat )

-- | Name of the pattern synonym derived from a constructor name.
--
-- Symbolic constructors (which start with @:@) get the symbol suffix @%@.
-- Alphanumeric ones get an apostrophe, since mixing the two character
-- classes would produce an invalid identifier.
patternSynonymName :: Name -> Name
patternSynonymName conName = mkName (base ++ suffix)
  where
    base = nameBase conName
    suffix = case base of
      ':' : _ -> "%"
      _       -> "'"
-- | Derive one pattern synonym, with its signature, for every constructor of
-- the named data type or newtype.
--
-- For constructor number @i@ the synonym matches
-- @Wrap (Eval (GS s))@, where @s@ selects alternative @i@. All signatures are
-- emitted first, followed by all synonym declarations.
--
-- Reports an error and produces no declarations when the name does not
-- refer to a data type or newtype.
derivePatternSynonyms :: Name -> Q [Dec]
derivePatternSynonyms name = do
  nameInfo <- reify name
  case nameInfo of
    TyConI (DataD _ tyName tyVars _ constrs _) ->
      synonymsFor tyName tyVars constrs
    -- A newtype has the same shape as a single-constructor data type.
    TyConI (NewtypeD _ tyName tyVars _ constr _) ->
      synonymsFor tyName tyVars [constr]
    _ -> do
      reportError (show name ++ " is not a data type name")
      return []
  where
    -- Generate the declarations for a type applied to its own variables.
    synonymsFor tyName tyVars constrs = do
      let ty = foldl AppT (ConT tyName) (map tyVarType tyVars)
      decs <- sequence (zipWith (constructor2PatternDec ty) [0 ..] constrs)
      let (sigs, syns) = unzip decs
      return (sigs ++ syns)

    -- A type variable binder as a type, keeping any kind annotation.
    tyVarType (PlainTV nm)     = VarT nm
    tyVarType (KindedTV nm kd) = SigT (VarT nm) kd