{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Language.Haskell.TH.ExpandSyns(
expandSyns
,expandSynsWith
,SynonymExpansionSettings
,noWarnTypeFamilies
,substInType
,substInCon
,evades,evade) where
import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr
import Language.Haskell.TH.ExpandSyns.SemigroupCompat as Sem
import Language.Haskell.TH hiding(cxt)
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Generics
import Control.Monad
import Prelude
#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
#endif
#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(X,Y,Z) 1
#endif
-- | Name of this package; used as the prefix of every warning and error
-- message this module produces.
packagename :: String
packagename = "th-expand-syns"
-- | Replace the name of a type-variable binder. Only the name is touched
-- (via 'mapTVName'); everything else about the binder is kept as-is.
tyVarBndrSetName :: Name -> TyVarBndr_ flag -> TyVarBndr_ flag
tyVarBndrSetName n = mapTVName (const n)
-- | Options for 'expandSynsWith'. Use 'mempty' for the defaults and
-- 'noWarnTypeFamilies' to silence the type-family warning.
data SynonymExpansionSettings =
  SynonymExpansionSettings {
    sesWarnTypeFamilies :: Bool
    -- ^ When 'True', 'expandSynsWith' first scans the type and reports a
    -- warning for every type synonym family it mentions.
  }
-- | Combining settings keeps the type-family warning enabled only if both
-- operands enable it.
instance Semigroup SynonymExpansionSettings where
  SynonymExpansionSettings w1 <> SynonymExpansionSettings w2 =
    SynonymExpansionSettings (w1 && w2)
-- | Default settings ('mempty'): warn when type families are encountered.
instance Monoid SynonymExpansionSettings where
  mempty =
    SynonymExpansionSettings {
      sesWarnTypeFamilies = True
    }
#if !MIN_VERSION_base(4,11,0)
  -- Before base-4.11 'Semigroup' is not a superclass of 'Monoid', so
  -- 'mappend' has to be defined explicitly.
  mappend = (Sem.<>)
#endif
-- | Default settings, but without the warning about unsupported type
-- families.
noWarnTypeFamilies :: SynonymExpansionSettings
noWarnTypeFamilies = mempty { sesWarnTypeFamilies = False }
-- | Report a non-fatal compile-time warning, prefixed with the package name.
-- Uses 'reportWarning' on template-haskell >= 2.8 and @report False@ before.
warn :: String -> Q ()
warn msg =
#if MIN_VERSION_template_haskell(2,8,0)
  reportWarning
#else
  report False
#endif
    (packagename ++ ": WARNING: " ++ msg)
-- | 'reify' a name and, if it is a type constructor ('TyConI') or a family
-- ('FamilyI'), pass its declaration to 'warnIfDecIsTypeFamily'. All other
-- kinds of names are ignored.
--
-- 'Info' constructors are listed exhaustively (no wildcard), so one added by a
-- future template-haskell release surfaces as an incomplete-pattern warning.
warnIfNameIsTypeFamily :: Name -> Q ()
warnIfNameIsTypeFamily n = do
  info <- reify n
  case info of
    TyConI d -> warnIfDecIsTypeFamily d
#if MIN_VERSION_template_haskell(2,7,0)
    FamilyI d _ -> warnIfDecIsTypeFamily d
#endif
    ClassI {} -> return ()
    ClassOpI {} -> return ()
    PrimTyConI {} -> return ()
    DataConI {} -> return ()
    VarI {} -> return ()
    TyVarI {} -> return ()
#if MIN_VERSION_template_haskell(2,12,0)
    PatSynI {} -> return ()
#endif
-- | Call 'maybeWarnTypeFamily' if the declaration is a type synonym family
-- (open or closed); do nothing for every other declaration, including data
-- families.
--
-- 'Dec' constructors are matched exhaustively (no wildcard), so one added by
-- a future template-haskell release surfaces as an incomplete-pattern warning.
warnIfDecIsTypeFamily :: Dec -> Q ()
warnIfDecIsTypeFamily = go
  where
    go :: Dec -> Q ()
    go (TySynD {}) = return ()
#if MIN_VERSION_template_haskell(2,11,0)
    go (OpenTypeFamilyD (TypeFamilyHead name _ _ _)) = maybeWarnTypeFamily name
    go (ClosedTypeFamilyD (TypeFamilyHead name _ _ _) _) = maybeWarnTypeFamily name
#else
#if MIN_VERSION_template_haskell(2,9,0)
    go (ClosedTypeFamilyD name _ _ _) = maybeWarnTypeFamily name
#endif
    go (FamilyD TypeFam name _ _) = maybeWarnTypeFamily name
#endif
    go (FunD {}) = return ()
    go (ValD {}) = return ()
    go (DataD {}) = return ()
    go (NewtypeD {}) = return ()
    go (ClassD {}) = return ()
    go (InstanceD {}) = return ()
    go (SigD {}) = return ()
    go (ForeignD {}) = return ()
#if MIN_VERSION_template_haskell(2,8,0)
    go (InfixD {}) = return ()
#endif
    go (PragmaD {}) = return ()
#if MIN_VERSION_template_haskell(2,11,0)
    go (DataFamilyD {}) = return ()
#else
    go (FamilyD DataFam _ _ _) = return ()
#endif
    go (DataInstD {}) = return ()
    go (NewtypeInstD {}) = return ()
    go (TySynInstD {}) = return ()
#if MIN_VERSION_template_haskell(2,9,0)
    go (RoleAnnotD {}) = return ()
#endif
#if MIN_VERSION_template_haskell(2,10,0)
    go (StandaloneDerivD {}) = return ()
    go (DefaultSigD {}) = return ()
#endif
#if MIN_VERSION_template_haskell(2,12,0)
    go (PatSynD {}) = return ()
    go (PatSynSigD {}) = return ()
#endif
#if MIN_VERSION_template_haskell(2,15,0)
    go (ImplicitParamBindD {}) = return ()
#endif
#if MIN_VERSION_template_haskell(2,16,0)
    go (KiSigD {}) = return ()
#endif
#if MIN_VERSION_template_haskell(2,19,0)
    go (DefaultD {}) = return ()
#endif
#if MIN_VERSION_template_haskell(2,20,0)
    go (TypeDataD {}) = return ()
#endif
-- | Walk a type and call 'warnIfNameIsTypeFamily' on every type-constructor
-- name ('ConT') and every infix operator name it contains, including names
-- inside kind annotations, forall binder kinds and (on template-haskell >=
-- 2.10) constraints.
warnTypeFamiliesInType :: Type -> Q ()
warnTypeFamiliesInType = go
  where
    go :: Type -> Q ()
    go (ConT n) = warnIfNameIsTypeFamily n
    go (AppT t1 t2) = go t1 >> go t2
    go (SigT t k) = go t >> goKind k
    go ListT{} = return ()
    go ArrowT{} = return ()
    go VarT{} = return ()
    go TupleT{} = return ()
    go (ForallT tvbs ctxt body) = do
      mapM_ (goKind . tvKind) tvbs
      mapM_ goPred ctxt
      go body
#if MIN_VERSION_template_haskell(2,6,0)
    go UnboxedTupleT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,8,0)
    go PromotedT{} = return ()
    go PromotedTupleT{} = return ()
    go PromotedConsT{} = return ()
    go PromotedNilT{} = return ()
    go StarT{} = return ()
    go ConstraintT{} = return ()
    go LitT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,10,0)
    go EqualityT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,11,0)
    go (InfixT t1 n t2) = goInfix t1 n t2
    go (UInfixT t1 n t2) = goInfix t1 n t2
    go (ParensT t) = go t
    go WildCardT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,12,0)
    go UnboxedSumT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,15,0)
    go (AppKindT t k) = go t >> goKind k
    go (ImplicitParamT _ t) = go t
