{-# LANGUAGE TypeSynonymInstances, FlexibleInstances, RankNTypes,
             TemplateHaskell, GeneralizedNewtypeDeriving,
             MultiParamTypeClasses, StandaloneDeriving,
             UndecidableInstances, MagicHash, UnboxedTuples,
             LambdaCase, NoMonomorphismRestriction #-}
module Data.Singletons.Util where
import Prelude hiding ( exp, foldl, concat, mapM, any, pred )
import Language.Haskell.TH.Syntax hiding ( lift )
import Language.Haskell.TH.Desugar
import Data.Char
import Control.Monad hiding ( mapM )
import Control.Monad.Writer hiding ( mapM )
import Control.Monad.Reader hiding ( mapM )
import qualified Data.Map as Map
import Data.List.NonEmpty (NonEmpty(..))
import Data.Map ( Map )
import qualified Data.Monoid as Monoid
import Data.Semigroup as Semigroup
import Data.Foldable
import Data.Functor.Identity
import Data.Traversable
import Data.Generics
import Data.Maybe
import Data.Void
import Control.Monad.Fail ( MonadFail )
-- | Names of the built-in "basic" types listed in this module: 'Maybe',
-- lists, 'Either', 'NonEmpty' and 'Void', followed by every name in
-- 'boundedBasicTypes'.
basicTypes :: [Name]
basicTypes =
  [ ''Maybe, ''[], ''Either, ''NonEmpty, ''Void ] ++ boundedBasicTypes

-- | Tuple types of arity 2 through 7, then 'Identity', then every name in
-- 'enumBasicTypes'.
-- NOTE(review): the name suggests these are the basic types with 'Bounded'
-- instances — confirm against callers.
boundedBasicTypes :: [Name]
boundedBasicTypes =
  tupleNames ++ [ ''Identity ] ++ enumBasicTypes
  where
    tupleNames = [ ''(,), ''(,,), ''(,,,), ''(,,,,), ''(,,,,,), ''(,,,,,,) ]

-- | Basic types whose constructors are all nullary: 'Bool', 'Ordering'
-- and @()@.
enumBasicTypes :: [Name]
enumBasicTypes =
  [ ''Bool
  , ''Ordering
  , ''()
  ]

-- | Semigroup-related newtype wrappers. The first five are defined in
-- "Data.Monoid"; the rest come from "Data.Semigroup".
semigroupBasicTypes :: [Name]
semigroupBasicTypes =
  fromDataMonoid ++ fromDataSemigroup
  where
    fromDataMonoid    = [ ''Dual, ''All, ''Any, ''Sum, ''Product ]
    fromDataSemigroup =
      [ ''Min, ''Max, ''Semigroup.First, ''Semigroup.Last, ''WrappedMonoid ]

-- | The "Data.Monoid" versions of 'Monoid.First' and 'Monoid.Last'. These
-- are distinct from the same-named "Data.Semigroup" wrappers listed in
-- 'semigroupBasicTypes'.
monoidBasicTypes :: [Name]
monoidBasicTypes = [ ''Monoid.First, ''Monoid.Last ]
-- | Report a warning-level diagnostic through 'qReport' (flag 'False').
qReportWarning :: Quasi q => String -> q ()
qReportWarning msg = qReport False msg

-- | Report an error-level diagnostic through 'qReport' (flag 'True').
qReportError :: Quasi q => String -> q ()
qReportError msg = qReport True msg
-- | Obtain a fresh 'Int'. It generates a new name and returns the unique
-- stored in that name's 'NameU' flavour. If the generated name has any other
-- flavour, this raises an internal error.
qNewUnique :: DsMonad q => q Int
qNewUnique = qNewName "x" >>= \(Name _ flavour) ->
  case flavour of
    NameU u -> return u
    _       -> error "Internal error: `qNewName` didn't return a NameU"
-- | Fail (via 'fail') if any of the given names has the base name @Rep@.
-- Such a data type cannot be promoted as expected. Otherwise do nothing.
checkForRep :: Quasi q => [Name] -> q ()
checkForRep names
  | any isRep names = fail repMessage
  | otherwise       = return ()
  where
    isRep n = nameBase n == "Rep"
    repMessage = "A data type named <<Rep>> is a special case.\n" ++
                 "Promoting it will not work as expected.\n" ++
                 "Please choose another name for your data type."
-- | Apply 'checkForRep' to every name that 'allNamesIn' collects from the
-- given declarations.
checkForRepInDecls :: Quasi q => [DDec] -> q ()
checkForRepInDecls = checkForRep . allNamesIn
-- | The field types of a constructor. Works for both normal and record
-- syntax, in declaration order.
tysOfConFields :: DConFields -> [DType]
tysOfConFields fields = case fields of
  DNormalC _ bangTys -> [ ty | (_, ty) <- bangTys ]
  DRecC recTys       -> [ ty | (_, _, ty) <- recTys ]

-- | A constructor's name paired with its number of fields.
extractNameArgs :: DCon -> (Name, Int)
extractNameArgs con = case extractNameTypes con of
  (name, tys) -> (name, length tys)

-- | A constructor's name paired with its field types.
extractNameTypes :: DCon -> (Name, [DType])
extractNameTypes (DCon _ _ name fields _) = (name, tysOfConFields fields)

-- | A constructor's name.
extractName :: DCon -> Name
extractName = fst . extractNameTypes
-- | Is this name string spelled like a symbolic (infix) data constructor,
-- i.e. does it start with a colon, as in @:|@?
isInfixDataCon :: String -> Bool
isInfixDataCon (':' : _) = True
isInfixDataCon _         = False

-- | Is the name spelled like a data constructor? That means its base name
-- begins with an uppercase letter, or it is symbolic and begins with @:@.
-- An empty base name yields 'False' instead of crashing.
isDataConName :: Name -> Bool
isDataConName n = isUpcase n || isInfixDataCon (nameBase n)

-- | Does the name's base begin with an uppercase letter? An empty base name
-- yields 'False' instead of crashing.
isUpcase :: Name -> Bool
isUpcase n = case nameBase n of
  first : _ -> isUpper first
  []        -> False
-- | Capitalise a name with no prefixes added. For an alphanumeric name, the
-- first letter is upper-cased. A symbolic name is returned unchanged.
upcase :: Name -> Name
upcase = mkName . toUpcaseStr noPrefix

-- | Render a name's base as an "upper-case" string.
--
-- If the first character satisfies 'isHsLetter', the alphanumeric prefix is
-- prepended and that first character is upper-cased. Any other name,
-- including one with an empty base, just gets the symbolic prefix
-- prepended. (The original used 'head'/'tail' here and crashed on an empty
-- base name.)
toUpcaseStr :: (String, String)  -- ^ (alphanumeric prefix, symbolic prefix)
            -> Name -> String
toUpcaseStr (alpha, symb) n =
  case nameBase n of
    first : rest | isHsLetter first -> alpha ++ toUpper first : rest
    str                             -> symb ++ str

-- | A prefix pair that adds nothing to either kind of name.
noPrefix :: (String, String)
noPrefix = ("", "")

-- | Prefix a constructor name. A symbolic constructor (starting with @:@)
-- keeps its leading colon, and @tyPre@ is inserted right after it. Any other
-- name gets @pre@ prepended.
prefixConName :: String -> String -> Name -> Name
prefixConName pre tyPre n = case nameBase n of
  ':' : rest -> mkName (':' : tyPre ++ rest)
  alpha      -> mkName (pre ++ alpha)

-- | Prefix a name: @pre@ if its first character satisfies 'isHsLetter',
-- otherwise @tyPre@. An empty base name takes the @tyPre@ branch instead
-- of crashing.
prefixName :: String -> String -> Name -> Name
prefixName pre tyPre n = mkName $ case nameBase n of
  str@(first : _) | isHsLetter first -> pre ++ str
  str                                -> tyPre ++ str

-- | Suffix a name: @ident@ if its first character satisfies 'isHsLetter',
-- otherwise @symb@. An empty base name takes the @symb@ branch instead of
-- crashing.
suffixName :: String -> String -> Name -> Name
suffixName ident symb n = mkName $ case nameBase n of
  str@(first : _) | isHsLetter first -> str ++ ident
  str                                -> str ++ symb
-- | Make a pair of prefixes unique by appending an integer.
--
-- The alphanumeric prefix gets the decimal digits of @n@. The symbolic prefix
-- gets those digits transliterated to operator characters:
-- '0'->'!', '1'->'#', '2'->'$', '3'->'%', '4'->'&',
-- '5'->'*', '6'->'+', '7'->'.', '8'->'/', '9'->'>'.
--
-- NOTE(review): a negative @n@ shows a leading @-@, which hits the
-- "non-digit" error when the symbolic result is forced.
uniquePrefixes :: String   -- ^ alphanumeric prefix
               -> String   -- ^ symbolic prefix
               -> Int
               -> (String, String)  -- ^ (alphanumeric result, symbolic result)
uniquePrefixes alpha symb n = (alpha ++ digits, symb ++ map toSymbol digits)
  where
    digits = show n
    toSymbol d = fromMaybe (error "non-digit in show #") (lookup d digitTable)
    digitTable = zip "0123456789" "!#$%&*+./>"
-- | The kind annotation of a type-variable binder, if it has one.
extractTvbKind :: DTyVarBndr -> Maybe DKind
extractTvbKind tvb = case tvb of
  DKindedTV _ k -> Just k
  DPlainTV _    -> Nothing

-- | The variable bound by a type-variable binder.
extractTvbName :: DTyVarBndr -> Name
extractTvbName tvb = case tvb of
  DKindedTV n _ -> n
  DPlainTV n    -> n

-- | An occurrence of the bound variable as a type. Any kind is dropped.
tvbToType :: DTyVarBndr -> DType
tvbToType tvb = DVarT (extractTvbName tvb)

-- | A binder for @n@, kinded exactly when a kind is supplied.
inferMaybeKindTV :: Name -> Maybe DKind -> DTyVarBndr
inferMaybeKindTV n = maybe (DPlainTV n) (DKindedTV n)

-- | The kind recorded in a type-family result signature, if any. It comes
-- either from a bare kind signature or from a kinded result binder.
resultSigToMaybeKind :: DFamilyResultSig -> Maybe DKind
resultSigToMaybeKind sig = case sig of
  DNoSig        -> Nothing
  DKindSig k    -> Just k
  DTyVarSig tvb -> extractTvbKind tvb
-- | Build a curried function type from argument types and a result type.
-- For example, @ravel [a, b] r@ is @a -> b -> r@.
ravel :: [DType] -> DType -> DType
ravel args res = foldr arrow res args
  where
    arrow arg rest = DAppT (DAppT DArrowT arg) rest

-- | Convert a predicate into the equivalent type, constructor by
-- constructor.
predToType :: DPred -> DType
predToType pr = case pr of
  DForallPr tvbs cxt p -> DForallT tvbs cxt (predToType p)
  DAppPr p t           -> predToType p `DAppT` t
  DSigPr p k           -> predToType p `DSigT` k
  DVarPr n             -> DVarT n
  DConPr n             -> DConT n
  DWildCardPr          -> DWildCardT

-- | The number of argument types that 'unravel' finds in a type.
countArgs :: DType -> Int
countArgs ty = case unravel ty of
  (_, _, args, _) -> length args
-- | Replace every exact ('NameU') name in type-variable binders, type-variable
-- occurrences and injectivity annotations with a plain 'mkName' name. The
-- new name is the occurrence name followed by the shown unique. The rewrite
-- is applied everywhere in the structure via 'everywhere'. All other names
-- are left alone.
noExactTyVars :: Data a => a -> a
noExactTyVars = everywhere transform
  where
    transform :: Data b => b -> b
    transform = mkT onTvb `extT` onType `extT` onInjAnn

    onTvb :: DTyVarBndr -> DTyVarBndr
    onTvb (DPlainTV n)    = DPlainTV (plainName n)
    onTvb (DKindedTV n k) = DKindedTV (plainName n) k

    onType :: DType -> DType
    onType ty = case ty of
      DVarT n -> DVarT (plainName n)
      _       -> ty

    onInjAnn :: InjectivityAnn -> InjectivityAnn
    onInjAnn (InjectivityAnn lhs rhs) =
      InjectivityAnn (plainName lhs) (map plainName rhs)

    plainName :: Name -> Name
    plainName (Name (OccName occ) (NameU u)) = mkName (occ ++ show u)
    plainName other                          = other
-- | Apply a substitution to a kind. Kinds share their representation with
-- types here, so this is just 'substType'.
substKind :: Map Name DKind -> DKind -> DKind
substKind subst = substType subst

-- | Apply a substitution of type variables to a type.
--
-- An empty substitution returns the input untouched. Binders introduced by a
-- 'DForallT' remove their names from the substitution for the rest of their
-- scope, including the context and the body.
--
-- NOTE(review): bound variables are never renamed. If a replacement type
-- mentions a variable with the same name as an inner binder, that variable
-- would be captured.
substType :: Map Name DType -> DType -> DType
substType subst ty
  | Map.null subst = ty
  | otherwise      = case ty of
      DForallT tvbs cxt inner_ty ->
        let (subst', tvbs') = mapAccumL subst_tvb subst tvbs
        in  DForallT tvbs' (map (substPred subst') cxt) (substType subst' inner_ty)
      DAppT ty1 ty2 -> DAppT (substType subst ty1) (substType subst ty2)
      DSigT ty1 ki  -> DSigT (substType subst ty1) (substType subst ki)
      DVarT n       -> Map.findWithDefault ty n subst
      DConT {}      -> ty
      DArrowT       -> ty
      DLitT {}      -> ty
      DWildCardT    -> ty
-- | Apply a substitution of type variables to a predicate.
--
-- An empty substitution returns the input untouched. Binders of a
-- 'DForallPr' remove their names from the substitution for the context and
-- the body. Types inside the predicate are substituted with 'substType', and
-- kind signatures with 'substKind'.
substPred :: Map Name DType -> DPred -> DPred
substPred subst p
  | Map.null subst = p
  | otherwise      = case p of
      DForallPr tvbs cxt inner_pred ->
        let (subst', tvbs') = mapAccumL subst_tvb subst tvbs
        in  DForallPr tvbs' (map (substPred subst') cxt) (substPred subst' inner_pred)
      DAppPr inner ty -> DAppPr (substPred subst inner) (substType subst ty)
      DSigPr inner ki -> DSigPr (substPred subst inner) (substKind subst ki)
      DVarPr {}       -> p
      DConPr {}       -> p
      DWildCardPr     -> p
-- | Push a substitution through one binder of a telescope.
--
-- The binder's kind (if any) is substituted using the substitution as it
-- stands. The bound name is then deleted from the substitution that is
-- threaded on, so later binders and the body see it as shadowed.
subst_tvb :: Map Name DKind -> DTyVarBndr -> (Map Name DKind, DTyVarBndr)
subst_tvb s tvb = case tvb of
  DPlainTV n    -> (Map.delete n s, tvb)
  DKindedTV n k -> (Map.delete n s, DKindedTV n (substKind s k))
-- | Give an unkinded binder the kind named by 'typeKindName'. A binder that
-- already has a kind is returned unchanged.
-- NOTE(review): the name suggests this is used to build complete
-- user-supplied kind signatures — confirm with callers.
cuskify :: DTyVarBndr -> DTyVarBndr
cuskify tvb = case tvb of
  DPlainTV tvname -> DKindedTV tvname (DConT typeKindName)
  _               -> tvb
-- | Apply a type to arguments, nesting to the left.
-- For example, @foldType f [a, b]@ is @(f a) b@.
foldType :: DType -> [DType] -> DType
foldType acc []           = acc
foldType acc (arg : args) = foldType (DAppT acc arg) args

-- | Apply a type to the variables bound by some binders (their kinds are
-- dropped).
foldTypeTvbs :: DType -> [DTyVarBndr] -> DType
foldTypeTvbs ty tvbs = foldType ty (map tvbToType tvbs)

-- | Apply a predicate to type arguments, nesting to the left.
foldPred :: DPred -> [DType] -> DPred
foldPred acc []           = acc
foldPred acc (arg : args) = foldPred (DAppPr acc arg) args

-- | Apply a predicate to the variables bound by some binders (their kinds
-- are dropped).
foldPredTvbs :: DPred -> [DTyVarBndr] -> DPred
foldPredTvbs pr tvbs = foldPred pr (map tvbToType tvbs)
-- | Split a type into its head followed by the arguments it is applied to,
-- in order. Kind signatures and foralls are looked through wherever they
-- appear on the spine.
unfoldType :: DType -> NonEmpty DType
unfoldType = peel []
  where
    peel :: [DType] -> DType -> NonEmpty DType
    peel args ty = case ty of
      DAppT fun arg      -> peel (arg : args) fun
      DSigT inner _      -> peel args inner
      DForallT _ _ inner -> peel args inner
      _                  -> ty :| args
-- | Extend a data declaration's binders with the extra binders that
-- 'mkExtraDKindBinders' produces from its result kind. When no kind is
-- given, the kind named by 'typeKindName' is used.
-- NOTE(review): presumably one fresh binder per argument of the kind —
-- see th-desugar's 'mkExtraDKindBinders'.
buildDataDTvbs :: DsMonad q => [DTyVarBndr] -> Maybe DKind -> q [DTyVarBndr]
buildDataDTvbs tvbs mk =
  (tvbs ++) <$> mkExtraDKindBinders (fromMaybe (DConT typeKindName) mk)
-- | Apply an expression to arguments, nesting to the left.
-- For example, @foldExp f [a, b]@ is @(f a) b@.
foldExp :: DExp -> [DExp] -> DExp
foldExp acc []           = acc
foldExp acc (arg : args) = foldExp (DAppE acc arg) args

-- | Is the type (at its top level) either a function arrow application or a
-- @forall@ type?
isFunTy :: DType -> Bool
isFunTy ty = case ty of
  DAppT (DAppT DArrowT _) _ -> True
  DForallT {}               -> True
  _                         -> False
-- | Return the first list unless it is empty. An empty first list yields the
-- fallback instead.
orIfEmpty :: [a] -> [a] -> [a]
orIfEmpty xs fallback
  | null xs   = fallback
  | otherwise = xs
-- | Match several scrutinees against several patterns at once. This is done
-- with a single @case@ on a tuple of the scrutinees against a tuple of the
-- patterns. With no scrutinees and no patterns, the body is returned as is.
-- NOTE(review): the two lists are assumed to have equal length; this is
-- not checked.
multiCase :: [DExp] -> [DPat] -> DExp -> DExp
multiCase scruts pats body
  | null scruts, null pats = body
  | otherwise =
      DCaseE (mkTupleDExp scruts) [DMatch (mkTupleDPat pats) body]
-- | Lift a transformation on desugared syntax to Template Haskell syntax.
-- The input is desugared, @f@ receives both the original and the desugared
-- form, and the result is sweetened back.
wrapDesugar :: (Desugar th ds, DsMonad q) => (th -> ds -> q ds) -> th -> q th
wrapDesugar f th = sweeten <$> (desugar th >>= f th)
-- | A computation in @q@ that also accumulates auxiliary output of type @m@.
-- It is a newtype over @'WriterT' m q@, and all of the listed class
-- instances are derived from that wrapped 'WriterT'.
newtype QWithAux m q a = QWA { runQWA :: WriterT m q a }
  deriving ( Functor, Applicative, Monad, MonadIO, MonadFail
           , MonadTrans, MonadReader r, MonadWriter m )
-- | 'QWithAux' is a 'Quasi' monad whenever the underlying @q@ is. Every
-- operation is run in @q@ via 'lift' and leaves the auxiliary output alone.
-- The exception is 'qRecover', which keeps the output of whichever branch
-- produced the result.
instance (Quasi q, Monoid m) => Quasi (QWithAux m q) where
  qNewName            = lift `comp1` qNewName
  qReport             = lift `comp2` qReport
  qLookupName         = lift `comp2` qLookupName
  qReify              = lift `comp1` qReify
  qReifyInstances     = lift `comp2` qReifyInstances
  qLocation           = lift qLocation
  qRunIO              = lift `comp1` qRunIO
  qAddDependentFile   = lift `comp1` qAddDependentFile
  qReifyRoles         = lift `comp1` qReifyRoles
  qReifyAnnotations   = lift `comp1` qReifyAnnotations
  qReifyModule        = lift `comp1` qReifyModule
  qAddTopDecls        = lift `comp1` qAddTopDecls
  qAddModFinalizer    = lift `comp1` qAddModFinalizer
  qGetQ               = lift qGetQ
  qPutQ               = lift `comp1` qPutQ
  qReifyFixity        = lift `comp1` qReifyFixity
  qReifyConStrictness = lift `comp1` qReifyConStrictness
  qIsExtEnabled       = lift `comp1` qIsExtEnabled
  qExtsEnabled        = lift qExtsEnabled
  qAddForeignFilePath = lift `comp2` qAddForeignFilePath
  qAddTempFile        = lift `comp1` qAddTempFile
  qAddCorePlugin      = lift `comp1` qAddCorePlugin
  -- In the 'Quasi' class, the FIRST argument of 'qRecover' is the error
  -- handler and the SECOND is the action that may fail. (The previous
  -- parameter names had these the wrong way round.) Each branch runs in @q@
  -- with its aux output captured next to its result. The output paired with
  -- the result that 'qRecover' returns is then re-emitted here, so output
  -- from a failed action is dropped.
  qRecover handler action = do
    (result, aux) <- lift $ qRecover (evalForPair handler) (evalForPair action)
    tell aux
    return result
-- | The local declarations of a 'QWithAux' computation are those of the
-- underlying monad. Asking for them adds no auxiliary output.
instance (DsMonad q, Monoid m) => DsMonad (QWithAux m q) where
  localDeclarations = QWA (lift localDeclarations)
-- | Post-compose after a one-argument function. This is the same as '(.)'.
comp1 :: (b -> c) -> (a -> b) -> a -> c
comp1 f g x = f (g x)

-- | Post-compose after a two-argument function, so that
-- @comp2 f g a b = f (g a b)@.
comp2 :: (c -> d) -> (a -> b -> c) -> a -> b -> d
comp2 f g = \a -> f . g a
-- | Run a 'QWithAux' computation in the underlying monad. Only its result is
-- kept; the accumulated auxiliary output is discarded.
evalWithoutAux :: Quasi q => QWithAux m q a -> q a
evalWithoutAux (QWA w) = fst <$> runWriterT w

-- | Run a 'QWithAux' computation, keeping only the auxiliary output.
evalForAux :: Quasi q => QWithAux m q a -> q m
evalForAux (QWA w) = execWriterT w

-- | Run a 'QWithAux' computation, returning its result together with its
-- auxiliary output.
evalForPair :: QWithAux m q a -> q (a, m)
evalForPair (QWA w) = runWriterT w

-- | Emit a single key/value binding into a 'Map'-valued auxiliary output.
addBinding :: (Quasi q, Ord k) => k -> v -> QWithAux (Map.Map k v) q ()
addBinding k = tell . Map.singleton k

-- | Emit a single element into a list-valued auxiliary output.
addElement :: Quasi q => elt -> QWithAux [elt] q ()
addElement = tell . (: [])
-- | Look up a type name by its unqualified base name (any module qualifier
-- is ignored) using 'lookupTypeNameWithLocals'. If the name is found, it is
-- reified with 'dsReify'; otherwise the result is 'Nothing'.
dsReifyTypeNameInfo :: DsMonad q => Name -> q (Maybe DInfo)
dsReifyTypeNameInfo ty_name =
  lookupTypeNameWithLocals (nameBase ty_name) >>= maybe (pure Nothing) dsReify
-- | Map a monadic function over a traversable container, running the effects
-- in traversal order. The resulting monoidal values are then combined with
-- 'fold'.
concatMapM :: (Monad monad, Monoid monoid, Traversable t)
           => (a -> monad monoid) -> t a -> monad monoid
concatMapM fn = fmap fold . mapM fn
-- | Wrap a value in a singleton list.
listify :: a -> [a]
listify x = [x]

-- | The first component of a triple.
fstOf3 :: (a,b,c) -> a
fstOf3 triple = case triple of
  (a, _, _) -> a

-- | Apply a function to the first component of a pair.
liftFst :: (a -> b) -> (a, c) -> (b, c)
liftFst f pair = case pair of
  (a, c) -> (f a, c)

-- | Apply a function to the second component of a pair.
liftSnd :: (a -> b) -> (c, a) -> (c, b)
liftSnd f pair = case pair of
  (c, a) -> (c, f a)

-- | Split a non-empty list into all but its last element, and that last
-- element. Calls 'error' on the empty list.
snocView :: [a] -> ([a], a)
snocView []       = error "snocView nil"
snocView (y : ys) = go y ys
  where
    go final []            = ([], final)
    go prev (next : more)  = case go next more of
      (front, final) -> (prev : front, final)
-- | Split a list with a classifier. 'Left' results go to the first list and
-- 'Right' results to the second, each keeping the input order.
partitionWith :: (a -> Either b c) -> [a] -> ([b], [c])
partitionWith f = loop [] []
  where
    loop ls rs []       = (reverse ls, reverse rs)
    loop ls rs (x : xs) = either (\l -> loop (l : ls) rs xs)
                                 (\r -> loop ls (r : rs) xs)
                                 (f x)

-- | Monadic version of 'partitionWith'. The classifier's effects run left to
-- right over the whole list before the two order-preserving lists are
-- returned.
partitionWithM :: Monad m => (a -> m (Either b c)) -> [a] -> m ([b], [c])
partitionWithM f = loop [] []
  where
    loop ls rs []       = return (reverse ls, reverse rs)
    loop ls rs (x : xs) = f x >>= either (\l -> loop (l : ls) rs xs)
                                         (\r -> loop ls (r : rs) xs)
-- | Separate declarations into the let-bound ones (unwrapped from 'DLetDec')
-- and all the others, keeping relative order within each group.
partitionLetDecs :: [DDec] -> ([DLetDec], [DDec])
partitionLetDecs = partitionWith classify
  where
    classify (DLetDec ld) = Left ld
    classify dec          = Right dec
{-# INLINEABLE zipWith3M #-}
-- | Despite its name, this zips TWO lists with a monadic function, exactly
-- like 'zipWithM'. Effects run left to right, and zipping stops at the end
-- of the shorter list.
zipWith3M :: Monad m => (a -> b -> m c) -> [a] -> [b] -> m [c]
zipWith3M f xs ys = zipWithM f xs ys

-- | Map a monadic function that returns triples over a list, running the
-- effects left to right. The results are unzipped into three lists.
mapAndUnzip3M :: Monad m => (a -> m (b,c,d)) -> [a] -> m ([b],[c],[d])
mapAndUnzip3M f = foldr step (return ([], [], []))
  where
    step x rest = do
      (b, c, d)    <- f x
      (bs, cs, ds) <- rest
      return (b : bs, c : cs, d : ds)
-- | True for the underscore and for any Unicode letter. Elsewhere in this
-- module it is applied to the first character of a name to decide whether
-- the name is alphanumeric rather than symbolic.
isHsLetter :: Char -> Bool
isHsLetter c = c == '_' || isLetter c