{-# LANGUAGE ScopedTypeVariables #-}
module Data.Singletons.Deriving.Functor where
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Names
import Data.Singletons.Syntax
import Data.Singletons.Util
import Language.Haskell.TH.Desugar
-- | Derive a @Functor@ instance for the given data declaration.
--
-- After 'functorLikeValidityChecks' succeeds, two fresh names are made:
-- @_f@, the function argument of the generated 'fmapName' method, and
-- @_z@, the value argument of the generated 'replaceName' method
-- (presumably @(<$)@ — confirm against "Data.Singletons.Names").
-- Each method is built by folding over every constructor's fields.
-- The instance context comes from 'inferConstraintsDef'.
--
-- When the data type has no constructors, each method instead gets the
-- single clause @method _ v = case v of {}@.
mkFunctorInstance :: forall q. DsMonad q => DerivDesc q
mkFunctorInstance mb_ctxt ty dd@(DataDecl _ _ cons) = do
  functorLikeValidityChecks False dd
  f <- newUniqueName "_f"
  z <- newUniqueName "_z"
  let -- Field-wise fold producing one expression per field for 'fmapName'.
      --   * field not mentioning the variable: mkSimpleLam pure
      --     (presumably the identity lambda)
      --   * field that is the variable itself: apply @_f@
      --   * type application: wrap the inner result in 'fmapName'
      fmap_fold :: FFoldType (q DExp)
      fmap_fold = FT
        { ft_triv    = mkSimpleLam pure
        , ft_var     = pure (DVarE f)
        , ft_ty_app  = \_ inner -> DAppE (DVarE fmapName) <$> inner
        , ft_forall  = \_ inner -> inner
        , ft_bad_app = error "in other argument in ft_fmap"
        }

      -- Field-wise fold for 'replaceName'. A field that is the variable
      -- itself is tagged 'Immediate'; everything else is 'Nested'.
      replace_fold :: FFoldType (q Replacer)
      replace_fold = FT
        { ft_triv    = Nested <$> mkSimpleLam pure
        , ft_var     = Immediate <$> mkSimpleLam (\_ -> pure (DVarE z))
        , ft_ty_app  = \_ inner -> rewrap <$> inner
        , ft_forall  = \_ inner -> inner
        , ft_bad_app = error "in other argument in ft_replace"
        }

      -- Rebuild a 'Replacer' found under a type application. A direct
      -- occurrence of the variable becomes @replaceName _z@; any other
      -- inner expression @g@ becomes @fmapName g@. Either way the result
      -- is 'Nested'.
      rewrap :: Replacer -> Replacer
      rewrap (Nested g)    = Nested (DVarE fmapName `DAppE` g)
      rewrap (Immediate _) = Nested (DVarE replaceName `DAppE` DVarE z)

      -- Turn one constructor plus its per-field expressions into a clause
      -- whose body re-applies the constructor to those expressions.
      con_clause :: [DPat] -> DCon -> [DExp] -> q DClause
      con_clause = mkSimpleConClause (foldExp . DConE)

      -- Shared shape of both methods. With no constructors, emit the
      -- empty-case clause; otherwise emit one clause per constructor.
      method_clauses :: (DCon -> q DClause) -> q [DClause]
      method_clauses per_con
        | null cons = do
            v <- newUniqueName "v"
            pure [DClause [DWildPa, DVarPa v] (DCaseE (DVarE v) [])]
        | otherwise = traverse per_con cons

      -- One 'fmapName' clause: pattern @_f@, then the constructor.
      fmap_clause :: DCon -> q DClause
      fmap_clause con =
        foldDataConArgs fmap_fold con
          >>= sequence
          >>= con_clause [DVarPa f] con

      -- One 'replaceName' clause: pattern @_z@, then the constructor.
      replace_clause :: DCon -> q DClause
      replace_clause con =
        foldDataConArgs replace_fold con
          >>= traverse (fmap replace)
          >>= con_clause [DVarPa z] con

  fmap_clauses    <- method_clauses fmap_clause
  replace_clauses <- method_clauses replace_clause
  constraints     <- inferConstraintsDef mb_ctxt (DConPr functorName) ty cons
  pure InstDecl
    { id_cxt     = constraints
    , id_name    = functorName
    , id_arg_tys = [ty]
    , id_sigs    = mempty
    , id_meths   = [ (fmapName,    UFunction fmap_clauses)
                   , (replaceName, UFunction replace_clauses)
                   ]
    }
-- | Per-field result of the fold that generates the 'replaceName' method
-- of a derived @Functor@ instance.
--
-- Both constructors carry the generated expression in 'replace'. The tag
-- only matters when the field sits under a type application. There, an
-- 'Immediate' result is rebuilt as @replaceName _z@, while a 'Nested'
-- result @g@ is rebuilt as @fmapName g@.
data Replacer
    -- | The field is the mapped type variable itself. The expression
    -- ignores its argument and yields the replacement value @_z@.
  = Immediate { replace :: DExp }
    -- | Any other field shape, including ones produced under a type
    -- application.
  | Nested    { replace :: DExp }