-- This file is part of the 'union-find-array' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

{-# LANGUAGE RankNTypes, FlexibleContexts, CPP #-}
-- |
-- Low-level interface for managing a disjoint set data structure, based on
-- 'Control.Monad.ST'. For a higher level convenience interface, look at
-- 'Control.Monad.Union'.
module Data.Union.ST (
    UnionST,
    runUnionST,
    new,
    grow,
    copy,
    lookup,
    annotate,
    merge,
    flatten,
    size,
    unsafeFreeze,
) where

import qualified Data.Union.Type as U

import Prelude hiding (lookup)
import Control.Monad.ST
import Control.Monad
import Control.Applicative
import Data.Array.Base hiding (unsafeFreeze)
import Data.Array.ST hiding (unsafeFreeze)
import qualified Data.Array.Base as A (unsafeFreeze)

-- | A disjoint set forest, with nodes numbered from 0, which can carry labels.
--
-- All arrays are indexed by node number.
data UnionST s l = UnionST {
    -- | Parent pointers. A node that is its own parent is the representative
    -- (root) of its set.
    up :: STUArray s Int Int,
    -- | Rank of each node; 'merge' compares the ranks of two roots to decide
    -- which one becomes the child of the other.
    rank :: STUArray s Int Int,
    -- | Label of each node. 'lookup' reports the label stored at the
    -- representative.
    label :: STArray s Int l,
    -- | Number of nodes in the forest.
    size :: !Int,
    -- | Default label, used for every node when the forest is created and
    -- for the additional nodes when it is grown.
    def :: l
}

-- Compatibility instance, compiled only when __GLASGOW_HASKELL__ is below 702;
-- presumably the libraries shipped with those compilers lack it (TODO confirm).
-- Both methods are defined through the Monad instance of ST.
#if __GLASGOW_HASKELL__ < 702
instance Applicative (ST s) where
    pure x = return x
    mf <*> mx = mf `ap` mx
#endif

-- Use http://www.haskell.org/pipermail/libraries/2008-March/009465.html ?

-- | Analogous to 'Data.Array.ST.runSTArray': run an 'ST' action that builds
-- a disjoint set forest and return the result as an immutable 'U.Union'.
--
-- The forest produced by the action is converted with 'unsafeFreeze' as the
-- last step inside 'runST'.
runUnionST :: (forall s. ST s (UnionST s l)) -> U.Union l
runUnionST a = runST (a >>= unsafeFreeze)

-- | Analogous to 'Data.Array.Base.unsafeFreeze'.
--
-- Builds a 'U.Union' from the size of the forest and the frozen parent and
-- label arrays; the rank array is not part of the result. Both arrays are
-- frozen with 'Data.Array.Base.unsafeFreeze', so the caveats of that
-- function apply: the mutable forest should not be modified afterwards.
unsafeFreeze :: UnionST s l -> ST s (U.Union l)
unsafeFreeze u =
    U.Union (size u) <$> A.unsafeFreeze (up u) <*> A.unsafeFreeze (label u)

-- What about thawing?

-- | Create a new disjoint set forest, of given capacity.
--
-- The first argument is the number of nodes (numbered from 0), the second
-- the label initially attached to every node. Each node starts out as its
-- own parent with rank 0, i.e. as a singleton set.
new :: Int -> l -> ST s (UnionST s l)
new n d = do
    -- Local names are kept distinct from the record selectors to avoid
    -- shadowing them.
    upArr <- newListArray (0, n - 1) [0..]
    rankArr <- newArray (0, n - 1) 0
    labelArr <- newArray (0, n - 1) d
    return UnionST{ up = upArr, rank = rankArr, label = labelArr, size = n, def = d }

-- | Grow the capacity of a disjoint set forest. Shrinking is not possible.
-- Trying to shrink a disjoint set forest will return the same forest
-- unmodified.
--
-- When the requested size exceeds the current one, a new forest with
-- freshly allocated arrays is returned and the argument is left untouched.
grow :: UnionST s l -> Int -> ST s (UnionST s l)
grow u size'
    | size' <= size u = return u
    | otherwise = grow' u size'

-- | Copy a disjoint set forest.
--
-- The result has the same size and contents as the argument but uses
-- freshly allocated arrays.
copy :: UnionST s l -> ST s (UnionST s l)
copy u = grow' u (size u)

-- | Allocate arrays for a forest of the given size and copy the contents of
-- an existing forest into them.
--
-- The new arrays are first initialised as in 'new' (every node its own
-- parent, rank 0, default label); then the entries of all nodes of the old
-- forest are copied over. The requested size is assumed to be at least the
-- size of the old forest; the callers 'grow' and 'copy' ensure this.
grow' :: UnionST s l -> Int -> ST s (UnionST s l)
grow' u size' = do
    up' <- newListArray (0, size' - 1) [0..]
    rank' <- newArray (0, size' - 1) 0
    label' <- newArray (0, size' - 1) (def u)
    forM_ [0 .. size u - 1] $ \i -> do
        readArray (up u) i >>= writeArray up' i
        readArray (rank u) i >>= writeArray rank' i
        readArray (label u) i >>= writeArray label' i
    return u{ up = up', rank = rank', label = label', size = size' }

