{-# LANGUAGE ScopedTypeVariables #-}
module Data.Singletons.Deriving.Traversable where
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Names
import Data.Singletons.Syntax
import Language.Haskell.TH.Desugar
-- | Derive a 'Traversable' instance for the given data declaration.
--
-- For a data type with at least one constructor, one @traverse@ clause is
-- produced per constructor via 'mkSimpleConClause'. The constructor is
-- recombined with its traversed fields using @pure@ (no fields), @fmap@
-- (one field), or @liftA2@ followed by repeated applications of the
-- @apName@ combinator (two or more fields).
--
-- A data type with no constructors gets the single clause
-- @traverse _ v = pure (case v of {})@.
--
-- The instance context comes from 'inferConstraintsDef', which is given
-- @mb_ctxt@ (NOTE(review): presumably an explicitly supplied context, e.g.
-- from standalone deriving — confirm against 'inferConstraintsDef').
mkTraversableInstance :: forall q. DsMonad q => DerivDesc q
mkTraversableInstance mb_ctxt ty dd@(DataDecl _ _ cons) = do
  functorLikeValidityChecks False dd
  fun <- newUniqueName "_f"
  let
      -- Per-field traversal expressions, selected by 'foldDataConArgs'
      -- according to how the field's type is classified.
      fieldTraversal :: FFoldType (q DExp)
      fieldTraversal = FT
        { -- Trivial fields are handled by plain 'pure'.
          ft_triv    = pure (DVarE pureName)
          -- Variable fields use the user-supplied function bound to 'fun'.
        , ft_var     = pure (DVarE fun)
          -- Type-application fields recurse through 'traverse'.
        , ft_ty_app  = \_ inner -> fmap (DAppE (DVarE traverseName)) inner
        , ft_forall  = \_ inner -> inner
        , ft_bad_app = error "in other argument in ft_trav"
        }

      -- Rebuild a constructor expression on top of its traversed fields.
      applyCon :: DExp -> [DExp] -> DExp
      applyCon con args = case args of
        []               -> DAppE (DVarE pureName) con
        [only]           -> DAppE (DAppE (DVarE fmapName) con) only
        (a1 : a2 : more) ->
          let seed         = DAppE (DAppE (DAppE (DVarE liftA2Name) con) a1) a2
              step acc arg = DAppE (DAppE (DVarE apName) acc) arg
          in foldl step seed more

      -- Build the @traverse@ clause for one data constructor.
      conClause :: DCon -> q DClause
      conClause con = do
        pending  <- foldDataConArgs fieldTraversal con
        travExps <- sequence pending
        mkSimpleConClause (applyCon . DConE) [DVarPa fun] con travExps

  clauses <-
    if null cons
      then do
        -- No constructors: traverse _ v = pure (case v of {})
        scrut <- newUniqueName "v"
        pure [ DClause [DWildPa, DVarPa scrut]
                       (DAppE (DVarE pureName) (DCaseE (DVarE scrut) [])) ]
      else mapM conClause cons

  ctxt <- inferConstraintsDef mb_ctxt (DConPr traversableName) ty cons
  pure (InstDecl
    { id_cxt     = ctxt
    , id_name    = traversableName
    , id_arg_tys = [ty]
    , id_sigs    = mempty
    , id_meths   = [(traverseName, UFunction clauses)]
    })