{-# LANGUAGE NoMonomorphismRestriction, TypeFamilies, FlexibleInstances
           , MultiParamTypeClasses, FunctionalDependencies #-}
{-|
Module      : Language.Haskell.TH.Alpha
Description : Alpha equivalence in TH
Copyright   : (c) Julian K. Arni, 2014
License     : BSD3
Maintainer  : jkarni@gmail.com
Stability   : experimental

Compare TH expressions (or clauses, patterns, etc.) for alpha equivalence.
That is, compare for equality modulo the renaming of bound variables.

>>> areExpEq [| \x -> x |] [| \y -> y |]
True

This can be useful when for instance testing libraries that use Template
Haskell: usually correctness is only defined up to alpha equivalence.

For most cases, 'areExpEq' is the only function you'll need. The 'AlphaEq'
class is only exported to make it easier to expand alpha-equality
comparison to other instances, or to be used (in combination with the
package __th-desugar__) for alpha equality comparisons of non-expression
types (e.g. TyVarBndr).

/N.B.:/ Type annotations are only partially handled: the context of a
@forall@ type is not compared.
-}
module Language.Haskell.TH.Alpha (
    areExpEq,
    exp_equal,
    AlphaEq(..)
    ) where

import Language.Haskell.TH
import Language.Haskell.TH.Syntax  (Quasi, returnQ)
import Language.Haskell.TH.Desugar
import Data.Function               (on)
import Control.Monad               (liftM, liftM2, liftM3, join, foldM)
import Data.Data                   (toConstr, Data)
import Data.Maybe                  (isJust)

-- A poor man's bound variable lookup table. The first list maps the bound
-- names of the left-hand term to binding indices, the second list does the
-- same for the right-hand term, and the Int is the next unused index. Two
-- bound names correspond to each other exactly when they map to the same
-- index. Newer bindings are prepended, so 'lookup' sees the innermost one.
type Lookup = ([(Name,Int)], [(Name,Int)], Int)

-- | The main Alpha Equivalence class. '@=' is by default defined in terms
-- of 'lkEq'. 'lkEq' is exposed for composability: it is easy to
-- recursively build 'AlphaEq' instances from other 'AlphaEq' instances by
-- delegating the lookup update to the subinstances.
class AlphaEq a where
    -- | Compares its arguments for alpha equivalence.
    (@=) :: a -> a -> Bool
    -- | Given a variable binding lookup compares arguments for alpha
    -- equivalence, returning Just of updated lookup in case of
    -- equivalence, Nothing otherwise.
    lkEq :: a -> a -> Lookup -> Maybe Lookup
    x @= y = isJust $ lkEq x y ([], [], 0)

---------------------------------------------------------------------------
-- Exp
---------------------------------------------------------------------------

-- | Convenience function that uses 'runQ' on 'exp_equal'.
--
-- >>> areExpEq [| let x = 5 in x |] [| let y = 5 in y |]
-- True
areExpEq :: Quasi m
         => ExpQ    -- ^ Quoted expression
         -> ExpQ    -- ^ Quoted expression
         -> m Bool
areExpEq e1 e2 = do
    x <- runQ e1
    y <- runQ e2
    exp_equal x y

-- | Compare two expressions for alpha-equivalence. Since this uses
-- th-desugar to desugar the expressions, returns a Bool in the Quasi
-- context.
exp_equal :: Quasi m => Exp -> Exp -> m Bool
exp_equal t1 t2 = do
    d1 <- dsExp t1
    d2 <- dsExp t2
    return (d1 @= d2)

-- Expressions bind nothing that outlives them, so a successful comparison
-- hands back the lookup it was given.
instance AlphaEq DExp where
    lkEq a b lk = guardLk (exp_equal' a b lk) lk

-- Structural comparison of desugared expressions under the given lookup.
exp_equal' :: DExp -> DExp -> Lookup -> Bool
exp_equal' (DVarE a)     (DVarE b)     c = cmpLk a b c
-- Constructors are never entered into the lookup, so they are compared by
-- name.
exp_equal' (DConE a)     (DConE b)     _ = a == b
exp_equal' (DLitE l1)    (DLitE l2)    _ = l1 == l2
exp_equal' (DAppE a1 a2) (DAppE b1 b2) c =
    exp_equal' a1 b1 c && exp_equal' a2 b2 c
exp_equal' (DLamE a1 a2) (DLamE b1 b2) c
    | sameLength a1 b1 = exp_equal' a2 b2 (bindNames a1 b1 c)
    | otherwise        = False
exp_equal' (DCaseE a1 a2) (DCaseE b1 b2) c =
    -- Every alternative has to agree with its counterpart, not just one.
    sameLength a2 b2 && exp_equal' a1 b1 c && and (zipWith mec a2 b2)
  where
    mec x y = match_equal x y c
exp_equal' (DLetE a1 a2) (DLetE b1 b2) c = isJust $ do
    decs  <- zipSame a1 b1
    -- Bind every name introduced by the group before comparing any
    -- right-hand side, so forward and mutual references resolve.
    bound <- foldM bindLetDec c decs
    final <- foldM compareDec bound decs
    lkEq a2 b2 final
  where
    compareDec l (x, y) = lkEq x y l
exp_equal' (DSigE a1 a2) (DSigE b1 b2) c = lkEqB a1 b1 c && lkEqB a2 b2 c
exp_equal' _ _ _ = False

---------------------------------------------------------------------------
-- Match
---------------------------------------------------------------------------

-- Compare two case alternatives: the patterns must agree, and the bodies
-- must agree under the bindings the patterns introduce.
match_equal :: DMatch -> DMatch -> Lookup -> Bool
match_equal (DMatch pat1 exp1) (DMatch pat2 exp2) c =
    maybe False (exp_equal' exp1 exp2) (lkEq pat1 pat2 c)

---------------------------------------------------------------------------
-- LetDec
---------------------------------------------------------------------------

instance AlphaEq DLetDec where
    lkEq = letDec_equal

-- Compare two let declarations. The returned lookup contains the names the
-- declarations bind.
letDec_equal :: DLetDec -> DLetDec -> Lookup -> Maybe Lookup
letDec_equal (DFunD n1 cls1) (DFunD n2 cls2) c =
    -- The function names are binders: pair them up instead of demanding
    -- identical names, and bind them first so recursive calls resolve.
    lkEqAll cls1 cls2 (bindNames [n1] [n2] c)
letDec_equal (DValD pat1 exp1) (DValD pat2 exp2) c =
    -- let is recursive, so the pattern's variables scope over the
    -- right-hand side.
    lkEq pat1 pat2 c >>= lkEq exp1 exp2
letDec_equal (DSigD _ typ1) (DSigD _ typ2) c =
    -- Hard to tell how the name will be bound, so just check types
    lkEq typ1 typ2 c
letDec_equal (DInfixD fx1 name1) (DInfixD fx2 name2) c =
    guardLk (fx1 == fx2 && cmpLk name1 name2 c) c
letDec_equal _ _ _ = Nothing

-- Bind the names a pair of let declarations introduces without looking at
-- their right-hand sides. Pairs that bind nothing, or whose constructors
-- differ, leave the lookup untouched; 'letDec_equal' rejects the latter.
bindLetDec :: Lookup -> (DLetDec, DLetDec) -> Maybe Lookup
bindLetDec c (DFunD n1 _, DFunD n2 _) = Just (bindNames [n1] [n2] c)
bindLetDec c (DValD p1 _, DValD p2 _) = pat_equal p1 p2 c
bindLetDec c _                        = Just c

---------------------------------------------------------------------------
-- Type
---------------------------------------------------------------------------

instance AlphaEq DType where
    lkEq = type_equal

-- Compare two types. Type variables bound by a forall are scoped to that
-- forall; a successful comparison returns the lookup it was given.
--
-- NOTE: the context of a forall is not compared.
type_equal :: DType -> DType -> Lookup -> Maybe Lookup
type_equal (DForallT tybs1 _ typ1) (DForallT tybs2 _ typ2) c = do
    bndrs <- zipSame tybs1 tybs2
    inner <- foldM bindTyVar c bndrs
    _     <- type_equal typ1 typ2 inner
    return c
  where
    -- Binders must be of the same flavour; kind annotations must agree.
    bindTyVar lk (DPlainTV n1, DPlainTV n2) = Just (bindNames [n1] [n2] lk)
    bindTyVar lk (DKindedTV n1 k1, DKindedTV n2 k2) =
        fmap (bindNames [n1] [n2]) (lkEq k1 k2 lk)
    bindTyVar _ _ = Nothing
type_equal (DAppT ty1 arg1) (DAppT ty2 arg2) c =
    type_equal ty1 ty2 c >>= type_equal arg1 arg2
type_equal (DSigT ty1 knd1) (DSigT ty2 knd2) c =
    type_equal ty1 ty2 c >>= lkEq knd1 knd2
type_equal (DVarT n1) (DVarT n2) c = guardLk (cmpLk n1 n2 c) c
type_equal (DConT n1) (DConT n2) c = guardLk (n1 == n2) c
type_equal DArrowT    DArrowT    c = Just c
type_equal (DLitT l1) (DLitT l2) c = guardLk (l1 == l2) c
type_equal _ _ _ = Nothing

---------------------------------------------------------------------------
-- Kind
---------------------------------------------------------------------------

instance AlphaEq DKind where
    lkEq = kind_equal

-- Compare two kinds. Kind variables bound by a forall are scoped to that
-- forall; a successful comparison returns the lookup it was given.
kind_equal :: DKind -> DKind -> Lookup -> Maybe Lookup
kind_equal (DForallK ns1 k1) (DForallK ns2 k2) c
    | sameLength ns1 ns2 = kind_equal k1 k2 (bindNames ns1 ns2 c) >> Just c
    | otherwise          = Nothing
kind_equal (DVarK n1) (DVarK n2) c = guardLk (cmpLk n1 n2 c) c
kind_equal (DConK n1 ks1) (DConK n2 ks2) c
    | n1 == n2  = lkEqAll ks1 ks2 c
    | otherwise = Nothing
kind_equal (DArrowK a1 r1) (DArrowK a2 r2) c =
    kind_equal a1 a2 c >>= kind_equal r1 r2
kind_equal DStarK DStarK c = Just c
kind_equal _ _ _ = Nothing

---------------------------------------------------------------------------
-- Clause
---------------------------------------------------------------------------

instance AlphaEq DClause where
    lkEq = clause_equal

-- Compare two function clauses: the argument patterns must agree pairwise
-- and the bodies must agree under the bindings the patterns introduce.
-- The pattern variables are local to the clause, so the caller's lookup is
-- returned on success.
clause_equal :: DClause -> DClause -> Lookup -> Maybe Lookup
clause_equal (DClause pats1 exp1) (DClause pats2 exp2) lk =
    lkEqAll pats1 pats2 lk >>= lkEq exp1 exp2 >> Just lk

---------------------------------------------------------------------------
-- Pat
---------------------------------------------------------------------------

instance AlphaEq DPat where
    lkEq = pat_equal

-- Compare two patterns. On success the returned lookup is extended with
-- the variables the patterns bind.
pat_equal :: DPat -> DPat -> Lookup -> Maybe Lookup
pat_equal (DLitPa lit1) (DLitPa lit2) c = guardLk (lit1 == lit2) c
pat_equal (DVarPa n1)   (DVarPa n2)   c = Just (bindNames [n1] [n2] c)
pat_equal (DConPa n1 p1) (DConPa n2 p2) c
    -- Constructors are compared by name; the bindings of each sub-pattern
    -- are threaded through to the next one.
    | n1 == n2  = lkEqAll p1 p2 c
    | otherwise = Nothing
pat_equal (DTildePa pat1) (DTildePa pat2) c = pat_equal pat1 pat2 c
pat_equal (DBangPa pat1)  (DBangPa pat2)  c = pat_equal pat1 pat2 c
pat_equal DWildPa         DWildPa         c = Just c
pat_equal _ _ _ = Nothing

---------------------------------------------------------------------------
-- Utils
---------------------------------------------------------------------------

fst3 :: (a, b, c) -> a
fst3 (a,_,_) = a

snd3 :: (a, b, c) -> b
snd3 (_,b,_) = b

thrd3 :: (a, b, c) -> c
thrd3 (_,_,c) = c

-- Do two names refer to the same thing? Names bound on both sides must
-- share a binding index; names free on both sides must be identical; a
-- bound name never matches a free one.
cmpLk :: Name -> Name -> Lookup -> Bool
cmpLk a b (m1,m2,_) = case (lookup a m1, lookup b m2) of
    (Just i,  Just j)  -> i == j
    (Nothing, Nothing) -> a == b
    _                  -> False

cmpLkC :: (Name, Name) -> Lookup -> Bool
cmpLkC (a,b) c = cmpLk a b c

-- Boolean view of 'lkEq'.
lkEqB :: AlphaEq a => a -> a -> Lookup -> Bool
lkEqB a b c = isJust $ lkEq a b c

-- Keep the lookup if the condition holds, fail otherwise.
guardLk :: Bool -> Lookup -> Maybe Lookup
guardLk ok lk = if ok then Just lk else Nothing

-- Do two lists have the same number of elements?
sameLength :: [a] -> [a] -> Bool
sameLength = (==) `on` length

-- Zip two lists, failing (rather than truncating) when the lengths differ.
zipSame :: [a] -> [a] -> Maybe [(a, a)]
zipSame xs ys
    | sameLength xs ys = Just (zip xs ys)
    | otherwise        = Nothing

-- Compare two lists element by element, threading the lookup from left to
-- right. Lists of different lengths are never equal.
lkEqAll :: AlphaEq a => [a] -> [a] -> Lookup -> Maybe Lookup
lkEqAll xs ys lk = zipSame xs ys >>= foldM step lk
  where
    step l (x, y) = lkEq x y l

-- Bind two equally long lists of names pairwise: the n-th name of each list
-- receives the same fresh index. Callers check the lengths beforehand.
bindNames :: [Name] -> [Name] -> Lookup -> Lookup
bindNames ns1 ns2 (m1,m2,cnt) =
    (zip ns1 [cnt..] ++ m1, zip ns2 [cnt..] ++ m2, cnt + length ns1)