#endif
#if MIN_VERSION_template_haskell(2,16,0)
    go (ForallVisT tvbs body) = do
      mapM_ (goKind . tvKind) tvbs
      go body
#endif
#if MIN_VERSION_template_haskell(2,17,0)
    go MulArrowT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,19,0)
    go (PromotedInfixT t1 n t2) = goInfix t1 n t2
    go (PromotedUInfixT t1 n t2) = goInfix t1 n t2
#endif

#if MIN_VERSION_template_haskell(2,11,0)
    -- Shared by all infix type forms: the operator name is checked first,
    -- then the left and right operands.
    goInfix :: Type -> Name -> Type -> Q ()
    goInfix t1 n t2 = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
#endif

    -- Kinds are types from template-haskell 2.8 on; before that they
    -- cannot mention type families, so there is nothing to check.
    goKind :: Kind -> Q ()
#if MIN_VERSION_template_haskell(2,8,0)
    goKind = go
#else
    goKind _ = return ()
#endif

    -- Constraints are types from template-haskell 2.10 on; before that
    -- they are 'ClassP' / 'EqualP' and their argument types are walked.
    goPred :: Pred -> Q ()
#if MIN_VERSION_template_haskell(2,10,0)
    goPred = go
#else
    goPred (ClassP _ ts) = mapM_ go ts
    goPred (EqualP t1 t2) = go t1 >> go t2
