module Language.Haskell.TH.Alpha (
areExpAEq,
expEqual,
(@=),
AlphaEq(..)
) where
import Language.Haskell.TH
#if !MIN_VERSION_th_desugar(1,5,0)
import Language.Haskell.TH.Syntax (Quasi)
#endif
import Language.Haskell.TH.Desugar
import Data.Function (on)
import Control.Monad.State
import Control.Monad.Identity
import Control.Monad.Trans.Maybe
import Control.Monad.Morph
import Data.Maybe (isJust)
import qualified Data.Map as Map
import Control.Applicative
-- | Raw binder table: a map for names in the left-hand term, a map for
-- names in the right-hand term, and the next binding index to hand out.
-- A left name and a right name count as bound together when both maps send
-- them to the same index.
type Lookup = (Map.Map Name Int, Map.Map Name Int, Int)

-- | A table with no bindings on either side and the counter at zero.
emptyLookup :: Lookup
emptyLookup = (noBindings, noBindings, 0)
  where
    noBindings = Map.empty
-- | Record-of-functions view of the binder table used while comparing two
-- terms.
--
-- * @insertLR@ records that a left-hand name and a right-hand name are bound
--   at the same point and returns the extended table.
-- * @eqInTbl@ tests whether the two names map to the same binding. In the
--   'mapLookup' implementation, two unbound names also compare equal here.
-- * @isInL@ / @isInR@ test whether a name is bound on the given side.
data LookupTbl = LookupTbl
  { insertLR :: Name -> Name -> LookupTbl
  , eqInTbl :: Name -> Name -> Bool
  , isInL :: Name -> Bool
  , isInR :: Name -> Bool
  }
-- | Wrap a raw 'Lookup' triple in the 'LookupTbl' interface.
--
-- Binding a pair gives both names the current counter value and advances
-- the counter. Two names are therefore related exactly when they were bound
-- in the same step. Looking up two unbound names yields @Nothing == Nothing@,
-- which 'eqInTbl' reports as equal.
mapLookup :: Lookup -> LookupTbl
mapLookup (lefts, rights, next) =
  LookupTbl
    { insertLR = bindPair
    , eqInTbl = \l r -> Map.lookup l lefts == Map.lookup r rights
    , isInL = (`Map.member` lefts)
    , isInR = (`Map.member` rights)
    }
  where
    bindPair l r =
      mapLookup (Map.insert l next lefts, Map.insert r next rights, next + 1)
-- | The comparison monad. It threads the current 'LookupTbl' as state, and
-- its inner 'MaybeT' layer lets any mismatch ('mzero' or a failed 'guard')
-- abort the whole comparison.
-- Note that the type is named @LookupSTM@ while its data constructor is
-- @LookupST@.
newtype LookupSTM m b = LookupST {
    unLookupST :: StateT LookupTbl (MaybeT m) b
  } deriving (Functor, Applicative, Monad, MonadState LookupTbl
             , MonadPlus, Alternative)
-- | Run an action of the base monad inside a comparison. The table is left
-- untouched and the comparison does not fail.
instance MonadTrans LookupSTM where
  lift action = LookupST . StateT $ \tbl -> MaybeT $ do
    x <- action
    return (Just (x, tbl))
-- | Change the base monad of a comparison with a natural transformation.
-- The transformation is applied to the fully unwrapped
-- @m (Maybe (b, LookupTbl))@ action, so both state and failure pass
-- through unchanged.
hoist' :: (Monad m) => (forall a . m a -> n a) -> LookupSTM m b -> LookupSTM n b
hoist' nat lkstm = LookupST (StateT step)
  where
    step tbl = MaybeT (nat (runMaybeT (runStateT (unLookupST lkstm) tbl)))

-- | 'hoist' is 'hoist''.
instance MFunctor LookupSTM where
  hoist nat = hoist' nat
-- | Embed a pure ('Identity'-based) comparison into one running over 'Q'.
-- @return . runIdentity@ is the same conversion as mmorph's @generalize@.
toQ :: LookupST b -> LookupSTQ b
toQ = hoist (return . runIdentity)

-- | Pure comparison computation.
type LookupST b = LookupSTM Identity b

