{-# LANGUAGE TypeOperators      #-}
{-# LANGUAGE DataKinds          #-}
{-# LANGUAGE DeriveFunctor      #-}
{-# LANGUAGE TupleSections      #-}
{-# LANGUAGE TemplateHaskell    #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE CPP                #-}
-- |This module provides some Template Haskell functionality to
-- help out the declaration of 'Deep' instances.
--
-- Note that we chose to not automate the whole process on purpose.
-- Sometimes the user will need to define standalone 'Generic'
-- instances for some select types in the family, some other times
-- the user might want better control over naming, for example.
-- Consequently, the most adaptable option is to provide
-- two TH utilities:
--
-- 1. Unfolding a family into a list of types until a fixpoint is reached,
-- given in 'unfoldFamilyInto'
-- 2. Declaring 'Deep' for a list of types, given in 'declareDeepFor'
--
-- The steps in between unfolding the family and declaring 'Deep' vary
-- too much from case to case and hence, must be manually executed.
-- Let us run through a simple example, which involves mutual
-- recursion and type synonyms in the AST of a pseudo-language.
--
-- > data Stmt var
-- >   = SAssign var (Exp var)
-- >   | SIf     (Exp var) (Stmt var) (Stmt var)
-- >   | SSeq    (Stmt var) (Stmt var)
-- >   | SReturn (Exp var)
-- >   | SDecl (Decl var)
-- >   | SSkip
-- >   deriving (Show, Generic)
-- >
-- > data ODecl var
-- >   = DVar var
-- >   | DFun var var (Stmt var)
-- >   deriving (Show, Generic)
-- >
-- > type Decl x = TDecl x
-- > type TDecl x = ODecl x
-- >
-- > data Exp var
-- >   = EVar  var
-- >   | ECall var (Exp var)
-- >   | EAdd (Exp var) (Exp var)
-- >   | ESub (Exp var) (Exp var)
-- >   | ELit Int
-- >   deriving (Show, Generic)
--
-- Now say we want to use some code written with /generics-simplistic/
-- over these datatypes above. We must declare the 'Deep'
-- instances for the types in the family and "GHC.Generics"
-- takes care of the rest.
--
-- The first step is defining @Prim@ and @Fam@, which
-- will be type-level lists with the primitive types and the non-primitive,
-- or compound, types.
--
-- An easy way to gather /all/ types involved in the family is with
-- 'unfoldFamilyInto', like:
--
-- > unfoldFamilyInto "stmtFam" [t| Stmt Int |]
--
-- The call above will be expanded into:
--
-- > stmtFam :: [String]
-- > stmtFam = ["Generics.Simplistic.Example.Exp Int"
-- >           ,"Generics.Simplistic.Example.ODecl Int"
-- >           ,"Generics.Simplistic.Example.Stmt Int"
-- >           ,"Int"
-- >           ]
--
-- Which can then be inspected with GHCi and, with
-- some elbow grease (or text-editing macros!) we can
-- easily generate the necessary type-level lists:
--
-- > type Fam = '[Generics.Simplistic.Example.Exp Int
-- >             ,Generics.Simplistic.Example.ODecl Int
-- >             ,Generics.Simplistic.Example.Stmt Int
-- >             ]
-- >
-- > type Prim = '[Int]
--
-- Finally, we are ready to call 'deriveDeepFor' and get
-- the instances declared.
--
-- > deriveDeepFor ''Prim ''Fam
--
-- The TH code above expands to:
--
-- > instance Deep Prim Fam (Exp Int)
-- > instance Deep Prim Fam (ODecl Int)
-- > instance Deep Prim Fam (Stmt Int)
--
-- This workflow is crucial to be able to work
-- with large mutually recursive families, and it becomes
-- especially easy if coupled with
-- a text editor with good macro support (e.g. Emacs or Vim).
--
module Generics.Simplistic.Deep.TH
  ( unfoldFamilyInto
  , deriveDeepFor
  , deriveInstancesWith
  ) where

import Control.Monad.State
import Control.Arrow ((***))

import Language.Haskell.TH hiding (match)
import Language.Haskell.TH.Syntax hiding (lift)

import qualified Data.Set as S

import Generics.Simplistic.Deep

-- |Lists all the necessary types that should
-- have 'Generic' and 'Deep' instances. For example,
--
-- > data Rose2 a b = Fork (Either a b) [Rose2 a b]
-- > unfoldFamilyInto "rose2tys" [t| Rose2 Int Char |]
--
-- will yield code along the lines of:
--
-- > rose2tys :: [String]
-- > rose2tys = [ "Rose2 Int Char"
-- >            , "Either Int Char"
-- >            , "[Rose2 Int Char]"
-- >            , "Int"
-- >            , "Char"
-- >            ]
--
-- The entries are collected in a set, so they come out in that set's
-- sorted order, not in the order in which they were discovered.
--
-- You should then use some elbow grease or your favorite text editor
-- and its provided macro functionality to produce:
--
-- > type Rose2Prim = '[Int , Char]
-- > type Rose2Fam  = '[Rose2 Int Char , Either Int Char , [Rose2 Int Char]]
-- > deriving instance Generic (Rose2 Int Char)
-- > instance Deep Rose2Prim Rose2Fam (Rose2 Int Char)
-- > instance Deep Rose2Prim Rose2Fam (Either Int Char)
-- > instance Deep Rose2Prim Rose2Fam [Rose2 Int Char]
--
-- Types that already have a 'Generic' instance, such as 'Either' and
-- lists, need no standalone deriving clause.
--
-- Note that types like @Int@ will appear fully qualified,
-- so they will need some renaming.
unfoldFamilyInto :: String -> Q Type -> Q [Dec]
unfoldFamilyInto n first = do
  ty <- first >>= convertType
  allTys <- S.toList <$> execStateT (process ty) S.empty
  listStr <- [t| [String] |]
  return [ SigD (mkName n) listStr
         , FunD (mkName n) [Clause [] (NormalB $ mkExp allTys) []]
         ]
 where
   -- Pretty-print every collected type and embed it as a string literal.
   mkExp :: [STy] -> Exp
   mkExp = ListE . map (LitE . StringL . pprint . trevnocType)

-- |Given the names @prim@ and @fam@ of two type-level lists, generates
-- @instance Deep prim fam f@ for every @f@ listed in @fam@.
deriveDeepFor :: Name -> Name -> Q [Dec]
deriveDeepFor pr fam = deriveInstancesWith deepInstanceHead fam
 where
   -- The instance head for a single member @x@ of the family.
   deepInstanceHead x =
     [t| Deep $(return (ConT pr)) $(return (ConT fam)) $(return x) |]

-- |Given a function @f@ and the name @fam@ of a type-level list,
-- 'deriveInstancesWith' will generate:
--
-- > instance f x
--
-- (with an empty context and no method definitions) for each @x@ in
-- @fam@. This function is mostly internal; please check 'deriveDeepFor'.
deriveInstancesWith :: (Type -> Q Type) -- ^ Builds the instance head for a member of @fam@
                    -> Name -- ^ fam: a type synonym for a promoted list
                    -> Q [Dec]
deriveInstancesWith f fam = do
  tys <- getTypeLevelList fam
  mapM (fmap emptyInstance . f) tys
 where
   -- An instance declaration with no context and no method bodies.
   emptyInstance instTy = InstanceD Nothing [] instTy []


-- |Reads the elements of a type-level list declared as a type synonym,
-- e.g. @type Fam = '[A, B]@ yields @[A, B]@. Fails in 'Q' when the name
-- does not reify to a type constructor, is not a type synonym, or the
-- synonym's right-hand side is not a promoted list.
getTypeLevelList :: Name -> Q [Type]
getTypeLevelList x = do
  mtyDecl <- reifyDec x
  case mtyDecl of
    Nothing              -> fail ("Not a type declaration: " ++ show (ppr x))
    Just (TySynD _ _ ty) -> getTyLL ty
    -- Separate the name from the offending declaration, as 'getTyLL' does.
    Just d -> fail ("Not a type-level list: " ++ show (ppr x) ++ "; " ++ show (ppr d))
 where
   -- Walks a promoted list, skipping kind signatures.
   getTyLL :: Type -> Q [Type]
   getTyLL (SigT t _) = getTyLL t
   getTyLL PromotedNilT = return []
   getTyLL (AppT (AppT PromotedConsT a) as) = (a:) <$> getTyLL as
   getTyLL t = fail ("Not a type-level list: " ++ show (ppr x) ++ "; " ++ show t)

-- |Visits @ty@ and, through 'processDecl', every type mentioned in the
-- fields of its declaration. Types already present in the state are
-- skipped. Type constructors for which 'reifyDec' finds no declaration
-- are ignored. A type whose head is not a constructor is an error.
process :: STy -> StateT (S.Set STy) Q ()
process ty = do
  visited <- gets (S.member ty)
  if visited then return () else expand (styFlatten ty)
 where
   expand (ConST tyName , args) =
     lift (reifyDec tyName) >>= maybe (return ()) (\dec -> processDecl dec args)
   expand _ = fail "Invalid type"

-- |Unfolds the declaration @dec@ of a type constructor applied to @args@.
--
-- * 'DataD' / 'NewtypeD': the applied type, rebuilt from the /reified/
--   constructor name, is added to the visited set. Each constructor is
--   then processed with the declaration's type variables instantiated
--   to @args@. The visited check is repeated here on that rebuilt key.
--   The name 'process' looked up may be a different 'Name' for the same
--   constructor: 'convertType' turns 'ListT' into @mkName "[]"@, which
--   presumably differs from the name 'reify' reports. Without this check,
--   recursive types such as lists would never be recognised as visited.
--
-- * 'TySynD': the synonym is expanded and the result processed; the
--   synonym itself is never recorded.
--
-- * Any other declaration fails.
processDecl :: Dec -> [STy] -> StateT (S.Set STy) Q ()
processDecl (DataD _ tyName vars _ cons _) args = do
  let key = styApp tyName args
  seen <- gets (S.member key)
  if seen
    then return ()
    else do
      modify (S.insert key)
      mapM_ (processCon (zip (map tyvarName vars) args)) cons
processDecl (NewtypeD ctx tyName vars kind con derivs) args =
  -- A newtype is handled exactly like a single-constructor datatype.
  processDecl (DataD ctx tyName vars kind [con] derivs) args
processDecl (TySynD _ vars ty) args = do
  sty <- convertType ty
  let (used , extra) = splitAt (length vars) args
      argVal         = zip (map tyvarName vars) used
      -- Arguments the synonym does not bind (it may be eta-reduced, e.g.
      -- @type Decl = ODecl@ used as @Decl Int@) are re-applied to the
      -- expansion instead of being silently dropped.
      reapply t []       = t
      reapply t (x : xs) = reapply (AppST t x) xs
  process (reapply (styReduce argVal sty) extra)
processDecl _ _
  = fail "unknown decl"

-- |Processes every field type of a constructor, after instantiating the
-- declaration's type variables with the bindings in @argVal@. All field
-- types are converted before any of them is processed.
processCon :: [(Name , STy)] -> Con -> StateT (S.Set STy) Q ()
processCon argVal con = do
  converted <- traverse convertType (conType con)
  mapM_ (process . styReduce argVal) converted

-- |The name bound by a type-variable binder, ignoring any kind annotation.
-- From GHC 9.0 (template-haskell 2.17) 'TyVarBndr' carries an extra flag
-- field, so each version gets its own patterns.
#if __GLASGOW_HASKELL__ >= 900
tyvarName :: TyVarBndr flag -> Name
tyvarName (PlainTV n _)    = n
tyvarName (KindedTV n _ _) = n
#else
tyvarName :: TyVarBndr -> Name
tyvarName (PlainTV n)    = n
tyvarName (KindedTV n _) = n
#endif

-- |The type of a record field, dropping its name and strictness.
vbtyTy :: VarBangType -> Type
vbtyTy field = case field of (_ , _ , fieldTy) -> fieldTy

-- |The type of a positional field, dropping its strictness.
btyTy :: BangType -> Type
btyTy = snd

-- |The field types of a data constructor, in declaration order. For
-- existentially quantified constructors, the quantifier is looked through.
conType :: Con -> [Type]
conType con = case con of
  NormalC _ fields      -> map btyTy fields
  RecC _ fields         -> map vbtyTy fields
  InfixC lhs _ rhs      -> [btyTy lhs , btyTy rhs]
  ForallC _ _ inner     -> conType inner
  GadtC _ fields _      -> map btyTy fields
  RecGadtC _ fields _   -> map vbtyTy fields

----------------------

-- |A simplified representation of Template Haskell types that keeps
-- only applications, variables and constructors. 'convertType' and
-- 'trevnocType' translate between 'Type' and 'STy'. The derived 'Ord'
-- instance lets values be stored in a 'S.Set'.
data STy
  = AppST STy STy -- ^ Type application
  | VarST Name    -- ^ Type variable
  | ConST Name    -- ^ Type constructor
  deriving (Eq, Ord, Show)

-- |Converts a Template Haskell 'Type' into the simplified 'STy'
-- representation. Kind signatures and parentheses are dropped. The
-- built-in list and tuple type constructors become 'ConST's named
-- @[]@ and @(,)@, @(,,)@, ... Any other form of 'Type' fails.
#if __GLASGOW_HASKELL__ >= 808
convertType :: (MonadFail m) => Type -> m STy
#else
convertType :: (Monad m) => Type -> m STy
#endif
convertType ty = case ty of
  AppT f x  -> AppST <$> convertType f <*> convertType x
  SigT t _  -> convertType t
  ParensT t -> convertType t
  VarT v    -> return (VarST v)
  ConT c    -> return (ConST c)
  ListT     -> return (ConST (mkName "[]"))
  TupleT k  -> return (ConST (mkName ("(" ++ replicate (k - 1) ',' ++ ")")))
  other     -> fail ("convertType: Unsupported Type: " ++ show other)

-- |Converts an 'STy' back into a Template Haskell 'Type'. The names
-- that 'convertType' gives the list and tuple constructors are mapped
-- back to 'ListT' and 'TupleT'; every other constructor becomes 'ConT'.
trevnocType :: STy -> Type
trevnocType sty = case sty of
  AppST f x -> AppT (trevnocType f) (trevnocType x)
  VarST v   -> VarT v
  ConST c
    | c == mkName "[]" -> ListT
    | isTupleName c    -> TupleT (length (show c) - 1)
    | otherwise        -> ConT c
 where
   -- Tuple constructor names render as "(,)", "(,,)", ...; the arity
   -- is the rendered length minus one.
   isTupleName nm = take 2 (show nm) == "(,"

-- |Handy substitution function.
--
--  @stySubst t m n@ replaces every occurrence of the type variable @m@
--  within @t@ by @n@, that is: @t[n/m]@. Constructors are left untouched.
stySubst :: STy -> Name -> STy -> STy
stySubst (AppST a b) m n = AppST (stySubst a m n) (stySubst b m n)
stySubst c@(ConST _) _ _ = c
stySubst (VarST x)   m n
  | x == m    = n
  | otherwise = VarST x

-- |Like 'stySubst', but applies a list of @(variable, replacement)@
-- pairs. The pairs are applied one after another, the last pair first.
-- When the replacement types are closed, this is the same as
-- substituting all of them at once.
styReduce :: [(Name , STy)] -> STy -> STy
styReduce substs ty0 = foldr applyOne ty0 substs
  where applyOne (var , val) acc = stySubst acc var val

-- |Splits a type application into its head and its arguments, listed
-- left to right:
--
--  @styFlatten (AppST (AppST Tree A) B) == (Tree , [A , B])@
styFlatten :: STy -> (STy , [STy])
styFlatten sty = case sty of
  AppST fun arg -> (id *** (++ [arg])) (styFlatten fun)
  _             -> (sty , [])

-- |Applies the type constructor @name@ to @args@, left to right, so that
-- @styFlatten (styApp name args) == (ConST name , args)@.
styApp :: Name -> [STy] -> STy
styApp name = applyTo (ConST name)
  where applyTo acc []       = acc
        applyTo acc (x : xs) = applyTo (AppST acc x) xs

-- * Parsing Haskell's AST

-- |Reifies @name@ and returns its declaration when 'reify' describes it
-- as a type constructor ('TyConI'), or 'Nothing' for any other 'Info'.
reifyDec :: Name -> Q (Maybe Dec)
reifyDec name = fromTyCon <$> reify name
  where
    fromTyCon (TyConI dec) = Just dec
    fromTyCon _            = Nothing