{-# Language CPP, DeriveGeneric, DeriveDataTypeable #-}
module Language.Haskell.TH.Datatype
(
DatatypeInfo(..)
, ConstructorInfo(..)
, DatatypeVariant(..)
, ConstructorVariant(..)
, reifyDatatype
, normalizeInfo
, normalizeDec
, normalizeCon
, TypeSubstitution(..)
, quantifyType
, freshenFreeVariables
, equalPred
, classPred
, dataDCompat
, arrowKCompat
, resolveTypeSynonyms
, unifyTypes
, tvName
, datatypeType
) where
import Data.Data (Typeable, Data)
import Data.Foldable (foldMap, foldl')
import Data.List (find, union, (\\))
import Data.Map (Map)
import qualified Data.Map as Map
import Control.Monad (foldM)
import GHC.Generics (Generic)
import Language.Haskell.TH
import Language.Haskell.TH.Lib (arrowK)
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative (Applicative(..), (<$>))
import Data.Traversable (traverse, sequenceA)
#endif
-- | Normalized description of a data type, newtype, or data/newtype family
-- instance, as produced by 'normalizeDec'.
data DatatypeInfo = DatatypeInfo
  { datatypeContext :: Cxt
    -- ^ Context attached to the declaration
  , datatypeName    :: Name
    -- ^ Name of the type constructor
  , datatypeVars    :: [Type]
    -- ^ Type parameters; binders are converted to types by 'bndrParams'
  , datatypeVariant :: DatatypeVariant
    -- ^ Which kind of declaration this information came from
  , datatypeCons    :: [ConstructorInfo]
    -- ^ Normalized constructors of the declaration
  }
  deriving (Show, Eq, Typeable, Data, Generic)
-- | The declaration form a 'DatatypeInfo' was derived from.
data DatatypeVariant
  = Datatype        -- ^ Built from a @DataD@ declaration
  | Newtype         -- ^ Built from a @NewtypeD@ declaration
  | DataInstance    -- ^ Built from a @DataInstD@ declaration
  | NewtypeInstance -- ^ Built from a @NewtypeInstD@ declaration
  deriving (Show, Read, Eq, Ord, Typeable, Data, Generic)
-- | Normalized description of a single data constructor, as produced by
-- 'normalizeCon'.
data ConstructorInfo = ConstructorInfo
  { constructorName    :: Name
    -- ^ Name of the constructor
  , constructorVars    :: [TyVarBndr]
    -- ^ Type variables quantified by the constructor itself
  , constructorContext :: Cxt
    -- ^ Constraints attached to the constructor
  , constructorFields  :: [Type]
    -- ^ Field types, in declaration order
  , constructorVariant :: ConstructorVariant
    -- ^ Whether the constructor was declared with record syntax
  }
  deriving (Show, Eq, Typeable, Data, Generic)
-- | How a constructor's fields were declared.
data ConstructorVariant
  = NormalConstructor
    -- ^ Positional fields (prefix or infix constructor)
  | RecordConstructor [Name]
    -- ^ Record syntax; carries the field names in declaration order
  deriving (Show, Eq, Ord, Typeable, Data, Generic)
-- | Rebuild the fully applied type described by a 'DatatypeInfo': the type
-- constructor applied, left to right, to every entry of 'datatypeVars'.
datatypeType :: DatatypeInfo -> Type
datatypeType info = foldl AppT tyCon (datatypeVars info)
  where
    tyCon = ConT (datatypeName info)
-- | Reify the given name and normalize the result with 'normalizeInfo'.
--
-- Fails in 'Q' when the reified information is not something
-- 'normalizeInfo' accepts.
reifyDatatype ::
  Name {- ^ name to reify -} ->
  Q DatatypeInfo
reifyDatatype name = reify name >>= normalizeInfo
-- | Normalize reified 'Info' into a 'DatatypeInfo'.
--
-- A type constructor is normalized from its declaration; a data constructor
-- is resolved through its parent type with 'reifyParent'. Any other kind of
-- 'Info' fails in 'Q'.
normalizeInfo :: Info -> Q DatatypeInfo
normalizeInfo info =
  case info of
    TyConI dec -> normalizeDec dec
    -- template-haskell 2.11 dropped the fixity field from 'DataConI'.
#if MIN_VERSION_template_haskell(2,11,0)
    DataConI name ty parent -> reifyParent name ty parent
#else
    DataConI name ty parent _ -> reifyParent name ty parent
#endif
    _ -> fail "reifyDatatype: Expected a type constructor"
-- | Find the 'DatatypeInfo' that a data constructor belongs to.
--
-- Arguments are the constructor's name, the constructor's type, and the name
-- of its parent. The parent is reified: a plain type constructor is
-- normalized directly, while for a family every instance is normalized and
-- the first one whose constructors include the requested name is returned.
reifyParent :: Name -> Type -> Name -> Q DatatypeInfo
reifyParent con ty parent =
  reify parent >>= \parentInfo ->
    case parentInfo of
      TyConI dec -> normalizeDec dec
      FamilyI dec instances -> do
        candidates <- traverse (normalizeDec . repairInstance dec ty) instances
        maybe (fail "PANIC: reifyParent lost the instance") return
              (find declaresCon candidates)
      _ -> fail "PANIC: reifyParent unexpected parent"
  where
    -- Does this normalized declaration define the requested constructor?
    declaresCon di = any ((== con) . constructorName) (datatypeCons di)

#if (!MIN_VERSION_template_haskell(2,10,0)) && MIN_VERSION_template_haskell(2,9,0)
    -- On template-haskell 2.9 only, the parameter list of a reified instance
    -- is rebuilt: the constructor's quantified variables are appended, as
    -- many leading entries as the family binders mention kind variables are
    -- dropped, and the result is cut to the family's arity.
    -- NOTE(review): presumably compensates for how that release reports
    -- instance parameters; confirm against the template-haskell 2.9 docs.
    kindPart (KindedTV _ k) = [k]
    kindPart (PlainTV  _  ) = []

    countKindVars bndrs = length (freeVariables (map kindPart bndrs))

    repairParams dvars tvars ts =
      take (length dvars) (drop (countKindVars dvars) (ts ++ bndrParams tvars))

    repairInstance
      (FamilyD _ _ dvars _)
      (ForallT tvars _ _)
      (NewtypeInstD cx n ts c deriv) =
        NewtypeInstD cx n (repairParams dvars tvars ts) c deriv
    repairInstance
      (FamilyD _ _ dvars _)
      (ForallT tvars _ _)
      (DataInstD cx n ts cs deriv) =
        DataInstD cx n (repairParams dvars tvars ts) cs deriv
#endif
    -- Everything else is passed through unchanged.
    repairInstance _ _ other = other
-- | Normalize a 'Dec' into a 'DatatypeInfo'.
--
-- Accepts data, newtype, data-instance and newtype-instance declarations;
-- any other declaration fails in 'Q'. Instance declarations are additionally
-- passed through 'repair13618'.
normalizeDec :: Dec -> Q DatatypeInfo
normalizeDec dec =
  case dec of
    -- From template-haskell 2.11 on these constructors carry an extra
    -- kind-signature field, which is ignored here.
#if MIN_VERSION_template_haskell(2,11,0)
    NewtypeD context name tyvars _kind con _derives ->
      plain context name tyvars [con] Newtype
    DataD context name tyvars _kind cons _derives ->
      plain context name tyvars cons Datatype
    NewtypeInstD context name params _kind con _derives ->
      inst context name params [con] NewtypeInstance
    DataInstD context name params _kind cons _derives ->
      inst context name params cons DataInstance
#else
    NewtypeD context name tyvars con _derives ->
      plain context name tyvars [con] Newtype
    DataD context name tyvars cons _derives ->
      plain context name tyvars cons Datatype
    NewtypeInstD context name params con _derives ->
      inst context name params [con] NewtypeInstance
    DataInstD context name params cons _derives ->
      inst context name params cons DataInstance
#endif
    _ -> fail "reifyDatatype: DataD or NewtypeD required"
  where
    -- Ordinary declarations list binders, which become type parameters.
    plain context name tyvars =
      normalizeDec' context name (bndrParams tyvars)

    -- Instance declarations already list types, and get the extra repair.
    inst context name params cons variant =
      repair13618 =<< normalizeDec' context name params cons variant
-- | Turn type variable binders into types: a plain binder becomes a type
-- variable, a kinded binder becomes a type variable with a kind signature.
bndrParams :: [TyVarBndr] -> [Type]
bndrParams = map toType
  where
    toType (PlainTV  n  ) = VarT n
    toType (KindedTV n k) = SigT (VarT n) k
-- | Assemble a 'DatatypeInfo' from the pieces of a declaration.
--
-- Each constructor is normalized with 'normalizeCon', which receives the
-- type constructor name and the free variables of the type parameters.
normalizeDec' ::
  Cxt             {- ^ Datatype context    -} ->
  Name            {- ^ Type constructor    -} ->
  [Type]          {- ^ Type parameters     -} ->
  [Con]           {- ^ Constructors        -} ->
  DatatypeVariant {- ^ Declaration variant -} ->
  Q DatatypeInfo
normalizeDec' context name params cons variant =
  assemble . concat <$> traverse normalizeOne cons
  where
    normalizeOne = normalizeCon name (freeVariables params)

    assemble normalized = DatatypeInfo
      { datatypeContext = context
      , datatypeName    = name
      , datatypeVars    = params
      , datatypeVariant = variant
      , datatypeCons    = normalized
      }
-- | Normalize a 'Con' into a list of 'ConstructorInfo'.
--
-- Quantified variables and contexts from nested 'ForallC' wrappers are
-- accumulated, outermost last. A GADT-style constructor may declare several
-- names at once, hence the list result.
normalizeCon ::
  Name   {- ^ Type constructor                       -} ->
  [Name] {- ^ Free variables of the type parameters  -} ->
  Con    {- ^ Constructor                            -} ->
  Q [ConstructorInfo]
normalizeCon typename vars = go [] []
  where
    go tyvars context (NormalC n xs) =
      pure [ConstructorInfo n tyvars context (map snd xs) NormalConstructor]
    go tyvars context (InfixC l n r) =
      pure [ConstructorInfo n tyvars context [snd l, snd r] NormalConstructor]
    go tyvars context (RecC n xs) =
      pure [ConstructorInfo n tyvars context
              (takeFieldTypes xs) (RecordConstructor (takeFieldNames xs))]
    go tyvars context (ForallC tyvars' context' inner) =
      go (tyvars' ++ tyvars) (context' ++ context) inner
#if MIN_VERSION_template_haskell(2,11,0)
    go tyvars context (GadtC ns xs innerType) =
      normalizeGadtC typename vars tyvars context ns innerType
        (map snd xs) NormalConstructor
    go tyvars context (RecGadtC ns xs innerType) =
      normalizeGadtC typename vars tyvars context ns innerType
        (takeFieldTypes xs) (RecordConstructor (takeFieldNames xs))
#endif

#if MIN_VERSION_template_haskell(2,11,0)
-- | Normalize a GADT-style constructor declaration.
--
-- The declared result type has its synonyms resolved and must be an
-- application of the expected type constructor. Its arguments are matched
-- against the datatype's variables by 'mergeArguments'; the resulting
-- renaming is applied to the context and fields, and renamed variables are
-- removed from the constructor's own binders.
normalizeGadtC ::
  Name               {- ^ Type constructor                      -} ->
  [Name]             {- ^ Free variables of the type parameters -} ->
  [TyVarBndr]        {- ^ Constructor type variables            -} ->
  Cxt                {- ^ Constructor context                   -} ->
  [Name]             {- ^ Constructor names                     -} ->
  Type               {- ^ Declared result type                  -} ->
  [Type]             {- ^ Constructor field types               -} ->
  ConstructorVariant {- ^ Constructor variant                   -} ->
  Q [ConstructorInfo]
normalizeGadtC typename vars tyvars context names innerType fields variant =
  do resolved <- resolveTypeSynonyms innerType
     case decomposeType resolved of
       ConT headCon :| args
         | typename == headCon ->
             pure [ ConstructorInfo name keptVars context' fields' variant
                  | name <- names ]
         where
           (renaming, equalities) = mergeArguments vars args
           subst    = fmap VarT renaming
           keptVars = filter (\tv -> Map.notMember (tvName tv) subst) tyvars
           context' = applySubstitution subst (equalities ++ context)
           fields'  = applySubstitution subst fields
       _ -> fail "normalizeGadtC: Expected type constructor application"

-- | Pair datatype variables with the arguments of a constructor's result
-- type, working from the right.
--
-- An argument that is a type variable not yet in the renaming is mapped to
-- the datatype variable at the same position. Every other argument yields an
-- equality constraint between the datatype variable and that argument.
mergeArguments :: [Name] -> [Type] -> (Map Name Name, Cxt)
mergeArguments ns ts = foldr step (Map.empty, []) (zip ns ts)
  where
    step (n, VarT m) (renaming, equalities)
      | Map.notMember m renaming = (Map.insert m n renaming, equalities)
    step (n, p) (renaming, equalities) =
      (renaming, AppT (AppT EqualityT (VarT n)) p : equalities)
#endif
-- | Expand the type synonyms found at the head and in the arguments of a
-- type.
--
-- The type is split into a head and its arguments. If the head reifies to a
-- type synonym, the synonym's variables are substituted by the leading
-- arguments, the remaining arguments are re-applied, and the result is
-- resolved again. Otherwise each argument is resolved and the application
-- is rebuilt.
resolveTypeSynonyms :: Type -> Q Type
resolveTypeSynonyms t =
  case hd of
    ConT n -> reify n >>= expand
    _      -> descend
  where
    hd :| args = decomposeType t

    -- The head is not a synonym: only the arguments need resolving.
    descend = foldl AppT hd <$> traverse resolveTypeSynonyms args

    expand (TyConI (TySynD _ synvars def)) =
      let names            = map tvName synvars
          (used, leftover) = splitAt (length names) args
          subst            = Map.fromList (zip names used)
      in resolveTypeSynonyms (foldl AppT (applySubstitution subst def) leftover)
    expand _ = descend
-- | Decompose a type into its head followed by the arguments it is applied
-- to, in application order. Infix applications are treated as the operator
-- applied to its two operands, and explicit parentheses are dropped.
--
-- > t ~= foldl1 AppT (decomposeType t)
decomposeType :: Type -> NonEmpty Type
decomposeType = reverseNonEmpty . go
  where
    -- 'go' builds its result in REVERSE order (last argument first, head
    -- last) so that each application is a cheap prepend; every equation
    -- must follow that convention because the result is reversed once at
    -- the end.
    go (AppT f x      ) = x <| go f
#if MIN_VERSION_template_haskell(2,11,0)
    go (InfixT  l f r ) = r :| [l, ConT f]
    go (UInfixT l f r ) = r :| [l, ConT f]
    -- Recurse with 'go', not 'decomposeType', which would hand back an
    -- already-forward list that then gets reversed a second time.
    go (ParensT t     ) = go t
#endif
    go t                = t :| []
-- | Extract the variable name from a type variable binder, ignoring any kind
-- annotation.
tvName :: TyVarBndr -> Name
tvName bndr =
  case bndr of
    PlainTV  name   -> name
    KindedTV name _ -> name
-- | Project the first component (the field name) out of each triple.
takeFieldNames :: [(Name,a,b)] -> [Name]
takeFieldNames = foldr (\(name, _, _) rest -> name : rest) []
-- | Project the third component (the field type) out of each triple.
takeFieldTypes :: [(a,b,Type)] -> [Type]
takeFieldTypes = foldr (\(_, _, ty) rest -> ty : rest) []
-- | Universally quantify every free variable of a type with plain binders
-- and an empty context. A type without free variables is returned as is.
quantifyType :: Type -> Type
quantifyType t =
  case freeVariables t of
    [] -> t
    vs -> ForallT (map PlainTV vs) [] t
-- | Replace every free variable of a type with a fresh variable generated
-- by 'newName' from the original variable's base name.
freshenFreeVariables :: Type -> Q Type
freshenFreeVariables t =
  (\subst -> applySubstitution subst t) <$> sequenceA freshVars
  where
    freshVars =
      Map.fromList
        [ (n, fmap VarT (newName (nameBase n))) | n <- freeVariables t ]
-- | Things that contain type variables which can be listed and substituted.
class TypeSubstitution a where
  -- | Replace free occurrences of the mapped variables with their types.
  applySubstitution :: Map Name Type -> a -> a

  -- | List the free type variables.
  freeVariables     :: a -> [Name]
-- | Lists are handled element-wise.
instance TypeSubstitution a => TypeSubstitution [a] where
  -- The result must be duplicate-free: callers remove bound variables with
  -- '\\', which deletes only ONE occurrence per binder. With plain
  -- concatenation a variable mentioned by two list elements (for example
  -- two predicates of a context) would survive that subtraction and be
  -- reported as free even though it is bound. Folding with 'union' keeps
  -- first-occurrence order and never adds a name already collected.
  freeVariables     = foldl' (\seen x -> seen `union` freeVariables x) []

  applySubstitution = fmap . applySubstitution
-- | Substitution and free variables for 'Type'. Variables bound by a
-- 'ForallT' are neither substituted inside it nor reported as free.
instance TypeSubstitution Type where
  applySubstitution subst ty =
    case ty of
      ForallT tvs context body ->
        -- Binders shadow the substitution within their scope.
        let visible = foldr (Map.delete . tvName) subst tvs
        in ForallT tvs (applySubstitution visible context)
                       (applySubstitution visible body)
      AppT f x       -> AppT (recur f) (recur x)
      SigT t k       -> SigT (recur t) (applySubstitution subst k)
      VarT v         -> Map.findWithDefault ty v subst
#if MIN_VERSION_template_haskell(2,11,0)
      InfixT  l c r  -> InfixT  (recur l) c (recur r)
      UInfixT l c r  -> UInfixT (recur l) c (recur r)
      ParensT t      -> ParensT (recur t)
#endif
      _              -> ty
    where
      recur t = applySubstitution subst (t `asTypeOf` ty)

  freeVariables = go
    where
      go (ForallT tvs context body) =
        (freeVariables context `union` go body) \\ map tvName tvs
      go (AppT f x)      = go f `union` go x
      go (SigT t k)      = go t `union` freeVariables k
      go (VarT v)        = [v]
#if MIN_VERSION_template_haskell(2,11,0)
      go (InfixT  l _ r) = go l `union` go r
      go (UInfixT l _ r) = go l `union` go r
      go (ParensT t)     = go t
#endif
      go _               = []
-- | Substitution and free variables for 'ConstructorInfo'. The
-- constructor's own type variables are treated as bound: they are excluded
-- from the free variables and shielded from substitution.
instance TypeSubstitution ConstructorInfo where
  freeVariables ci = mentioned \\ bound
    where
      mentioned = freeVariables (constructorContext ci)
          `union` freeVariables (constructorFields ci)
      bound     = map tvName (constructorVars ci)

  applySubstitution subst ci =
    ci { constructorContext = applySubstitution visible (constructorContext ci)
       , constructorFields  = applySubstitution visible (constructorFields ci)
       }
    where
      visible = foldr (Map.delete . tvName) subst (constructorVars ci)
#if !MIN_VERSION_template_haskell(2,10,0)
-- | Before template-haskell 2.10 'Pred' is a separate data type and needs
-- its own instance.
instance TypeSubstitution Pred where
  freeVariables p =
    case p of
      ClassP _ xs -> freeVariables xs
      EqualP x y  -> freeVariables x `union` freeVariables y

  applySubstitution subst p =
    case p of
      ClassP n xs -> ClassP n (applySubstitution subst xs)
      EqualP x y  -> EqualP (applySubstitution subst x)
                            (applySubstitution subst y)
#endif
#if !MIN_VERSION_template_haskell(2,8,0)
-- | Before template-haskell 2.8 'Kind' is a separate data type; this
-- instance reports no free variables and leaves kinds unchanged.
instance TypeSubstitution Kind where
  freeVariables       = const []
  applySubstitution _ = id
#endif
-- | Combine two substitutions: the second is applied to every type in the
-- first, and the result is merged with the second. When both define the
-- same variable the (updated) entry of the first wins, because 'Map.union'
-- is left-biased.
combineSubstitutions :: Map Name Type -> Map Name Type -> Map Name Type
combineSubstitutions x y = Map.map (applySubstitution y) x `Map.union` y
-- | Compute a substitution that unifies all of the given types, failing in
-- 'Q' when two of them cannot be unified.
--
-- Type synonyms are resolved first. Every remaining type is then unified
-- against the first one, applying the substitution collected so far to both
-- sides before each step.
unifyTypes :: [Type] -> Q (Map Name Type)
unifyTypes [] = return Map.empty
unifyTypes (t:ts) =
  do t'  <- resolveTypeSynonyms t
     ts' <- traverse resolveTypeSynonyms ts
     let step sub u =
           combineSubstitutions sub <$>
             unify' (applySubstitution sub t') (applySubstitution sub u)
     either mismatch return (foldM step Map.empty ts')
  where
    mismatch (x, y) =
      fail ("Unable to unify types "
              ++ showsPrec 11 x (" and " ++ showsPrec 11 y ""))
-- | Unify two types structurally.
--
-- Returns the substitution that makes them equal, or 'Left' with the
-- offending pair. Handles variables, type constructors, applications and
-- tuple constructors; any other combination is a mismatch. Alternatives are
-- tried top to bottom, so their order matters.
unify' :: Type -> Type -> Either (Type,Type) (Map Name Type)
unify' lhs rhs =
  case (lhs, rhs) of
    (VarT n, VarT m) | n == m -> Right Map.empty
    (VarT n, t)               -> bind n t
    (t, VarT n)               -> bind n t
    (ConT n, ConT m) | n == m -> Right Map.empty
    (AppT f1 x1, AppT f2 x2)  ->
      unify' f1 f2 >>= \sub1 ->
        combineSubstitutions sub1 <$>
          unify' (applySubstitution sub1 x1) (applySubstitution sub1 x2)
    (TupleT n, TupleT m) | n == m -> Right Map.empty
    _                         -> Left (lhs, rhs)
  where
    -- Bind a variable to a type unless the variable occurs in that type
    -- (occurs check).
    bind n t
      | n `elem` freeVariables t = Left (VarT n, t)
      | otherwise                = Right (Map.singleton n t)
-- | Build an equality constraint between two types, using whichever
-- representation the available template-haskell version provides.
equalPred :: Type -> Type -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
equalPred x y = EqualityT `AppT` x `AppT` y
#else
equalPred x y = EqualP x y
#endif
-- | Build a class constraint from a class name and its argument types,
-- using whichever representation the available template-haskell version
-- provides.
classPred :: Name -> [Type] -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
classPred cls tys = foldl AppT (ConT cls) tys
#else
classPred cls tys = ClassP cls tys
#endif
-- | A list with at least one element: a head and a possibly empty tail.
data NonEmpty a = a :| [a]

-- | Prepend an element to a non-empty list.
(<|) :: a -> NonEmpty a -> NonEmpty a
x <| rest =
  case rest of
    y :| ys -> x :| (y : ys)

-- | Reverse a non-empty list.
reverseNonEmpty :: NonEmpty a -> NonEmpty a
reverseNonEmpty (x :| xs) = foldl push (x :| []) xs
  where
    -- Each later element becomes the new head of the accumulated result.
    push (y :| ys) e = e :| (y : ys)
-- | Rename constructor variables that do not match the datatype's own.
--
-- Any variable that is free in the constructors but is not a free variable
-- of 'datatypeVars' is replaced by the unique datatype variable with the
-- same 'nameBase'. Fails in 'Q' when there is no such variable or more than
-- one.
-- NOTE(review): the name presumably refers to GHC issue #13618; confirm.
repair13618 :: DatatypeInfo -> Q DatatypeInfo
repair13618 info = patch <$> sequenceA (Map.fromList replacements)
  where
    patch subst =
      info { datatypeCons = applySubstitution subst (datatypeCons info) }

    paramVars = freeVariables (datatypeVars info)
    strays    = freeVariables (datatypeCons info) \\ paramVars

    replacements = [ (u, pick u) | u <- strays ]

    pick u =
      case filter (\v -> nameBase v == nameBase u) paramVars of
        [v] -> varT v
        []  -> fail ("Impossible free variable: " ++ show u)
        _   -> fail ("Ambiguous free variable: "  ++ show u)
-- | Version-independent wrapper around 'dataD': builds a data declaration
-- with no kind signature and a deriving list given as class names.
dataDCompat ::
  CxtQ        {- ^ context                 -} ->
  Name        {- ^ type constructor        -} ->
  [TyVarBndr] {- ^ type parameters         -} ->
  [ConQ]      {- ^ constructor definitions -} ->
  [Name]      {- ^ derived class names     -} ->
  DecQ
#if MIN_VERSION_template_haskell(2,12,0)
dataDCompat c n ts cs ds = dataD c n ts Nothing cs clauses
  where
    -- No deriving clause at all is emitted when nothing is derived.
    clauses
      | null ds   = []
      | otherwise = [derivClause Nothing (map conT ds)]
#elif MIN_VERSION_template_haskell(2,11,0)
dataDCompat c n ts cs ds = dataD c n ts Nothing cs (pure [ConT d | d <- ds])
#else
dataDCompat c n ts cs ds = dataD c n ts cs ds
#endif
-- | Version-independent way to build the function kind from an argument
-- kind and a result kind.
arrowKCompat :: Kind -> Kind -> Kind
#if MIN_VERSION_template_haskell(2,8,0)
-- Here 'arrowK' is a bare kind that has to be applied with 'appK'.
arrowKCompat x y = appK (appK arrowK x) y
#else
arrowKCompat x y = arrowK x y
#endif