{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE EmptyCase #-}
module Data.Parameterized.TH.GADT
(
structuralEquality
, structuralTypeEquality
, structuralTypeOrd
, structuralTraversal
, structuralShowsPrec
, structuralHash
, PolyEq(..)
, DataD
, lookupDataType'
, asTypeCon
, conPat
, TypePat(..)
, dataParamTypes
, assocTypePats
) where
import Control.Monad
import Data.Hashable (hashWithSalt)
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Data.Parameterized.Classes
-- | Exported alias for th-abstraction's 'DatatypeInfo'.
type DataD = DatatypeInfo
-- | Reify the data type with the given name.
--
-- This simply delegates to 'reifyDatatype'.
lookupDataType' :: Name -> Q DatatypeInfo
lookupDataType' name = reifyDatatype name
-- | Build a constructor pattern that binds every field of the constructor to
-- a fresh variable.
--
-- The variable names are derived from the given prefix via 'newNames'.
-- Returns the pattern together with the bound names, in field order.
conPat ::
  ConstructorInfo ->
  String ->
  Q (Pat, [Name])
conPat con pre = do
  let arity = length (constructorFields con)
  names <- newNames pre arity
  pure (ConP (constructorName con) (map VarP names), names)
-- | The constructor of a 'ConstructorInfo' as an expression.
conExpr :: ConstructorInfo -> Exp
conExpr con = ConE (constructorName con)
-- | A pattern over Template Haskell 'Type's, interpreted by 'matchTypePat'.
data TypePat
  = TypeApp TypePat TypePat
    -- ^ Matches a type application whose head matches the first pattern and
    -- whose argument matches the second.
  | AnyType
    -- ^ Matches any type.
  | DataArg Int
    -- ^ Matches the data type parameter at the given (0-based) index, with
    -- any kind signature on that parameter ignored.
  | ConType TypeQ
    -- ^ Matches a type that is syntactically equal to the quoted type.
-- | Decide whether a type matches a 'TypePat'.
--
-- The first argument lists the data type's parameters; @'DataArg' i@ refers
-- to the @i@-th of them (0-based).  A kind signature ('SigT') on the
-- parameter is stripped before comparing.  'ConType' runs its quoted type and
-- compares for syntactic equality.
--
-- Fails in 'Q' when a 'DataArg' index is outside @[0, length params)@.
matchTypePat :: [Type] -> TypePat -> Type -> Q Bool
matchTypePat d (TypeApp p q) (AppT x y) = do
  headMatches <- matchTypePat d p x
  -- Only try the argument pattern once the head has matched.
  if headMatches then matchTypePat d q y else return False
matchTypePat _ AnyType _ = return True
matchTypePat tps (DataArg i) tp =
  -- Total lookup: an index equal to 'length tps' used to slip past the bounds
  -- check and crash inside '!!'.  'drop' ignores negative counts, so the
  -- explicit @i >= 0@ guard is still required.
  case drop i tps of
    (param : _) | i >= 0 -> return (stripSigT param == tp)
    _ -> fail $ "Illegal type pattern index " ++ show i
  where
    stripSigT (SigT t _) = t
    stripSigT t = t
matchTypePat _ (ConType tpq) tp = do
  tp' <- tpq
  return (tp' == tp)
matchTypePat _ _ _ = return False
-- | The type parameters of a reified data type, as reported by
-- 'datatypeVars'.
dataParamTypes :: DatatypeInfo -> [Type]
dataParamTypes info = datatypeVars info
-- | Look up the value attached to the first pattern that matches a type.
--
-- Patterns are tried in list order with 'matchTypePat'; the search stops at
-- the first match.  Returns 'Nothing' if no pattern matches.
assocTypePats :: [Type] -> [(TypePat,v)] -> Type -> Q (Maybe v)
assocTypePats dTypes pats tp = go pats
  where
    go [] = return Nothing
    go ((p, v) : rest) = do
      hit <- matchTypePat dTypes p tp
      if hit then return (Just v) else go rest
-- | The set of free type variables occurring in the argument.
typeVars :: TypeSubstitution a => a -> Set Name
typeVars x = Set.fromList (freeVariables x)
-- | Generate a two-argument Boolean equality test for a data type.
--
-- The generated function returns 'True' exactly when the function produced
-- by 'structuralTypeEquality' (with the same type and patterns) returns a
-- 'Just'.
structuralEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralEquality tpq pats = [| \lhs rhs -> isJust ($(typedEq) lhs rhs) |]
  where
    typedEq = structuralTypeEquality tpq pats
-- | Generate code comparing two variables with '=='.
--
-- The generated code continues with the given expression when they are
-- equal and evaluates to 'Nothing' otherwise.
joinEqMaybe :: Name -> Name -> ExpQ -> ExpQ
joinEqMaybe x y r = [| if $(lhs) == $(rhs) then $(r) else Nothing |]
  where
    lhs = varE x
    rhs = varE y
-- | Generate code that applies the given type-equality test to two variables.
--
-- On @Just Refl@ the generated code continues with the given expression;
-- on 'Nothing' it evaluates to 'Nothing'.
joinTestEquality :: ExpQ -> Name -> Name -> ExpQ -> ExpQ
joinTestEquality f x y r =
  [| case $(f) $(lhs) $(rhs) of
       Nothing -> Nothing
       Just Refl -> $(r)
   |]
  where
    lhs = varE x
    rhs = varE y
-- | Generate the equality check for the fields of two values built with the
-- same constructor.
--
-- Each field is handled according to its type:
--
-- * If the type matches one of the supplied patterns, the associated
--   expression is used as a type-equality test ('joinTestEquality').  If the
--   type has the form @f a@, the variable @a@ is added to the bound set for
--   the remaining fields.
-- * Otherwise, if all free type variables of the field type are bound, the
--   fields are compared with '=='.
-- * Otherwise generation fails with an "Unsupported argument type" error.
--
-- When every field succeeds, the generated code returns @Just Refl@.
matchEqArguments :: [Type]
                    -- ^ Data type parameters.
                 -> [(TypePat,ExpQ)]
                    -- ^ Patterns paired with equality-test expressions.
                 -> Name
                    -- ^ Constructor name (for error messages).
                 -> Set Name
                    -- ^ Type variables already known to be bound.
                 -> [Type]
                    -- ^ Remaining field types.
                 -> [Name]
                    -- ^ Field variables of the first value.
                 -> [Name]
                    -- ^ Field variables of the second value.
                 -> ExpQ
matchEqArguments dTypes pats cnm bnd (tp:tpl) (x:xl) (y:yl) = do
  doesMatch <- assocTypePats dTypes pats tp
  case doesMatch of
    Just q -> do
      let bnd' = case tp of
            AppT _ (VarT nm) -> Set.insert nm bnd
            _ -> bnd
      joinTestEquality q x y (matchEqArguments dTypes pats cnm bnd' tpl xl yl)
    Nothing | typeVars tp `Set.isSubsetOf` bnd ->
      joinEqMaybe x y (matchEqArguments dTypes pats cnm bnd tpl xl yl)
    Nothing ->
      -- Pretty-print the type (as 'matchOrdArguments' does) instead of
      -- dumping the raw TH syntax tree with 'show'.
      fail $ "Unsupported argument type " ++ show (ppr tp)
          ++ " in " ++ show (ppr cnm) ++ "."
matchEqArguments _ _ _ _ [] [] [] = [| Just Refl |]
matchEqArguments _ _ _ _ [] _ _ = error "Unexpected end of types."
matchEqArguments _ _ _ _ _ [] _ = error "Unexpected end of names."
matchEqArguments _ _ _ _ _ _ [] = error "Unexpected end of names."
-- | Generate a @case@ on the second value for one constructor of the first.
--
-- If the second value uses the same constructor, the fields are compared
-- with 'matchEqArguments'.  When the type has more than one constructor, a
-- wildcard alternative yields 'Nothing' for any other constructor.
mkSimpleEqF :: [Type]
               -- ^ Data type parameters.
            -> Set Name
               -- ^ Type variables treated as bound.
            -> [(TypePat,ExpQ)]
               -- ^ Patterns paired with equality-test expressions.
            -> ConstructorInfo
               -- ^ Constructor being matched.
            -> [Name]
               -- ^ Field variables bound for the first value.
            -> ExpQ
               -- ^ The second value.
            -> Bool
               -- ^ Whether the type has multiple constructors.
            -> ExpQ
mkSimpleEqF dTypes bnd pats con xv yQ multipleCases = do
  (yPat, yNames) <- conPat con "y"
  let cnm = constructorName con
      body = matchEqArguments dTypes pats cnm bnd (constructorFields con) xv yNames
      sameCon = match (pure yPat) (normalB body) []
      otherCon = match wildP (normalB [| Nothing |]) []
  caseE yQ (if multipleCases then [sameCon, otherCon] else [sameCon])
-- | Like 'mkSimpleEqF', but takes the data type parameters from the
-- 'DatatypeInfo'.
--
-- The free variables of every parameter except the last are treated as
-- already bound.
mkEqF :: DatatypeInfo
      -> [(TypePat,ExpQ)]
      -> ConstructorInfo
      -> [Name]
      -> ExpQ
      -> Bool
      -> ExpQ
mkEqF d pats con = mkSimpleEqF params bound pats con
  where
    params = datatypeVars d
    bound = case params of
      [] -> Set.empty
      _ -> typeVars (init params)
-- | Generate a type-equality test for a data type, one constructor at a time.
--
-- The generated function takes two values.  It evaluates to @Just Refl@ when
-- both use the same constructor and all fields compare equal, and to
-- 'Nothing' otherwise.  For a type with no constructors the result is an
-- empty-case lambda.
structuralTypeEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralTypeEquality tpq pats = do
  d <- reifyDatatype =<< asTypeCon "structuralTypeEquality" =<< tpq
  let cons = datatypeCons d
      manyCons = case cons of
        (_ : _ : _) -> True
        _ -> False
      lhsMatch rhs con = do
        (xPat, xNames) <- conPat con "x"
        match (pure xPat) (normalB (mkEqF d pats con xNames rhs manyCons)) []
  case cons of
    [] -> [| \x -> case x of {} |]
    _ -> do
      lhs <- newName "x"
      rhs <- newName "y"
      lamE [varP lhs, varP rhs] (caseE (varE lhs) (map (lhsMatch (varE rhs)) cons))
-- | Generate an 'OrderingF'-returning comparison for a data type.
--
-- Values built with the same constructor are compared field by field with
-- 'mkOrdF'.  Values built with different constructors are ordered by
-- constructor index.  The index of the second value is computed once, and
-- only when the type has more than one constructor.  For a type with no
-- constructors the result is an empty-case lambda.
structuralTypeOrd ::
  TypeQ ->
  [(TypePat,ExpQ)] ->
  ExpQ
structuralTypeOrd tpq l = do
  -- Report errors under this function's own name (previously mislabelled
  -- as "structuralTypeEquality").
  d <- reifyDatatype =<< asTypeCon "structuralTypeOrd" =<< tpq
  let withNumber :: ExpQ -> (Maybe ExpQ -> ExpQ) -> ExpQ
      withNumber yQ k
        | null (drop 1 (datatypeCons d)) = k Nothing
        | otherwise =
            [| let yn :: Int
                   yn = $(caseE yQ (constructorNumberMatches (datatypeCons d)))
               in $(k (Just [| yn |])) |]
  if null (datatypeCons d)
    then [| \x -> case x of {} |]
    else [| \x y -> $(withNumber [|y|] $ \mbYn -> caseE [| x |] (outerOrdMatches d [|y|] mbYn)) |]
  where
    -- Map each constructor to its 0-based position in the declaration.
    constructorNumberMatches :: [ConstructorInfo] -> [MatchQ]
    constructorNumberMatches cons =
      [ match (recP (constructorName con) [])
              (normalB (litE (integerL i)))
              []
      | (i,con) <- zip [0..] cons ]

    -- One alternative per constructor of the first value.
    outerOrdMatches :: DatatypeInfo -> ExpQ -> Maybe ExpQ -> [MatchQ]
    outerOrdMatches d yExp mbYn =
      [ do (pat,xv) <- conPat con "x"
           match (pure pat)
                 (normalB (do xs <- mkOrdF d l con i mbYn xv
                              caseE yExp xs))
                 []
      | (i,con) <- zip [0..] (datatypeCons d) ]
-- | Create @n@ fresh names with base names @base1@ through @base<n>@.
newNames ::
  String ->
  Int ->
  Q [Name]
newNames base n = mapM fresh [1 .. n]
  where
    fresh k = newName (base ++ show k)
-- | Generate code that compares two variables with an 'OrderingF'-returning
-- function.
--
-- The generated code returns 'LTF'/'GTF' as-is and continues with the given
-- expression on 'EQF'.
joinCompareF :: ExpQ -> Name -> Name -> ExpQ -> ExpQ
joinCompareF f x y r =
  [| case $(f) $(lhs) $(rhs) of
       LTF -> LTF
       GTF -> GTF
       EQF -> $(r)
   |]
  where
    lhs = varE x
    rhs = varE y
-- | Generate code that compares two variables with 'compare'.
--
-- 'LT'/'GT' become 'LTF'/'GTF'; on 'EQ' the generated code continues with
-- the given expression.
joinCompareToOrdF :: Name -> Name -> ExpQ -> ExpQ
joinCompareToOrdF x y r =
  [| case compare $(lhs) $(rhs) of
       LT -> LTF
       GT -> GTF
       EQ -> $(r)
   |]
  where
    lhs = varE x
    rhs = varE y
-- | Generate the ordering comparison for the fields of two values built with
-- the same constructor.
--
-- For each field:
--
-- * If its type matches a supplied pattern, the associated function is used
--   via 'joinCompareF'.  If the type has the form @f a@, the variable @a@ is
--   added to the bound set for later fields.
-- * Otherwise, if all its free type variables are bound, 'compare' is used.
-- * Otherwise generation fails.
--
-- When all fields are equal, the generated code returns 'EQF'.
matchOrdArguments :: [Type]
                     -- ^ Data type parameters.
                  -> [(TypePat,ExpQ)]
                     -- ^ Patterns paired with comparison expressions.
                  -> Name
                     -- ^ Constructor name (for error messages).
                  -> Set Name
                     -- ^ Type variables already known to be bound.
                  -> [Type]
                     -- ^ Remaining field types.
                  -> [Name]
                     -- ^ Field variables of the first value.
                  -> [Name]
                     -- ^ Field variables of the second value.
                  -> ExpQ
matchOrdArguments dTypes pats cnm bnd (tp : tpl) (x:xl) (y:yl) = do
  found <- assocTypePats dTypes pats tp
  let continueWith bound = matchOrdArguments dTypes pats cnm bound tpl xl yl
  case found of
    Just cmp ->
      let bound' = case tp of
            AppT _ (VarT nm) -> Set.insert nm bnd
            _ -> bnd
      in joinCompareF cmp x y (continueWith bound')
    Nothing
      | typeVars tp `Set.isSubsetOf` bnd ->
          joinCompareToOrdF x y (continueWith bnd)
      | otherwise ->
          fail $ "Unsupported argument type " ++ show (ppr tp)
            ++ " in " ++ show (ppr cnm) ++ "."
matchOrdArguments _ _ _ _ [] [] [] = [| EQF |]
matchOrdArguments _ _ _ _ [] _ _ = error "Unexpected end of types."
matchOrdArguments _ _ _ _ _ [] _ = error "Unexpected end of names."
matchOrdArguments _ _ _ _ _ _ [] = error "Unexpected end of names."
-- | Build the @case@ alternatives on the second value for one constructor of
-- the first.
--
-- The same-constructor alternative compares fields with 'matchOrdArguments',
-- starting from an empty bound set.  When the second value's constructor
-- index is available, a wildcard alternative orders by index: 'LTF' if the
-- first constructor's index is smaller, 'GTF' otherwise.
mkSimpleOrdF :: [Type]
                -- ^ Data type parameters.
             -> [(TypePat,ExpQ)]
                -- ^ Patterns paired with comparison expressions.
             -> ConstructorInfo
                -- ^ Constructor of the first value.
             -> Integer
                -- ^ Index of that constructor.
             -> Maybe ExpQ
                -- ^ Index of the second value's constructor, if computed.
             -> [Name]
                -- ^ Field variables bound for the first value.
             -> Q [MatchQ]
mkSimpleOrdF dTypes pats con xnum mbYn xv = do
  (yPat, yNames) <- conPat con "y"
  let body = matchOrdArguments dTypes pats (constructorName con) Set.empty (constructorFields con) xv yNames
      sameCon = match (pure yPat) (normalB body) []
      byIndex yn = match wildP (normalB [| if xnum < $(yn) then LTF else GTF |]) []
  pure (sameCon : maybe [] (\yn -> [byIndex yn]) mbYn)
-- | Like 'mkSimpleOrdF', but takes the data type parameters from the
-- 'DatatypeInfo'.
mkOrdF :: DatatypeInfo
       -> [(TypePat,ExpQ)]
       -> ConstructorInfo
       -> Integer
       -> Maybe ExpQ
       -> [Name]
       -> Q [MatchQ]
mkOrdF d = mkSimpleOrdF (datatypeVars d)
-- | Produce the traversal expression for a single constructor field, if any.
--
-- A function supplied for the field's type takes priority and is applied as
-- @g f v@.  Otherwise the field is handled by its shape:
--
-- * @T (a x)@ with a constructor @T@ is traversed as @traverse f v@;
-- * @a x@ is handled as @f v@.
--
-- Any other field is left alone ('Nothing').
recurseArg :: (Type -> Q (Maybe ExpQ))
           -> ExpQ
           -> ExpQ
           -> Type
           -> Q (Maybe Exp)
recurseArg m f v tp = do
  custom <- m tp
  case (custom, tp) of
    (Just g, _) -> Just <$> [| $(g) $(f) $(v) |]
    (Nothing, AppT (ConT _) (AppT (VarT _) _)) -> Just <$> [| traverse $(f) $(v) |]
    (Nothing, AppT (VarT _) _) -> Just <$> [| $(f) $(v) |]
    (Nothing, _) -> return Nothing
-- | Build the @case@ alternative that traverses one constructor.
--
-- Fields with a traversal expression (see 'recurseArg') are combined
-- applicatively.  The constructor is first rebuilt as a lambda over those
-- fields, applied with '<$>', then '<*>' for each further field.  Other
-- fields are passed through unchanged.  With no traversed fields, the result
-- is @pure (Con ...)@.
traverseAppMatch :: (Type -> Q (Maybe ExpQ))
                 -> ExpQ
                 -> ConstructorInfo
                 -> MatchQ
traverseAppMatch pats fv c0 = do
  (pat, patArgs) <- conPat c0 "p"
  exprs <- zipWithM (recurseArg pats fv) (varE <$> patArgs) (constructorFields c0)
  -- Rebuild the constructor; every traversed field becomes a lambda parameter.
  let rebuild :: ExpQ -> [(Name, Maybe Exp)] -> ExpQ
      rebuild acc [] = acc
      rebuild acc ((v, Nothing) : rest) = rebuild (appE acc (varE v)) rest
      rebuild acc ((_, Just _) : rest) = do
        r <- newName "r"
        lamE [varP r] (rebuild (appE acc (varE r)) rest)
  -- Apply the remaining effects left to right with '<*>'.
  let chain :: ExpQ -> [Exp] -> ExpQ
      chain acc [] = acc
      chain acc (a : rest) = chain [| $(acc) <*> $(pure a) |] rest
  let ctor = rebuild (pure (conExpr c0)) (zip patArgs exprs)
      rhs = case catMaybes exprs of
        [] -> [| pure $(ctor) |]
        (a : rest) -> chain [| $(ctor) <$> $(pure a) |] rest
  match (pure pat) (normalB rhs) []
-- | Generate a traversal function @\\f a -> ...@ for a data type.
--
-- The generated function has one alternative per constructor, built by
-- 'traverseAppMatch'.
structuralTraversal :: TypeQ -> [(TypePat, ExpQ)] -> ExpQ
structuralTraversal tpq pats0 = do
  d <- reifyDatatype =<< asTypeCon "structuralTraversal" =<< tpq
  fName <- newName "f"
  aName <- newName "a"
  let matchArgs = assocTypePats (datatypeVars d) pats0
      alts = [ traverseAppMatch matchArgs (varE fName) con | con <- datatypeCons d ]
  lamE [varP fName, varP aName] (caseE (varE aName) alts)
-- | Extract the name of a bare type constructor ('ConT').
--
-- Any other form of type fails.  The message is prefixed with the given
-- function name.
--
-- NOTE(review): this uses 'fail' under a plain 'Monad' constraint, which
-- requires a GHC where 'fail' is still a 'Monad' method (pre-8.8).  Newer
-- compilers need a 'MonadFail' constraint — confirm the supported range.
asTypeCon :: Monad m => String -> Type -> m Name
asTypeCon fn tp =
  case tp of
    ConT nm -> return nm
    _ -> fail (fn ++ " expected type constructor.")
-- | Generate a @\\salt value -> ...@ hashing function for a data type.
--
-- Each constructor is hashed together with its 0-based index; see
-- 'matchHashCtor'.
structuralHash :: TypeQ -> ExpQ
structuralHash tpq = do
  d <- reifyDatatype =<< asTypeCon "structuralHash" =<< tpq
  salt <- newName "s"
  val <- newName "a"
  let alts = [ matchHashCtor (varE salt) i con | (i, con) <- zip [0 ..] (datatypeCons d) ]
  lamE [varP salt, varP val] (caseE (varE val) alts)
-- | Build the hashing alternative for one constructor.
--
-- Starting from the given salt expression, the code threads 'hashWithSalt'
-- first over the constructor index (as an 'Int'), then over each field in
-- order.
matchHashCtor :: ExpQ -> Integer -> ConstructorInfo -> MatchQ
matchHashCtor s0 i c = do
  (pat, vars) <- conPat c "x"
  let tag = [| $(litE (IntegerL i)) :: Int |]
      go acc [] = acc
      go acc (e : es) = go [| hashWithSalt $(acc) $(e) |] es
  match (pure pat) (normalB (go s0 (tag : map varE vars))) []
-- | Generate a 'showsPrec'-style function @\\p a -> ...@ for a data type.
--
-- The generated function has one alternative per constructor; see
-- 'matchShowCtor'.
structuralShowsPrec :: TypeQ -> ExpQ
structuralShowsPrec tpq = do
  -- Use the function's actual name in the error (was "structuralShowPrec").
  d <- reifyDatatype =<< asTypeCon "structuralShowsPrec" =<< tpq
  -- Leading underscore, presumably to avoid unused-variable warnings when no
  -- constructor has fields (the precedence is then never consulted).
  p <- newName "_p"
  a <- newName "a"
  lamE [varP p, varP a] $
    caseE (varE a) (matchShowCtor (varE p) <$> datatypeCons d)
-- | Build the 'showsPrec' alternative for a constructor with @n@ fields.
--
-- A nullary constructor shows as its bare name.  Otherwise the output is the
-- name followed by each field.  Fields are shown at precedence 11, and the
-- whole application is parenthesised when the context precedence is 11 or
-- more.  This follows the Haskell Report convention used by derived 'Show'
-- instances (application precedence 10, @showParen (d > 10)@).
--
-- The previous code showed fields at precedence 10.  Derived instances only
-- parenthesise above 10, so a field such as @Just 3@ printed as
-- @C Just 3@ instead of @C (Just 3)@.
showCon :: ExpQ -> Name -> Int -> MatchQ
showCon p nm n = do
  vars <- newNames "x" n
  let pat = ConP nm (VarP <$> vars)
  let go s e = [| $(s) . showChar ' ' . showsPrec 11 $(varE e) |]
  let ctor = [| showString $(return (LitE (StringL (nameBase nm)))) |]
  let rhs | null vars = ctor
          | otherwise = [| showParen ($(p) >= 11) $(foldl go ctor vars) |]
  match (pure pat) (normalB rhs) []
-- | Build the 'showsPrec' alternative for a constructor, taking its name and
-- field count from the 'ConstructorInfo'.
matchShowCtor :: ExpQ -> ConstructorInfo -> MatchQ
matchShowCtor p con = showCon p name arity
  where
    name = constructorName con
    arity = length (constructorFields con)