-- This file is part of the 'union-find-array' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

-- |
-- Immutable disjoint set forests.
module Data.Union (
    Union,
    Node (..),
    size,
    lookup,
    lookupFlattened,
) where

import Prelude hiding (lookup)
import Data.Union.Type (Union, Node (..))
import qualified Data.Union.Type as T
import Data.Array.Base ((!))

-- | Get the number of nodes in the forest.
--
-- This simply forwards to the 'T.size' field of the underlying
-- representation.
size :: Union l -> Int
size = T.size

-- | Look up the representative of a node, and its label.
--
-- Starting at the given node, follow the parent links stored in 'T.up'
-- until reaching a node that is its own parent (the root of the tree).
-- Return that root together with the label stored for it in 'T.label'.
--
-- Indexing uses the bounds-checked '(!)', so an out-of-range node is a
-- runtime error. The loop only terminates if the parent chain ends in a
-- self-loop.
lookup :: Union l -> Node -> (Node, l)
lookup u (Node start) = go start
  where
    -- No local type signature: mentioning the outer @l@ here would need
    -- ScopedTypeVariables, and inference gets it right anyway.
    go n
      | parent == n = (Node n, T.label u ! n)
      | otherwise   = go parent
      where
        parent = T.up u ! n

-- | Version of 'lookup' that assumes the forest to be flattened, i.e. that
-- every node's parent link points directly at its root.
-- (cf. 'Control.Union.ST.flatten'.)
--
-- It reads the node's parent link exactly once. That parent is returned as
-- the representative, together with the label stored for it. Unlike
-- 'lookup', it does not follow a chain of parent links.
--
-- Do not use it on a forest that is not flattened: it will give wrong
-- results!
lookupFlattened :: Union a -> Node -> (Node, a)
lookupFlattened u (Node n) = (Node root, T.label u ! root)
  where
    -- In a flattened forest the parent of every node is its root (a root
    -- being its own parent), so a single read suffices. The label is taken
    -- from the root rather than from @n@. This keeps the answer identical to
    -- 'lookup' even if non-root label slots were never updated.
    root = T.up u ! n