-- This file is part of the 'union-find-array' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}
module Control.Monad.Union.Class (
    MonadUnion (..),
) where

import Data.Union.Type (Node (..), Union (..))
import Control.Monad.Trans (MonadTrans (..))
import Prelude hiding (lookup)

-- | Monads that maintain a union-find (disjoint set) structure in which
-- every set carries a label of type @l@. The functional dependency
-- @m -> l@ means the monad determines the label type.
class Monad m => MonadUnion l m | m -> l where
    -- | Create a fresh node carrying the given label.
    new :: l -> m Node

    -- | Resolve a node to the node that represents its set, paired with
    -- that set's label.
    lookup :: Node -> m (Node, l)

    -- | Join the sets containing two nodes. The supplied function is given
    -- the labels of the two sets' representatives and produces the label for
    -- the combined set together with an extra result, which is handed back
    -- wrapped in 'Just'. When both nodes are already in the same set, the
    -- result is 'Nothing'.
    merge :: (l -> l -> (l, r)) -> Node -> Node -> m (Maybe r)

    -- | Give a node a new label.
    annotate :: Node -> l -> m ()

    -- | Flatten the disjoint set forest so that later lookups are faster.
    flatten :: m ()

-- | Lift the union-find operations of an underlying monad @m@ through any
-- monad transformer @t@. Code running in @t m@ can then use @m@'s union-find
-- structure directly. Every method delegates to @m@ via 'lift'.
--
-- NOTE(review): the instance head @t m@ matches any applied type
-- constructor, so another 'MonadUnion' instance whose monad has that shape
-- would overlap with this one. Confirm against the instances defined
-- elsewhere.
instance (MonadUnion l m, MonadTrans t, Monad (t m)) => MonadUnion l (t m) where
    new lbl = lift (new lbl)

    lookup node = lift (lookup node)

    merge combine x y = lift (merge combine x y)

    annotate node lbl = lift (annotate node lbl)

    flatten = lift flatten