-- | Annotate a node with a new label.
--
-- The label is written at exactly the given node index; no representative
-- lookup is performed. Since 'lookup' reports the label stored at the
-- representative, callers normally pass a representative here.
annotate :: UnionST s l -> Int -> l -> ST s ()
annotate u i v = writeArray (label u) i v

-- | Look up the representative of a given node.
--
-- lookup' does path compression: after the representative has been found,
-- every node visited on the way has its parent pointer set directly to it.
lookup' :: UnionST s l -> Int -> ST s Int
lookup' u i = do
    i' <- readArray (up u) i
    -- A node that is its own parent is the representative.
    if i == i' then return i else do
        i'' <- lookup' u i'
        writeArray (up u) i i''
        return i''

-- | Look up the representative of a given node and its label.
--
-- The returned label is the one stored at the representative, not at the
-- queried node. Performs path compression as a side effect.
lookup :: UnionST s l -> Int -> ST s (Int, l)
lookup u i = do
    i' <- lookup' u i
    l' <- readArray (label u) i'
    return (i', l')

-- | Check whether two nodes are in the same set.
--
-- Two nodes are in the same set exactly when they have the same
-- representative. Performs path compression as a side effect.
equals :: UnionST s l -> Int -> Int -> ST s Bool
equals u a b = do
    a' <- lookup' u a
    b' <- lookup' u b
    return (a' == b')

-- | Merge two nodes if they are in distinct equivalence classes. The
-- passed function is used to combine labels, if a merge happens.
--
-- Returns 'Nothing' if both nodes already have the same representative.
-- Otherwise the root of lower rank is attached below the other one (on
-- equal ranks the root of the first node goes below the root of the second,
-- whose rank is then incremented). The combining function receives the
-- label of the surviving root as its first argument and the label of the
-- absorbed root as its second; the first component of its result becomes
-- the label of the merged set, the second is returned in a 'Just'.
merge :: UnionST s l -> (l -> l -> (l, a)) -> Int -> Int -> ST s (Maybe a)
merge u f a b = do
    (a', va) <- lookup u a
    (b', vb) <- lookup u b
    if a' == b' then return Nothing else do
        ra <- readArray (rank u) a'
        rb <- readArray (rank u) b'
        -- cont x vx y vy: x is the surviving root, y the absorbed one.
        let cont x vx y vy = do
                -- Overwrite the label slot of the absorbed root; presumably
                -- this drops the array's reference to the stale label.
                writeArray (label u) y (error "invalid entry")
                let (v, w) = f vx vy
                writeArray (label u) x v
                return (Just w)
        case ra `compare` rb of
            LT -> do
                writeArray (up u) a' b'
                cont b' vb a' va
            GT -> do
                writeArray (up u) b' a'
                cont a' va b' vb
            EQ -> do
                writeArray (up u) a' b'
                writeArray (rank u) b' (ra + 1)
                cont b' vb a' va

-- | Flatten a disjoint set forest, for faster lookups.
--
-- Runs the path-compressing lookup on every node, so that afterwards each
-- node's parent pointer refers directly to its representative.
flatten :: UnionST s l -> ST s ()
flatten u = forM_ [0 .. size u - 1] $ lookup' u