module Data.Singletons.Deriving.Ord ( mkOrdInstance ) where
import Language.Haskell.TH.Desugar
import Data.Singletons.Names
import Data.Singletons.Util
import Language.Haskell.TH.Syntax
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Syntax
-- | Build an instance declaration for the class named by 'ordName' at the
-- given type, with a single method named by 'compareName'.
--
-- The method body is a list of clauses:
--
-- * one clause per constructor, comparing two values built with that same
--   constructor (see 'mk_equal_clause');
-- * one clause per ordered pair of constructors with different names, whose
--   result depends only on the constructors' 1-based positions in @cons@
--   (see 'mk_nonequal_clause');
-- * if the data type has no constructors at all, the single clause
--   'mk_empty_clause' is used instead.
--
-- The instance context comes from 'inferConstraintsDef', applied to the
-- optional user-supplied context @mb_ctxt@.
mkOrdInstance :: DsMonad q => DerivDesc q
mkOrdInstance mb_ctxt ty (DataDecl _ _ cons) = do
  constraints      <- inferConstraintsDef mb_ctxt (DConPr ordName) ty cons
  same_con_clauses <- mapM mk_equal_clause cons
  let -- Pair every constructor with its 1-based position in the declaration.
      indexed_cons = zip cons [1..]
      -- Two indexed constructors are distinct when their names differ.
      distinct (c1, _) (c2, _) = extractName c1 /= extractName c2
      -- Every ordered pair of distinct constructors, left constructor varying
      -- slowest.
      diff_con_clauses =
        [ mk_nonequal_clause lhs rhs
        | lhs <- indexed_cons
        , rhs <- indexed_cons
        , distinct lhs rhs
        ]
      clauses =
        if null cons
          then [mk_empty_clause]
          else same_con_clauses ++ diff_con_clauses
  return $ InstDecl
    { id_cxt     = constraints
    , id_name    = ordName
    , id_arg_tys = [ty]
    , id_sigs    = mempty
    , id_meths   = [(compareName, UFunction clauses)]
    }
-- | Build the clause that matches two values built with the /same/
-- constructor.
--
-- Each field of the constructor is bound to a fresh variable on both sides
-- (prefix @"a"@ on the left, @"b"@ on the right). The clause body is the
-- expression
--
-- > foldlName thenCmpName cmpEQName [compareName a1 b1, compareName a2 b2, ...]
--
-- written with the names those identifiers refer to. For a constructor
-- without fields the list is empty.
mk_equal_clause :: Quasi q => DCon -> q DClause
mk_equal_clause (DCon _tvbs _cxt con_name con_fields _rty) = do
  let field_tys    = tysOfConFields con_fields
      -- One fresh name per field, all sharing the given prefix.
      fresh prefix = mapM (\_ -> newUniqueName prefix) field_tys
  lhs_vars <- fresh "a"
  rhs_vars <- fresh "b"
  let lhs_pat     = DConPa con_name (map DVarPa lhs_vars)
      rhs_pat     = DConPa con_name (map DVarPa rhs_vars)
      comparisons = mkListE (zipWith compare_fields lhs_vars rhs_vars)
      body        = DVarE foldlName
                      `DAppE` DVarE thenCmpName
                      `DAppE` DConE cmpEQName
                      `DAppE` comparisons
  return (DClause [lhs_pat, rhs_pat] body)
  where
    -- The expression @compareName l r@ for one pair of field variables.
    compare_fields l r = DVarE compareName `DAppE` DVarE l `DAppE` DVarE r
-- | Build the clause that matches two values built with the given
-- constructors, ignoring all of their fields.
--
-- Each argument pairs a constructor with an 'Int' position. The clause body
-- is the constructor named by 'cmpLTName', 'cmpEQName' or 'cmpGTName',
-- according to whether the first position is less than, equal to, or
-- greater than the second.
mk_nonequal_clause :: (DCon, Int) -> (DCon, Int) -> DClause
mk_nonequal_clause (DCon _tvbs1 _cxt1 lhs_name lhs_fields _rty1, lhs_ix)
                   (DCon _tvbs2 _cxt2 rhs_name rhs_fields _rty2, rhs_ix) =
  DClause [ wild_pat lhs_name lhs_fields
          , wild_pat rhs_name rhs_fields
          ]
          (DConE ordering_name)
  where
    -- Constructor pattern with one wildcard per field.
    wild_pat con_name con_fields =
      DConPa con_name (map (\_ -> DWildPa) (tysOfConFields con_fields))

    -- Outcome is decided purely by the relative positions.
    ordering_name
      | lhs_ix <  rhs_ix = cmpLTName
      | lhs_ix == rhs_ix = cmpEQName
      | otherwise        = cmpGTName
-- | The clause used when the data type has no constructors: it ignores both
-- arguments and returns the constructor named by 'cmpEQName'.
mk_empty_clause :: DClause
mk_empty_clause = DClause arg_pats result
  where
    arg_pats = [DWildPa, DWildPa]
    result   = DConE cmpEQName