{- Language/Haskell/TH/Desugar/Match.hs

(c) Richard Eisenberg 2013

Simplifies case statements in desugared TH. After this pass, there are no
more nested patterns.

This code is directly based on the analogous operation as written in GHC.
-}

{-# LANGUAGE CPP, TemplateHaskell #-}

module Language.Haskell.TH.Desugar.Match (scExp, scLetDec) where

import Prelude hiding ( fail, exp )

#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Control.Monad hiding ( fail )
import qualified Control.Monad as Monad
import Data.Data
import qualified Data.Foldable as F
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as Map
import Language.Haskell.TH.Instances ()
import Language.Haskell.TH.Syntax

import Language.Haskell.TH.Desugar.AST
import Language.Haskell.TH.Desugar.Core
import Language.Haskell.TH.Desugar.FV
import qualified Language.Haskell.TH.Desugar.OSet as OS
import Language.Haskell.TH.Desugar.Util
import Language.Haskell.TH.Desugar.Reify

-- | Remove all nested pattern-matches within this expression. This also
-- removes all 'DTildeP's and 'DBangP's. After this is run, every pattern
-- produced by the flattener is either a 'DConP' with bare variables as
-- arguments, a 'DLitP', or a 'DWildP'.
--
-- Applications, lambdas, lets, signatures and type applications are
-- traversed recursively; variables, constructors, literals and static
-- expressions are returned unchanged.
--
-- A @case@ whose scrutinee is not already a variable is rewritten to bind the
-- scrutinee to a fresh name with a @let@ first, because the flattener
-- ('simplCaseExp') matches against named variables only.
--
-- NOTE(review): for 'DCaseE', neither the scrutinee nor the right-hand sides
-- of the matches are passed through 'scExp' again, whereas 'scLetDec' does
-- recurse into clause right-hand sides. Confirm whether this is intended.
scExp :: DsMonad q => DExp -> q DExp
scExp (DAppE e1 e2) = DAppE <$> scExp e1 <*> scExp e2
scExp (DLamE names exp) = DLamE names <$> scExp exp
scExp (DCaseE scrut matches)
  | DVarE name <- scrut
  = simplCaseExp [name] clauses
  | otherwise
  = do scrut_name <- newUniqueName "scrut"
       case_exp <- simplCaseExp [scrut_name] clauses
       return $ DLetE [DValD (DVarP scrut_name) scrut] case_exp
  where
    -- Each case alternative becomes a one-pattern clause.
    clauses = map match_to_clause matches
    match_to_clause (DMatch pat exp) = DClause [pat] exp

scExp (DLetE decs body) = DLetE <$> mapM scLetDec decs <*> scExp body
scExp (DSigE exp ty) = DSigE <$> scExp exp <*> pure ty
scExp (DAppTypeE exp ty) = DAppTypeE <$> scExp exp <*> pure ty
scExp e@(DVarE {}) = return e
scExp e@(DConE {}) = return e
scExp e@(DLitE {}) = return e
scExp e@(DStaticE {}) = return e

-- | Like 'scExp', but for a 'DLetDec'.
--
-- A function declaration with at least one clause is collapsed into a single
-- clause whose arguments are fresh variables (one per pattern of the first
-- clause) and whose body is the flattened match over those variables. The
-- right-hand side of every clause is simplified with 'scExp' first.
--
-- Value declarations keep their pattern as-is and only have their right-hand
-- side simplified. Pragmas are handled by 'scLetPragma'. Signatures, fixity
-- declarations and clause-less function declarations are returned unchanged.
scLetDec :: DsMonad q => DLetDec -> q DLetDec
scLetDec (DFunD name clauses@(DClause pats1 _ : _)) = do
  arg_names <- mapM (const (newUniqueName "_arg")) pats1
  clauses' <- mapM sc_clause_rhs clauses
  case_exp <- simplCaseExp arg_names clauses'
  return $ DFunD name [DClause (map DVarP arg_names) case_exp]
  where
    sc_clause_rhs (DClause pats exp) = DClause pats <$> scExp exp
scLetDec (DValD pat exp) = DValD pat <$> scExp exp
scLetDec (DPragmaD prag) = DPragmaD <$> scLetPragma prag
scLetDec dec@(DSigD {}) = return dec
scLetDec dec@(DInfixD {}) = return dec
scLetDec dec@(DFunD _ []) = return dec

-- | Run 'scExp' over a pragma via 'topEverywhereM'.
--
-- Only 'topEverywhereM' is used because 'scExp' already recurses on its own.
scLetPragma :: DsMonad q => DPragma -> q DPragma
scLetPragma prag = topEverywhereM scExp prag

-- | The result of compiling a group of equations: a function that, given the
-- expression to use when the match fails, produces the resulting expression.
-- Match results are chained with function composition so that the failure
-- expression of one is the result of the next.
type MatchResult = DExp -> DExp

-- | Close off a 'MatchResult' by supplying the final failure expression:
-- a call to 'error' with the message @"Pattern-match failure"@.
matchResultToDExp :: MatchResult -> DExp
matchResultToDExp mr = mr failed_pattern_match
  where
    failed_pattern_match = DAppE (DVarE 'error)
                                 (DLitE $ StringL "Pattern-match failure")

-- | Flatten a list of clauses matched against the given scrutinee variables
-- into a single expression. Each clause is turned into an 'EquationInfo'
-- whose match result ignores the failure expression and yields the clause's
-- right-hand side; the combined match result is then closed off with
-- 'matchResultToDExp'.
simplCaseExp :: DsMonad q
             => [Name]
             -> [DClause]
             -> q DExp
simplCaseExp vars clauses = do
  match_result <- simplCase vars (map to_equation clauses)
  return (matchResultToDExp match_result)
  where
    to_equation (DClause pats rhs) = EquationInfo pats (const rhs)

-- | One equation being compiled: the patterns still to be matched (one per
-- remaining scrutinee variable) and the match result to use once they have
-- all matched.
data EquationInfo = EquationInfo [DPat] MatchResult  -- like DClause, but with a hole

-- analogous to GHC's match (in deSugar/Match.lhs)
-- | Compile a list of equations against a list of scrutinee variables.
--
-- With no variables left, the match results of the equations are chained so
-- that each equation's failure falls through to the next one. Otherwise the
-- first pattern of every equation is tidied against the first variable, the
-- equations are split into runs of the same 'PatGroup', and each run is
-- compiled by the matcher for its group. The bindings produced by tidying
-- are wrapped around the combined result.
--
-- The match results are combined with @foldr (.) id@ rather than @foldr1@,
-- so an empty list yields a match result that simply fails instead of
-- crashing.
simplCase :: DsMonad q
          => [Name]         -- the names of the scrutinees
          -> [EquationInfo] -- the matches (where the # of pats == length (1st arg))
          -> q MatchResult
simplCase [] clauses = return (foldr (.) id match_results)
  where
    match_results = [ mr | EquationInfo _ mr <- clauses ]
simplCase vars@(v:_) clauses = do
  (aux_binds, tidy_clauses) <- mapAndUnzipM (tidyClause v) clauses
  let grouped = groupClauses tidy_clauses
  match_results <- match_groups grouped
  return (adjustMatchResult (foldr (.) id aux_binds) $
          foldr (.) id match_results)
  where
    match_groups :: DsMonad q => [[(PatGroup, EquationInfo)]] -> q [MatchResult]
    -- With no equations at all, still scrutinise the variable (see 'matchEmpty').
    match_groups [] = matchEmpty v
    match_groups gs = mapM match_group gs

    match_group :: DsMonad q => [(PatGroup, EquationInfo)] -> q MatchResult
    match_group [] = error "Internal error in th-desugar (match_group)"
    match_group eqns@((group,_) : _) =
      case group of
        PgCon _ -> matchConFamily vars (subGroup [(c,e) | (PgCon c, e) <- eqns])
        PgLit _ -> matchLiterals  vars (subGroup [(l,e) | (PgLit l, e) <- eqns])
        PgBang  -> matchBangs     vars (drop_group eqns)
        PgAny   -> matchVariables vars (drop_group eqns)

    drop_group = map snd

-- analogous to GHC's tidyEqnInfo
-- | Tidy the first pattern of an equation against the given variable using
-- 'tidy1'. Returns the wrapper produced by tidying together with the
-- equation whose first pattern has been replaced by the tidied one. An
-- equation with no patterns is an internal error.
tidyClause :: DsMonad q => Name -> EquationInfo -> q (DExp -> DExp, EquationInfo)
tidyClause v (EquationInfo all_pats body) =
  case all_pats of
    [] ->
      error "Internal error in th-desugar: no patterns in tidyClause."
    first_pat : other_pats -> do
      (wrapper, tidied) <- tidy1 v first_pat
      return (wrapper, EquationInfo (tidied : other_pats) body)

-- | Tidy a single pattern that is matched against the given variable.
--
-- The result is a wrapper, which adds any bindings the original pattern
-- introduced around the right-hand side, and a simplified pattern:
--
-- * literal, constructor and wildcard patterns are kept as they are;
-- * a variable pattern becomes a wildcard plus a binding of that variable
--   to the scrutinee ('wrapBind');
-- * a lazy pattern becomes a wildcard plus the selector bindings from
--   'mkSelectorDecs';
-- * a bang pattern is dropped around literals and constructors, collapsed
--   when nested around another bang, lazy or signature pattern, and kept
--   around variables and wildcards;
-- * a signature pattern is stripped if its type mentions no type variables,
--   and otherwise makes the computation fail.
tidy1 :: DsMonad q
      => Name   -- the name of the variable that ...
      -> DPat   -- ... this pattern is matching against
      -> q (DExp -> DExp, DPat)   -- a wrapper and tidied pattern
tidy1 _ p@(DLitP {}) = return (id, p)
tidy1 v (DVarP var) = return (wrapBind var v, DWildP)
tidy1 _ p@(DConP {}) = return (id, p)
tidy1 v (DTildeP pat) = do
  sel_decs <- mkSelectorDecs pat v
  return (maybeDLetE sel_decs, DWildP)
tidy1 v (DBangP pat) =
  case pat of
    DLitP _   -> tidy1 v pat   -- already strict
    DVarP _   -> return (id, DBangP pat)  -- no change
    DConP _ _ -> tidy1 v pat   -- already strict
    DTildeP p -> tidy1 v (DBangP p) -- discard ~ under !
    DBangP p  -> tidy1 v (DBangP p) -- discard ! under !
    DSigP p _ -> tidy1 v (DBangP p) -- discard sig under !
    DWildP    -> return (id, DBangP pat)  -- no change
tidy1 v (DSigP pat ty)
  | no_tyvars_ty ty = tidy1 v pat
  -- The match-flattener doesn't know how to deal with patterns that mention
  -- type variables properly, so we give up if we encounter one.
  -- See https://github.com/goldfirere/th-desugar/pull/48#issuecomment-266778976
  -- for further discussion.
  | otherwise = Monad.fail
    "Match-flattening patterns that mention type variables is not supported."
  where
    -- True if no 'DVarT' occurs anywhere inside the given value.
    no_tyvars_ty :: Data a => a -> Bool
    no_tyvars_ty = everything (&&) (mkQ True no_tyvar_ty)

    no_tyvar_ty :: DType -> Bool
    no_tyvar_ty (DVarT{}) = False
    no_tyvar_ty t         = gmapQl (&&) True no_tyvars_ty t
tidy1 _ DWildP = return (id, DWildP)

-- | @wrapBind new old body@ wraps @body@ in a @let@ that binds @new@ to the
-- variable @old@. When both names are equal no binding is needed and the
-- body is returned untouched.
wrapBind :: Name -> Name -> DExp -> DExp
wrapBind new old body
  | new /= old = DLetE [DValD (DVarP new) (DVarE old)] body
  | otherwise  = body

-- like GHC's mkSelectorBinds
-- | Produce the @let@ declarations that bind the variables of a (lazily
-- matched) pattern by selecting them out of the given scrutinee variable.
--
-- * A bare variable pattern is bound directly to the scrutinee.
-- * A pattern that binds nothing needs no declarations.
-- * A pattern with exactly one binder yields three declarations: a fresh
--   alias for the scrutinee, a fresh variable holding an
--   @"Irrefutable match failed"@ error, and the binder itself, defined by
--   matching the alias against the pattern with the error variable as the
--   failure expression.
-- * Otherwise all binders are extracted at once into a tuple (built with
--   'simplCaseExp'), and each binder is defined as a projection from it.
mkSelectorDecs :: DsMonad q
               => DPat      -- pattern to deconstruct
               -> Name      -- variable being matched against
               -> q [DLetDec]
mkSelectorDecs (DVarP v) name = return [DValD (DVarP v) (DVarE name)]
mkSelectorDecs pat name
  | OS.null binders
  = return []

  | [bndr] <- binders_list
  = do val_var <- newUniqueName "var"
       err_var <- newUniqueName "err"
       bind    <- mk_bind val_var err_var bndr
       return [DValD (DVarP val_var) (DVarE name),
               DValD (DVarP err_var) (DVarE 'error `DAppE`
                                       (DLitE $ StringL "Irrefutable match failed")),
               bind]

  | otherwise
  = do tuple_expr <- simplCaseExp [name] [DClause [pat] local_tuple]
       tuple_var <- newUniqueName "tuple"
       projections <- mapM (mk_projection tuple_var) [0 .. tuple_size-1]
       return (DValD (DVarP tuple_var) tuple_expr :
               zipWith DValD (map DVarP binders_list) projections)

  where
    binders = extractBoundNamesDPat pat
    binders_list = F.toList binders
    tuple_size = length binders_list
    local_tuple = mkTupleDExp (map DVarE binders_list)

    -- A case expression that picks the i-th component out of the tuple.
    mk_projection :: DsMonad q
                  => Name   -- of the tuple
                  -> Int    -- which element to get (0-indexed)
                  -> q DExp
    mk_projection tup_name i = do
      var_name <- newUniqueName "proj"
      return $ DCaseE (DVarE tup_name) [DMatch (DConP (tupleDataName tuple_size) (mk_tuple_pats var_name i))
                                               (DVarE var_name)]

    -- Wildcards everywhere except for a variable at position i.
    mk_tuple_pats :: Name   -- of the projected element
                  -> Int    -- which element to get (0-indexed)
                  -> [DPat]
    mk_tuple_pats elt_name i = replicate i DWildP ++ DVarP elt_name : replicate (tuple_size - i - 1) DWildP

    -- Bind the single binder by matching scrut_var against the pattern,
    -- falling back to err_var if the match fails.
    mk_bind scrut_var err_var bndr_var = do
      rhs_mr <- simplCase [scrut_var] [EquationInfo [pat] (\_ -> DVarE bndr_var)]
      return (DValD (DVarP bndr_var) (rhs_mr (DVarE err_var)))

-- | Classification of the first pattern of an equation (see 'patGroup').
-- Consecutive equations in the same group (per 'sameGroup') are compiled
-- together.
data PatGroup
  = PgAny         -- immediate match (wilds, vars, lazies)
  | PgCon Name    -- constructor pattern, with the constructor's name
  | PgLit Lit     -- literal pattern, with the literal
  | PgBang        -- bang pattern

-- like GHC's groupEquations
-- | Tag every equation with the 'PatGroup' of its first pattern and split
-- the list into maximal runs of consecutive equations whose groups agree
-- according to 'sameGroup'. The order of the equations is preserved.
groupClauses :: [EquationInfo] -> [[(PatGroup, EquationInfo)]]
groupClauses clauses
  = runs same_gp [(patGroup (firstPat clause), clause) | clause <- clauses]
  where
    same_gp :: (PatGroup, EquationInfo) -> (PatGroup, EquationInfo) -> Bool
    (pg1,_) `same_gp` (pg2,_) = pg1 `sameGroup` pg2

-- | Classify a pattern into its 'PatGroup'. Variable, lazy and signature
-- patterns are reported as internal errors here.
patGroup :: DPat -> PatGroup
patGroup pat =
  case pat of
    DLitP lit    -> PgLit lit
    DVarP {}     -> error "Internal error in th-desugar (patGroup DVarP)"
    DConP name _ -> PgCon name
    DTildeP {}   -> error "Internal error in th-desugar (patGroup DTildeP)"
    DBangP {}    -> PgBang
    DSigP {}     -> error "Internal error in th-desugar (patGroup DSigP)"
    DWildP       -> PgAny

-- | Do two pattern groups belong in the same run of equations? Only the
-- kind of group is compared: any two constructor groups (or any two literal
-- groups) count as the same, regardless of the constructor or literal they
-- carry.
sameGroup :: PatGroup -> PatGroup -> Bool
sameGroup left right =
  case (left, right) of
    (PgAny,   PgAny)   -> True
    (PgBang,  PgBang)  -> True
    (PgCon _, PgCon _) -> True
    (PgLit _, PgLit _) -> True
    _                  -> False

-- | Partition keyed equations into sub-groups that share the same key.
-- The sub-groups come out in ascending key order, and within each sub-group
-- the equations keep their original relative order.
subGroup :: Ord a => [(a, EquationInfo)] -> [[EquationInfo]]
subGroup group
  -- 'Map.fromListWith (++)' prepends each later equation to those already
  -- collected for its key, so every bucket is built in reverse and has to be
  -- flipped back afterwards.
  = map reverse $ Map.elems $
    Map.fromListWith (++) [ (pg, [eqn]) | (pg, eqn) <- group ]

-- | The first pattern of an equation. An equation without patterns is
-- reported as an error.
firstPat :: EquationInfo -> DPat
firstPat (EquationInfo pats _) =
  case pats of
    first : _ -> first
    []        -> error "Clause encountered with no patterns -- should never happen"

-- | One alternative of a constructor @case@ under construction.
data CaseAlt = CaseAlt { alt_con  :: Name         -- con name
                       , _alt_args :: [Name]       -- bound var names
                       , _alt_rhs  :: MatchResult  -- RHS
                       }

-- from GHC's MatchCon.lhs
-- | Compile a run of equations whose first patterns are all constructor
-- patterns. Each sub-group (one per constructor) becomes a case alternative
-- via 'matchOneCon'; the alternatives are assembled into a @case@ on the
-- first scrutinee variable by 'mkDataConCase'.
matchConFamily :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchConFamily [] _ = error "Internal error in th-desugar (matchConFamily)"
matchConFamily (scrut : rest_vars) groups =
  mapM (matchOneCon rest_vars) groups >>= mkDataConCase scrut

-- like matchOneConLike from MatchCon
-- | Build the case alternative for a group of equations that all start with
-- the same constructor. Fresh variables are chosen for the constructor's
-- arguments, the argument patterns of every equation are moved to the front
-- of its pattern list, and matching continues against the argument variables
-- followed by the remaining scrutinee variables.
matchOneCon :: DsMonad q => [Name] -> [EquationInfo] -> q CaseAlt
matchOneCon vars eqns@(eqn1 : _)
  = do arg_vars <- selectMatchVars (pat_args pat1)
       match_result <- match_group arg_vars

       return $ CaseAlt (pat_con pat1) arg_vars match_result
  where
    -- The first equation determines the constructor and its argument count.
    pat1 = firstPat eqn1

    pat_args (DConP _ pats) = pats
    pat_args _              = error "Internal error in th-desugar (pat_args)"

    pat_con (DConP con _) = con
    pat_con _             = error "Internal error in th-desugar (pat_con)"

    match_group :: DsMonad q => [Name] -> q MatchResult
    match_group arg_vars
      = simplCase (arg_vars ++ vars) (map shift eqns)

    -- Replace the leading constructor pattern by its argument patterns.
    shift (EquationInfo (DConP _ args : pats) exp) = EquationInfo (args ++ pats) exp
    shift _ = error "Internal error in th-desugar (shift)"
matchOneCon _ _ = error "Internal error in th-desugar (matchOneCon)"

-- | Assemble a @case@ on the given variable from a non-empty list of
-- constructor alternatives.
--
-- The data type is looked up (via the constructor of the first alternative)
-- to find all of its constructors. If the alternatives do not mention every
-- constructor, a wildcard alternative yielding the failure expression is
-- appended; otherwise no default is emitted. An empty list of alternatives
-- is an internal error.
mkDataConCase :: DsMonad q => Name -> [CaseAlt] -> q MatchResult
mkDataConCase _ [] = error "Internal error in th-desugar (mkDataConCase)"
mkDataConCase var case_alts@(first_alt : _) = do
  all_ctors <- get_all_ctors (alt_con first_alt)
  return $ \fail ->
    let matches = map (mk_alt fail) case_alts in
    DCaseE (DVarE var) (matches ++ mk_default all_ctors fail)
  where
    mk_alt fail (CaseAlt con args body_fn)
      = let body = body_fn fail in
        DMatch (DConP con (map DVarP args)) body

    mk_default all_ctors fail | exhaustive_case all_ctors = []
                              | otherwise       = [DMatch DWildP fail]

    mentioned_ctors = S.fromList $ map alt_con case_alts
    exhaustive_case all_ctors = all_ctors `S.isSubsetOf` mentioned_ctors

    -- All constructor names of the data type that con_name belongs to.
    get_all_ctors :: DsMonad q => Name -> q (S.Set Name)
    get_all_ctors con_name = do
      ty_name <- dataConNameToDataName con_name
      Just (DTyConI tycon_dec _) <- dsReify ty_name
      return $ S.fromList $ map get_con_name $ get_cons tycon_dec

    get_cons (DDataD _ _ _ _ _ cons _)     = cons
    get_cons (DDataInstD _ _ _ _ _ cons _) = cons
    get_cons _                             = []

    get_con_name (DCon _ _ n _ _) = n

-- | The match result used when there are no equations at all: a @case@ on
-- the variable whose only alternative is a wildcard that yields the failure
-- expression.
matchEmpty :: DsMonad q => Name -> q [MatchResult]
matchEmpty var = return [mk_seq]
  where
    mk_seq fail = DCaseE (DVarE var) [DMatch DWildP fail]

-- | Compile a run of equations whose first patterns are all literals. Each
-- sub-group (one per distinct literal) is compiled against the remaining
-- variables with its leading pattern removed, and the results become the
-- alternatives of a @case@ on the first scrutinee variable
-- ('mkCoPrimCaseMatchResult').
matchLiterals :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchLiterals (var:vars) sub_groups
  = do alts <- mapM match_group sub_groups
       return (mkCoPrimCaseMatchResult var alts)
  where
    match_group :: DsMonad q => [EquationInfo] -> q (Lit, MatchResult)
    match_group eqns
      = do -- The literal of a sub-group is taken from its first equation.
           let lit = case eqns of
                       eqn1 : _ | DLitP lit' <- firstPat eqn1 -> lit'
                       _ -> error $ "Internal error in th-desugar "
                                 ++ "(matchLiterals.match_group)"
           match_result <- simplCase vars (shiftEqns eqns)
           return (lit, match_result)
matchLiterals [] _ = error "Internal error in th-desugar (matchLiterals)"

-- | Build a @case@ on the scrutinee with one alternative per literal,
-- followed by a wildcard alternative that yields the failure expression.
-- The failure expression is also passed on to every alternative's body.
mkCoPrimCaseMatchResult :: Name -- Scrutinee
                        -> [(Lit, MatchResult)]
                        -> MatchResult
mkCoPrimCaseMatchResult var match_alts = mk_case
  where
    mk_case fail = let alts = map (mk_alt fail) match_alts in
                   DCaseE (DVarE var) (alts ++ [DMatch DWildP fail])
    mk_alt fail (lit, body_fn)
      = DMatch (DLitP lit) (body_fn fail)

-- | Compile a run of equations whose first patterns are bang patterns. The
-- bang is stripped from each leading pattern, matching continues against
-- the same variables, and the result is wrapped so that the first scrutinee
-- is forced with 'seq' ('mkEvalMatchResult').
matchBangs :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchBangs [] _ = error "Internal error in th-desugar (matchBangs)"
matchBangs all_vars@(scrut : _) eqns = do
  let unbanged = map (decomposeFirstPat getBangPat) eqns
  inner <- simplCase all_vars unbanged
  return (mkEvalMatchResult scrut inner)

-- | Apply a function to the first pattern of an equation, leaving the other
-- patterns and the body alone. An equation without patterns is an internal
-- error.
decomposeFirstPat :: (DPat -> DPat) -> EquationInfo -> EquationInfo
decomposeFirstPat extractpat (EquationInfo all_pats body) =
  case all_pats of
    first : others -> EquationInfo (extractpat first : others) body
    []             -> error "Internal error in th-desugar (decomposeFirstPat)"

-- | The pattern underneath a bang. Anything other than a bang pattern is an
-- internal error.
getBangPat :: DPat -> DPat
getBangPat pat =
  case pat of
    DBangP inner -> inner
    _            -> error "Internal error in th-desugar (getBangPat)"

-- | Wrap a match result so that the given variable is forced first: the
-- resulting expression is @seq var body@, where @body@ is the inner match
-- result applied to the failure expression.
mkEvalMatchResult :: Name -> MatchResult -> MatchResult
mkEvalMatchResult var body_fn fail =
  DAppE (DAppE (DVarE 'seq) (DVarE var)) (body_fn fail)

-- | Compile a run of equations whose first patterns match anything: drop
-- the first scrutinee variable and the first pattern of every equation, and
-- carry on with the rest.
matchVariables :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchVariables vars eqns =
  case vars of
    _ : remaining -> simplCase remaining (shiftEqns eqns)
    []            -> error "Internal error in th-desugar (matchVariables)"

-- | Drop the first pattern from every equation. An equation that has no
-- patterns left is reported as an internal error instead of hitting the
-- partial 'tail'.
shiftEqns :: [EquationInfo] -> [EquationInfo]
shiftEqns = map shift
  where
    shift (EquationInfo (_ : pats) rhs) = EquationInfo pats rhs
    shift _ = error "Internal error in th-desugar (shiftEqns)"

-- | Post-process the expression produced by a match result with the given
-- wrapper.
adjustMatchResult :: (DExp -> DExp) -> MatchResult -> MatchResult
adjustMatchResult wrap mr = wrap . mr

-- from DsUtils
-- | Choose a fresh variable for each pattern, using 'selectMatchVar'.
selectMatchVars :: DsMonad q => [DPat] -> q [Name]
selectMatchVars pats = mapM selectMatchVar pats

-- from DsUtils
-- | Choose a fresh variable to stand for the value matched by a pattern.
-- Bang and lazy annotations are looked through. A variable pattern gives a
-- fresh name based on that variable's base name prefixed with an
-- underscore; every other pattern gives a fresh name based on @"_pat"@.
selectMatchVar :: DsMonad q => DPat -> q Name
selectMatchVar pat =
  case pat of
    DBangP inner  -> selectMatchVar inner
    DTildeP inner -> selectMatchVar inner
    DVarP var     -> newUniqueName ('_' : nameBase var)
    _             -> newUniqueName "_pat"

-- like GHC's runs
-- | Split a list into consecutive runs. A run starts at some element and
-- extends over the following elements for as long as they are related to
-- that /first/ element by the given predicate. Concatenating the runs gives
-- back the original list.
runs :: (a -> a -> Bool) -> [a] -> [[a]]
runs related = go
  where
    go [] = []
    go (leader : others) =
      case span (related leader) others of
        (followers, remaining) -> (leader : followers) : go remaining