{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Language.Haskell.TH.ExpandSyns(-- * Expand synonyms
                                      expandSyns
                                     ,expandSynsWith
                                     ,SynonymExpansionSettings
                                     ,noWarnTypeFamilies

                                      -- * Misc utilities
                                     ,substInType
                                     ,substInCon
                                     ,evades,evade) where

import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr
import Language.Haskell.TH.ExpandSyns.SemigroupCompat as Sem
import Language.Haskell.TH hiding(cxt)
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Generics
import Control.Monad
import Prelude

#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
#endif

-- For ghci
#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(X,Y,Z) 1
#endif

-- | Name of this package; used as a prefix for the diagnostics emitted by
-- this module.
packagename :: String
packagename = "th-expand-syns"

-- | Replace the name carried by a type variable binder with the given one
-- (implemented by mapping a constant function over the binder's name).
tyVarBndrSetName :: Name -> TyVarBndr_ flag -> TyVarBndr_ flag
tyVarBndrSetName n = mapTVName (const n)

-- | Settings that control how 'expandSynsWith' behaves.
data SynonymExpansionSettings =
  SynonymExpansionSettings {
    -- | Whether a warning is emitted for type families found in the input.
    sesWarnTypeFamilies :: Bool
  }

-- | Combining two settings keeps the type-family warning enabled only if it
-- is enabled in both operands.
instance Semigroup SynonymExpansionSettings where
  SynonymExpansionSettings w1 <> SynonymExpansionSettings w2 =
    SynonymExpansionSettings (w1 && w2)

-- | Default settings ('mempty'):
--
-- * Warn if type families are encountered.
--
-- (The 'mappend' is currently rather useless; the monoid instance is intended for additional settings in the future).
instance Monoid SynonymExpansionSettings where
  mempty =
    SynonymExpansionSettings {
      sesWarnTypeFamilies = True
    }

#if !MIN_VERSION_base(4,11,0)
-- starting with base-4.11, mappend definitions are redundant;
-- at some point `mappend` will be removed from `Monoid`
  mappend = (Sem.<>)
#endif

-- | Suppresses the warning that type families are unsupported.
--
-- This is the default settings value ('mempty') with the type-family warning
-- switched off.
noWarnTypeFamilies :: SynonymExpansionSettings
noWarnTypeFamilies = mempty { sesWarnTypeFamilies = False }

-- | Report a warning in the 'Q' monad. The message is prefixed with the
-- package name and the text @": WARNING: "@.
--
-- On template-haskell >= 2.8 this uses 'reportWarning'; older versions fall
-- back to @report False@.
warn :: String -> Q ()
warn msg =
#if MIN_VERSION_template_haskell(2,8,0)
    reportWarning
#else
    report False
#endif
      (packagename ++ ": WARNING: " ++ msg)

-- | Reify the given name and, if it turns out to refer to a declaration,
-- hand that declaration to 'warnIfDecIsTypeFamily'. Every other kind of
-- 'Info' is ignored.
warnIfNameIsTypeFamily :: Name -> Q ()
warnIfNameIsTypeFamily n = do
  i <- reify n
  case i of
    ClassI {} -> return ()
    ClassOpI {} -> return ()
    TyConI d -> warnIfDecIsTypeFamily d
#if MIN_VERSION_template_haskell(2,7,0)
    FamilyI d _ -> warnIfDecIsTypeFamily d -- Called for warnings
#endif
    PrimTyConI {} -> return ()
    DataConI {} -> return ()
    VarI {} -> return ()
    TyVarI {} -> return ()
#if MIN_VERSION_template_haskell(2,12,0)
    PatSynI {} -> return ()
#endif

-- | Emit the "unsupported type family" warning (via 'maybeWarnTypeFamily')
-- when the declaration is an open or closed type synonym family; do nothing
-- for every other declaration form.
--
-- Every 'Dec' constructor is listed explicitly, guarded by the
-- template-haskell version that provides it.
warnIfDecIsTypeFamily :: Dec -> Q ()
warnIfDecIsTypeFamily = go
  where
    go :: Dec -> Q ()
    go (TySynD {}) = return ()

#if MIN_VERSION_template_haskell(2,11,0)
    go (OpenTypeFamilyD (TypeFamilyHead name _ _ _)) = maybeWarnTypeFamily name
    go (ClosedTypeFamilyD (TypeFamilyHead name _ _ _) _) = maybeWarnTypeFamily name
#else

#if MIN_VERSION_template_haskell(2,9,0)
    go (ClosedTypeFamilyD name _ _ _) = maybeWarnTypeFamily name
#endif

    go (FamilyD TypeFam name _ _) = maybeWarnTypeFamily name
#endif

    go (FunD {}) = return ()
    go (ValD {}) = return ()
    go (DataD {}) = return ()
    go (NewtypeD {}) = return ()
    go (ClassD {}) = return ()
    go (InstanceD {}) = return ()
    go (SigD {}) = return ()
    go (ForeignD {}) = return ()

#if MIN_VERSION_template_haskell(2,8,0)
    go (InfixD {}) = return ()
#endif

    go (PragmaD {}) = return ()

    -- Nothing to expand for data families, so no warning
#if MIN_VERSION_template_haskell(2,11,0)
    go (DataFamilyD {}) = return ()
#else
    go (FamilyD DataFam _ _ _) = return ()
#endif

    go (DataInstD {}) = return ()
    go (NewtypeInstD {}) = return ()
    go (TySynInstD {}) = return ()

#if MIN_VERSION_template_haskell(2,9,0)
    go (RoleAnnotD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,10,0)
    go (StandaloneDerivD {}) = return ()
    go (DefaultSigD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,12,0)
    go (PatSynD {}) = return ()
    go (PatSynSigD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,15,0)
    go (ImplicitParamBindD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,16,0)
    go (KiSigD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,19,0)
    go (DefaultD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,20,0)
    go (TypeDataD {}) = return ()
#endif

-- | Walk a type and call 'warnIfNameIsTypeFamily' on every type constructor
-- name ('ConT') and infix operator name found in it.
--
-- The traversal descends into applications, kind signatures, the binder
-- kinds, context and body of @forall@s, parenthesised types, and (where the
-- template-haskell version has them) kind applications and implicit
-- parameter types. Leaf constructors that carry no type name are skipped.
warnTypeFamiliesInType :: Type -> Q ()
warnTypeFamiliesInType = go
  where
    go :: Type -> Q ()
    go (ConT n)     = warnIfNameIsTypeFamily n
    go (AppT t1 t2) = go t1 >> go t2
    go (SigT t k)   = go t  >> go_kind k
    go ListT{}      = return ()
    go ArrowT{}     = return ()
    go VarT{}       = return ()
    go TupleT{}     = return ()
    go (ForallT tvbs ctxt body) = do
      mapM_ (go_kind . tvKind) tvbs
      mapM_ go_pred ctxt
      go body
#if MIN_VERSION_template_haskell(2,6,0)
    go UnboxedTupleT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,8,0)
    go PromotedT{}      = return ()
    go PromotedTupleT{} = return ()
    go PromotedConsT{}  = return ()
    go PromotedNilT{}   = return ()
    go StarT{}          = return ()
    go ConstraintT{}    = return ()
    go LitT{}           = return ()
#endif
#if MIN_VERSION_template_haskell(2,10,0)
    go EqualityT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,11,0)
    go (InfixT t1 n t2) = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
    go (UInfixT t1 n t2) = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
    go (ParensT t) = go t
    go WildCardT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,12,0)
    go UnboxedSumT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,15,0)
    go (AppKindT t k)       = go t >> go_kind k
    go (ImplicitParamT _ t) = go t
