{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE MonomorphismRestriction #-} {-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS -fwarn-missing-signatures #-} {-| Example usage: @ import Generics.MultiRec import Generics.MultiRec.TH.Alt import Data.Tree data TheFam :: (* -> *) where TreeIntPrf :: TheFam ('Tree' Int) ForestIntPrf :: TheFam ('Forest' Int) $('deriveEverything' ('DerivOptions' { 'familyTypes' = [ ( [t| 'Tree' Int |], \"TreeIntPrf\" ), ( [t| 'Forest' Int |], \"ForestIntPrf\" ) ], 'indexGadtName' = \"TheFam\", 'constructorNameModifier' = defaultConstructorNameModifier, 'patternFunctorName' = \"ThePF\", 'verbose' = True, 'sumMode' = 'Balanced' } ) ) type instance 'PF' TheFam = ThePF @ -} module Generics.MultiRec.TH.Alt ( DerivOptions(..), SumMode(..), defaultConstructorNameModifier, deriveEverything, module Generics.MultiRec.Base ) where import Generics.MultiRec.TH.Alt.DerivOptions import THUtils(AppliedTyCon, (@@), (@@@), toAppliedTyCon, fromAppliedTyCon, atc2constructors, pprintUnqual, sMatch, sClause, cleanConstructorName) import BalancedFold(balancedFold, ascendFromLeaf) import MonadRQ import Generics.MultiRec.Base import Generics.MultiRec.Constructor(Associativity(..), Fixity(..), Constructor(..)) import Control.Monad.Reader(Monad(return, fail, (>>)), Functor(..), (=<<), mapM, sequence, liftM, zipWithM, asks) import Language.Haskell.TH.Syntax(Lift(..)) import Language.Haskell.TH(newName, mkName, wildP, clause, conE, appE, normalB, funD, dataD, instanceD, cxt, conT, appT, Exp(VarE, SigE, LamE, ConE, CaseE, AppE), Match, Clause, Q, Pat(WildP, VarP, ConP), TypeQ, Type(ConT), Dec(TySynD, InstanceD, FunD), Name, Con(RecC, NormalC, InfixC), FixityDirection(..), Info(DataConI), nameBase, reify, stringE) import Data.Map(lookup, elems) import Control.Applicative((<$>)) import qualified Data.Map as Map import qualified Language.Haskell.TH as TH -- | 
-- Main entry point: runs 'deriveConstructors' and then 'deriveFamily'
-- with the given options and concatenates the generated declarations.
deriveEverything :: DerivOptions -> Q [Dec]
deriveEverything opts =
  -- NOTE(review): an optional 'makeSanityChecks' pass (guarded by an
  -- @mkSanityChecks@ option) was planned here; its sketch is kept,
  -- commented out, at the end of this module.
  runRQ (concat <$> sequence [deriveConstructors, deriveFamily]) opts

-- | For every type in the family, generate one empty datatype per
-- constructor (see 'mkData') together with its 'Constructor' instance
-- (see 'mkInstance').
deriveConstructors :: RQ [Dec]
deriveConstructors = concat <$> foreachType constrInstance

-- | Generate the pattern-functor type synonym and the 'El', 'Fam' and
-- 'EqS' instances for the family described by the options.
--
-- /IMPORTANT/: the constructor names of the index GADT are taken from
-- the strings paired with each family type in the options; they must
-- name the GADT's actual constructors.  The @type instance PF@ line is
-- not generated (see 'derivePF').
deriveFamily :: RQ [Dec]
deriveFamily = do
  pf  <- derivePF
  el  <- deriveEl
  fam <- deriveFam
  eq  <- deriveEqS
  return (concat [pf, el, fam, eq])

-- | Generate only the pattern-functor type synonym (named by
-- 'patternFunctorName'): the sum, chosen by the 'SumMode', of the
-- per-type functors built by 'pfType'.  Not needed if 'deriveFamily' is
-- used.  A reminder to add the @type instance PF@ line by hand is
-- reported, because this function does not emit it.
derivePF :: RQ [Dec]
derivePF = do
  branches <- foreachType pfType
  pfn      <- asks (patternFunctorName . derivOptions)
  sumT     <- askSumT
  famName  <- asks (indexGadtName . derivOptions)
  messageReport
    ( "Reminder: Don't forget to add this line manually:\n"
      ++ " type instance PF " ++ famName ++ " = " ++ pfn
    )
  return [TySynD (mkName pfn) [] (sumT branches)]

-- | The type-level sum combinator selected by the configured 'SumMode'.
askSumT :: RQ ([Type] -> Type)
askSumT = do
  mode <- askSumMode
  return $ case mode of
    Balanced    -> balancedSumT
    RightNested -> rightSumT

-- | Right-nested sum: @a :+: (b :+: (c :+: ...))@.  Non-empty input only.
rightSumT :: [Type] -> Type
rightSumT = foldr1 plusT

-- | Balanced sum built with 'balancedFold'.
balancedSumT :: [Type] -> Type
balancedSumT = balancedFold plusT

-- | @a :+: b@ at the type level.
plusT :: Type -> Type -> Type
plusT a b = ConT ''(:+:) @@ a @@ b

-- | Right-nested product of the field functors.  Non-empty input only.
prodT :: [Type] -> Type
prodT = foldr1 timesT

-- | @a :*: b@ at the type level.
timesT :: Type -> Type -> Type
timesT a b = ConT ''(:*:) @@ a @@ b

-- | Derive only the 'El' instances.  Not needed if 'deriveFamily'
-- is used.
deriveEl :: RQ [Dec]
deriveEl = foreachType elInstance

-- | The index GADT as a type constructor.
indexGadtType :: RQ Type
indexGadtType = ConT . mkName <$> asks (indexGadtName . derivOptions)

-- | Derive only the 'Fam' instance.  Not needed if 'deriveFamily'
-- is used.
deriveFam :: RQ [Dec]
deriveFam = do
  fromClauses <- liftM concat $ foreachTypeNumbered mkFrom
  toClauses   <- foreachTypeNumbered mkTo
  s           <- indexGadtType
  return [ InstanceD [] (ConT ''Fam @@ s)
             [FunD 'from fromClauses, FunD 'to toClauses] ]

-- | Derive only the 'EqS' instance.  Not needed if 'deriveFamily'
-- is used.  Each index constructor compared with itself yields
-- @Just Refl@; any other pair falls through to @Nothing@.
deriveEqS :: RQ [Dec]
deriveEqS = do
  s  <- indexGadtType
  ns <- elems <$> asks familyTypesMap
  return [ InstanceD [] (ConT ''EqS @@ s) [FunD 'eqS (eqClauses ns)] ]
  where
    eqClauses ns = fmap trueClause ns ++ fallback ns
    trueClause n =
      sClause [ConP (mkName n) [], ConP (mkName n) []]
              (ConE 'Just `AppE` ConE 'Refl)
    -- With exactly one index the single 'trueClause' presumably already
    -- covers every pair, so the catch-all is only added otherwise.
    fallback [_] = []
    fallback _   = [sClause [WildP, WildP] (ConE 'Nothing)]

-- | Empty datatypes plus 'Constructor' instances for every constructor
-- of one family type.
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc,s) = do
  cs <- liftq (atc2constructors atc)
  ds <- mapM (mkData s) cs
  is <- mapM (mkInstance s) cs
  return (ds ++ is)

-- | Turn a record constructor into the equivalent positional one.
stripRecordNames :: Con -> Con
stripRecordNames (RecC n f) = NormalC n (fmap (\(_, s, t) -> (s, t)) f)
stripRecordNames c = c

-- | Name of the empty datatype that represents constructor @n@ of the
-- family type whose index constructor is @s@, after applying the
-- configured constructor-name modifier.
constructorTypeName :: String -> Name -> RQ Name
constructorTypeName s n = do
  modifier <- askConstructorNameModifier
  return (mkName . modifier s . cleanConstructorName . nameBase $ n)

-- | Empty datatype standing for one constructor.
-- TODO: Handle colons in the constructor name
mkData :: String -> Con -> RQ Dec
mkData s (NormalC n _) = do
  tyName <- constructorTypeName s n
  liftq $ dataD (cxt []) tyName [] [] []
mkData s r@(RecC _ _) = mkData s (stripRecordNames r)
mkData s (InfixC t1 n t2) = mkData s (NormalC n [t1,t2])

-- Orphan 'Lift' instances so that a multirec 'Fixity' computed at splice
-- time can be embedded in generated code via @[| fi |]@ (see 'mkInstance').
instance Lift Fixity where
  lift Prefix      = conE 'Prefix
  lift (Infix a n) = conE 'Infix `appE` [| a |] `appE` [| n |]

instance Lift Associativity where
  lift LeftAssociative  = conE 'LeftAssociative
  lift RightAssociative = conE 'RightAssociative
  lift NotAssociative   = conE 'NotAssociative

-- | 'Constructor' instance for one constructor.  Infix constructors also
-- get 'conFixity', taken from the fixity that 'reify' reports.
mkInstance :: String -> Con -> RQ Dec
mkInstance s (NormalC n _) = do
  tyName <- constructorTypeName s n
  liftq $ instanceD (cxt []) (appT (conT ''Constructor) (conT tyName))
    [funD 'conName [clause [wildP] (normalB (stringE (nameBase n))) []]]
mkInstance s r@(RecC _ _) = mkInstance s (stripRecordNames r)
mkInstance s (InfixC _ n _) = do
  tyName <- constructorTypeName s n
  info   <- liftq (reify n)
  let fi = fixityOf info
  liftq $ instanceD (cxt []) (appT (conT ''Constructor) (conT tyName))
    [ funD 'conName   [clause [wildP] (normalB (stringE (nameBase n))) []]
    , funD 'conFixity [clause [wildP] (normalB [| fi |]) []]
    ]
  where
    fixityOf (DataConI _ _ _ f) = convertFixity f
    fixityOf _                  = Prefix
    convertFixity (TH.Fixity prec dir) = Infix (convertDirection dir) prec
    convertDirection InfixL = LeftAssociative
    convertDirection InfixR = RightAssociative
    convertDirection InfixN = NotAssociative

-- | Per-type pattern functor: @(sum of constructor functors) :>: T@.
-- Fails for types without constructors (see 'guardEmptyData').
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc,s) = do
  cs <- liftq (atc2constructors atc)
  guardEmptyData cs atc
  sumT <- askSumT
  b <- sumT <$> mapM (pfCon s) cs
  return $ ConT ''(:>:) @@ b @@ fromAppliedTyCon atc

-- | Functor for one constructor: @C Tag (fields)@, where a constructor
-- without fields uses 'U'.
pfCon :: String -> Con -> RQ Type
pfCon s (NormalC n fs) = do
  n' <- constructorTypeName s n
  fieldResults <- mapM (pfField . snd) fs
  let rest | null fs   = ConT ''U
           | otherwise = prodT fieldResults
  return $ ConT ''C @@ ConT n' @@ rest
pfCon s r@(RecC _ _) = pfCon s (stripRecordNames r)
pfCon s (InfixC t1 n t2) = pfCon s (NormalC n [t1,t2])

-- | Recursive position ('I') for family types, constant ('K') otherwise.
pfField :: Type -> RQ Type
pfField t = ifInFamily t (ConT ''I @@ t) (ConT ''K @@ t)

-- | Index-constructor name of a type, if the type is in the family.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
  ts <- asks familyTypesMap
  t' <- liftq $ toAppliedTyCon t
  return $ case t' of
    Right t'' -> Map.lookup t'' ts
    Left _    -> Nothing

-- | Pure variant of 'ifInFamily''.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily n x y = ifInFamily' n (return x) (return y)

-- | Run the first action if the type is in the family, else the second.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t x y = maybe y (const x) =<< lookupFam t

-- | @El@ instance for one family type, with 'proof' from 'mkProof'.
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance x@(atc,_) = do
  s   <- indexGadtType
  prf <- mkProof x
  return $ InstanceD [] (ConT ''El @@ s @@ fromAppliedTyCon atc) [prf]

-- | 'from' clauses for family type number @i@ of @m@: one per constructor.
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc,s) = do
  cs  <- liftq (atc2constructors atc)
  lrE <- ask_lrE
  let wrapE e = lrE m i (ConE 'Tag @@@ e)
      dn      = mkName s
  zipWithM (fromCon wrapE dn (length cs)) [0..] cs

-- | 'to' clause for family type number @i@ of @m@.  It selects this
-- type's summand, then this type's constructor.  Type signatures are
-- attached to both scrutinees.
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc,s) = do
  cs     <- liftq (atc2constructors atc)
  pfname <- mkName <$> asks (patternFunctorName . derivOptions)
  matchesOfCons <- zipWithM (toCon (length cs) atc) [0..] cs
  xvar   <- liftq (newName "x")
  convar <- liftq (newName "con")
  typeOfConvar <- do
    t0 <- pfType (atc,s)
    return (t0 @@ ConT ''I0 @@ fromAppliedTyCon atc)
  lrP <- ask_lrP
  let typeOfXvar = ConT pfname @@ ConT ''I0 @@ fromAppliedTyCon atc
      body =
        LamE [VarP xvar]
          (CaseE (VarE xvar `SigE` typeOfXvar)
            [ sMatch (lrP m i (VarP convar))
                (CaseE (VarE convar `SigE` typeOfConvar) matchesOfCons)
            ])
  return (sClause [ConP (mkName s) []] body)

-- | @proof = <index constructor>@ for one family type.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof (_,s) = return $ FunD 'proof [sClause [] (ConE (mkName s))]

-- | One 'from' clause: matches index constructor @n@ and data constructor
-- number @i@ of @m@, then builds @wrap (L/R ... (C fields))@.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i (NormalC cn []) = do
  lrE <- ask_lrE
  -- Nullary constructor case
  return $ sClause [ConP n [], ConP cn []]
                   (wrap . lrE m i $ ConE 'C @@@ ConE 'U)
fromCon wrap n m i (NormalC cn fs) = do
  lrE <- ask_lrE
  rhs <- zipWithM fromField [0..] (snd <$> fs)
  return $ sClause [ ConP n []
                   , ConP cn (fmap (VarP . field) [0..length fs - 1]) ]
                   (wrap . lrE m i $ ConE 'C @@@ foldr1 prod rhs)
  where
    prod x y = ConE '(:*:) @@@ x @@@ y
fromCon wrap n m i r@(RecC _ _) = fromCon wrap n m i (stripRecordNames r)
fromCon wrap n m i (InfixC t1 cn t2) = fromCon wrap n m i (NormalC cn [t1,t2])

-- | One inner match of 'to': the functor pattern for constructor @i@
-- rebuilt into the original constructor.
toCon :: Int          -- ^ Number of constructors
      -> AppliedTyCon
      -> Int          -- ^ Index of this constructor
      -> Con
      -> RQ Match
toCon m _ i (NormalC cn []) = do
  lrP <- ask_lrP
  -- Nullary constructor case
  return $ sMatch (ConP 'Tag [lrP m i $ ConP 'C [ConP 'U []]])
                  (ConE cn)
toCon m _ i (NormalC cn fs) = do
  lrP <- ask_lrP
  lhs <- zipWithM toField [0..] (fmap snd fs)
  return $ sMatch (ConP 'Tag [lrP m i $ ConP 'C [foldr1 prod lhs]])
                  (foldl AppE (ConE cn)
                         (fmap (VarE . field) [0..length fs - 1]))
  where
    prod x y = ConP '(:*:) [x,y]
toCon m atc i r@(RecC _ _) = toCon m atc i (stripRecordNames r)
toCon m atc i (InfixC t1 cn t2) = toCon m atc i (NormalC cn [t1,t2])

-- | Expression for field @nr@ in 'from'.  Family types become @I (I0 f)@;
-- other types become @K f@, and an info message is emitted for them.
fromField :: Int -> Type -> RQ Exp
fromField nr t =
  ifInFamily' t
    (return (ConE 'I @@@ (ConE 'I0 @@@ VarE (field nr))))
    (message ("* Info: Type not in family: " ++ pprintUnqual t) >>
     return (ConE 'K @@@ VarE (field nr)))

-- | Pattern for field @nr@ in 'to', mirroring 'fromField'.
toField :: Int -> Type -> RQ Pat
toField nr t =
  ifInFamily t (ConP 'I [ConP 'I0 [VarP (field nr)]])
               (ConP 'K [VarP (field nr)])

-- | Generated variable name for the @n@-th field: @f0@, @f1@, ...
field :: Int -> Name
field n = mkName $ "f" ++ show n

-- | Sum-injection pattern builder matching the configured 'SumMode'.
ask_lrP :: RQ (Int -> Int -> (Pat -> Pat))
ask_lrP = do
  mode <- askSumMode
  return $ case mode of
    Balanced    -> lrP_balanced
    RightNested -> lrP_rightNested

-- | Wrap a pattern in the 'L'/'R' path to summand @i@ of a balanced sum
-- of @m@ summands.
lrP_balanced :: Int -> Int -> (Pat -> Pat)
lrP_balanced m i p =
  ascendFromLeaf (ConP 'L . (:[])) (ConP 'R . (:[])) p m i

-- | Wrap a pattern in the 'L'/'R' path to summand @i@ of a right-nested
-- sum of @m@ summands.
lrP_rightNested :: Int -> Int -> (Pat -> Pat)
lrP_rightNested 1 0 p = p
lrP_rightNested _ 0 p = ConP 'L [p]
lrP_rightNested m i p = ConP 'R [lrP_rightNested (m-1) (i-1) p]

-- | Sum-injection expression builder matching the configured 'SumMode'.
ask_lrE :: RQ (Int -> Int -> (Exp -> Exp))
ask_lrE = do
  mode <- askSumMode
  return $ case mode of
    Balanced    -> lrE_balanced
    RightNested -> lrE_rightNested

-- | Expression counterpart of 'lrP_balanced'.
lrE_balanced :: Int -> Int -> (Exp -> Exp)
lrE_balanced m i e = ascendFromLeaf (ConE 'L @@@) (ConE 'R @@@) e m i

-- | Expression counterpart of 'lrP_rightNested'.
lrE_rightNested :: Int -> Int -> (Exp -> Exp)
lrE_rightNested 1 0 e = e
lrE_rightNested _ 0 e = ConE 'L @@@ e
lrE_rightNested m i e = ConE 'R @@@ lrE_rightNested (m-1) (i-1) e

-- | Fail on datatypes without constructors, which are not supported.
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc =
  fail ("Empty types not supported yet (" ++ show (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()

-- | Placeholder that drops the type annotation and keeps the expression.
noSigE :: Exp -> Type -> Exp
x `noSigE` _ = x

-- NOTE(review): disabled sketches kept for reference.
--
-- Append a line describing a non-family type to "dump.dump"
-- (formerly called from 'fromField'):
--
-- helper t = do
--   Right (AppliedTyCon n args) <- liftq (toAppliedTyCon t)
--   let prefix = "Prf_"
--   str <- if n == ''[]
--            then do
--              Right (AppliedTyCon n1 _) <- liftq (toAppliedTyCon (head args))
--              return ("T("++prefix++"List"++nameBase n1
--                      ++",["++pprintUnqual (head args)++"])")
--            else
--              return ("T("++prefix++nameBase n
--                      ++","++pprintUnqual t++")")
--   liftq . runIO $ appendFile "dump.dump" (str++"\n")
--
-- makeSanityChecks :: RQ [Dec]
-- makeSanityChecks = concat <$> foreachType makeSanityCheck
--
-- makeSanityCheck :: (AppliedTyCon,String) -> RQ [Dec]
-- makeSanityCheck (atc,s) = do
--   famname <- mkName <$> asks indexGadtName
--   let chkName = mkName ("sanityCheck"++s)
--   return [ SigD chkName (ConT famname @@ fromAppliedTyCon atc)
--          , ValD (VarP chkName) (NormalB (ConE (mkName s))) []
--          ]