{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
-- | Collect, at splice time, every declaration reachable from a name.
module Language.Haskell.TH.Universe
    ( get_universe
      -- ** Utilities
      -- Lower-level helpers used by 'get_universe'; exported for callers
      -- that want to drive the walk themselves.
    , get_type_names
    , filter_dups'
    , collect_new_dec_names
    , collect_dec_type_names
    , eval_state
    ) where

import Language.Haskell.TH
import Language.Haskell.TH.Syntax hiding (lift)
import Control.Monad.State
import Control.Monad.Error
import Control.Monad.Trans
import Data.Generics.Uniplate.Data
import Data.List
import Data.Tuple.Select
import Control.Applicative
import Data.Composition
import Control.Monad

-- | Declarations collected so far, each paired with the name
-- 'get_dec_name' extracted from it.
type Universe = [(Name, Dec)]

type ErrorStateType m e s a = ErrorT e (StateT s m) a

-- | Error layer over a state layer: a 'throwError' stops the computation,
-- but state changes made before it are kept (see 'exec_state').
newtype ErrorStateT e s m a = ErrorStateT { runErrorStateT :: ErrorStateType m e s a }
    -- Applicative/Alternative are superclasses of Monad/MonadPlus since
    -- GHC 7.10, so they have to be derived alongside them.
    deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadState s, MonadError e)

instance MonadTrans (ErrorStateT String Universe) where
    lift = ErrorStateT . lift . lift

type UniverseState a = ErrorStateT String Universe Q a

type Result a = Either String a

-- | Collect all the ancestor Dec's for whatever is passed in by name.
-- For instance if we have
--
-- > data OtherThing = OtherThing Float
--
-- > data Thing = Thing OtherThing Int
--
-- then
--
-- > get_universe ''Thing
--
-- would return the Dec's for Thing, OtherThing, Int and Float.
--
-- Names already present in the universe are not reified again. If the walk
-- reaches a declaration 'get_dec_name' cannot name, the walk stops there
-- and the declarations gathered so far are returned (the error is dropped).
get_universe :: Name -> Q Universe
get_universe = exec_state . create_universe_name

-- | Reify @name@ and walk the declaration it resolves to.
create_universe_name :: Name -> UniverseState ()
create_universe_name name = do
    reify_result <- lift $ reify name
    case reify_result of
        ClassI dec _            -> create_universe_dec dec
        -- a class method: walk the class that owns it
        ClassOpI _ _ dec_name _ -> create_universe_name dec_name
        TyConI dec              -> create_universe_dec dec
        -- primitive type constructors carry no Dec to collect
        PrimTyConI _ _ _        -> return ()
        -- a data constructor: walk its parent type
        DataConI _ _ dec_name _ -> create_universe_name dec_name
        VarI _ _ m_dec _        -> maybe (return ()) create_universe_dec m_dec
        TyVarI _ _              -> error "Don't know what a TyVarI is"

-- | Collect all the ancestor Dec's for the given Dec.
create_universe_dec :: Dec -> UniverseState ()
create_universe_dec dec = mapM_ visit =<< collect_new_dec_names dec
  where
    -- The list from collect_new_dec_names was filtered before any sibling
    -- was walked; a sibling may have added this name since, and reifying
    -- it again would put a duplicate entry in the universe.
    visit name = do
        seen <- gets (elem name . map fst)
        unless seen (create_universe_name name)

-- | Record @dec@ in the universe and return the type names it mentions that
-- are not in the universe yet. Throws when @dec@ has no name
-- (see 'get_dec_name').
collect_new_dec_names
    :: (Monad m, MonadState Universe m, MonadError String m)
    => Dec -> m [Name]
collect_new_dec_names dec = do
    dec_name <- get_dec_name dec
    -- Insert before filtering so a type that mentions itself is not
    -- reported as new (this is what stops recursion on recursive types).
    modify ((dec_name, dec) :)
    filter_dups (collect_dec_type_names dec)

-- | Distinct names of the type constructors a declaration depends on, in
-- order of first occurrence.
collect_dec_type_names :: Dec -> [Name]
collect_dec_type_names = nub . concatMap get_type_names . get_dec_types

