module GHC.Data.UnionFind where
import GHC.Prelude
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- | A node of the union-find structure. Every point belongs to exactly one
-- equivalence class; the class is identified by its representative point,
-- which 'repr' returns.
--
-- The derived 'Eq' compares the wrapped 'STRef's, and 'STRef' equality is
-- pointer (identity) equality. Two 'Point's are therefore equal only when
-- they are the very same node, not merely members of the same class
-- (use 'equivalent' for that).
newtype Point s a = Point (STRef s (Link s a))
    deriving Eq
-- | Overwrite the 'Link' stored in a point.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point ref) = writeSTRef ref
-- | Read the 'Link' currently stored in a point.
readPoint :: Point s a -> ST s (Link s a)
readPoint (Point ref) = readSTRef ref
-- | What a 'Point' holds: either the data of a class representative, or a
-- pointer to another point of the same class.
data Link s a
    = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
      -- ^ Representative (root) of a class. The first ref is the class
      -- weight ('fresh' starts it at 1, 'union' adds the weights of merged
      -- classes); the second ref is the class descriptor read by 'find'.
    | Link {-# UNPACK #-} !(Point s a)
      -- ^ Non-root point: follow this parent to reach the representative.
-- | Create a new point that forms a singleton equivalence class with the
-- given descriptor. The new class starts with weight 1.
fresh :: a -> ST s (Point s a)
fresh desc = do
    weight     <- newSTRef 1
    descriptor <- newSTRef desc
    Point <$> newSTRef (Info weight descriptor)
-- | Return the representative (root) of the class containing a point.
--
-- Side effect: path compression. After @repr p@ returns @root@, if
-- @p /= root@ then @p@ links directly to @root@.
repr :: Point s a -> ST s (Point s a)
repr point = do
    link <- readPoint point
    case link of
        Info _ _ -> return point
        Link parent -> do
            root <- repr parent
            -- If the parent is not itself the root, the recursive call has
            -- already made the parent point at the root. Copy that link so
            -- this point also skips straight to the root.
            when (root /= parent) $
                writePoint point =<< readPoint parent
            return root
-- | Read the descriptor of the equivalence class containing a point.
--
-- Fast path: a point that is the representative, or links directly to it,
-- is answered without calling 'repr'. Longer chains fall back to 'repr'
-- (which compresses the path) and read the descriptor from the root.
find :: Point s a -> ST s a
find point = do
    link <- readPoint point
    case link of
        Info _ descRef -> readSTRef descRef
        Link parent -> do
            parentLink <- readPoint parent
            case parentLink of
                Info _ descRef -> readSTRef descRef
                Link _         -> repr point >>= find
-- | Merge the equivalence classes of two points. Does nothing if they are
-- already in the same class.
--
-- The representative with the larger weight becomes the root of the merged
-- class (ties go to the first argument's class), and its weight becomes the
-- sum of both weights. In either case the merged class keeps the
-- descriptor of the /second/ argument's class.
--
-- Calls 'error' if a representative returned by 'repr' does not hold an
-- 'Info' link, which would mean the structure's invariant is broken.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
    point1 <- repr refpoint1
    point2 <- repr refpoint2
    when (point1 /= point2) $ do
        l1 <- readPoint point1
        l2 <- readPoint point2
        case (l1, l2) of
            (Info wref1 dref1, Info wref2 dref2) -> do
                weight1 <- readSTRef wref1
                weight2 <- readSTRef wref2
                -- Union by weight: hang the lighter root under the heavier
                -- one so that chains stay short.
                if weight1 >= weight2
                    then do
                        writePoint point2 (Link point1)
                        -- Store an evaluated sum rather than a thunk.
                        writeSTRef wref1 $! weight1 + weight2
                        -- point1 is the new root, but the merged class
                        -- takes point2's descriptor.
                        writeSTRef dref1 =<< readSTRef dref2
                    else do
                        -- point2 stays root and already holds its own
                        -- descriptor, so only the weight changes.
                        writePoint point1 (Link point2)
                        writeSTRef wref2 $! weight1 + weight2
            _ -> error "UnionFind.union: repr invariant broken"
-- | Test whether two points belong to the same equivalence class by
-- comparing their representatives. Like 'repr', this may compress paths as
-- a side effect. The first point is resolved before the second.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = (==) <$> repr point1 <*> repr point2