module Jikka.Common.Graph where

import Control.Monad
import Control.Monad.ST
import Data.List (nub)
import Data.STRef
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- | A digraph as an adjacency list.
-- The vertices are \(0, 1, \dots, n - 1\) where \(n\) is the length of the vector,
-- and the @x@-th element lists the targets of the edges leaving vertex @x@.
type Graph = V.Vector [Int]

-- | `makeReversedDigraph` returns the transpose of a digraph: every edge \(x \to y\) of @g@ becomes \(y \to x\).
-- The result has the same number of vertices as @g@.
-- Within each adjacency list of the result, the sources appear in non-increasing order of their index,
-- because later edges are prepended in front of earlier ones.
makeReversedDigraph :: Graph -> Graph
makeReversedDigraph g = V.accum (flip (:)) (V.replicate (V.length g) []) reversedEdges
  where
    -- One pair (y, x) per edge x -> y, enumerated by increasing x and then in the order of @g V.! x@.
    reversedEdges = [(y, x) | (x, ys) <- V.toList (V.indexed g), y <- ys]

-- | `makeInducedDigraph` contracts a digraph along a vertex mapping.
-- Given a digraph @g@ on \(V\) and a mapping \(f : V \to V'\) (the vector @f@, indexed by the vertices of @g@),
-- it returns the digraph on \(V' = \lbrace 0, \dots, k - 1 \rbrace\), where \(k = \max f + 1\)
-- (or \(k = 0\) when @f@ is empty), with an edge \(f(x) \to f(y)\) for every edge \(x \to y\) of @g@.
-- Duplicate targets in each adjacency list are removed with 'nub' (first occurrence kept).
-- Edges whose endpoints share a class are kept as self-loops \(f(x) \to f(x)\).
makeInducedDigraph :: Graph -> V.Vector Int -> Graph
makeInducedDigraph g f = V.map nub (V.accum (flip (:)) (V.replicate k []) mappedEdges)
  where
    k = if V.null f then 0 else V.maximum f + 1
    -- One pair (f x, f y) per edge x -> y, enumerated by increasing x and then in the order of @g V.! x@.
    mappedEdges = [(f V.! x, f V.! y) | (x, ys) <- V.toList (V.indexed g), y <- ys]

-- | `decomposeToStronglyConnectedComponents` does SCC in \(O(V + E)\) using Kosaraju's algorithm.
-- It takes a digraph \(G = (V, E)\) as an adjacency list \(g : V \to V^{\lt \omega}\), and returns a mapping \(f : V \to V'\) for the SCC DAG \(G' = (V', E')\).
-- The vertices of the SCC DAG are \(V' = \lbrace 0, 1, \dots, |V'| - 1 \rbrace\).
-- Their indices are topologically sorted, i.e. \(f(x) \le f(y)\) holds for every edge \(x \to y\).
decomposeToStronglyConnectedComponents :: Graph -> V.Vector Int
decomposeToStronglyConnectedComponents g = runST $ do
  let n = V.length g
  -- Run @f@ only when vertex @x@ has not been visited yet.
  let unless' used x f = do
        usedX <- MV.read used x
        unless usedX f
  -- The first DFS: vertices in decreasing order of finishing time.
  let order = topologicalSort g
  -- DFS on the reversed graph; each DFS tree started from @order@ is exactly one SCC.
  let gRev = makeReversedDigraph g
  componentOf <- MV.replicate n (-1)
  -- The number of SCCs found so far, which is also the index of the next one.
  size <- newSTRef 0
  used <- MV.replicate n False
  -- Label every vertex reachable from @x@ in the reversed graph (and not yet labelled) with component @k@.
  let go k x = do
        MV.write used x True
        MV.write componentOf x k
        forM_ (gRev V.! x) $ \y ->
          unless' used y $ go k y
  V.forM_ order $ \x ->
    unless' used x $ do
      k <- readSTRef size
      go k x
      modifySTRef' size succ
  V.freeze componentOf

-- | `topologicalSort` does topological sort in \(O(V + E)\) using Tarjan's DFS-based algorithm.
-- The input is an adjacency list of a DAG.
-- The output lists every vertex exactly once, so that \(x\) comes before \(y\) for every edge \(x \to y\).
-- DFS roots are tried in increasing index order.
-- For a digraph with cycles, the result is still the vertices in decreasing order of DFS finishing time
-- (reversed post-order). `decomposeToStronglyConnectedComponents` relies on this.
topologicalSort :: Graph -> V.Vector Int
topologicalSort g = runST $ do
  visited <- MV.replicate (V.length g) False
  -- Vertices are prepended when they finish, so the last one to finish ends up at the head.
  finished <- newSTRef []
  let visit x = do
        MV.write visited x True
        mapM_ visitIfNew (g V.! x)
        modifySTRef' finished (x :)
      visitIfNew y = do
        seen <- MV.read visited y
        unless seen (visit y)
  mapM_ visitIfNew [0 .. V.length g - 1]
  V.fromList <$> readSTRef finished