module DDC.Type.Equiv
( equivT
, equivWithBindsT)
where
import DDC.Type.Transform.Crush
import DDC.Type.Compounds
import DDC.Type.Bind
import DDC.Type.Exp
import qualified DDC.Type.Sum as Sum
-- | Check whether two types are equivalent when neither is nested inside
--   any enclosing forall.
--
--   This is 'equivWithBindsT' with both binder stacks empty.
equivT :: Ord n => Type n -> Type n -> Bool
equivT = equivWithBindsT [] []
-- | Check whether two types are equivalent, given the binders of the
--   foralls enclosing each of them.
--
--   Both types are first simplified with 'crushSomeT', and a sum with exactly
--   one component is replaced by that component. The results are then
--   compared structurally:
--
--   * Variables that 'getBindType' finds in neither stack are compared
--     with '=='.
--
--   * Variables that 'getBindType' finds in both stacks are equivalent when
--     the reported indices are equal and the associated binder types are
--     equivalent. (NOTE(review): the index is presumably the binder's
--     position in the stack -- confirm in "DDC.Type.Bind".)
--
--   * A variable found in only one of the stacks is never equivalent to the
--     other. Every variable comparison goes through 'checkBounds', which
--     calls 'error' if a primitive and a non-primitive bound share a name.
--
--   * Constructors must be equal.
--
--   * Foralls must have binders whose types are equivalent under 'equivT';
--     the binders are then pushed onto the front of the respective stacks
--     while the bodies are compared.
--
--   * Applications are compared component-wise.
--
--   * Sums must have the same number of components. They are equivalent
--     when either the components are pairwise equivalent in list order, or
--     every component on each side is equivalent to some component on the
--     other side.
--
--   Any other combination of type forms is not equivalent.
equivWithBindsT
        :: Ord n
        => [Bind n]     -- ^ Binders enclosing the first type, innermost first.
        -> [Bind n]     -- ^ Binders enclosing the second type, innermost first.
        -> Type n       -- ^ First type.
        -> Type n       -- ^ Second type.
        -> Bool
equivWithBindsT stack1 stack2 t1 t2
 = let  t1'     = unpackSumT $ crushSomeT t1
        t2'     = unpackSumT $ crushSomeT t2
   in case (t1', t2') of
        (TVar u1, TVar u2)
         -> checkBounds u1 u2
         $  case (getBindType stack1 u1, getBindType stack2 u2) of
                -- Neither variable is bound by an enclosing forall,
                -- so they must be the same variable.
                (Nothing, Nothing)
                 -> u1 == u2

                -- Both are bound by enclosing foralls: they must be bound at
                -- the same index and their binder types must be equivalent.
                (Just (ix1, bt1), Just (ix2, bt2))
                 -> ix1 == ix2
                 && equivWithBindsT stack1 stack2 bt1 bt2

                -- One variable is bound and the other is not.
                _ -> False

        -- Constructor names must be equal.
        (TCon tc1, TCon tc2)
         -> tc1 == tc2

        -- Push binders on the stacks as we enter foralls. If the binder
        -- types differ, the guard fails and control falls through to the
        -- final catch-all alternative.
        (TForall b11 t12, TForall b21 t22)
         |  equivT (typeOfBind b11) (typeOfBind b21)
         -> equivWithBindsT (b11 : stack1) (b21 : stack2) t12 t22

        -- Descend into applications.
        (TApp t11 t12, TApp t21 t22)
         -> equivWithBindsT stack1 stack2 t11 t21
         && equivWithBindsT stack1 stack2 t12 t22

        (TSum ts1, TSum ts2)
         -> let ts1'      = Sum.toList ts1
                ts2'      = Sum.toList ts2

                -- Cheap pairwise comparison; succeeds when corresponding
                -- components of the two lists are equivalent.
                checkFast = and $ zipWith (equivWithBindsT stack1 stack2) ts1' ts2'

                -- Order-independent comparison, only evaluated when the fast
                -- check fails: every component on each side must be
                -- equivalent to some component on the other side.
                checkSlow = all (\c1 -> any (equivWithBindsT stack1 stack2 c1) ts2') ts1'
                         && all (\c2 -> any (equivWithBindsT stack2 stack1 c2) ts1') ts2'

            in  length ts1' == length ts2'
            &&  (checkFast || checkSlow)

        (_, _)  -> False
-- | Guard against a primitive and a non-primitive bound carrying the same
--   name.
--
--   If one bound is a 'UName' and the other is a 'UPrim' whose name is
--   equal to it, this calls 'error'. Otherwise the third argument is
--   returned unchanged.
checkBounds :: Eq n => Bound n -> Bound n -> a -> a
checkBounds bound1 bound2 result
 | UName nameN   <- bound1
 , UPrim nameP _ <- bound2
 , nameP == nameN
 = clash

 | UPrim nameP _ <- bound1
 , UName nameN   <- bound2
 , nameP == nameN
 = clash

 | otherwise
 = result
 where
  clash = error $ unlines
        [ "DDC.Type.Equiv"
        , " Found a primitive and non-primitive bound variable with the same name."]
-- | Replace a sum with exactly one component by that component.
--   Any other type, including sums with zero or several components,
--   is returned unchanged.
unpackSumT :: Type n -> Type n
unpackSumT tt
 = case tt of
        TSum ts
         | [t] <- Sum.toList ts
         -> t
        _ -> tt