-- | Lightweight union-find implementation suitable for use with nondeterminism

-- A mutable union-find, as in Data.Equivalence.Monad, should be faster overall,
-- but this persistent implementation is suitable for use in nondeterministic
-- search (e.g., in the list monad).

module Data.Persistent.UnionFind (
    UVarGen
  , initUVarGen
  , nextUVar

  , UVar
  , uvarToInt
  , intToUVar

  , UnionFind
  , empty
  , withInitialValues
  , union
  , find
  ) where

import Control.Monad.State.Strict ( State, runState, execState, get, put, modify' )
import Data.Coerce ( coerce )
import Data.IntMap.Strict ( IntMap )
import qualified Data.IntMap.Strict as IntMap


----------------------------------------------------------

---------------------------
-------- UVarGen
---------------------------

-- | A supply of fresh 'UVar's.
--
-- Wraps the 'Int' that the next call to 'nextUVar' will hand out.
newtype UVarGen = UVarGen Int
  deriving ( Eq, Ord, Show )

-- | The initial generator; its counter starts at @0@.
initUVarGen :: UVarGen
initUVarGen = UVarGen 0

-- | Draw a fresh variable from the generator.
--
-- Returns the advanced generator (counter incremented by one) together with
-- the variable numbered by the generator's counter before the increment.
nextUVar :: UVarGen -> (UVarGen, UVar)
nextUVar (UVarGen n) = (UVarGen (n + 1), UVar n)


---------------------------
-------- UVar
---------------------------

-- | A union-find variable, identified by an 'Int'.
newtype UVar = UVar Int
  deriving ( Eq, Ord, Show )

-- | Extract the 'Int' identifier underlying a variable.
uvarToInt :: UVar -> Int
uvarToInt (UVar i) = i

-- | Wrap an 'Int' identifier as a variable. No validation is performed.
--
-- NOTE: the union-find map stores a parent's identifier as a map /value/ and
-- treats negative values as (negated) class sizes, so a negative identifier
-- would be misread once another variable points at it. Callers should supply
-- non-negative identifiers.
intToUVar :: Int -> UVar
intToUVar = UVar

---------------------------
-------- Union-find data structure
---------------------------

-- | A persistent union-find structure over 'UVar's.
--
-- The map is keyed by variable identifier. Each value is either
--
-- * a non-negative number: the identifier of the variable's parent, or
-- * a negative number: the variable is the representative (root) of its
--   class, and the value is the negated number of members in that class.
--
-- Variables without an entry are treated as singleton classes and are
-- inserted on first lookup.
newtype UnionFind = UnionFind { getUnionFindMap :: IntMap Int }
  deriving ( Eq, Ord, Show )

-- | The union-find structure with no registered variables.
empty :: UnionFind
empty = UnionFind IntMap.empty

-- | Build a structure in which every given variable is the representative of
-- its own singleton class (entry @-1@, i.e. negated size one).
--
-- Duplicate variables in the input collapse to a single entry.
withInitialValues :: [UVar] -> UnionFind
withInitialValues uvs =
  UnionFind (IntMap.fromList [ (key, -1) | key <- coerce uvs ])

---------------------------
-------- Union-find operations
---------------------------

-- | Merge the equivalence classes of two variables.
--
-- Both representatives are looked up with 'findWithNegSize', which also
-- registers previously unseen variables. If the representatives coincide the
-- only changes are those made by the lookups. Otherwise the representative of
-- the smaller class is re-pointed at the representative of the larger class,
-- whose entry becomes the combined negated size.
--
-- Sizes are stored negated, so the numerically /greater/ value denotes the
-- /smaller/ class. On a tie, the second class is attached beneath the first.
union :: UVar -> UVar -> UnionFind -> UnionFind
union uv1 uv2 uf = execState merge uf
  where
    merge :: State UnionFind ()
    merge = do
      (rep1, negSize1) <- findWithNegSize uv1
      (rep2, negSize2) <- findWithNegSize uv2
      if rep1 == rep2
        then return ()
        else if negSize1 > negSize2
          then link rep1 rep2 (negSize1 + negSize2)
          else link rep2 rep1 (negSize1 + negSize2)

    -- Point @child@ at @root@, then record the merged negated size on @root@.
    link :: UVar -> UVar -> Int -> State UnionFind ()
    link (UVar child) (UVar root) negSize =
      modify' $ \(UnionFind m) ->
        UnionFind (IntMap.insert root negSize (IntMap.insert child root m))

-- | Find the representative of a variable together with the negated size of
-- its class, updating the structure along the way.
--
-- * A variable with no entry is inserted as a singleton root (entry @-1@) and
--   returned as its own representative.
-- * A negative entry marks a root; the variable is returned with that entry.
-- * A non-negative entry is a parent pointer: the parent is resolved
--   recursively and the variable is then re-pointed directly at the
--   representative that was found.
findWithNegSize :: UVar -> State UnionFind (UVar, Int)
findWithNegSize uv@(UVar key) = do
  uf <- get
  let m = getUnionFindMap uf
  case IntMap.lookup key m of
    Nothing -> do
      put (UnionFind (IntMap.insert key (-1) m))
      return (uv, -1)
    Just entry
      | entry < 0 -> return (uv, entry)
      | otherwise -> do
          (rep, negSize) <- findWithNegSize (UVar entry)
          -- Update the *current* state rather than the map read before the
          -- recursive call; writing back the earlier snapshot would discard
          -- the re-pointing the recursive call performed on the ancestors.
          modify' (UnionFind . IntMap.insert key (uvarToInt rep) . getUnionFindMap)
          return (rep, negSize)


-- | Look up the representative of a variable's equivalence class.
--
-- Returns the representative together with the updated structure: the lookup
-- registers a previously unseen variable as a singleton class and shortens
-- parent chains it walks (see 'findWithNegSize'), so callers should continue
-- with the returned 'UnionFind'.
find :: UVar -> UnionFind -> (UVar, UnionFind)
find uv uf = runState (fst <$> findWithNegSize uv) uf