module Language.Hakaru.Syntax.AST.Eq where
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.TypeOf
import Control.Monad.Reader
import qualified Data.Foldable as F
import qualified Data.List.NonEmpty as L
import qualified Data.Sequence as S
import qualified Data.Traversable as T
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Data.Traversable
#endif
import Data.Maybe
import Unsafe.Coerce
-- | Try to prove that two 'SCon' applications have the same result type
-- and the same argument signature.
--
-- The two constructors must be the same 'SCon', and their argument vectors
-- must be 'jmEq1'. Any pair of constructors without an equation below
-- (including two different constructors) yields 'Nothing'.
jmEq_S
    :: (ABT Term abt, JmEq2 abt)
    => SCon args  a  -> SArgs abt args
    -> SCon args' a' -> SArgs abt args'
    -> Maybe (TypeEq a a', TypeEq args args')
jmEq_S Lam_ es Lam_ es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S App_ es App_ es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S Let_ es Let_ es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S (CoerceTo_ c) (es :* End) (CoerceTo_ c') (es' :* End) = do
    (Refl, Refl) <- jmEq2 es es'
    -- The same source type may be coerced to different targets, so the
    -- result types must be compared separately.
    let t1 = coerceTo c  (typeOf es)
    let t2 = coerceTo c' (typeOf es')
    Refl <- jmEq1 t1 t2
    return (Refl, Refl)
jmEq_S (UnsafeFrom_ c) (es :* End) (UnsafeFrom_ c') (es' :* End) = do
    (Refl, Refl) <- jmEq2 es es'
    -- As above, but the result type is the coercion's source type.
    let t1 = coerceFrom c  (typeOf es)
    let t2 = coerceFrom c' (typeOf es')
    Refl <- jmEq1 t1 t2
    return (Refl, Refl)
