{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.DisjointSet.Int.Monadic.Impl where
import Prelude (
  Int, (+), (-), (*), negate,
  ($),
  return,
  Monad, (>>),
  Bool(True, False),
  (<), (>=), (==), (>), (/=), (<=),
  mapM_,
  max,
  pred, succ,
  undefined,
  minBound,
  div,
  Maybe,
  Foldable,
  (.),
  Num, Ord
  )
import Data.Vector.Unboxed.Mutable (
  MVector,
  new,
  unsafeRead, unsafeWrite, unsafeSwap,
  unsafeGrow
  )
import qualified Data.Vector.Unboxed.Mutable as Vector
import Data.Vector.Unboxed (
  Vector,
  unsafeFreeze
  )
import Control.Monad.Primitive (
  PrimState, PrimMonad
  )
import Control.Monad.ST (
  ST, runST
  )
import Control.Monad.Ref (
  MonadRef, Ref, newRef,
  readRef, writeRef, modifyRef'
  )
import Control.Monad (
  when
  )
-- | Mutable unboxed vector of 'Int' in the primitive state of monad @m@.
-- Both the parent/size array (@?v@) and the set-successor array
-- (@?set_v@) have this type.
type MVectorT m = MVector (PrimState m) Int
-- | Turn slot @i@ into a singleton set.
--
-- It records a size of 1 in @?v@ and makes @i@ its own successor in
-- @?set_v@.
initNewElem :: (PrimMonad m, ?v :: MVectorT m, ?set_v :: MVectorT m) => Int -> m ()
initNewElem elemIdx = init_count elemIdx >> init_set elemIdx
-- | Grow the structure so that index @i@ is a valid element.
--
-- If @i + 1@ does not exceed the current element count, this does nothing.
-- Otherwise:
--
-- * every newly covered index in @[numElems, i]@ becomes its own singleton
--   set (size 1, successor = itself);
-- * @?numSets_ref@ grows by the number of added elements;
-- * @?numElems_ref@ becomes @i + 1@.
--
-- When the backing vectors are too short they are reallocated, and the new
-- vectors are written back to @?v_ref@ and @?set_v_ref@.
resize :: (
  MonadRef m,
  PrimMonad m,
  ?v_ref :: Ref m (MVectorT m),
  ?set_v_ref :: Ref m (MVectorT m),
  ?numElems_ref :: Ref m Int,
  ?numSets_ref :: Ref m Int
  ) => Int -> m ()
resize i = do
  numElems <- readRef ?numElems_ref
  let new_numElems = i + 1
      addedElems = new_numElems - numElems
  when (addedElems > 0) $ do
    v <- readRef ?v_ref
    set_v <- readRef ?set_v_ref
    modifyRef' ?numSets_ref (+ addedElems)
    writeRef ?numElems_ref new_numElems
    -- Pick the vectors to initialise into: reallocate only if too short.
    (v', set_v') <-
      if new_numElems > Vector.length v
        then do
          -- unsafeGrow extends a vector *by* the given count (not *to* it).
          -- The new capacity is therefore length v + new_numElems, which
          -- leaves headroom for later resizes.
          new_v <- unsafeGrow v new_numElems
          writeRef ?v_ref new_v
          new_set_v <- unsafeGrow set_v new_numElems
          writeRef ?set_v_ref new_set_v
          return (new_v, new_set_v)
        else return (v, set_v)
    -- Grown slots are uninitialised, and shrunk-then-regrown slots are
    -- stale, so set up exactly the newly covered range.
    let { ?v = v'; ?set_v = set_v' } in init_range numElems new_numElems
-- | Does a raw slot value encode a set size, i.e. is the slot a root?
--
-- Roots store their set size negated (see 'write_count'), so every negative
-- value is a count.
--
-- The explicit signature matters. Without it, the monomorphism restriction
-- defaulted this binding to @Integer -> Bool@, which does not fit the 'Int'
-- slot values. The polymorphic type keeps the old 'Integer' use working.
isCount :: (Ord a, Num a) => a -> Bool
isCount = (< 0)
-- | Does a raw slot value encode a parent pointer, i.e. is the slot a non-root?
--
-- This is the complement of 'isCount': non-negative values are indices of
-- the parent slot.
isPointer :: (Ord a, Num a) => a -> Bool
isPointer = (>= 0)
-- | Path compression toward @target_i@.
--
-- Walks parent pointers starting at @current_i@ and re-points every
-- visited slot directly at @target_i@. It stops at the first slot whose
-- parent already is @target_i@. Nothing happens when
-- @current_i == target_i@.
--
-- Assumes @target_i@ is reachable from @current_i@ (callers pass the root
-- just found for @current_i@).
collapse :: (PrimMonad m, ?v :: MVectorT m) => Int -> Int -> m ()
collapse target_i current_i =
  if target_i == current_i then return () else walk current_i
  where
    walk node = do
      parent <- read_absolute node
      if parent == target_i
        then return ()
        else do
          write node target_i
          walk parent
-- | Decoded form of a slot in the parent/size array.
--
-- The fields are strict and unpacked, so a decoded value never holds a
-- thunk (for example the @negate r@ built by 'read').
data PointerOrCount
  = Pointer {-# UNPACK #-} !Int -- ^ index of the parent slot
  | Count {-# UNPACK #-} !Int   -- ^ size of the set rooted at this slot
-- | Read slot @i@ of @?v@ and decode it.
--
-- Non-negative values are parent pointers. Negative values are negated
-- set sizes and are returned as a positive 'Count'.
read :: (PrimMonad m, ?v :: MVectorT m) => Int -> m PointerOrCount
read i = do
  raw <- read_absolute i
  return $ case isPointer raw of
    True -> Pointer raw
    False -> Count (negate raw)
-- | Raw, undecoded value of slot @i@ in @?v@. No bounds check is done.
read_absolute :: (PrimMonad m, ?v :: MVectorT m) => Int -> m Int
read_absolute = unsafeRead ?v
-- | Store a raw value into slot @idx@ of @?v@. No bounds check is done.
write :: (PrimMonad m, ?v :: MVectorT m) => Int -> Int -> m ()
write idx val = unsafeWrite ?v idx val
-- | Mark slot @i@ as a root of a set of size @n@.
--
-- Sizes are stored negated to tell them apart from pointers.
write_count :: (PrimMonad m, ?v :: MVectorT m) => Int -> Int -> m ()
write_count i n = write i (negate n)
-- | Successor of element @i@ in its set's membership cycle (@?set_v@).
read_set :: (PrimMonad m, ?set_v :: MVectorT m) => Int -> m Int
read_set i = unsafeRead ?set_v i
-- | Exchange the successors of @a@ and @b@ in @?set_v@.
--
-- 'union' applies this to two roots to splice their membership cycles
-- into one.
swap_set :: (PrimMonad m, ?set_v :: MVectorT m) => Int -> Int -> m ()
swap_set a b = unsafeSwap ?set_v a b
-- | Run an action for every index in @[0, end_n)@, in ascending order.
mapM_zeroToN :: (Monad m) => Int -> (Int -> m ()) -> m ()
mapM_zeroToN end_n action = mapM_nToN 0 end_n action
-- | Run an action for every index in the half-open range
-- @[start_n, end_n)@, in ascending order. The range is empty when
-- @end_n <= start_n@.
mapM_nToN :: (Monad m) => Int -> Int -> (Int -> m ()) -> m ()
mapM_nToN start_n end_n f = mapM_ f indices
  where
    indices = [start_n .. end_n - 1]
-- | Allocate a parent/size vector with capacity @?array_size@.
--
-- Indices @[0, ?size)@ are marked as singleton roots. Slots from @?size@
-- upward are left as returned by 'new'.
--
-- NOTE(review): assumes @?size <= ?array_size@. Writes are unchecked.
new_count :: (PrimMonad m, ?size :: Int, ?array_size :: Int) => m (MVectorT m)
new_count = do
  counts <- new ?array_size
  let ?v = counts in init_count_range 0 ?size
  return counts
-- | Make slot @idx@ a root whose set has exactly one member.
init_count :: (PrimMonad m, ?v :: MVectorT m) => Int -> m ()
init_count idx = write_count idx 1
-- | Apply 'init_count' to every slot in @[lo, hi)@.
init_count_range :: (PrimMonad m, ?v :: MVectorT m) => Int -> Int -> m ()
init_count_range lo hi = mapM_nToN lo hi init_count
-- | Make element @self@ its own successor: a one-element membership cycle.
init_set :: (PrimMonad m, ?set_v :: MVectorT m) => Int -> m ()
init_set self = unsafeWrite ?set_v self self
-- | Apply 'init_set' to every element in @[lo, hi)@.
init_set_range :: (PrimMonad m, ?set_v :: MVectorT m) => Int -> Int -> m ()
init_set_range lo hi = mapM_nToN lo hi init_set
-- | Turn every index in @[start_i, end_i)@ into a singleton set.
--
-- All counts are written first, then all successors.
init_range :: (PrimMonad m, ?v :: MVectorT m, ?set_v :: MVectorT m) => Int -> Int -> m ()
init_range start_i end_i =
  init_count_range start_i end_i >> init_set_range start_i end_i
-- | Allocate a set-successor vector with capacity @?array_size@.
--
-- Each index in @[0, ?size)@ becomes its own successor. Slots from @?size@
-- upward are left as returned by 'new'.
--
-- NOTE(review): assumes @?size <= ?array_size@. Writes are unchecked.
new_set :: (PrimMonad m, ?size :: Int, ?array_size :: Int) => m (MVectorT m)
new_set = do
  successors <- new ?array_size
  let ?set_v = successors in init_set_range 0 ?size
  return successors
-- | Merge the sets containing @x_i@ and @y_i@, using union by size.
--
-- Returns 'True' when the elements were in different sets. In that case:
--
-- * the smaller tree is hung under the larger one (ties make @y_i@'s root
--   the parent);
-- * the membership cycles are spliced together;
-- * @?numSets_ref@ is decremented.
--
-- Returns 'False' when they already shared a root. Either way, the paths
-- from @x_i@ and @y_i@ are compressed onto the resulting root afterwards.
union :: forall m. (MonadRef m, PrimMonad m, ?v :: MVectorT m, ?set_v :: MVectorT m, ?numSets_ref :: Ref m Int) => (Int,Int) -> m Bool
union (x_i, y_i) = do
  -- Locate both roots first. Compression happens only after any relinking.
  (x_root, x_size) <- findAndCountNoCollapse x_i
  (y_root, y_size) <- findAndCountNoCollapse y_i
  (target_i, merged) <-
    if x_root == y_root
      then return (x_root, False)
      else do
        let (parent_i, child_i) =
              if x_size > y_size then (x_root, y_root) else (y_root, x_root)
        write_count parent_i (x_size + y_size)
        write child_i parent_i
        -- Swapping the two roots' successors splices their cycles into one.
        swap_set x_root y_root
        modifyRef' ?numSets_ref pred
        return (parent_i, True)
  collapse target_i x_i
  collapse target_i y_i
  return merged
-- | Fold @f@ over every member of the set containing @i@.
--
-- It follows the successor cycle in @?set_v@, starting at @i@ itself and
-- stopping when the cycle returns to @i@. For each member, its successor is
-- read before @f@ is applied to it.
--
-- NOTE(review): assumes @i@ lies on a closed successor cycle, as
-- established by 'init_set' and 'swap_set'.
equivalent :: (PrimMonad m, ?set_v :: MVectorT m) => (a -> Int -> m a) -> a -> Int -> m a
equivalent f seed start = walk seed start
  where
    walk acc cur = do
      next <- read_set cur
      acc' <- f acc cur
      if next == start then return acc' else walk acc' next
-- | Root (representative) of the set containing @i@.
--
-- As a side effect, it compresses @i@'s path onto that root.
find :: (PrimMonad m, ?v :: MVectorT m) => Int -> m Int
find i = do
  (root, _size) <- findAndCount i
  return root
-- | Number of elements in the set containing @i@.
--
-- As a side effect, it compresses @i@'s path onto that set's root.
count :: (PrimMonad m, ?v :: MVectorT m) => Int -> m Int
count i = do
  -- Named 'size' so the local does not shadow this function's own name.
  (_, size) <- findAndCount i
  return size
-- | Follow parent pointers from @i@ to its root without modifying @?v@.
--
-- Returns @(root, setSize)@.
findAndCountNoCollapse :: (PrimMonad m, ?v :: MVectorT m) => Int -> m (Int,Int)
findAndCountNoCollapse i = do
  entry <- read i
  case entry of
    Count size -> return (i, size)
    Pointer parent -> findAndCountNoCollapse parent
-- | Like 'findAndCountNoCollapse' (returns @(root, setSize)@).
--
-- Afterwards it re-points every slot on the path from @i@ directly at the
-- root.
findAndCount :: (PrimMonad m, ?v :: MVectorT m) => Int -> m (Int,Int)
findAndCount i = do
  (root, size) <- findAndCountNoCollapse i
  collapse root i
  return (root, size)
-- | Next member after @i@ in its set's membership cycle.
nextInSet :: (PrimMonad m, ?set_v :: MVectorT m) => Int -> m Int
nextInSet i = read_set i