{-# LANGUAGE ScopedTypeVariables #-}
module Data.Singletons.Deriving.Util where
import Control.Monad
import Data.List
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Set as Set
import Data.Singletons.Names
import Data.Singletons.Syntax
import Data.Singletons.Util
import Language.Haskell.TH.Desugar
import Language.Haskell.TH.Syntax
-- | The shape of a function that derives a class instance for a data type.
--
-- The arguments, in order, are:
--
-- * an optional instance context (NOTE(review): presumably 'Nothing' means the
--   context is to be inferred by the deriver — confirm against callers);
-- * a 'DType' (presumably the type the instance is being derived for — TODO
--   confirm);
-- * a 'DataDecl'.
--
-- The result is an instance declaration produced in the monad @q@.
type DerivDesc q = Maybe DCxt -> DType -> DataDecl -> q UInstDecl
-- | Decide whether any constructor in the list is "non-vanilla".
--
-- A constructor counts as non-vanilla when 'conExistentialTvbs' reports at
-- least one existentially quantified type variable for it (relative to the
-- given data type), or when its constructor context is non-empty.
--
-- Constructors are inspected left to right, and inspection stops at the
-- first non-vanilla one.
isNonVanillaDataType :: forall q. DsMonad q => DType -> [DCon] -> q Bool
isNonVanillaDataType data_ty = foldr step (pure False)
  where
    -- Check one constructor; only continue with the rest if it is vanilla.
    step :: DCon -> q Bool -> q Bool
    step con rest = do
      nonVanilla <- conIsNonVanilla con
      if nonVanilla then pure True else rest

    conIsNonVanilla :: DCon -> q Bool
    conIsNonVanilla con@(DCon _ ctxt _ _ _) = do
      ex_tvbs <- conExistentialTvbs data_ty con
      pure (not (null ex_tvbs && null ctxt))
-- | A collection of cases used by 'functorLikeTraverse' to fold over a type,
-- relative to a distinguished type variable.
--
-- The fields are lazy. 'deepSubtypesContaining', for example, stores an
-- 'error' call in 'ft_bad_app', which must not be forced eagerly.
data FFoldType a
  = FT { ft_triv :: a
         -- ^ Result for a type that does not mention the variable.
       , ft_var :: a
         -- ^ Result for the variable itself.
       , ft_ty_app :: DType -> a -> a
         -- ^ Result for an application @f x@ where only @x@ mentions the
         -- variable. It receives @f@ and the result computed for @x@.
       , ft_bad_app :: a
         -- ^ Result when the variable occurs in a disallowed position: in
         -- the function part of an application, in a kind signature, or
         -- as an argument of a type family.
       , ft_forall :: [DTyVarBndr] -> a -> a
         -- ^ Result for a @forall@ whose body mentions the variable and
         -- whose binders do not rebind it. It receives the binders and the
         -- result for the body.
       }
-- | Fold over a type, looking for occurrences of the type variable @var@.
--
-- The type is first passed through 'expandType'. The cases are then used as
-- follows:
--
-- * the variable itself yields 'ft_var';
-- * a type that does not mention the variable yields 'ft_triv';
-- * an application @f x@ where only @x@ mentions the variable yields
--   'ft_ty_app', unless the head of @f@ is a type family, in which case it
--   yields 'ft_bad_app';
-- * an application whose function part mentions the variable, or a kind
--   signature whose kind mentions it, yields 'ft_bad_app';
-- * a @forall@ whose body mentions the variable and which does not rebind it
--   yields 'ft_forall'.
functorLikeTraverse :: forall q a.
                       DsMonad q
                    => Name        -- ^ Variable to look for
                    -> FFoldType a -- ^ How to fold
                    -> DType       -- ^ Type to process
                    -> q a
functorLikeTraverse var (FT { ft_triv    = caseTrivial
                            , ft_var     = caseVar
                            , ft_ty_app  = caseTyApp
                            , ft_bad_app = caseWrongArg
                            , ft_forall  = caseForAll })
                    ty
  = fst <$> (visit =<< expandType ty)
  where
    -- Every intermediate result is paired with a flag recording whether the
    -- visited type mentions 'var'.
    noVar, wrongArg :: (a, Bool)
    noVar    = (caseTrivial, False)
    wrongArg = (caseWrongArg, True)

    visit :: DType -> q (a, Bool)
    visit (DAppT fun arg) = do
      (_, funMentionsVar) <- visit fun
      if funMentionsVar
        then pure wrongArg
        else do
          (argRes, argMentionsVar) <- visit arg
          if not argMentionsVar
            then pure noVar
            else let onApp :: q (a, Bool)
                     onApp = pure (caseTyApp fun argRes, True)
                 in case unfoldType fun of
                      funHead :| _ -> classifyHead onApp funHead
    visit (DSigT inner kind) = do
      (_, kindMentionsVar) <- visit kind
      if kindMentionsVar then pure wrongArg else visit inner
    visit (DVarT v) =
      pure (if v == var then (caseVar, True) else noVar)
    visit (DForallT tvbs _ body) = do
      (bodyRes, bodyMentionsVar) <- visit body
      let rebound = var `elem` map extractTvbName tvbs
      pure (if not rebound && bodyMentionsVar
              then (caseForAll tvbs bodyRes, True)
              else noVar)
    visit (DConT {}) = pure noVar
    visit DArrowT    = pure noVar
    visit (DLitT {}) = pure noVar
    visit DWildCardT = pure noVar

    -- Decide how to treat an application whose argument mentions 'var' by
    -- looking at the head of the applied type. A type family cannot be
    -- decomposed, so an argument under one is treated as a bad occurrence.
    classifyHead :: q (a, Bool) -> DType -> q (a, Bool)
    classifyHead onApp hd = case hd of
      DConT n -> do
        isFam <- isTyFamilyName n
        if isFam then pure wrongArg else onApp
      DForallT _ _ t -> classifyHead onApp t
      DSigT t _      -> classifyHead onApp t
      DAppT t _      -> classifyHead onApp t
      DVarT {}       -> onApp
      DArrowT        -> onApp
      DLitT {}       -> onApp
      DWildCardT     -> onApp
-- | Return 'True' when reifying the name yields an open or closed type
-- family declaration, and 'False' otherwise. This includes names that
-- 'dsReify' cannot resolve.
isTyFamilyName :: DsMonad q => Name -> q Bool
isTyFamilyName n = isFamilyInfo <$> dsReify n
  where
    isFamilyInfo (Just (DTyConI DOpenTypeFamilyD{}   _)) = True
    isFamilyInfo (Just (DTyConI DClosedTypeFamilyD{} _)) = True
    isFamilyInfo _                                       = False
-- | Reject data declarations that cannot have a Functor-like instance
-- derived, by calling 'fail' with a descriptive message.
--
-- The checks are:
--
-- * the data type has at least one type parameter;
-- * unless @allowConstrainedLastTyVar@ is set, the last type argument of
--   every constructor's result type is a universally quantified type
--   variable that does not occur free in the constructor's context;
-- * in every constructor field, that variable only appears where
--   'functorLikeTraverse' accepts it (any 'ft_bad_app' case fails).
functorLikeValidityChecks :: forall q. DsMonad q => Bool -> DataDecl -> q ()
functorLikeValidityChecks allowConstrainedLastTyVar (DataDecl n data_tvbs cons) =
  case data_tvbs of
    [] -> fail $ "Data type " ++ nameBase n ++ " must have some type parameters"
    _  -> mapM_ checkCon cons
  where
    -- Run the universality check, then every per-field check.
    checkCon :: DCon -> q ()
    checkCon con =
      checkUniversal con >> (sequence_ =<< foldDataConArgs (argChecker (extractName con)) con)

    checkUniversal :: DCon -> q ()
    checkUniversal con@(DCon con_tvbs con_theta con_name _ res_ty)
      | allowConstrainedLastTyVar = pure ()
      | _ :| res_ty_args <- unfoldType res_ty
      , (_, last_res_ty_arg) <- snocView res_ty_args
      , Just last_tv <- getDVarTName_maybe last_res_ty_arg
      = do ex_tvbs <- conExistentialTvbs (foldTypeTvbs (DConT n) data_tvbs) con
           let ex_names   = map extractTvbName ex_tvbs
               univ_names = map extractTvbName con_tvbs \\ ex_names
               theta_fvs  = foldMap (fvDType . predToType) con_theta
               isUniversal = last_tv `elem` univ_names
                          && last_tv `Set.notMember` theta_fvs
           unless isUniversal $ fail $ badCon con_name existential
      | otherwise = fail $ badCon con_name existential

    -- Per-field fold: succeeds everywhere except a bad occurrence.
    argChecker :: Name -> FFoldType (q ())
    argChecker con_name = FT
      { ft_triv    = pure ()
      , ft_var     = pure ()
      , ft_ty_app  = const id
      , ft_bad_app = fail (badCon con_name wrong_arg)
      , ft_forall  = const id
      }

    badCon :: Name -> String -> String
    badCon con_name msg = concat ["Constructor ", nameBase con_name, " ", msg]

    existential, wrong_arg :: String
    existential = "must be truly polymorphic in the last argument of the data type"
    wrong_arg = "must use the type variable only as the last argument of a data type"
-- | Collect the applied types @f x@, within a type, whose argument @x@
-- mentions the type variable @tv@. The traversal is 'functorLikeTraverse'.
--
-- Under a @forall@, a collected subtype is dropped if any of the @forall@'s
-- binders occurs free in it. A bad occurrence of @tv@ leaves an 'error'
-- thunk in the result, so this should only be used on types that have
-- already been validated.
deepSubtypesContaining :: DsMonad q => Name -> DType -> q [DType]
deepSubtypesContaining tv = functorLikeTraverse tv collector
  where
    collector :: FFoldType [DType]
    collector = FT { ft_triv    = []
                   , ft_var     = []
                   , ft_ty_app  = \t ts -> t : ts
                   , ft_bad_app = error "in other argument in deepSubtypesContaining"
                   , ft_forall  = \tvbs -> filter (\sub -> not (any (boundIn sub) tvbs))
                   }

    -- Does the binder's name occur free in the given type?
    boundIn :: DType -> DTyVarBndr -> Bool
    boundIn sub tvb = extractTvbName tvb `Set.member` fvDType sub
-- | Apply an 'FFoldType' to every field of a data constructor, one result per
-- field, in field order.
--
-- The variable searched for is the last type argument of the constructor's
-- result type. Each field type is passed through 'expandType' first.
--
-- If the result type has no type arguments, or its last argument is not a
-- type variable (possibly under kind signatures), every field folds to
-- 'ft_triv'.
foldDataConArgs :: forall q a. DsMonad q => FFoldType a -> DCon -> q [a]
foldDataConArgs ft (DCon _ _ _ fields res_ty) = do
  field_tys <- traverse expandType $ tysOfConFields fields
  traverse foldArg field_tys
  where
    foldArg :: DType -> q a
    foldArg
      | _ :| res_ty_args <- unfoldType res_ty
        -- A total match on the reversed argument list picks out the last
        -- argument. A result type with no arguments (e.g. a nullary data
        -- type) then falls through to the trivial case below.
      , last_res_ty_arg : _ <- reverse res_ty_args
      , Just last_tv <- getDVarTName_maybe last_res_ty_arg
      = functorLikeTraverse last_tv ft
      | otherwise
      = const (return (ft_triv ft))
-- | Return the name of a type variable, looking through any number of kind
-- signatures. Any other type gives 'Nothing'.
getDVarTName_maybe :: DType -> Maybe Name
getDVarTName_maybe ty = case ty of
  DSigT inner _ -> getDVarTName_maybe inner
  DVarT n       -> Just n
  _             -> Nothing
-- | Build a one-argument lambda. A fresh name (based on @"n"@) is generated,
-- and the callback receives it as a variable expression to build the body.
mkSimpleLam :: Quasi q => (DExp -> q DExp) -> q DExp
mkSimpleLam lam = do
  arg <- newUniqueName "n"
  DLamE [arg] <$> lam (DVarE arg)
-- | Build a two-argument lambda. Fresh names (based on @"n1"@ and @"n2"@,
-- generated in that order) are given to the callback as variable
-- expressions to build the body.
mkSimpleLam2 :: Quasi q => (DExp -> DExp -> q DExp) -> q DExp
mkSimpleLam2 lam = do
  x <- newUniqueName "n1"
  y <- newUniqueName "n2"
  DLamE [x, y] <$> lam (DVarE x) (DVarE y)
-- | Build a clause that matches the given constructor applied to fresh
-- variables, one per element of @insides@, after the @extra_pats@.
--
-- Each fresh variable @a_i@ is paired with the matching expression @i_i@ to
-- form @i_i a_i@. The right-hand side is the fold function applied to the
-- constructor name and that list.
mkSimpleConClause :: Quasi q
                  => (Name -> [DExp] -> DExp)
                  -> [DPat]
                  -> DCon
                  -> [DExp]
                  -> q DClause
mkSimpleConClause fold extra_pats (DCon _ _ con_name _ _) insides = do
  arg_names <- mapM (const (newUniqueName "a")) insides
  let con_pat = DConPa con_name (DVarPa <$> arg_names)
      applied = [ DAppE inside (DVarE nm) | (inside, nm) <- zip insides arg_names ]
  pure (DClause (extra_pats ++ [con_pat]) (fold con_name applied))
-- | Is this the name of one of the Functor-like classes handled here
-- (@Functor@, @Foldable@ or @Traversable@)?
isFunctorLikeClassName :: Name -> Bool
isFunctorLikeClassName class_name =
  any (== class_name) [functorName, foldableName, traversableName]