#endif
-- | Emit (via 'warn') the warning that the named type synonym family is not
-- supported and will not be expanded.
maybeWarnTypeFamily :: Name -> Q ()
maybeWarnTypeFamily name =
  warn ("Type synonym families (and associated type synonyms) are currently not supported (they won't be expanded). Name of unsupported family: " ++ show name)
-- | Expand all type synonyms in a type, using the default settings
-- ('mempty'), which warn about type families encountered along the way.
expandSyns :: Type -> Q Type
expandSyns = expandSynsWith mempty
-- | Expand all type synonyms in a type with th-abstraction's
-- 'resolveTypeSynonyms'.
--
-- If 'sesWarnTypeFamilies' is set, the type is first scanned and a warning is
-- reported for every type synonym family it mentions, since those are not
-- expanded.
expandSynsWith :: SynonymExpansionSettings -> Type -> Q Type
expandSynsWith settings ty = do
  when (sesWarnTypeFamilies settings) $
    warnTypeFamiliesInType ty
  resolveTypeSynonyms ty
-- | @evade n t@ picks a name that does not clash with any 'Name' stored
-- anywhere inside @t@.
--
-- The names of @t@ are collected by a generic traversal, so constructor and
-- type-constructor names count as well as variables. If @n@ is not among
-- them it is returned unchanged. Otherwise @\'f\'@ is prepended to the base
-- name (building the candidate with 'mkName') until the result is unused.
evade :: Data d => Name -> d -> Name
evade n t = go n
  where
    -- Every Name occurring anywhere inside t.
    taken :: Set.Set Name
    taken = everything Set.union (mkQ Set.empty Set.singleton) t

    go :: Name -> Name
    go candidate
      | candidate `Set.member` taken = go (bump candidate)
      | otherwise = candidate

    bump :: Name -> Name
    bump = mkName . ('f' :) . nameBase
-- | @evades ns t@ applies 'evade' to every name in @ns@. The result has the
-- same length and order as @ns@.
--
-- Each name also avoids the results already chosen for the names after it,
-- so the returned names are pairwise distinct and none of them occurs in @t@.
evades :: (Data t) => [Name] -> t -> [Name]
evades ns t = foldr pick [] ns
  where
    pick :: Name -> [Name] -> [Name]
    pick n chosen = evade n (chosen, t) : chosen
