{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
module Data.GADT.Compare.TH
    ( DeriveGEQ(..)
    , DeriveGCompare(..)
    , module Data.GADT.Compare.Monad
    ) where

import Control.Monad
import Control.Monad.Writer
import Data.GADT.TH.Internal
import Data.Functor.Identity
import Data.GADT.Compare
import Data.GADT.Compare.Monad
import Data.Type.Equality ((:~:) (..))
import qualified Data.Set as Set
import Data.Set (Set)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Map
import Data.Map (Map)
import Language.Haskell.TH
import Language.Haskell.TH.Datatype

-- | A class that exists only for overloading: it lets 'deriveGEq' take a type
-- 'Name', a 'Dec', a single-element list of either, or a 'Q' action producing
-- one of those.
class DeriveGEQ a where
    -- | Generate Template Haskell declarations for a 'GEq' instance.
    deriveGEq :: a -> Q [Dec]

-- | Derive a complete @instance GEq (T a1 ... a(n-1))@ for the datatype named
-- by the argument. The datatype's last type parameter is dropped from the
-- instance head. The instance context is whatever constraints the generated
-- clauses report needing.
instance DeriveGEQ Name where
  deriveGEq typeName = do
    typeInfo <- reifyDatatype typeName
    let instTypes = datatypeInstTypes typeInfo
        -- Type variables occurring in the datatype's parameters; geqClause
        -- uses them to decide which fields must be compared with 'geq'.
        paramVars = Set.unions [freeTypeVariables t | t <- instTypes]
    -- Drop the last parameter. The check runs in Q so that a missing
    -- parameter is actually reported: a `fail` inside a pure binding of type
    -- [Type] is the list monad's `fail`, which silently yields [].
    instTypes' <- case reverse instTypes of
      [] -> fail "deriveGEq: Not enough type parameters"
      (_:xs) -> return (reverse xs)
    let instanceHead = AppT (ConT ''GEq) (foldl AppT (ConT typeName) instTypes')
    -- One clause per constructor; the writer output collects the constraints
    -- those clauses require, which become the instance context.
    (clauses, instCxt) <- runWriterT (mapM (geqClause paramVars) (datatypeCons typeInfo))
    return [InstanceD Nothing instCxt instanceHead [geqFunction clauses]]

-- | Derive 'GEq' starting from a declaration. 'deriveForDec' (from
-- "Data.GADT.TH.Internal") supplies the datatype information and assembles
-- the result. NOTE(review): presumably the 'Dec' is an instance-declaration
-- skeleton; confirm against 'deriveForDec'.
instance DeriveGEQ Dec where
    deriveGEq = deriveForDec ''GEq buildGeq
      where
        -- The 'geq' method: one clause per constructor plus the catch-all.
        buildGeq info =
            geqFunction <$> traverse (geqClause rigidVars) (datatypeCons info)
          where
            rigidVars = Set.unions (map freeTypeVariables (datatypeInstTypes info))

-- | Accepts exactly one element and derives from it; any other length fails
-- in 'Q'. Together with the 'Q' instance, this lets a @Q [Dec]@ (for example
-- a declaration quote) be passed straight to 'deriveGEq'.
instance DeriveGEQ t => DeriveGEQ [t] where
  deriveGEq = \case
    [single] -> deriveGEq single
    _ -> fail "deriveGEq: [] instance only applies to single-element lists"

-- | Runs the 'Q' action, then derives from whatever it produced.
instance DeriveGEQ t => DeriveGEQ (Q t) where
  deriveGEq action = deriveGEq =<< action

-- | Assemble the 'geq' method from the per-constructor clauses. A final
-- catch-all clause is appended that answers 'Nothing'; it is reached when the
-- two arguments were built with different constructors.
geqFunction :: [Clause] -> Dec
geqFunction clauses = FunD 'geq (clauses ++ [mismatch])
  where
    -- geq _ _ = Nothing
    mismatch = Clause [WildP, WildP] (NormalB (ConE 'Nothing)) []
    -- TODO: only include the catch-all if there's more than one constructor?

-- | Build the 'geq' clause for one constructor. The clause matches both
-- arguments against that constructor and compares their fields pairwise
-- inside a @do@ block that ends with @return Refl@.
--
-- A field whose type is an application @f x@ mentioning the instance's type
-- variables or the constructor's own type variables is compared with 'geq'
-- and bound against 'Refl'. Any other field is compared with '==' under
-- 'guard'. The 'GEq' constraints needed for the 'geq' comparisons are
-- accumulated in the writer output.
geqClause :: Set Name -> ConstructorInfo -> WriterT Cxt Q Clause
geqClause paramVars con = do
  let fieldTypes = constructorFields con
      arity = length fieldTypes
      -- Variables of the instance head together with the constructor's own.
      rigidVars = Set.union paramVars
                    (Set.fromList (map tvName (constructorVars con)))
      mentionsRigid ty =
        not (Set.null (Set.intersection (freeTypeVariables ty) rigidVars))
      -- Record in the writer output the context needed for @GEq tyCon@.
      requireGEq tyCon = do
        found <- lift (reifyInstancesWithRigids paramVars ''GEq [tyCon])
        case found of
          [] -> tell [AppT (ConT ''GEq) tyCon]
          [InstanceD _ ctx _ _] -> tell ctx
          _ -> fail ("More than one instance found for GEq (" ++ show (ppr tyCon) ++ "), and unsure what to do. Please report this.")
      -- One do-statement comparing a pair of corresponding fields.
      compareField (lhs, rhs, ty) = case ty of
        AppT tyCon _ | mentionsRigid ty -> do
          requireGEq tyCon
          lift (bindS (conP 'Refl []) [| geq $(varE lhs) $(varE rhs) |])
        _ -> lift (noBindS [| guard ($(varE lhs) == $(varE rhs)) |])
  -- Fresh names for the fields of the left and right arguments.
  xs <- lift (replicateM arity (newName "x"))
  ys <- lift (replicateM arity (newName "y"))
  stmts <- traverse compareField (zip3 xs ys fieldTypes)
  finalStmt <- lift (noBindS [| return Refl |])
  pats <- lift (traverse (conP (constructorName con) . map varP) [xs, ys])
  return (Clause pats (NormalB (doUnqualifiedE (stmts ++ [finalStmt]))) [])

-- | A class that exists only for overloading: it lets 'deriveGCompare' take a
-- type 'Name', a 'Dec', a single-element list of either, or a 'Q' action
-- producing one of those.
class DeriveGCompare a where
    -- | Generate Template Haskell declarations for a 'GCompare' instance.
    deriveGCompare :: a -> Q [Dec]

-- | Derive a complete @instance GCompare (T a1 ... a(n-1))@ for the datatype
-- named by the argument. The datatype's last type parameter is dropped from
-- the instance head. The instance context is whatever constraints the
-- generated clauses report needing.
instance DeriveGCompare Name where
    deriveGCompare typeName = do
      typeInfo <- reifyDatatype typeName
      let instTypes = datatypeInstTypes typeInfo
          -- Type variables occurring in the datatype's parameters;
          -- gcompareClauses uses them to decide which fields need geq'.
          paramVars = Set.unions [freeTypeVariables t | t <- instTypes]
      -- Drop the last parameter. The check runs in Q so that a missing
      -- parameter is actually reported: a `fail` inside a pure binding of
      -- type [Type] is the list monad's `fail`, which silently yields [].
      instTypes' <- case reverse instTypes of
        [] -> fail "deriveGCompare: Not enough type parameters"
        (_:xs) -> return (reverse xs)
      let instanceHead = AppT (ConT ''GCompare) (foldl AppT (ConT typeName) instTypes')
      -- Each constructor contributes several clauses; the writer output
      -- collects the constraints they require, which become the context.
      (clauses, instCxt) <- runWriterT (concat <$> mapM (gcompareClauses paramVars) (datatypeCons typeInfo))
      dec <- gcompareFunction clauses
      return [InstanceD Nothing instCxt instanceHead [dec]]

-- | Derive 'GCompare' starting from a declaration. 'deriveForDec' (from
-- "Data.GADT.TH.Internal") supplies the datatype information and assembles
-- the result. NOTE(review): presumably the 'Dec' is an instance-declaration
-- skeleton; confirm against 'deriveForDec'.
instance DeriveGCompare Dec where
    deriveGCompare = deriveForDec ''GCompare buildGCompare
      where
        -- The 'gcompare' method built from every constructor's clauses.
        buildGCompare info = do
            perCon <- traverse (gcompareClauses rigidVars) (datatypeCons info)
            lift (gcompareFunction (concat perCon))
          where
            rigidVars = Set.unions (map freeTypeVariables (datatypeInstTypes info))

-- | Accepts exactly one element and derives from it; any other length fails
-- in 'Q'. Together with the 'Q' instance, this lets a @Q [Dec]@ (for example
-- a declaration quote) be passed straight to 'deriveGCompare'.
instance DeriveGCompare t => DeriveGCompare [t] where
    deriveGCompare = \case
      [single] -> deriveGCompare single
      _ -> fail "deriveGCompare: [] instance only applies to single-element lists"

-- | Runs the 'Q' action, then derives from whatever it produced.
instance DeriveGCompare t => DeriveGCompare (Q t) where
    deriveGCompare action = deriveGCompare =<< action

-- | Assemble the 'gcompare' method from the per-constructor clauses.
--
-- With no clauses (a datatype without constructors), the method is defined
-- to force both arguments with 'seq' and then evaluate 'undefined'.
gcompareFunction :: [Clause] -> Q Dec
gcompareFunction clauses
  | null clauses =
      funD 'gcompare [clause [] (normalB [| \x y -> seq x (seq y undefined) |]) []]
  | otherwise = pure (FunD 'gcompare clauses)

-- | Build the three 'gcompare' clauses contributed by one constructor:
--
-- * both arguments use the constructor: compare the fields pairwise inside a
--   @do@ block run by 'runGComparing', ending with @return GEQ@;
-- * otherwise, if the left argument uses it: 'GLT';
-- * otherwise, if the right argument uses it: 'GGT'.
--
-- When the clauses are concatenated constructor by constructor, constructors
-- listed earlier in 'datatypeCons' compare as smaller.
--
-- A field whose type is an application @f x@ mentioning the instance's or
-- the constructor's type variables is compared with @geq'@ and bound against
-- 'Refl'. Any other field goes through @compare'@. The 'GCompare'
-- constraints needed by the @geq'@ comparisons are accumulated in the writer
-- output.
gcompareClauses :: Set Name -> ConstructorInfo -> WriterT Cxt Q [Clause]
gcompareClauses paramVars con = do
  let conName = constructorName con
      fieldTypes = constructorFields con
      arity = length fieldTypes
      -- Variables of the instance head together with the constructor's own.
      rigidVars = Set.union paramVars
                    (Set.fromList (map tvName (constructorVars con)))
      mentionsRigid ty =
        not (Set.null (Set.intersection (freeTypeVariables ty) rigidVars))
      -- Record in the writer output the context needed for @GCompare tyCon@.
      requireGCompare tyCon = do
        found <- lift (reifyInstancesWithRigids paramVars ''GCompare [tyCon])
        case found of
          [] -> tell [AppT (ConT ''GCompare) tyCon]
          -- this might not be enough, may want to do full instance resolution.
          [InstanceD _ ctx _ _] -> tell ctx
          _ -> fail ("More than one instance of GCompare (" ++ show (ppr tyCon) ++ ") found, and unsure what to do. Please report this.")
      -- One do-statement comparing a pair of corresponding fields.
      compareField (lhs, rhs, ty) = case ty of
        AppT tyCon _ | mentionsRigid ty -> do
          requireGCompare tyCon
          lift (bindS (conP 'Refl []) [| geq' $(varE lhs) $(varE rhs) |])
        _ -> lift (noBindS [| compare' $(varE lhs) $(varE rhs) |])
  -- Fresh names for the fields of the left and right arguments.
  xs <- lift (replicateM arity (newName "x"))
  ys <- lift (replicateM arity (newName "y"))
  stmts <- traverse compareField (zip3 xs ys fieldTypes)
  finalStmt <- lift (noBindS [| return GEQ |])
  pats <- lift (traverse (conP conName . map varP) [xs, ys])
  let sameCon = Clause pats
        (NormalB (AppE (VarE 'runGComparing) (doUnqualifiedE (stmts ++ [finalStmt]))))
        []
      leftOnly = Clause [RecP conName [], WildP] (NormalB (ConE 'GLT)) []
      rightOnly = Clause [WildP, RecP conName []] (NormalB (ConE 'GGT)) []
  return [sameCon, leftOnly, rightOnly]

-- | Build a plain (unqualified) @do@ expression from its statements.
-- From template-haskell 2.17, 'DoE' takes an extra leading
-- @Maybe ModName@ argument (used for QualifiedDo); 'Nothing' selects the
-- ordinary @do@. The signature is shared by both CPP branches so the binding
-- is always explicitly typed.
doUnqualifiedE :: [Stmt] -> Exp
#if MIN_VERSION_template_haskell(2,17,0)
doUnqualifiedE = DoE Nothing
#else
doUnqualifiedE = DoE
#endif