{-# LANGUAGE ScopedTypeVariables #-}
module Data.Singletons.Deriving.Functor where
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Names
import Data.Singletons.Syntax
import Data.Singletons.Util
import Language.Haskell.TH.Desugar
mkFunctorInstance :: forall q. DsMonad q => DerivDesc q
mkFunctorInstance mb_ctxt ty dd@(DataDecl _ _ cons) = do
functorLikeValidityChecks False dd
f <- newUniqueName "_f"
z <- newUniqueName "_z"
let ft_fmap :: FFoldType (q DExp)
ft_fmap = FT { ft_triv = mkSimpleLam pure
, ft_var = pure $ DVarE f
, ft_ty_app = \_ g -> DAppE (DVarE fmapName) <$> g
, ft_forall = \_ g -> g
, ft_bad_app = error "in other argument in ft_fmap"
}
ft_replace :: FFoldType (q Replacer)
ft_replace = FT { ft_triv = fmap Nested $ mkSimpleLam pure
, ft_var = fmap Immediate $ mkSimpleLam $ \_ -> pure $ DVarE z
, ft_ty_app = \_ gm -> do
g <- gm
case g of
Nested g' -> pure . Nested $ DVarE fmapName `DAppE` g'
Immediate _ -> pure . Nested $ DVarE replaceName `DAppE` DVarE z
, ft_forall = \_ g -> g
, ft_bad_app = error "in other argument in ft_replace"
}
clause_for_con :: [DPat] -> DCon -> [DExp] -> q DClause
clause_for_con = mkSimpleConClause $ \con_name ->
foldExp (DConE con_name)
mk_fmap_clause :: DCon -> q DClause
mk_fmap_clause con = do
parts <- foldDataConArgs ft_fmap con
clause_for_con [DVarPa f] con =<< sequence parts
mk_replace_clause :: DCon -> q DClause
mk_replace_clause con = do
parts <- foldDataConArgs ft_replace con
clause_for_con [DVarPa z] con =<< traverse (fmap replace) parts
mk_fmap :: q [DClause]
mk_fmap = case cons of
[] -> do v <- newUniqueName "v"
pure [DClause [DWildPa, DVarPa v] (DCaseE (DVarE v) [])]
_ -> traverse mk_fmap_clause cons
mk_replace :: q [DClause]
mk_replace = case cons of
[] -> do v <- newUniqueName "v"
pure [DClause [DWildPa, DVarPa v] (DCaseE (DVarE v) [])]
_ -> traverse mk_replace_clause cons
fmap_clauses <- mk_fmap
replace_clauses <- mk_replace
constraints <- inferConstraintsDef mb_ctxt (DConPr functorName) ty cons
return $ InstDecl { id_cxt = constraints
, id_name = functorName
, id_arg_tys = [ty]
, id_sigs = mempty
, id_meths = [ (fmapName, UFunction fmap_clauses)
, (replaceName, UFunction replace_clauses)
] }
data Replacer = Immediate { replace :: DExp }
| Nested { replace :: DExp }