-- | Substitute a type for one type variable inside a type. Delegates to
-- th-abstraction's 'applySubstitution' with a one-entry map.
substInType :: (Name,Type) -> Type -> Type
substInType (v, t) = applySubstitution (Map.singleton v t)
-- | Substitute a type for one type variable inside a data constructor
-- declaration.
--
-- Under a @forall@, binders whose names occur in the substituted type are
-- renamed first (see 'commonForallCase'), so the substituted type's free
-- variables are not captured.
--
-- GADT-style constructors ('GadtC', 'RecGadtC'; template-haskell >= 2.11)
-- are not supported and make this function call 'error'.
substInCon :: (Name,Type) -> Con -> Con
substInCon vt = go
  where
    go :: Con -> Con
    go (ForallC bndrs ctxt body) =
      commonForallCase vt bndrs $ \subst bndrs' ->
        ForallC (map (mapTVKind (applySubstitution subst)) bndrs')
                (applySubstitution subst ctxt)
                -- All entries of 'subst' (vt plus the binder renamings) are
                -- applied at once. Applying them one at a time in key order
                -- would let a renaming rewrite variables that vt's type has
                -- just introduced, or the other way round.
                (substAll subst body)
    go con = substAll (Map.fromList [vt]) con

    -- Simultaneous substitution of every entry of the map into a constructor.
    substAll :: Map Name Type -> Con -> Con
    substAll subst con =
      case con of
        NormalC n fields ->
          NormalC n [ (bng, st ty) | (bng, ty) <- fields ]
        RecC n fields ->
          RecC n [ (fld, bng, st ty) | (fld, bng, ty) <- fields ]
        InfixC (bng1, ty1) op (bng2, ty2) ->
          InfixC (bng1, st ty1) op (bng2, st ty2)
        ForallC bndrs ctxt body ->
          let
            bound = tvName <$> bndrs
            -- Entries for variables rebound by this forall are shadowed here.
            live = foldr Map.delete subst bound
            -- Rename binders whose names occur in a substituted type.
            fresh = evades bound (Map.elems live)
            subst' = Map.union live (Map.fromList (zip bound (VarT <$> fresh)))
          in
            ForallC (map (mapTVKind (applySubstitution subst'))
                         (zipWith tyVarBndrSetName fresh bndrs))
                    (applySubstitution subst' ctxt)
                    (substAll subst' body)
#if MIN_VERSION_template_haskell(2,11,0)
        GadtC{} -> errGadt con
        RecGadtC{} -> errGadt con
#endif
      where
        st :: Type -> Type
        st = applySubstitution subst

#if MIN_VERSION_template_haskell(2,11,0)
    errGadt :: Con -> Con
    errGadt c = error (packagename++": substInCon currently doesn't support GADT constructors with GHC >= 8 ("++pprint c++")")
#endif
-- | Prepare the substitution @vt = (v, t)@ for use underneath a @forall@ that
-- binds @bndrs@. The continuation receives the map to apply underneath the
-- forall and the (possibly renamed) binders to use.
--
-- * If the forall rebinds @v@, every occurrence of @v@ underneath it refers
--   to the inner binder. Nothing may be substituted, so the map is empty and
--   the binders are unchanged. (Binder kinds are left untouched as well.)
-- * Otherwise the map contains @vt@ plus a renaming of each binder to a name,
--   chosen with 'evades', that does not occur in @t@. This prevents @t@'s
--   free variables from being captured. Binders that do not clash keep their
--   names.
--
-- NOTE(review): the fresh names avoid @t@ and each other, but not free
-- variables of the thing under the forall, which this function cannot see.
commonForallCase :: (Name, Type) -> [TyVarBndr_ flag]
                 -> (Map Name Type -> [TyVarBndr_ flag] -> a)
                 -> a
commonForallCase vt@(v, t) bndrs k
  | v `elem` boundNames = k Map.empty bndrs
  | otherwise =
      let
        freshes = evades boundNames t
        freshBndrs = zipWith tyVarBndrSetName freshes bndrs
        renames = zip boundNames (VarT <$> freshes)
      in
        k (Map.fromList (vt : renames)) freshBndrs
  where
    boundNames :: [Name]
    boundNames = tvName <$> bndrs