-- | Comparison computation that may run Template Haskell actions.
type LookupSTQ b = LookupSTM Q b
-- | Run a comparison from the given table. The result is 'Nothing' if the
-- comparison failed; otherwise it holds the result and the final table.
runLookupST :: Monad m => LookupSTM m a -> LookupTbl -> m (Maybe (a, LookupTbl))
runLookupST st = runMaybeT . runStateT (unLookupST st)

-- | 'runLookupST' specialised to pure comparisons.
runLookupST' :: LookupST a -> LookupTbl -> Maybe (a, LookupTbl)
runLookupST' st tbl = runIdentity (runLookupST st tbl)
-- | Types whose values can be compared up to alpha-equivalence.
-- The functional dependency fixes the monad @m@ the comparison runs in for
-- each type @a@. In this module that is 'Q' for 'Exp' and 'Identity' for the
-- desugared @D*@ types.
class AlphaEq a m | a -> m where
  -- | Succeed when the two values match under the current lookup table, and
  -- fail through the 'MaybeT' layer otherwise. Comparing binders may extend
  -- the table.
  lkEq :: a -> a -> LookupSTM m ()
-- | Top-level alpha-equivalence test. It starts from an empty table and
-- reports whether 'lkEq' succeeded.
(@=) :: (Monad m, AlphaEq a m) => a -> a -> m Bool
x @= y = do
  outcome <- runLookupST (lkEq x y) (mapLookup emptyLookup)
  return (isJust outcome)
infix 4 @=
-- | Run two expression quotations in @m@ (left first, then right) and test
-- the resulting expressions with 'expEqual'.
areExpAEq ::
#if MIN_VERSION_th_desugar(1,5,0)
  DsMonad m
#else
  Quasi m
#endif
  => ExpQ
  -> ExpQ
  -> m Bool
areExpAEq e1 e2 = do
  lhs <- runQ e1
  rhs <- runQ e2
  expEqual lhs rhs