-- | Types a declaration directly depends on: the field types of its
-- constructors, or the right-hand side of a type synonym.
get_dec_types :: Dec -> [Type]
get_dec_types (TySynD _ _ rhs) = [rhs]
get_dec_types dec              = concatMap get_con_types (get_cons dec)

-- | Keep the names that are not keys of the association list. Every
-- occurrence of an already-present name is dropped.
filter_dups' :: Eq a => [a] -> [(a, b)] -> [a]
filter_dups' names uni = filter (`notElem` seen) names
  where
    seen = map fst uni

-- | 'filter_dups'' against the current state.
filter_dups :: (Eq a, MonadState [(a, b)] m) => [a] -> m [a]
filter_dups names = gets (filter_dups' names)

--------------------------------------------------------------------------------
--utility functions without a home

-- | Constructors of a data/newtype (instance) declaration; @[]@ otherwise.
get_cons :: Dec -> [Con]
get_cons (NewtypeD _ _ _ con _)     = [con]
get_cons (DataD _ _ _ cons _)       = cons
get_cons (DataInstD _ _ _ cons _)   = cons
get_cons (NewtypeInstD _ _ _ con _) = [con]
get_cons _                          = []

-- | Field types of a constructor.
get_con_types :: Con -> [Type]
get_con_types (NormalC _ st)     = map snd st
get_con_types (RecC _ st)        = map sel3 st
get_con_types (InfixC x _ y)     = map snd [x, y]
get_con_types (ForallC _ _ con)  = get_con_types con

-- | Names of every type constructor ('ConT') occurring anywhere in a type.
get_type_names :: Type -> [Name]
-- The ConT pattern both filters and extracts, so no partial helper is needed.
get_type_names typ = [n | ConT n <- get_constr_types typ]

from_right :: Either a b -> b
from_right (Right x) = x
from_right _         = error "from_right"

get_type_name' :: Type -> Name
get_type_name' = from_right . get_type_name

get_type_name :: Type -> Result Name
get_type_name (ForallT _ _ typ) = get_type_name typ
get_type_name (VarT n)          = Right n
get_type_name (ConT n)          = Right n
get_type_name x                 = Left ("No name for " ++ show x)

-- | Every 'ConT' node in a type, including nested ones.
get_constr_types :: Type -> [Type]
get_constr_types = filter is_cont . universe

is_cont :: Type -> Bool
is_cont (ConT _) = True
is_cont _        = False

-- | The name a declaration introduces. Declarations without a name of their
-- own (and any form not listed here) are reported with 'throwError'.
get_dec_name :: (MonadError String m, Monad m) => Dec -> m Name
get_dec_name (FunD name _)               = return name
get_dec_name (ValD _ _ _)                = throwError "ValD does not have a name"
get_dec_name (DataD _ name _ _ _)        = return name
get_dec_name (NewtypeD _ name _ _ _)     = return name
get_dec_name (TySynD name _ _)           = return name
get_dec_name (ClassD _ name _ _ _)       = return name
get_dec_name (InstanceD _ _ _)           = throwError "InstanceD does not have a name"
get_dec_name (SigD name _)               = return name
get_dec_name (ForeignD _)                = throwError "ForeignD does not have a name"
get_dec_name (PragmaD _)                 = throwError "PragmaD does not have a name"
get_dec_name (FamilyD _ name _ _)        = return name
get_dec_name (DataInstD _ name _ _ _)    = return name
-- same shape get_cons already matches, so such Decs can reach this point
get_dec_name (NewtypeInstD _ name _ _ _) = return name
-- report remaining forms instead of crashing with a pattern-match failure
get_dec_name other                       = throwError ("No name for " ++ show other)

-- | Run from an empty state and return the final state. The error layer's
-- result is discarded, so a 'throwError' only truncates the collected state.
exec_state :: Monad m => ErrorStateT e [a1] m a -> m [a1]
exec_state x = execStateT (runErrorT (runErrorStateT x)) []

-- | Run from an empty state and return the error-or-result, discarding the
-- final state.
eval_state :: Monad m => ErrorStateT e [a1] m a -> m (Either e a)
eval_state x = evalStateT (runErrorT (runErrorStateT x)) []