{- Language/Haskell/TH/Desugar/Match.hs

(c) Richard Eisenberg 2013
rae@cs.brynmawr.edu

Simplifies case statements in desugared TH. After this pass, there are no
more nested patterns.

This code is directly based on the analogous operation as written in GHC.
-}

{-# LANGUAGE CPP, TemplateHaskell #-}

module Language.Haskell.TH.Desugar.Match (scExp, scLetDec) where

import Prelude hiding ( fail, exp )

#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Control.Monad hiding ( fail )
import qualified Control.Monad as Monad
import Data.Data
import qualified Data.Foldable as F
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as Map
import Language.Haskell.TH.Instances ()
import Language.Haskell.TH.Syntax

import Language.Haskell.TH.Desugar.AST
import Language.Haskell.TH.Desugar.Core
import Language.Haskell.TH.Desugar.FV
import qualified Language.Haskell.TH.Desugar.OSet as OS
import Language.Haskell.TH.Desugar.Util
import Language.Haskell.TH.Desugar.Reify

-- | Remove all nested pattern-matches within this expression. This also
-- removes all 'DTildeP's and 'DBangP's. After this is run, every pattern
-- is guaranteed to be either a 'DConP' with bare variables as arguments,
-- a 'DLitP', or a 'DWildP'.
--
-- The traversal descends into every subexpression that may itself contain a
-- @case@, including the scrutinee and the alternatives of a 'DCaseE', so that
-- matches nested inside other matches are flattened as well.
scExp :: DsMonad q => DExp -> q DExp
scExp (DAppE e1 e2) = DAppE <$> scExp e1 <*> scExp e2
scExp (DLamE names exp) = DLamE names <$> scExp exp
scExp (DCaseE scrut matches) = do
  -- Simplify the alternatives' right-hand sides first (as 'scLetDec' does for
  -- function clauses): 'simplCaseExp' only flattens the patterns it is given
  -- and copies the right-hand sides through untouched.
  clauses <- mapM match_to_clause matches
  case scrut of
    DVarE name -> simplCaseExp [name] clauses
    _ -> do
      -- Bind a non-variable scrutinee to a fresh name so the flattened match
      -- can refer to it; the scrutinee may itself contain a nested case.
      scrut' <- scExp scrut
      scrut_name <- newUniqueName "scrut"
      case_exp <- simplCaseExp [scrut_name] clauses
      return $ DLetE [DValD (DVarP scrut_name) scrut'] case_exp
  where
    match_to_clause (DMatch pat exp) = DClause [pat] <$> scExp exp
scExp (DLetE decs body) = DLetE <$> mapM scLetDec decs <*> scExp body
scExp (DSigE exp ty) = DSigE <$> scExp exp <*> pure ty
scExp (DAppTypeE exp ty) = DAppTypeE <$> scExp exp <*> pure ty
scExp e@(DVarE {}) = return e
scExp e@(DConE {}) = return e
scExp e@(DLitE {}) = return e
scExp e@(DStaticE {}) = return e

-- | Like 'scExp', but for a 'DLetDec'.
--
-- A function with at least one clause is rewritten into a single clause whose
-- arguments are fresh variables (one per pattern of the first clause) and
-- whose body is the flattened match of all the original clauses; each clause
-- RHS is simplified with 'scExp' first. Value bindings have their RHS
-- simplified, pragmas go through 'scLetPragma', and signatures, fixity
-- declarations and clause-less functions are returned unchanged.
scLetDec :: DsMonad q => DLetDec -> q DLetDec
scLetDec dec = case dec of
    DFunD name clauses@(DClause pats1 _ : _) -> do
      arg_names <- mapM (\_ -> newUniqueName "_arg") pats1
      clauses'  <- mapM simplify_rhs clauses
      body      <- simplCaseExp arg_names clauses'
      return (DFunD name [DClause (map DVarP arg_names) body])
    DFunD _ []  -> return dec
    DValD pat e -> DValD pat <$> scExp e
    DPragmaD p  -> DPragmaD <$> scLetPragma p
    DSigD {}    -> return dec
    DInfixD {}  -> return dec
  where
    -- keep a clause's patterns, simplify only its right-hand side
    simplify_rhs (DClause pats rhs) = DClause pats <$> scExp rhs

-- | Run 'scExp' over the expressions embedded in a pragma.
--
-- 'topEverywhereM' is used rather than a full traversal because 'scExp'
-- already recurses into subexpressions on its own.
scLetPragma :: DsMonad q => DPragma -> q DPragma
scLetPragma prag = topEverywhereM scExp prag

-- | A compiled match with a hole: given the expression to fall through to
-- when the match fails, it yields the expression for the whole match.
type MatchResult = DExp -> DExp

-- | Close off a 'MatchResult' by filling its failure hole with a call to
-- 'error' that reports a pattern-match failure.
matchResultToDExp :: MatchResult -> DExp
matchResultToDExp mr =
  let on_failure = DVarE 'error `DAppE` DLitE (StringL "Pattern-match failure")
  in  mr on_failure

-- | Compile clauses that all match against the same scrutinee variables into
-- one expression containing only flat patterns. The failure hole of the
-- resulting match is filled by 'matchResultToDExp' with a call to 'error'.
simplCaseExp :: DsMonad q
             => [Name]     -- ^ scrutinee variables, one per pattern column
             -> [DClause]  -- ^ clauses, each with one pattern per variable
             -> q DExp
simplCaseExp vars clauses =
  fmap matchResultToDExp (simplCase vars (map to_equation clauses))
  where
    -- a source clause's RHS never depends on the fall-through expression
    to_equation (DClause pats rhs) = EquationInfo pats (const rhs)

-- | One row of a match: the patterns still to be matched (one per scrutinee
-- variable handed to 'simplCase') and a right-hand side that is a
-- 'MatchResult', so it can later be given the expression to fall through to.
data EquationInfo = EquationInfo [DPat] MatchResult  -- like DClause, but with a hole

-- | Compile equations whose pattern columns line up with the given scrutinee
-- variables into a single 'MatchResult' (analogous to GHC's @match@ in
-- deSugar/Match.lhs).
--
-- With no variables left, every equation has fully matched, and their results
-- are chained so that each one's failure hole is the next one's result.
-- Otherwise the first column is tidied against the first variable, the
-- equations are grouped by the kind of their first pattern, and each group is
-- compiled by the matcher for that kind. The wrappers produced while tidying
-- are applied around the combined result.
simplCase :: DsMonad q
          => [Name]         -- the names of the scrutinees
          -> [EquationInfo] -- the matches (where the # of pats == length (1st arg))
          -> q MatchResult
simplCase [] clauses =
  return (foldr1 (.) [ mr | EquationInfo _ mr <- clauses ])
simplCase vars@(v:_) clauses = do
  (aux_binds, tidy_clauses) <- mapAndUnzipM (tidyClause v) clauses
  match_results <- compile_groups (groupClauses tidy_clauses)
  let wrap_all = foldr (.) id aux_binds
      combined = foldr1 (.) match_results
  return (adjustMatchResult wrap_all combined)
  where
    compile_groups :: DsMonad q => [[(PatGroup, EquationInfo)]] -> q [MatchResult]
    compile_groups [] = matchEmpty v
    compile_groups gs = mapM compile_group gs

    -- every equation in a group shares the group kind of its first pattern
    compile_group :: DsMonad q => [(PatGroup, EquationInfo)] -> q MatchResult
    compile_group [] = error "Internal error in th-desugar (match_group)"
    compile_group eqns@((group, _) : _) =
      case group of
        PgCon _ -> matchConFamily vars (subGroup [ (c, e) | (PgCon c, e) <- eqns ])
        PgLit _ -> matchLiterals  vars (subGroup [ (l, e) | (PgLit l, e) <- eqns ])
        PgBang  -> matchBangs     vars (map snd eqns)
        PgAny   -> matchVariables vars (map snd eqns)

-- | Tidy the first pattern of an equation against the scrutinee variable @v@
-- (analogous to GHC's tidyEqnInfo). Returns the wrapper produced by 'tidy1'
-- together with the equation whose first pattern has been replaced by the
-- tidied one.
tidyClause :: DsMonad q => Name -> EquationInfo -> q (DExp -> DExp, EquationInfo)
tidyClause v (EquationInfo pats body) =
  case pats of
    [] -> error "Internal error in th-desugar: no patterns in tidyClause."
    pat : rest -> do
      (wrap, pat') <- tidy1 v pat
      return (wrap, EquationInfo (pat' : rest) body)

-- | Tidy a single pattern that is matched against the variable @v@.
--
-- Returns a wrapper to put around the code for the match, together with a
-- pattern that is a 'DLitP', a 'DConP', a 'DWildP', or a 'DBangP' of a
-- 'DVarP' / 'DWildP':
--
-- * a variable becomes 'DWildP', and the wrapper binds it to @v@;
-- * a lazy pattern becomes 'DWildP', and the wrapper binds its variables
--   through the declarations built by 'mkSelectorDecs';
-- * a bang on an already-strict pattern is dropped, and @~@, @!@ and type
--   signatures directly under a bang are discarded;
-- * a signature pattern is dropped when its type mentions no type variables,
--   otherwise the monad's 'fail' is called.
tidy1 :: DsMonad q
      => Name   -- the name of the variable that ...
      -> DPat   -- ... this pattern is matching against
      -> q (DExp -> DExp, DPat)   -- a wrapper and tidied pattern
tidy1 _ p@(DLitP {}) = return (id, p)
tidy1 v (DVarP var) = return (wrapBind var v, DWildP)
tidy1 _ p@(DConP {}) = return (id, p)
tidy1 v (DTildeP pat) = do
  sel_decs <- mkSelectorDecs pat v
  return (maybeDLetE sel_decs, DWildP)
tidy1 v (DBangP pat) =
  case pat of
    DLitP _   -> tidy1 v pat   -- already strict
    DVarP _   -> return (id, DBangP pat)  -- no change
    DConP{}   -> tidy1 v pat   -- already strict
    DTildeP p -> tidy1 v (DBangP p) -- discard ~ under !
    DBangP p  -> tidy1 v (DBangP p) -- discard ! under !
    DSigP p _ -> tidy1 v (DBangP p) -- discard sig under !
    DWildP    -> return (id, DBangP pat)  -- no change
tidy1 v (DSigP pat ty)
  | no_tyvars_ty ty = tidy1 v pat
  -- The match-flattener doesn't know how to deal with patterns that mention
  -- type variables properly, so we give up if we encounter one.
  -- See https://github.com/goldfirere/th-desugar/pull/48#issuecomment-266778976
  -- for further discussion.
  | otherwise = Monad.fail
    "Match-flattening patterns that mention type variables is not supported."
  where
    -- True iff no 'DVarT' occurs anywhere inside the argument.
    -- 'everything' already visits every subterm, so the per-node query only
    -- needs to inspect the node itself. (Recursing again from inside the
    -- query re-traverses each subtree once more per level, which makes the
    -- check exponential in the depth of the type.)
    no_tyvars_ty :: Data a => a -> Bool
    no_tyvars_ty = everything (&&) (mkQ True not_a_tyvar)

    not_a_tyvar :: DType -> Bool
    not_a_tyvar (DVarT {}) = False
    not_a_tyvar _          = True
tidy1 _ DWildP = return (id, DWildP)

-- | @wrapBind new old body@ wraps @body@ in @let new = old in ...@. When the
-- two names are equal no binding is needed and @body@ is returned unchanged.
wrapBind :: Name -> Name -> DExp -> DExp
wrapBind new old =
  if new == old
    then id
    else DLetE [DValD (DVarP new) (DVarE old)]

-- | Build the let-declarations that bind the variables of a lazy pattern
-- @~pat@ matched against the variable @name@ (like GHC's mkSelectorBinds).
--
-- * A bare variable pattern is bound directly to @name@.
-- * A pattern that binds no variables needs no declarations.
-- * A pattern binding a single variable @x@ binds @x@ to a flattened match
--   of @pat@ (against a copy of @name@) that returns @x@ and falls through to
--   a call to 'error'.
-- * A pattern binding several variables is matched once to build a tuple of
--   all of them, and each variable is bound to a projection out of it.
mkSelectorDecs :: DsMonad q
               => DPat      -- pattern to deconstruct
               -> Name      -- variable being matched against
               -> q [DLetDec]
mkSelectorDecs (DVarP v) name = return [DValD (DVarP v) (DVarE name)]
mkSelectorDecs pat name =
  -- Dispatch on the binder list itself instead of testing the set's size and
  -- then calling the partial 'head'.
  case binders_list of
    [] -> return []

    [bndr] -> do
      val_var <- newUniqueName "var"
      err_var <- newUniqueName "err"
      bind    <- mk_bind val_var err_var bndr
      return [ DValD (DVarP val_var) (DVarE name)
             , DValD (DVarP err_var)
                     (DVarE 'error `DAppE`
                        (DLitE $ StringL "Irrefutable match failed"))
             , bind ]

    _ -> do
      tuple_expr  <- simplCaseExp [name] [DClause [pat] local_tuple]
      tuple_var   <- newUniqueName "tuple"
      projections <- mapM (mk_projection tuple_var) [0 .. tuple_size - 1]
      return (DValD (DVarP tuple_var) tuple_expr :
              zipWith DValD (map DVarP binders_list) projections)

  where
    binders :: OS.OSet Name
    binders = extractBoundNamesDPat pat

    binders_list :: [Name]
    binders_list = F.toList binders

    tuple_size :: Int
    tuple_size = length binders_list

    -- a tuple expression of all the binders, in 'binders_list' order
    local_tuple :: DExp
    local_tuple = mkTupleDExp (map DVarE binders_list)

    mk_projection :: DsMonad q
                  => Name   -- of the tuple
                  -> Int    -- which element to get (0-indexed)
                  -> q DExp
    mk_projection tup_name i = do
      var_name <- newUniqueName "proj"
      return $ DCaseE (DVarE tup_name)
                      [DMatch (DConP (tupleDataName tuple_size) []
                                     (mk_tuple_pats var_name i))
                              (DVarE var_name)]

    mk_tuple_pats :: Name   -- of the projected element
                  -> Int    -- which element to get (0-indexed)
                  -> [DPat]
    mk_tuple_pats elt_name i =
      replicate i DWildP ++ DVarP elt_name : replicate (tuple_size - i - 1) DWildP

    -- bndr_var = (match scrut_var against pat, returning bndr_var,
    --             falling through to err_var)
    mk_bind :: DsMonad m => Name -> Name -> Name -> m DLetDec
    mk_bind scrut_var err_var bndr_var = do
      rhs_mr <- simplCase [scrut_var] [EquationInfo [pat] (\_ -> DVarE bndr_var)]
      return (DValD (DVarP bndr_var) (rhs_mr (DVarE err_var)))

-- | The kind of the first pattern of a tidied equation. 'groupClauses' splits
-- equations by this kind (using 'sameGroup'), and 'simplCase' picks the
-- matcher for each group from it.
data PatGroup
  = PgAny         -- immediate match (wilds, vars, lazies)
  | PgCon Name    -- constructor pattern, carrying the constructor's name
  | PgLit Lit     -- literal pattern, carrying the literal
  | PgBang        -- bang pattern

-- | Pair every equation with the 'PatGroup' of its first pattern, then split
-- the list with 'runs', using 'sameGroup' on the groups (like GHC's
-- groupEquations; presumably contiguous runs — see 'runs').
groupClauses :: [EquationInfo] -> [[(PatGroup, EquationInfo)]]
groupClauses clauses = runs same_kind (map classify clauses)
  where
    classify clause = (patGroup (firstPat clause), clause)

    same_kind (g1, _) (g2, _) = sameGroup g1 g2

-- | Classify a tidied pattern. 'tidy1' has already removed variable, lazy and
-- signature patterns, so meeting one here is an internal error.
patGroup :: DPat -> PatGroup
patGroup pat =
  case pat of
    DWildP        -> PgAny
    DBangP {}     -> PgBang
    DLitP lit     -> PgLit lit
    DConP con _ _ -> PgCon con
    DVarP {}      -> error "Internal error in th-desugar (patGroup DVarP)"
    DTildeP {}    -> error "Internal error in th-desugar (patGroup DTildeP)"
    DSigP {}      -> error "Internal error in th-desugar (patGroup DSigP)"

-- | Whether two groups are of the same kind. The constructor name of a
-- 'PgCon' and the literal of a 'PgLit' are ignored.
sameGroup :: PatGroup -> PatGroup -> Bool
sameGroup g1 g2 =
  case (g1, g2) of
    (PgAny,   PgAny)   -> True
    (PgBang,  PgBang)  -> True
    (PgCon _, PgCon _) -> True
    (PgLit _, PgLit _) -> True
    _                  -> False

-- | Bucket equations by key (a constructor name or a literal). Buckets come
-- out in ascending key order, and each bucket keeps its equations in their
-- original relative order.
subGroup :: Ord a => [(a, EquationInfo)] -> [[EquationInfo]]
subGroup group =
  map reverse (Map.elems buckets)
  where
    -- 'fromListWith' combines as @f new old@, so @(++)@ prepends each later
    -- equation and every bucket is built newest-first; the 'reverse' above
    -- restores input order. Unlike a lazy 'foldl' over the map, this builds
    -- the map strictly instead of accumulating a chain of thunks.
    buckets = Map.fromListWith (++) [ (key, [eqn]) | (key, eqn) <- group ]

-- | The leftmost pattern of an equation. Equations reaching the grouping
-- step always have a pattern left, so an empty list is an internal error.
firstPat :: EquationInfo -> DPat
firstPat (EquationInfo pats _) =
  case pats of
    pat : _ -> pat
    []      -> error "Clause encountered with no patterns -- should never happen"

-- | One alternative of a data-constructor case, as built by 'matchOneCon'.
data CaseAlt = CaseAlt
  { alt_con   :: Name         -- ^ con name
  , _alt_args :: [Name]       -- ^ bound var names
  , _alt_rhs  :: MatchResult  -- ^ RHS
  }

-- | Compile groups of equations whose first patterns are all constructor
-- patterns (from GHC's MatchCon.lhs). Each inner list holds the equations for
-- one constructor. Each list is turned into an alternative by 'matchOneCon'
-- over the remaining variables, and 'mkDataConCase' assembles the
-- alternatives into a case on the first variable.
matchConFamily :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchConFamily [] _ = error "Internal error in th-desugar (matchConFamily)"
matchConFamily (var : vars) groups =
  mkDataConCase var =<< mapM (matchOneCon vars) groups

-- like matchOneConLike from MatchCon
-- | Build the 'CaseAlt' for one group of equations whose first patterns are
-- constructor patterns.
--
-- The constructor and the number of fresh field variables come from the first
-- equation of the group. Every equation then has its constructor's field
-- patterns spliced in front of its remaining patterns, and the result is
-- matched recursively with 'simplCase' on the field variables followed by
-- @vars@.
matchOneCon :: DsMonad q => [Name] -> [EquationInfo] -> q CaseAlt
matchOneCon vars eqns@(eqn1 : _) = do
  field_vars <- selectMatchVars (pat_args lead_pat)
  CaseAlt (pat_con lead_pat) field_vars <$> match_group field_vars
  where
    lead_pat :: DPat
    lead_pat = firstPat eqn1

    pat_args :: DPat -> [DPat]
    pat_args p = case p of
      DConP _ _ sub_pats -> sub_pats
      _                  -> error "Internal error in th-desugar (pat_args)"

    pat_con :: DPat -> Name
    pat_con p = case p of
      DConP con _ _ -> con
      _             -> error "Internal error in th-desugar (pat_con)"

    match_group :: DsMonad q => [Name] -> q MatchResult
    match_group field_vars = simplCase (field_vars ++ vars) (map shift eqns)

    -- Replace the leading constructor pattern by its field patterns.
    shift :: EquationInfo -> EquationInfo
    shift eqn = case eqn of
      EquationInfo (DConP _ _ sub_pats : rest) rhs ->
        EquationInfo (sub_pats ++ rest) rhs
      _ -> error "Internal error in th-desugar (shift)"
matchOneCon _ _ = error "Internal error in th-desugar (matchOneCon)"

-- | Turn the alternatives for one constructor column into a single
-- 'MatchResult': a @case@ on @var@ with one alternative per 'CaseAlt'.
--
-- The data type is found by reifying the type that declares the first
-- alternative's constructor. A trailing wildcard alternative, which falls
-- through to the failure expression, is added only when the alternatives do
-- not already mention every constructor of that type.
--
-- An empty list of alternatives is an internal error. This is reported with an
-- explicit message rather than the partial 'head'.
mkDataConCase :: DsMonad q => Name -> [CaseAlt] -> q MatchResult
mkDataConCase var case_alts@(first_alt : _) = do
  all_ctors <- get_all_ctors (alt_con first_alt)
  -- Exhaustiveness depends only on the constructor sets, so decide it once
  -- here rather than every time the continuation is applied.
  let exhaustive = all_ctors `S.isSubsetOf` mentioned_ctors
  return $ \fail ->
    let matches  = map (mk_alt fail) case_alts
        defaults | exhaustive = []
                 | otherwise  = [DMatch DWildP fail]
    in DCaseE (DVarE var) (matches ++ defaults)
  where
    -- Constructors named by the alternatives we were given.
    mentioned_ctors :: S.Set Name
    mentioned_ctors = S.fromList (map alt_con case_alts)

    mk_alt :: DExp -> CaseAlt -> DMatch
    mk_alt fail (CaseAlt con args body_fn)
      = DMatch (DConP con [] (map DVarP args)) (body_fn fail)

    -- All constructors of the data type (or data instance) declaring
    -- @con_name@. If reification does not yield a type constructor, fail in
    -- the monad with a message naming the type, instead of an anonymous
    -- pattern-match failure.
    get_all_ctors :: DsMonad q => Name -> q (S.Set Name)
    get_all_ctors con_name = do
      ty_name <- dataConNameToDataName con_name
      m_info  <- dsReify ty_name
      case m_info of
        Just (DTyConI tycon_dec _) ->
          return $ S.fromList $ map get_con_name $ get_cons tycon_dec
        _ -> Monad.fail $ "th-desugar (mkDataConCase): could not reify data type "
                       ++ show ty_name

    get_cons :: DDec -> [DCon]
    get_cons (DDataD _ _ _ _ _ cons _)     = cons
    get_cons (DDataInstD _ _ _ _ _ cons _) = cons
    get_cons _                             = []

    get_con_name :: DCon -> Name
    get_con_name (DCon _ _ n _ _) = n
mkDataConCase _ [] = error "Internal error in th-desugar (mkDataConCase)"

-- | A one-element list holding a 'MatchResult' that scrutinises @var@ with a
-- @case@ whose only alternative is a wildcard leading to the failure
-- expression.
matchEmpty :: DsMonad q => Name -> q [MatchResult]
matchEmpty var = pure [\fail -> DCaseE (DVarE var) [DMatch DWildP fail]]

-- | Match a column of literal patterns on @var@.
--
-- Each sub-group contributes one alternative. Its literal is read from the
-- group's first equation, and its right-hand side comes from matching the
-- group's remaining columns against @vars@. The alternatives are combined by
-- 'mkCoPrimCaseMatchResult'.
matchLiterals :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchLiterals (var:vars) sub_groups
  = do alts <- mapM match_group sub_groups
       return (mkCoPrimCaseMatchResult var alts)
  where
    match_group :: DsMonad q => [EquationInfo] -> q (Lit, MatchResult)
    match_group eqns
      = do match_result <- simplCase vars (shiftEqns eqns)
           return (group_lit eqns, match_result)

    -- The literal shared by a group, taken from its first equation. A pattern
    -- match is used instead of the partial 'head', which base >= 4.19 flags
    -- with an x-partial warning. Empty or non-literal groups report the
    -- th-desugar internal error. The result stays lazy, as before.
    group_lit :: [EquationInfo] -> Lit
    group_lit (eqn1 : _)
      | DLitP lit <- firstPat eqn1 = lit
    group_lit _ = error $ "Internal error in th-desugar "
                       ++ "(matchLiterals.match_group)"
matchLiterals [] _ = error "Internal error in th-desugar (matchLiterals)"

-- | Combine literal alternatives into one 'MatchResult'. The result is a
-- @case@ on the scrutinee with one alternative per literal, followed by a
-- wildcard alternative that falls through to the failure expression.
mkCoPrimCaseMatchResult :: Name                  -- ^ Scrutinee
                        -> [(Lit, MatchResult)]  -- ^ literal alternatives
                        -> MatchResult
mkCoPrimCaseMatchResult var match_alts fail =
  DCaseE (DVarE var) (lit_alts ++ [DMatch DWildP fail])
  where
    lit_alts = map (\(lit, body_fn) -> DMatch (DLitP lit) (body_fn fail))
                   match_alts

-- | Match a column of bang patterns on the first variable. The outer 'DBangP'
-- is stripped from every equation's first pattern, the equations are matched
-- on the same variables, and the result is wrapped by 'mkEvalMatchResult',
-- which forces the variable with 'seq'.
matchBangs :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchBangs [] _ = error "Internal error in th-desugar (matchBangs)"
matchBangs all_vars@(scrut : _) eqns =
  mkEvalMatchResult scrut
    <$> simplCase all_vars (map (decomposeFirstPat getBangPat) eqns)

-- | Rewrite the first pattern of an equation with the given function. The
-- remaining patterns and the right-hand side are left untouched.
decomposeFirstPat :: (DPat -> DPat) -> EquationInfo -> EquationInfo
decomposeFirstPat extractpat eqn = case eqn of
  EquationInfo (pat : pats) body -> EquationInfo (extractpat pat : pats) body
  _ -> error "Internal error in th-desugar (decomposeFirstPat)"

-- | The pattern underneath a 'DBangP'. Any other pattern is an internal error.
getBangPat :: DPat -> DPat
getBangPat pat = case pat of
  DBangP inner -> inner
  _            -> error "Internal error in th-desugar (getBangPat)"

-- | Wrap a 'MatchResult' so that @var@ is forced first. The generated
-- expression is @seq var (body_fn fail)@.
mkEvalMatchResult :: Name -> MatchResult -> MatchResult
mkEvalMatchResult var body_fn fail =
  DAppE (DAppE (DVarE 'seq) (DVarE var)) (body_fn fail)

-- | Drop the first column without inspecting its patterns, and continue
-- matching the shifted equations on the remaining variables.
matchVariables :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchVariables vars eqns = case vars of
  _ : rest -> simplCase rest (shiftEqns eqns)
  []       -> error "Internal error in th-desugar (matchVariables)"

-- | Drop the first pattern column from every equation, once the caller has
-- handled that column.
shiftEqns :: [EquationInfo] -> [EquationInfo]
shiftEqns = map shift
  where
    -- An explicit match replaces the partial 'tail', which base >= 4.19 flags
    -- with an x-partial warning. An equation with no patterns left is reported
    -- as a th-desugar internal error, like the other helpers in this module.
    shift :: EquationInfo -> EquationInfo
    shift (EquationInfo (_ : pats) rhs) = EquationInfo pats rhs
    shift _ = error "Internal error in th-desugar (shiftEqns)"


-- | Post-process the expression produced by a 'MatchResult' with @wrap@, after
-- the failure expression has been supplied.
adjustMatchResult :: (DExp -> DExp) -> MatchResult -> MatchResult
adjustMatchResult wrap mr = wrap . mr

-- from DsUtils
-- | One fresh match variable per pattern, in order, chosen by 'selectMatchVar'.
selectMatchVars :: DsMonad q => [DPat] -> q [Name]
selectMatchVars pats = traverse selectMatchVar pats

-- from DsUtils
-- | A fresh variable name for matching against the given pattern.
--
-- Bang and lazy ('DTildeP') wrappers are looked through. A variable pattern
-- contributes its base name, prefixed with an underscore. Any other pattern
-- gets the base name @_pat@.
selectMatchVar :: DsMonad q => DPat -> q Name
selectMatchVar pat = case pat of
  DBangP inner  -> selectMatchVar inner
  DTildeP inner -> selectMatchVar inner
  DVarP var     -> newUniqueName ('_' : nameBase var)
  _             -> newUniqueName "_pat"

-- like GHC's runs
-- | Split a list into maximal consecutive runs. Each run starts with a
-- "leader" element, and every later member @y@ of the run satisfies
-- @p leader y@. Membership is tested against the leader, not the previous
-- element.
runs :: (a -> a -> Bool) -> [a] -> [[a]]
runs p = go
  where
    go [] = []
    go (leader : rest) =
      case span (p leader) rest of
        (same_run, remaining) -> (leader : same_run) : go remaining