-- | Raw TH expressions are desugared in 'Q' (left first, then right), then
-- compared with the pure 'expEqual'' lifted into 'Q'.
instance AlphaEq Exp Q where
  lkEq lhs rhs = do
    dLhs <- lift (dsExp lhs)
    dRhs <- lift (dsExp rhs)
    toQ (expEqual' dLhs dRhs)
-- | Desugar both expressions (left first, then right) and decide whether
-- they are alpha-equivalent, starting from an empty lookup table.
expEqual ::
#if MIN_VERSION_th_desugar(1,5,0)
  DsMonad m
#else
  Quasi m
#endif
  => Exp
  -> Exp
  -> m Bool
expEqual t1 t2 = liftM2 alphaEq (dsExp t1) (dsExp t2)
  where
    alphaEq :: DExp -> DExp -> Bool
    alphaEq d1 d2 = isJust (runLookupST' (lkEq d1 d2) (mapLookup emptyLookup))
instance AlphaEq DExp Identity where
  lkEq = expEqual'

-- | Structural alpha-equivalence on desugared expressions.
--
-- Variables and constructors are compared with '~=~'. Lambda binders are
-- recorded pairwise with 'insertLRLST'. Sub-terms are compared left to right,
-- and any mismatch fails with 'mzero'. Constructors not listed here never
-- compare equal.
--
-- NOTE(review): the table is never restored when a scope ends, so names bound
-- by a lambda or pattern stay in the table for the rest of the comparison.
expEqual' :: DExp -> DExp -> LookupST ()
expEqual' (DVarE a1) (DVarE a2) = a1 ~=~ a2
expEqual' (DConE a1) (DConE a2) = a1 ~=~ a2
expEqual' (DLitE l1) (DLitE l2) = guard $ l1 == l2
expEqual' (DAppE a1 b1) (DAppE a2 b2) = lkEq a1 a2 >> lkEq b1 b2
expEqual' (DLamE a1 b1) (DLamE a2 b2) = do
  -- zipWithM_ truncates, so arities must be checked explicitly.
  guard $ ((==) `on` length) a1 a2
  zipWithM_ insertLRLST a1 a2
  lkEq b1 b2
expEqual' (DCaseE a1 b1) (DCaseE a2 b2) = do
  guard $ length b1 == length b2
  lkEq a1 a2
  zipWithM_ lkEq b1 b2
expEqual' (DLetE a1 b1) (DLetE a2 b2) = do
  -- Without this guard a let with extra declarations would compare equal to
  -- one containing only a prefix of them.
  guard $ length a1 == length a2
  zipWithM_ lkEq a1 a2
  lkEq b1 b2
expEqual' (DSigE a1 b1) (DSigE a2 b2) = lkEq a1 a2 >> lkEq b1 b2
expEqual' _ _ = mzero
instance AlphaEq DMatch Identity where
  lkEq = matchEqual

-- | Compare two case alternatives. The patterns are compared first, which
-- records their variables in the table, and then the right-hand sides are
-- compared under those bindings.
matchEqual :: DMatch -> DMatch -> LookupST ()
matchEqual (DMatch p1 rhs1) (DMatch p2 rhs2) = do
  lkEq p1 p2
  lkEq rhs1 rhs2
instance AlphaEq DLetDec Identity where
  lkEq = letDecEqual

-- | Compare two let-declarations.
--
-- * Function bindings must have exactly the same name (no renaming) and the
--   same number of clauses. The clauses are compared pairwise.
-- * Value bindings compare the right-hand side first, then the pattern, which
--   records the pattern's variables.
-- * Signatures compare only their types; the signature names are ignored.
-- * Fixity declarations must match exactly.
-- * Different declaration kinds never match.
letDecEqual :: DLetDec -> DLetDec -> LookupST ()
letDecEqual (DFunD n1 cls1) (DFunD n2 cls2) = do
  guard $ n1 == n2
  -- zipWithM_ truncates; without this, extra clauses would be ignored.
  guard $ length cls1 == length cls2
  zipWithM_ lkEq cls1 cls2
letDecEqual (DValD pat1 exp1) (DValD pat2 exp2) =
  lkEq exp1 exp2 >> lkEq pat1 pat2
letDecEqual (DSigD _name1 typ1) (DSigD _name2 typ2) =
  lkEq typ1 typ2
letDecEqual (DInfixD fx1 name1) (DInfixD fx2 name2) =
  guard $ fx1 == fx2 && name1 == name2
letDecEqual _ _ = mzero
instance AlphaEq DType Identity where
  lkEq = typeEqual

-- | Structural alpha-equivalence on desugared types.
--
-- * @forall@ binders are recorded pairwise. Kinds attached to binders and to
--   'DSigT' are compared via 'show', so they are not alpha-renamed.
-- * Type constructors are compared via 'show'.
-- * Type variables are compared with '~=~'.
-- * The arrow constructor matches itself, and type-level literals match when
--   they are equal.
typeEqual :: DType -> DType -> LookupST ()
typeEqual (DForallT tybs1 ctx1 typ1) (DForallT tybs2 ctx2 typ2) = do
  -- zipWithM_ truncates, so binder and context counts must match explicitly.
  guard $ length tybs1 == length tybs2
  guard $ length ctx1 == length ctx2
  zipWithM_ insertLRLSTty tybs1 tybs2
  zipWithM_ lkEq ctx1 ctx2
  lkEq typ1 typ2
typeEqual (DAppT ty1 arg1) (DAppT ty2 arg2) =
  lkEq ty1 ty2 >> lkEq arg1 arg2
typeEqual (DSigT ty1 knd1) (DSigT ty2 knd2) = do
  guard $ show knd1 == show knd2
  lkEq ty1 ty2
typeEqual (DConT n1) (DConT n2) =
  guard $ show n1 == show n2
typeEqual (DVarT n1) (DVarT n2) =
  n1 ~=~ n2
-- Function types desugar to DAppT (DAppT DArrowT a) b. Without this case,
-- no function type would ever be equal to itself.
typeEqual DArrowT DArrowT = return ()
typeEqual (DLitT l1) (DLitT l2) = guard $ l1 == l2
typeEqual _ _ = mzero
#if !MIN_VERSION_th_desugar(1,6,0)
-- Only compiled for th-desugar older than 1.6.
instance AlphaEq DKind Identity where
  lkEq = kindEqual

-- | Structural alpha-equivalence on desugared kinds. @forall@ binders are
-- recorded pairwise and kind variables are compared with '~=~'.
kindEqual :: DKind -> DKind -> LookupST ()
kindEqual (DForallK ns1 typ1) (DForallK ns2 typ2) = do
  -- zipWithM_ truncates, so binder counts must match explicitly.
  guard $ length ns1 == length ns2
  zipWithM_ insertLRLST ns1 ns2
  lkEq typ1 typ2
kindEqual (DVarK n1) (DVarK n2) = n1 ~=~ n2
kindEqual (DArrowK knda1 kndb1) (DArrowK knda2 kndb2) =
  lkEq knda1 knda2 >> lkEq kndb1 kndb2
kindEqual DStarK DStarK = return ()
kindEqual _ _ = mzero
#endif
instance AlphaEq DClause Identity where
  lkEq = clauseEqual

-- | Compare two function clauses. The argument patterns are compared
-- pairwise, which records their variables, and then the bodies are compared
-- under those bindings. Clauses with different numbers of arguments never
-- match.
clauseEqual :: DClause -> DClause -> LookupST ()
clauseEqual (DClause pats1 exp1) (DClause pats2 exp2) = do
  -- zipWithM_ truncates; without this, extra argument patterns would be
  -- ignored.
  guard $ length pats1 == length pats2
  zipWithM_ lkEq pats1 pats2
  lkEq exp1 exp2
instance AlphaEq DPred Identity where
  lkEq = predEqual

-- | Structural alpha-equivalence on desugared constraints.
-- Applications and kind signatures compare their parts left to right, and
-- variables and constructors are compared with '~=~'. Different predicate
-- shapes never match.
predEqual :: DPred -> DPred -> LookupST ()
predEqual lhs rhs = case (lhs, rhs) of
  (DAppPr f1 t1, DAppPr f2 t2) -> lkEq f1 f2 >> lkEq t1 t2
  (DSigPr p1 k1, DSigPr p2 k2) -> lkEq p1 p2 >> lkEq k1 k2
  (DVarPr n1, DVarPr n2) -> n1 ~=~ n2
  (DConPr n1, DConPr n2) -> n1 ~=~ n2
  _ -> mzero
instance AlphaEq DPat Identity where
  lkEq = patEqual

-- | Structural alpha-equivalence on desugared patterns.
--
-- * A variable pattern matches any other variable pattern and records the
--   pair as bound together.
-- * Constructor patterns need matching constructors (via '~=~') and the same
--   number of sub-patterns.
-- * Literals must be equal.
-- * Lazy and bang patterns recurse into their sub-pattern.
-- * Wildcards match wildcards.
patEqual :: DPat -> DPat -> LookupST ()
patEqual p q = case (p, q) of
  (DLitPa l1, DLitPa l2) -> guard (l1 == l2)
  (DVarPa n1, DVarPa n2) -> insertLRLST n1 n2
  (DConPa c1 ps1, DConPa c2 ps2) -> do
    c1 ~=~ c2
    guard (length ps1 == length ps2)
    zipWithM_ lkEq ps1 ps2
  (DTildePa s1, DTildePa s2) -> lkEq s1 s2
  (DBangPa s1, DBangPa s2) -> lkEq s1 s2
  (DWildPa, DWildPa) -> return ()
  _ -> mzero
-- | Compare a name from the left term with a name from the right term.
-- Both must map to the same binding in the table, where two unbound names
-- also count as the same. If the left name is unbound, the two names must
-- additionally print identically via 'show'.
(~=~) :: Name -> Name -> LookupST ()
a ~=~ b = do
  tbl <- get
  guard (eqInTbl tbl a b)
  -- The table is not modified between these checks, so one snapshot is
  -- enough.
  unless (isInL tbl a) $
    guard (show a == show b)
-- | Whether the name is currently bound on the left-hand side of the table.
isInL' :: Name -> LookupST Bool
isInL' n = gets (`isInL` n)
-- | Record a left-hand name and a right-hand name as bound together,
-- replacing the table in the state with the extended one.
insertLRLST :: Name -> Name -> LookupST ()
insertLRLST a b = do
  tbl <- get
  put (insertLR tbl a b)
-- | Bind a pair of type-variable binders.
-- Plain binders always pair up. Kinded binders pair up only when their kinds
-- print identically via 'show'. A plain binder and a kinded binder never
-- match.
insertLRLSTty :: DTyVarBndr -> DTyVarBndr -> LookupST ()
insertLRLSTty l r = case (l, r) of
  (DPlainTV n1, DPlainTV n2) -> insertLRLST n1 n2
  (DKindedTV n1 k1, DKindedTV n2 k2) ->
    guard (show k1 == show k2) >> insertLRLST n1 n2
  _ -> mzero