#endif
#if MIN_VERSION_template_haskell(2,16,0)
    go (ForallVisT tvbs body) = do
      mapM_ (go_kind . tvKind) tvbs
      go body
#endif
#if MIN_VERSION_template_haskell(2,17,0)
    go MulArrowT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,19,0)
    go (PromotedInfixT t1 n t2) = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
    go (PromotedUInfixT t1 n t2) = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
#endif

    -- Kinds are traversed like types on template-haskell >= 2.8 and are
    -- skipped on older versions.
    go_kind :: Kind -> Q ()
#if MIN_VERSION_template_haskell(2,8,0)
    go_kind = go
#else
    go_kind _ = return ()
#endif

    -- Predicates are traversed like types on template-haskell >= 2.10; on
    -- older versions the argument types of the predicate are traversed.
    go_pred :: Pred -> Q ()
#if MIN_VERSION_template_haskell(2,10,0)
    go_pred = go
#else
    go_pred (ClassP _ ts)  = mapM_ go ts
    go_pred (EqualP t1 t2) = go t1 >> go t2
#endif

-- | Emit (via 'warn') the warning that type synonym families are not
-- expanded, naming the offending family.
maybeWarnTypeFamily :: Name -> Q ()
maybeWarnTypeFamily name =
  warn ("Type synonym families (and associated type synonyms) are currently not supported (they won't be expanded). Name of unsupported family: " ++ show name)

-- | Calls 'expandSynsWith' with the default settings.
expandSyns :: Type -> Q Type
expandSyns = expandSynsWith mempty

-- | Expands all type synonyms in the given type. Type families currently won't be expanded (but will be passed through).
--
-- If the settings have the type-family warning enabled, the input type is
-- first scanned with 'warnTypeFamiliesInType'; the expansion itself is
-- delegated to 'resolveTypeSynonyms'.
expandSynsWith :: SynonymExpansionSettings -> Type -> Q Type
expandSynsWith settings = expandSyns'
    where
      expandSyns' :: Type -> Q Type
      expandSyns' x = do
        when (sesWarnTypeFamilies settings) $
          warnTypeFamiliesInType x
        resolveTypeSynonyms x

