{-# LANGUAGE NoImplicitPrelude #-}
module Data.DisjointSet.Int (
DisjointIntSet(DisjointIntSet),
create,
find,
count,
findAndCount,
numSets,
size,
nextInSet,
setToList
)
where
import Data.Vector.Unboxed (Vector, (!))
import Data.DisjointSet.Int.Monadic (DisjointIntSet(DisjointIntSet), runDisjointIntSet, newDisjointIntSetFixed, union)
import Data.DisjointSet.Int.Monadic.Impl (PointerOrCount(Pointer, Count), isPointer)
import Data.Foldable (Foldable, foldl', mapM_)
import Prelude (
Int,
Bool (True, False),
negate,
fst,
snd,
(/=),
Maybe (Just, Nothing),
(.), (+),
max,
($),
return
)
import Data.List (unfoldr)
import Data.Foldable (foldl')
-- | Shorthand for the unboxed 'Int' vector whose raw entries 'read' decodes.
type VectorT = Vector Int
-- | Decode the raw entry stored at index @i@ of @v@.
--
-- If 'isPointer' accepts the stored value, it is returned unchanged as a
-- 'Pointer'. Otherwise it is negated and wrapped in 'Count'. Presumably
-- counts are stored as negative numbers; confirm against
-- "Data.DisjointSet.Int.Monadic.Impl".
read :: VectorT -> Int -> PointerOrCount
read v i = decode (v ! i)
  where
    decode raw = case isPointer raw of
      True -> Pointer raw
      False -> Count (negate raw)
-- | Build a 'DisjointIntSet' from a collection of pairs by calling 'union'
-- on the two elements of each pair.
--
-- The mutable structure is allocated with 'newDisjointIntSetFixed', passing
-- one more than the largest element mentioned in any pair (0 for an empty
-- input). NOTE(review): elements are presumably expected to be non-negative;
-- confirm how 'union' treats negative values.
-- @pairs@ is traversed twice: once to find the maximum, once to do the
-- unions.
create :: Foldable t => t (Int, Int) -> DisjointIntSet
create pairs = runDisjointIntSet $ do
    ds <- newDisjointIntSetFixed (largest + 1)
    mapM_ (\(a, b) -> union ds a b) pairs
    return ds
  where
    -- Seed below every element so an empty input requests a size of 0.
    largest = foldl' step (-1) pairs
    step acc (a, b) = max acc (max a b)
-- | Follow 'Pointer' entries from @i@ until reaching an entry decoded as a
-- 'Count'. Returns the index of that entry (the root of @i@'s set) paired
-- with the count stored there (presumably the size of @i@'s set).
--
-- The structure is immutable, so no path compression happens here. The cost
-- is proportional to the length of the pointer chain starting at @i@.
findAndCount :: DisjointIntSet -> Int -> (Int, Int)
findAndCount (DisjointIntSet v _ _ _) i = climb i
  where
    -- The loop variable has its own name so it does not shadow the argument
    -- @i@ (-Wname-shadowing).
    climb node = case read v node of
      Pointer parent -> climb parent
      Count c -> (node, c)
-- | Root (representative) of the set containing the given element; the
-- first component of 'findAndCount'.
find :: DisjointIntSet -> Int -> Int
find ds = fst . findAndCount ds
-- | Count stored at the root of the given element's set; the second
-- component of 'findAndCount'.
count :: DisjointIntSet -> Int -> Int
count ds = snd . findAndCount ds
-- | The set-count field stored in the 'DisjointIntSet' (its third field).
-- It is read directly rather than recomputed.
numSets :: DisjointIntSet -> Int
numSets (DisjointIntSet _ _ n _) = n
-- | The size field stored in the 'DisjointIntSet' (its fourth field).
-- Presumably this is the number of elements the structure was allocated
-- for; confirm against "Data.DisjointSet.Int.Monadic".
size :: DisjointIntSet -> Int
size (DisjointIntSet _ _ _ n) = n
-- | Successor of the given element: the entry at that index of the
-- structure's second vector field. 'setToList' follows these successors
-- until it returns to its starting element.
nextInSet :: DisjointIntSet -> Int -> Int
nextInSet (DisjointIntSet _ successors _ _) idx = (!) successors idx
-- | The start element followed by every element reachable from it via
-- 'nextInSet' (presumably exactly the members of its set).
--
-- The walk stops as soon as 'nextInSet' returns the starting element again,
-- so the successors are expected to form a cycle through it. If they never
-- return to it, the list is infinite. The list is produced lazily.
setToList :: DisjointIntSet -> Int -> [Int]
setToList ds start = start : unfoldr advance start
  where
    advance cur = case nextInSet ds cur of
      nxt | nxt /= start -> Just (nxt, nxt)
      _ -> Nothing