jmEq_S (PrimOp_ op) es (PrimOp_ op') es' = do
    Refl <- jmEq1 es es'
    (Refl, Refl) <- jmEq2 op op'
    return (Refl, Refl)
jmEq_S (ArrayOp_ op) es (ArrayOp_ op') es' = do
    Refl <- jmEq1 es es'
    (Refl, Refl) <- jmEq2 op op'
    return (Refl, Refl)
jmEq_S (MeasureOp_ op) es (MeasureOp_ op') es' = do
    Refl <- jmEq1 es es'
    (Refl, Refl) <- jmEq2 op op'
    return (Refl, Refl)
jmEq_S Dirac es Dirac es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S MBind es MBind es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
-- Plate and Chain were previously missing, so equal plate/chain terms
-- were always reported as unequal.
jmEq_S Plate es Plate es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S Chain es Chain es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S Integrate es Integrate es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S (Summate h1 h2) es (Summate h1' h2') es' = do
    Refl <- jmEq1 (sing_HDiscrete h1) (sing_HDiscrete h1')
    Refl <- jmEq1 (sing_HSemiring h2) (sing_HSemiring h2')
    Refl <- jmEq1 es es'
    Just (Refl, Refl)
jmEq_S Expect es Expect es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S _ _ _ _ = Nothing
-- | Given the branches of two case expressions paired up positionally,
-- check every pair of branch bodies with 'jmEq2'. On success, return the
-- witness that the two result types agree.
--
-- An empty list yields 'Nothing', since no witness can be produced.
-- NOTE(review): the branch patterns themselves are not compared.
jmEq_Branch
    :: (ABT Term abt, JmEq2 abt)
    => [(Branch a abt b, Branch a abt b')]
    -> Maybe (TypeEq b b')
jmEq_Branch [] = Nothing
jmEq_Branch ((Branch _ body, Branch _ body') : rest) = do
    (Refl, Refl) <- jmEq2 body body'
    case rest of
        [] -> Just Refl
        _  -> jmEq_Branch rest
-- | Two argument vectors are 'jmEq1' when they have the same length and
-- every pair of corresponding arguments is 'jmEq2'.
instance JmEq2 abt => JmEq1 (SArgs abt) where
    jmEq1 End       End       = Just Refl
    jmEq1 (x :* xs) (y :* ys) = do
        (Refl, Refl) <- jmEq2 x y
        Refl         <- jmEq1 xs ys
        return Refl
    jmEq1 _         _         = Nothing
-- | Structural comparison of terms that yields a type-equality witness when
-- the two terms are built from the same constructors with 'jmEq'
-- subterms.
--
-- Lists of branches and superposition components must have equal length
-- (zipping alone would silently drop the extra elements).
-- For 'Datum_', only the constructor hint and the stored type singleton
-- are compared, not the constructor's payload.
instance (ABT Term abt, JmEq2 abt) => JmEq1 (Term abt) where
    jmEq1 (o :$ es) (o' :$ es') = do
        (Refl, Refl) <- jmEq_S o es o' es'
        return Refl
    jmEq1 (NaryOp_ o es) (NaryOp_ o' es') = do
        Refl <- jmEq1 o o'
        () <- all_jmEq2 es es'
        return Refl
    jmEq1 (Literal_ v) (Literal_ w) = jmEq1 v w
    jmEq1 (Empty_ a)   (Empty_ b)   = jmEq1 a b
    jmEq1 (Array_ i f) (Array_ j g) = do
        (Refl, Refl) <- jmEq2 i j
        (Refl, Refl) <- jmEq2 f g
        Just Refl
    -- The witness comes from the datums' type singletons rather than from
    -- 'unsafeCoerce': equal hints alone do not imply equal types (e.g. two
    -- pairs with different component types).
    jmEq1 (Datum_ (Datum hint typ _)) (Datum_ (Datum hint' typ' _))
        | hint == hint' = jmEq1 typ typ'
        | otherwise     = Nothing
    jmEq1 (Case_ a bs) (Case_ a' bs')
        | length bs /= length bs' = Nothing
        | otherwise = do
            (Refl, Refl) <- jmEq2 a a'
            jmEq_Branch (zip bs bs')
    jmEq1 (Superpose_ pms) (Superpose_ pms')
        | L.length pms /= L.length pms' = Nothing
        | otherwise = do
            (Refl,Refl) L.:| _ <- T.sequence $ fmap jmEq_Tuple (L.zip pms pms')
            return Refl
    jmEq1 (Reject_ a) (Reject_ b) = jmEq1 a b
    jmEq1 _ _ = Nothing
-- | Succeed with @()@ when the two sequences have the same length and every
-- pair of corresponding elements is 'jmEq2'; otherwise return 'Nothing'.
--
-- The length check matters because 'S.zipWith' truncates to the shorter
-- sequence. Without it, e.g. @x + y@ and @x + y + z@ would be accepted.
all_jmEq2
    :: (ABT Term abt, JmEq2 abt)
    => S.Seq (abt '[] a)
    -> S.Seq (abt '[] a)
    -> Maybe ()
all_jmEq2 xs ys
    | S.length xs /= S.length ys = Nothing
    | F.and (S.zipWith eq xs ys) = Just ()
    | otherwise                  = Nothing
  where
    eq x y = isJust (jmEq2 x y)
-- | Compare a pair of (left, right) tuples componentwise with 'jmEq2'.
-- Return witnesses that both first components and both second components
-- have equal types.
jmEq_Tuple :: (ABT Term abt, JmEq2 abt)
           => ((abt '[] a , abt '[] b),
               (abt '[] a', abt '[] b'))
           -> Maybe (TypeEq a a', TypeEq b b')
jmEq_Tuple ((lhs1, rhs1), (lhs2, rhs2)) = do
    (Refl, Refl) <- jmEq2 lhs1 lhs2
    (Refl, Refl) <- jmEq2 rhs1 rhs2
    Just (Refl, Refl)
-- | Boolean equality on terms: two terms are equal exactly when 'jmEq1'
-- produces a witness for them.
instance (ABT Term abt, JmEq2 abt) => Eq1 (Term abt) where
    eq1 x y = maybe False (const True) (jmEq1 x y)
-- | Ordinary 'Eq' for terms, delegating to 'eq1'.
instance (ABT Term abt, JmEq2 abt) => Eq (Term abt a) where
    t1 == t2 = eq1 t1 t2
-- | Structural comparison of 'TrivialABT' values by their views.
--
-- Syntax nodes are compared with the underlying 'jmEq1'. For variables and
-- binders only the variables' types are compared; a binder's body is then
-- compared recursively.
-- NOTE(review): variable identity is never checked here, so two different
-- variables of the same type compare as equal (and hence 'eq2' treats
-- them as equal too). Confirm this is intended.
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , JmEq1 (syn (TrivialABT syn))
         , Foldable21 syn
         ) => JmEq2 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *))
  where
    jmEq2 x y =
        case (viewABT x, viewABT y) of
          (Syn s1, Syn s2) -> do
              Refl <- jmEq1 s1 s2
              return (Refl, Refl)
          (Var (Variable _ _ ty1), Var (Variable _ _ ty2)) -> do
              Refl <- jmEq1 ty1 ty2
              return (Refl, Refl)
          (Bind (Variable _ _ ty1) body1, Bind (Variable _ _ ty2) body2) -> do
              Refl         <- jmEq1 ty1 ty2
              (Refl, Refl) <- jmEq2 (unviewABT body1) (unviewABT body2)
              return (Refl, Refl)
          _ -> Nothing
-- | Single-index comparison for 'TrivialABT' at a fixed binding list @xs@,
-- obtained from 'jmEq2' by discarding the (trivial) binding-list witness.
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , JmEq1 (syn (TrivialABT syn))
         , Foldable21 syn
         ) => JmEq1 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *) xs)
  where
    jmEq1 x y = do
        (Refl, Refl) <- jmEq2 x y
        return Refl
-- | Boolean equality on 'TrivialABT': equal exactly when 'jmEq2' produces
-- a witness.
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , Foldable21 syn
         , JmEq1 (syn (TrivialABT syn))
         ) => Eq2 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *))
  where
    eq2 x y = maybe False (const True) (jmEq2 x y)
-- | Single-index boolean equality, delegating to 'eq2'.
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , Foldable21 syn
         , JmEq1 (syn (TrivialABT syn))
         ) => Eq1 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *) xs)
  where
    eq1 x y = eq2 x y
-- | Ordinary 'Eq' for fully-indexed 'TrivialABT' values, delegating to 'eq1'.
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , Foldable21 syn
         , JmEq1 (syn (TrivialABT syn))
         ) => Eq (TrivialABT (syn :: ([k] -> k -> *) -> k -> *) xs a)
  where
    x == y = eq1 x y
-- | Correspondence from bound variables of the left-hand term to bound
-- variables of the right-hand term. 'alphaEq' extends it each time it
-- descends under a pair of binders.
type Varmap = Assocs (Variable :: Hakaru -> *)
-- | Check that two type singletons are equal inside the alpha-equivalence
-- monad. The witness is discarded; the computation fails when 'jmEq1'
-- returns 'Nothing'.
void_jmEq1
    :: Sing (a :: Hakaru)
    -> Sing (b :: Hakaru)
    -> ReaderT Varmap Maybe ()
void_jmEq1 x y = do
    _ <- lift (jmEq1 x y)
    return ()
-- | Check two variables with 'varEq' inside the alpha-equivalence monad.
-- The witness is discarded; the computation fails when 'varEq' returns
-- 'Nothing'.
void_varEq
    :: Variable (a :: Hakaru)
    -> Variable (b :: Hakaru)
    -> ReaderT Varmap Maybe ()
void_varEq x y = do
    _ <- lift (varEq x y)
    return ()
-- | Turn a boolean test into a success/failure step of the
-- alpha-equivalence monad.
try_bool :: Bool -> ReaderT Varmap Maybe ()
try_bool b
    | b         = return ()
    | otherwise = lift Nothing
-- | Decide whether two terms are alpha-equivalent, i.e. structurally equal
-- up to a consistent renaming of bound variables.
--
-- When both sides bind a variable (and the variables' types agree), the left
-- variable is recorded as corresponding to the right one in a 'Varmap'. A
-- left-hand variable occurrence is then compared (via 'varEq') with its
-- recorded partner. A variable with no recorded partner is compared with
-- the right-hand variable directly.
--
-- Sequences of subterms (n-ary operands, case branches, superposition
-- components) must have the same length on both sides. The subterm under a
-- coercion is compared as well.
-- NOTE(review): coercion values and case-branch patterns are not compared.
alphaEq
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] a
    -> Bool
alphaEq e1 e2 =
    maybe False (const True)
    $ runReaderT (go (viewABT e1) (viewABT e2)) emptyAssocs
  where
    -- Compare two views under the current binder correspondence.
    go  :: forall xs1 xs2 a
        .  View (Term abt) xs1 a
        -> View (Term abt) xs2 a
        -> ReaderT Varmap Maybe ()
    go (Var x) (Var y) = do
        s <- ask
        case lookupAssoc x s of
            Nothing -> void_varEq x y
            Just y' -> void_varEq y' y
    go (Bind x e1) (Bind y e2) = do
        Refl <- lift $ jmEq1 (varType x) (varType y)
        local (insertAssoc (Assoc x y)) (go e1 e2)
    go (Syn t1) (Syn t2) = termEq t1 t2
    go _ _ = lift Nothing

    -- Compare two terms of the same type, constructor by constructor.
    termEq :: forall a
        .  Term abt a
        -> Term abt a
        -> ReaderT Varmap Maybe ()
    termEq t1 t2 =
        case (t1, t2) of
          (o1 :$ es1, o2 :$ es2) -> sConEq o1 es1 o2 es2
          (NaryOp_ op1 es1, NaryOp_ op2 es2) -> do
              try_bool (op1 == op2)
              -- S.zipWith truncates, so the operand counts must agree first.
              try_bool (S.length es1 == S.length es2)
              F.sequence_ $ S.zipWith go (viewABT <$> es1) (viewABT <$> es2)
          (Literal_ x, Literal_ y) -> try_bool (x == y)
          (Empty_ x, Empty_ y) -> void_jmEq1 x y
          (Datum_ d1, Datum_ d2) -> datumEq d1 d2
          (Array_ n1 f1, Array_ n2 f2) -> do
              go (viewABT n1) (viewABT n2)
              go (viewABT f1) (viewABT f2)
          (Case_ s1 bs1, Case_ s2 bs2) -> do
              -- Scrutinee types must agree before branches can be compared.
              Refl <- lift $ jmEq1 (typeOf s1) (typeOf s2)
              go (viewABT s1) (viewABT s2)
              try_bool (length bs1 == length bs2)
              zipWithM_ sBranch bs1 bs2
          (Superpose_ pms1, Superpose_ pms2) -> do
              try_bool (L.length pms1 == L.length pms2)
              F.sequence_ $ L.zipWith pairEq pms1 pms2
          (Reject_ x, Reject_ y) -> void_jmEq1 x y
          (_, _) -> lift Nothing

    -- Compare argument vectors of identical shape, argument by argument.
    sArgsEq
        :: forall args
        .  SArgs abt args
        -> SArgs abt args
        -> ReaderT Varmap Maybe ()
    sArgsEq End End = return ()
    sArgsEq (e1 :* es1) (e2 :* es2) = do
        go (viewABT e1) (viewABT e2)
        sArgsEq es1 es2
    sArgsEq _ _ = lift Nothing

    -- Compare two SCon applications with the same result type.
    sConEq
        :: forall a args1 args2
        .  SCon args1 a
        -> SArgs abt args1
        -> SCon args2 a
        -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    sConEq Lam_ e1
           Lam_ e2 = sArgsEq e1 e2
    sConEq App_ (e1  :* e2  :* End)
           App_ (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e2) (typeOf e2')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')
    sConEq Let_ (e1  :* e2  :* End)
           Let_ (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e1')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')
    -- Previously only the subterms' types were compared, so any two
    -- same-typed subterms under a coercion were accepted.
    sConEq (CoerceTo_ _) (e1 :* End)
           (CoerceTo_ _) (e2 :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e2)
        go (viewABT e1) (viewABT e2)
    sConEq (UnsafeFrom_ _) (e1 :* End)
           (UnsafeFrom_ _) (e2 :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e2)
        go (viewABT e1) (viewABT e2)
    sConEq (PrimOp_ o1) es1
           (PrimOp_ o2) es2 = primOpEq o1 es1 o2 es2
    sConEq (ArrayOp_ o1) es1
           (ArrayOp_ o2) es2 = arrayOpEq o1 es1 o2 es2
    sConEq (MeasureOp_ o1) es1
           (MeasureOp_ o2) es2 = measureOpEq o1 es1 o2 es2
    sConEq Dirac e1
           Dirac e2 = sArgsEq e1 e2
    sConEq MBind (e1  :* e2  :* End)
           MBind (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e1')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')
    sConEq Plate     e1 Plate     e2 = sArgsEq e1 e2
    sConEq Chain     e1 Chain     e2 = sArgsEq e1 e2
    sConEq Integrate e1 Integrate e2 = sArgsEq e1 e2
    sConEq (Summate h1 h2) e1 (Summate h1' h2') e2 = do
        Refl <- lift $ jmEq1 (sing_HDiscrete h1) (sing_HDiscrete h1')
        Refl <- lift $ jmEq1 (sing_HSemiring h2) (sing_HSemiring h2')
        sArgsEq e1 e2
    sConEq Expect (e1  :* e2  :* End)
           Expect (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e1')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')
    sConEq _ _ _ _ = lift Nothing

    -- Same primitive operator (checked with jmEq2), then same arguments.
    primOpEq
        :: forall a typs1 typs2 args1 args2
        .  (typs1 ~ UnLCs args1, args1 ~ LCs typs1,
            typs2 ~ UnLCs args2, args2 ~ LCs typs2)
        => PrimOp typs1 a -> SArgs abt args1
        -> PrimOp typs2 a -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    primOpEq p1 e1 p2 e2 = do
        (Refl, Refl) <- lift $ jmEq2 p1 p2
        sArgsEq e1 e2

    -- Same array operator (checked with jmEq2), then same arguments.
    arrayOpEq
        :: forall a typs1 typs2 args1 args2
        .  (typs1 ~ UnLCs args1, args1 ~ LCs typs1,
            typs2 ~ UnLCs args2, args2 ~ LCs typs2)
        => ArrayOp typs1 a -> SArgs abt args1
        -> ArrayOp typs2 a -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    arrayOpEq p1 e1 p2 e2 = do
        (Refl, Refl) <- lift $ jmEq2 p1 p2
        sArgsEq e1 e2

    -- Same measure operator (checked with jmEq2), then same arguments.
    measureOpEq
        :: forall a typs1 typs2 args1 args2
        .  (typs1 ~ UnLCs args1, args1 ~ LCs typs1,
            typs2 ~ UnLCs args2, args2 ~ LCs typs2)
        => MeasureOp typs1 a -> SArgs abt args1
        -> MeasureOp typs2 a -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    measureOpEq m1 e1 m2 e2 = do
        (Refl,Refl) <- lift $ jmEq2 m1 m2
        sArgsEq e1 e2

    -- Datums are compared by their constructor codes; hints are ignored.
    datumEq :: forall a
        .  Datum (abt '[]) a
        -> Datum (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumEq (Datum _ _ d1) (Datum _ _ d2) = datumCodeEq d1 d2

    -- Both sides must select the same constructor (same Inl/Inr path).
    datumCodeEq
        :: forall xss a
        .  DatumCode xss (abt '[]) a
        -> DatumCode xss (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumCodeEq (Inr c) (Inr d) = datumCodeEq c d
    datumCodeEq (Inl c) (Inl d) = datumStructEq c d
    datumCodeEq _ _ = lift Nothing

    -- Compare constructor fields pairwise.
    datumStructEq
        :: forall xs a
        .  DatumStruct xs (abt '[]) a
        -> DatumStruct xs (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumStructEq (Et c1 c2) (Et d1 d2) = do
        datumFunEq c1 d1
        datumStructEq c2 d2
    datumStructEq Done Done = return ()
    datumStructEq _ _ = lift Nothing

    -- Compare a single constructor field.
    datumFunEq
        :: forall x a
        .  DatumFun x (abt '[]) a
        -> DatumFun x (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumFunEq (Konst e) (Konst f) = go (viewABT e) (viewABT f)
    datumFunEq (Ident e) (Ident f) = go (viewABT e) (viewABT f)
    datumFunEq _ _ = lift Nothing

    -- Compare a (weight, measure)-style pair componentwise.
    pairEq
        :: forall a b
        .  (abt '[] a, abt '[] b)
        -> (abt '[] a, abt '[] b)
        -> ReaderT Varmap Maybe ()
    pairEq (x1, y1) (x2, y2) = do
        go (viewABT x1) (viewABT x2)
        go (viewABT y1) (viewABT y2)

    -- Compare branch bodies; patterns are not compared.
    sBranch
        :: forall a b
        .  Branch a abt b
        -> Branch a abt b
        -> ReaderT Varmap Maybe ()
    sBranch (Branch _ e1) (Branch _ e2) = go (viewABT e1) (viewABT e2)