-- | Make a name (based on the first arg) that's distinct from every name in the second arg
--
-- Example why this is necessary:
--
-- > type E x = forall y. Either x y
-- >
-- > ... expandSyns [t| forall y. y -> E y |]
--
-- The example as given may actually work correctly without any special capture-avoidance depending
-- on how GHC handles the @y@s, but in any case, the input type to expandSyns may be an explicit
-- AST using 'mkName' to ensure a collision.
--
-- A colliding candidate is replaced by a 'mkName' name made of @\'f\'@
-- followed by the candidate's 'nameBase'; this repeats until the candidate
-- no longer occurs in the second argument.
evade :: Data d => Name -> d -> Name
evade n t =
    let
        -- Every 'Name' value occurring anywhere inside @t@.
        vars :: Set.Set Name
        vars = everything Set.union (mkQ Set.empty Set.singleton) t

        go :: Name -> Name
        go n1 = if n1 `Set.member` vars
                then go (bump n1)
                else n1

        bump :: Name -> Name
        bump = mkName . ('f':) . nameBase
    in
      go n

-- | Make a list of names (based on the first arg) such that every name in the result
-- is distinct from every name in the second arg, and from the other results
--
-- The input names are processed from right to left; each one is passed to
-- 'evade' together with the results chosen so far and the second argument.
evades :: (Data t) => [Name] -> t -> [Name]
evades ns t = foldr c [] ns
    where
      c :: Name -> [Name] -> [Name]
      c n chosen = evade n (chosen, t) : chosen

-- evadeTest = let v = mkName "x"
--             in
--               evade v (AppT (VarT v) (VarT (mkName "fx")))

-- | Capture-free substitution
--
-- Replaces the given variable by the given type, by applying the
-- corresponding one-element substitution with 'applySubstitution'.
substInType :: (Name,Type) -> Type -> Type
substInType vt = applySubstitution (Map.fromList [vt])

-- | Capture-free substitution
--
-- Substitutes the given variable in the field types of a constructor. For
-- 'ForallC', the bound variables are first handled by 'commonForallCase';
-- the substitution it provides is applied to the binder kinds, the context
-- and the body.
--
-- With template-haskell >= 2.11, GADT constructors ('GadtC', 'RecGadtC') are
-- not supported and cause a call to 'error'.
substInCon :: (Name,Type) -> Con -> Con
substInCon vt = go
    where
      vtSubst = Map.fromList [vt]
      st = applySubstitution vtSubst

      go (NormalC n ts) = NormalC n [(x, st y) | (x,y) <- ts]
      go (RecC n ts) = RecC n [(x, y, st z) | (x,y,z) <- ts]
      go (InfixC (y1,t1) op (y2,t2)) = InfixC (y1,st t1) op (y2,st t2)
      go (ForallC vars cxt body) =
          commonForallCase vt vars $ \vts' vars' ->
          ForallC (map (mapTVKind (applySubstitution vts')) vars')
                  (applySubstitution vts' cxt)
                  -- NOTE(review): the entries of vts' are applied to the body
                  -- one at a time, in the order Map.foldrWithKey visits them,
                  -- rather than simultaneously -- confirm this order cannot
                  -- make the substitutions interfere with each other.
                  (Map.foldrWithKey (\v t -> substInCon (v, t)) body vts')
#if MIN_VERSION_template_haskell(2,11,0)
      go c@GadtC{} = errGadt c
      go c@RecGadtC{} = errGadt c

      errGadt c = error (packagename++": substInCon currently doesn't support GADT constructors with GHC >= 8 ("++pprint c++")")
#endif

-- Apply a substitution to something underneath a @forall@. The continuation
-- argument provides new substitutions and fresh type variable binders to avoid
-- the outer substitution from capturing the thing underneath the @forall@.
--
-- Arguments: the substitution @(v, t)@ to perform, the binders of the
-- @forall@, and a continuation that receives the substitution to apply
-- underneath the @forall@ together with the binders to use in the result.
commonForallCase :: (Name, Type) -> [TyVarBndr_ flag]
                 -> (Map Name Type -> [TyVarBndr_ flag] -> a)
                 -> a
commonForallCase vt@(v,t) bndrs k
            -- If a variable with the same name as the one to be replaced is bound by the forall,
            -- the variable to be replaced is shadowed in the body, so we leave the whole thing alone (no recursion).
            -- The continuation therefore gets the empty substitution: handing it @(v, t)@ would make
            -- the caller replace occurrences that refer to the forall-bound variable.
          | v `elem` (tvName <$> bndrs) = k Map.empty bndrs

          | otherwise =
              let
                  -- prevent capture: rename every bound variable to a name
                  -- that does not occur in the replacement type @t@
                  vars = tvName <$> bndrs
                  freshes = evades vars t
                  freshTyVarBndrs = zipWith tyVarBndrSetName freshes bndrs
                  substs = zip vars (VarT <$> freshes)
              in
                k (Map.fromList (vt:substs)) freshTyVarBndrs