-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Singletons.Deriving.Ord
-- Copyright   :  (C) 2015 Richard Eisenberg
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Implements deriving of Ord instances
--
----------------------------------------------------------------------------

module Data.Singletons.Deriving.Ord ( mkOrdInstance ) where

import Language.Haskell.TH.Desugar
import Data.Singletons.Names
import Data.Singletons.Util
import Language.Haskell.TH.Syntax
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Syntax

-- | Make a *non-singleton* 'Ord' instance for a data declaration.
--
-- The instance context comes from 'inferConstraintsDef' (given the optional
-- user-supplied context @mb_ctxt@).  The generated @compare@ method is:
--
-- * for a datatype with no constructors, the single clause
--   'mk_empty_clause';
--
-- * otherwise, one same-constructor clause per constructor
--   ('mk_equal_clause'), followed by one clause for every ordered pair of
--   constructors with different names ('mk_nonequal_clause').  Each
--   constructor in such a pair is tagged with its 1-based position in
--   @cons@.
mkOrdInstance :: DsMonad q => DerivDesc q
mkOrdInstance mb_ctxt ty (DataDecl _ _ cons) = do
  constraints  <- inferConstraintsDef mb_ctxt (DConT ordName) ty cons
  same_clauses <- traverse mk_equal_clause cons
  let -- Each constructor paired with its 1-based declaration position.
      numbered :: [(DCon, Int)]
      numbered = zip cons [1..]

      -- Every ordered pair of constructors whose names differ
      -- (left constructor varies slowest).
      distinct_pairs :: [((DCon, Int), (DCon, Int))]
      distinct_pairs = [ (l, r)
                       | l@(lcon, _) <- numbered
                       , r@(rcon, _) <- numbered
                       , extractName lcon /= extractName rcon
                       ]

      cmp_clauses :: [DClause]
      cmp_clauses
        | null cons = [mk_empty_clause]
        | otherwise = same_clauses
                      ++ map (uncurry mk_nonequal_clause) distinct_pairs
  pure $ InstDecl { id_cxt     = constraints
                  , id_name    = ordName
                  , id_arg_tys = [ty]
                  , id_sigs    = mempty
                  , id_meths   = [(compareName, UFunction cmp_clauses)]
                  }

-- | Build the @compare@ clause used when both arguments share constructor
-- @name@.
--
-- The fields of the two arguments are bound to fresh names (prefixes @a@
-- and @b@).  The body folds the per-field comparisons with @thenCmp@,
-- starting from @EQ@:
--
-- > compare (Con a1 .. an) (Con b1 .. bn)
-- >   = foldl thenCmp EQ [compare a1 b1, .., compare an bn]
--
-- A constructor with no fields produces an empty list in that expression.
mk_equal_clause :: Quasi q => DCon -> q DClause
mk_equal_clause (DCon _tvbs _cxt name fields _rty) = do
  let field_tys = tysOfConFields fields
      -- One fresh name per field, all sharing the given prefix.
      fresh pfx = traverse (\_ -> newUniqueName pfx) field_tys
  lhs_vars <- fresh "a"
  rhs_vars <- fresh "b"
  let con_pat vars  = DConP name (DVarP <$> vars)
      -- compare l r
      cmp_field l r = foldl DAppE (DVarE compareName) [DVarE l, DVarE r]
      -- foldl thenCmp EQ [compare a1 b1, ...]
      body = foldl DAppE (DVarE foldlName)
                   [ DVarE thenCmpName
                   , DConE cmpEQName
                   , mkListE (zipWith cmp_field lhs_vars rhs_vars)
                   ]
  pure (DClause [con_pat lhs_vars, con_pat rhs_vars] body)

-- | Build the @compare@ clause for two arguments matched against (possibly
-- different) constructors, each tagged with an 'Int' position.  Every field
-- is a wildcard.  The result constructor (@LT@ / @EQ@ / @GT@) is chosen
-- statically by comparing the two positions.
--
-- NOTE(review): the positions are assumed to be declaration order;
-- 'mkOrdInstance' in this module supplies them via @zip cons [1..]@.
mk_nonequal_clause :: (DCon, Int) -> (DCon, Int) -> DClause
mk_nonequal_clause (DCon _tvbs1 _cxt1 name1 fields1 _rty1, n1)
                   (DCon _tvbs2 _cxt2 name2 fields2 _rty2, n2) =
  DClause [wild_pat name1 fields1, wild_pat name2 fields2]
          (DConE (ordering_con (compare n1 n2)))
  where
    -- Match a constructor while ignoring every one of its fields.
    wild_pat :: Name -> DConFields -> DPat
    wild_pat con flds = DConP con (DWildP <$ tysOfConFields flds)

    -- Name of the constructor to emit for each 'Ordering' result.
    ordering_con :: Ordering -> Name
    ordering_con LT = cmpLTName
    ordering_con EQ = cmpEQName
    ordering_con GT = cmpGTName

-- | The only @compare@ clause for a datatype with no constructors.  It
-- matches any two arguments with wildcards and returns @EQ@.  This is the
-- counterpart of 'mk_equal_clause' for empty datatypes.
mk_empty_clause :: DClause
mk_empty_clause = DClause (replicate 2 DWildP) (DConE cmpEQName)