{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE ScopedTypeVariables, BangPatterns #-}
module ToySolver.Combinatorial.BipartiteMatching
(
maximumCardinalityMatching
, maximumWeightMatching
, maximumWeightMatchingComplete
, maximumWeightPerfectMatching
, minimumWeightPerfectMatching
, maximumWeightPerfectMatchingComplete
, minimumWeightPerfectMatchingComplete
, minimumCardinalityEdgeCover
, minimumWeightEdgeCover
, minimumWeightEdgeCoverComplete
) where
import Control.Monad
import qualified Data.Foldable as F
import Data.IntMap.Strict (IntMap, (!))
import qualified Data.IntMap.Strict as IntMap
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Maybe
-- | Maximum-cardinality matching of a bipartite graph.
--
-- The result maps every matched vertex of A to its partner in B.
maximumCardinalityMatching
  :: IntSet        -- ^ vertex set A (not inspected; A vertices come from the edges)
  -> IntSet        -- ^ vertex set B
  -> [(Int,Int)]   -- ^ edges (a, b)
  -> IntMap Int    -- ^ matching, each matched a mapped to its b
maximumCardinalityMatching _as bs es =
  case maximumCardinalityMatching' bs neighboursOf IntMap.empty of
    (matching, _visitedA, _visitedB) -> matching
  where
    -- For each B vertex, the set of A vertices adjacent to it.
    adjacency :: IntMap IntSet
    adjacency = IntMap.fromListWith IntSet.union
      [ (b, IntSet.singleton a) | (a, b) <- es ]

    -- Neighbours of b; a B vertex without edges has none.
    neighboursOf :: Int -> IntSet
    neighboursOf b = IntMap.findWithDefault IntSet.empty b adjacency
-- | A partially grown alternating path in the breadth-first search:
-- (frontier B vertex, edges (a, b) that become matched if the path is
-- completed -- most recent first, exposed B vertex the path started from).
type AlternatingPath = (Int, [(Int,Int)], Int)

-- | A completed augmenting path: (edges (a, b) to match, exposed B root).
type AugmentingPath = ([(Int,Int)], Int)

-- | Grow a bipartite matching to maximum cardinality.
--
-- Each phase runs a breadth-first search simultaneously from every B vertex
-- left exposed by the current matching. It collects vertex-disjoint
-- augmenting paths and applies them all at once. Phases repeat until a
-- search finds no augmenting path.
--
-- Returns the final matching (a to b) together with the sets of A and B
-- vertices visited by that last search. The B set includes the exposed
-- roots.
maximumCardinalityMatching'
  :: IntSet            -- ^ vertex set B
  -> (Int -> IntSet)   -- ^ A neighbours of a B vertex
  -> IntMap Int        -- ^ initial matching, a to b
  -> (IntMap Int, IntSet, IntSet)
maximumCardinalityMatching' bs e_b2a m0 = loop m0 m0_b_exposed
  where
    m0_b_exposed = bs `IntSet.difference` IntSet.fromList (IntMap.elems m0)

    loop :: IntMap Int -> IntSet -> (IntMap Int, IntSet, IntSet)
    loop m m_b_exposed =
      case search m_b_exposed of
        (l_a, l_b, []) -> (m, l_a, l_b)
        (_, _, ds) ->
          let
            -- A path lists exactly the edges that become matched; keying
            -- them by their A endpoint overrides each a's previous partner.
            ds2 = [IntMap.fromList d2 | (d2,_) <- ds]
            m' = IntMap.unions ds2 `IntMap.union` m
            m_b_exposed' = m_b_exposed `IntSet.difference` IntSet.fromList [b0 | (_, b0) <- ds]
          in loop m' m_b_exposed'
      where
        search :: IntSet -> (IntSet, IntSet, [AugmentingPath])
        search b_exposed = loopB IntSet.empty b_exposed [(b, [], b) | b <- IntSet.toList b_exposed] [] []
          where
            -- loopB visitedA visitedB currB nextB result: expand the current
            -- BFS layer currB while queueing the next layer in nextB.
            loopB :: IntSet -> IntSet -> [AlternatingPath] -> [AlternatingPath] -> [AugmentingPath] -> (IntSet, IntSet, [AugmentingPath])
            loopB !visitedA !visitedB [] [] result = (visitedA, visitedB, result)
            loopB !visitedA !visitedB [] nextB result = loopB visitedA visitedB nextB [] result
            loopB !visitedA !visitedB ((b, d2, b0) : currB) nextB result =
              loopA visitedA visitedB (IntSet.toList (e_b2a b)) currB nextB result
              where
                -- Scan the A neighbours of the frontier vertex b. The
                -- parameters are named differently from loopB's so they do
                -- not shadow them (-Wname-shadowing is part of -Wall).
                loopA :: IntSet -> IntSet -> [Int] -> [AlternatingPath] -> [AlternatingPath] -> [AugmentingPath] -> (IntSet, IntSet, [AugmentingPath])
                loopA !visA !visB [] curr next acc = loopB visA visB curr next acc
                loopA !visA !visB (a:rest) curr next acc
                  | a `IntSet.member` visA = loopA visA visB rest curr next acc
                  | otherwise =
                      case IntMap.lookup a m of
                        Nothing ->
                          -- a is exposed, so the path from b0 is augmenting.
                          -- Drop the remaining partial paths of the same root.
                          let otherRoot (_, _, b0') = b0 /= b0'
                          in loopB (IntSet.insert a visA) visB (filter otherRoot curr) (filter otherRoot next) (((a,b) : d2, b0) : acc)
                        Just b2
                          | b == b2 -> loopA visA visB rest curr next acc
                          | b2 `IntSet.member` visB -> loopA (IntSet.insert a visA) visB rest curr next acc
                          | otherwise -> loopA (IntSet.insert a visA) (IntSet.insert b2 visB) rest curr ((b2, (a,b):d2, b0) : next) acc
-- | Maximum-weight matching on the complete bipartite graph A×B.
--
-- Pairs are weighted by the given function. Only pairs of strictly
-- positive weight appear in the returned matching. The returned weight is
-- the sum of their weights.
maximumWeightMatchingComplete
  :: forall w. (Real w)
  => IntSet              -- ^ vertex set A
  -> IntSet              -- ^ vertex set B
  -> (Int -> Int -> w)   -- ^ weight function
  -> (w, IntMap Int)
maximumWeightMatchingComplete as bs w =
  case maximumWeightMaximumMatchingComplete as bs clipped of
    (_, full) -> withWeight (IntMap.filterWithKey (\a b -> w a b > 0) full)
  where
    -- Non-positive pairs are clipped to 0, so leaving them unmatched costs
    -- nothing.
    clipped :: Int -> Int -> w
    clipped a b = max 0 (w a b)

    -- Pair a matching with its total weight under w.
    withWeight :: IntMap Int -> (w, IntMap Int)
    withWeight m = (sum (map (uncurry w) (IntMap.toList m)), m)
-- | Maximum-weight matching on a sparse bipartite graph.
--
-- Edges with negative weight are never used. Pairs that are not edges are
-- never reported. The returned matching may contain zero-weight edges.
maximumWeightMatching
  :: forall w. (Real w)
  => IntSet              -- ^ vertex set A
  -> IntSet              -- ^ vertex set B
  -> [(Int, Int, w)]     -- ^ weighted edges (a, b, weight)
  -> (w, IntMap Int)
maximumWeightMatching as bs w =
  case maximumWeightMaximumMatchingComplete as bs weightOf of
    (_, full) ->
      let kept = IntMap.filterWithKey (\a b -> isJust (usableEdge a b)) full
      in (sum (map (uncurry weightOf) (IntMap.toList kept)), kept)
  where
    -- Edge weights indexed by a, then by b. IntMap.union is left-biased and
    -- fromListWith passes the newer value first, so for a repeated edge the
    -- later weight wins.
    table :: IntMap (IntMap w)
    table = IntMap.fromListWith IntMap.union
      [ (a, IntMap.singleton b v) | (a, b, v) <- w ]

    -- Weight of (a, b) if that edge exists and its weight is non-negative.
    usableEdge :: Int -> Int -> Maybe w
    usableEdge a b = mfilter (>= 0) (IntMap.lookup a table >>= IntMap.lookup b)

    -- Missing or negative edges are treated as weight 0 by the solver.
    weightOf :: Int -> Int -> w
    weightOf a b = fromMaybe 0 (usableEdge a b)
-- | Maximum-weight matching on the complete bipartite graph A×B, among the
-- matchings that match every vertex of the smaller side.
--
-- The smaller side is padded with fresh dummy vertices so that a perfect
-- matching can be computed. Dummies are the smallest non-negative ids not
-- already in that side. Every pair involving a dummy weighs 0, and dummy
-- pairs are dropped from the returned matching.
maximumWeightMaximumMatchingComplete
  :: forall w. (Real w)
  => IntSet              -- ^ vertex set A
  -> IntSet              -- ^ vertex set B
  -> (Int -> Int -> w)   -- ^ weight function
  -> (w, IntMap Int)
maximumWeightMaximumMatchingComplete as bs w =
  case sizeA `compare` sizeB of
    EQ -> solve as bs w id
    GT -> solve as (padded bs (sizeA - sizeB)) realB (IntMap.filter (`IntSet.member` bs))
    LT -> solve (padded as (sizeB - sizeA)) bs realA (IntMap.filterWithKey (\a _ -> a `IntSet.member` as))
  where
    sizeA, sizeB :: Int
    sizeA = IntSet.size as
    sizeB = IntSet.size bs

    -- Add k fresh ids: the smallest non-negative ones missing from vs.
    padded :: IntSet -> Int -> IntSet
    padded vs k =
      vs `IntSet.union` IntSet.fromAscList (take k (filter (`IntSet.notMember` vs) [0..]))

    -- Weight functions that give 0 to every pair involving a dummy vertex.
    realB, realA :: Int -> Int -> w
    realB a b = if b `IntSet.member` bs then w a b else 0
    realA a b = if a `IntSet.member` as then w a b else 0

    -- Solve the balanced perfect-matching problem, discard the duals and
    -- post-process the matching.
    solve :: IntSet -> IntSet -> (Int -> Int -> w) -> (IntMap Int -> IntMap Int) -> (w, IntMap Int)
    solve as' bs' w' restrict =
      case maximumWeightPerfectMatchingComplete as' bs' w' of
        (obj, sol, _) -> (obj, restrict sol)
-- | Maximum-weight perfect matching on the complete bipartite graph A×B.
--
-- The weights are negated and the minimisation problem is solved. The
-- objective and the dual variables are then negated back. The two vertex
-- sets must have the same size (the minimisation calls 'error' otherwise).
maximumWeightPerfectMatchingComplete
  :: forall w. (Real w)
  => IntSet              -- ^ vertex set A
  -> IntSet              -- ^ vertex set B
  -> (Int -> Int -> w)   -- ^ weight function
  -> (w, IntMap Int, (IntMap w, IntMap w))
maximumWeightPerfectMatchingComplete as bs w =
  flipSigns (minimumWeightPerfectMatchingComplete as bs negated)
  where
    negated :: Int -> Int -> w
    negated a b = negate (w a b)

    flipSigns :: (w, IntMap Int, (IntMap w, IntMap w)) -> (w, IntMap Int, (IntMap w, IntMap w))
    flipSigns (obj, m, (ysA, ysB)) = (negate obj, m, (IntMap.map negate ysA, IntMap.map negate ysB))
-- | Minimum-weight perfect matching on the complete bipartite graph A×B,
-- computed with the Hungarian method.
--
-- The weight of every pair is given by the function argument. Calls
-- 'error' if the two vertex sets differ in size.
--
-- Returns three things:
--
-- * the objective, computed as the sum of all dual variables; once the
--   matching is perfect and every matched edge is tight, this equals the
--   matching's weight;
-- * the matching, each a mapped to its b;
-- * the final dual variables @(ysA, ysB)@.
minimumWeightPerfectMatchingComplete
  :: forall w. (Real w)
  => IntSet              -- ^ vertex set A
  -> IntSet              -- ^ vertex set B (same size as A)
  -> (Int -> Int -> w)   -- ^ weight function @w a b@
  -> (w, IntMap Int, (IntMap w, IntMap w))
minimumWeightPerfectMatchingComplete as bs w
  | n /= IntSet.size bs = error "minimumWeightPerfectMatchingComplete: two sets must be same size"
  | otherwise = loop m0 ys0 (equalityGraph ys0)
  where
    n = IntSet.size as
    -- Initial duals: each a gets its cheapest pair weight and every b gets 0,
    -- so ysA!a + ysB!b <= w a b holds for every pair.
    ys0 :: (IntMap w, IntMap w)
    ys0 = ( IntMap.fromSet (\a -> minimum [w a b | b <- IntSet.toList bs]) as
          , IntMap.fromSet (\_ -> 0) bs
          )
    m0 = IntMap.empty
    -- One iteration: maximise the matching inside the equality graph. If it
    -- is not perfect yet, shift the duals by the minimum slack and update the
    -- equality graph incrementally.
    loop
      :: IntMap Int -> (IntMap w, IntMap w) -> IntMap IntSet
      -> (w, IntMap Int, (IntMap w, IntMap w))
    loop m_pre ys@(ysA,ysB) g_eq
      | IntMap.size m == n = (F.sum ysA + F.sum ysB, m, ys)
      | otherwise = loop m ys' g_eq'
      where
        -- l_a / l_b: A and B vertices visited by the last search for
        -- augmenting paths from exposed B vertices (l_b includes the roots).
        (m, l_a, l_b) = maximumCardinalityMatching' bs (g_eq !) m_pre
        l_a' = as `IntSet.difference` l_a
        -- Smallest gap between weight and dual sum over pairs joining a
        -- visited b to an unvisited a.
        slack :: w
        slack = minimum
          [ w a b - (ysA!a + ysB!b)
          | a <- IntSet.toList l_a'
          , b <- IntSet.toList l_b
          ]
        -- Raise the duals of unvisited a and lower those of unvisited b by
        -- the slack. Pairs (unvisited a, visited b) gain the slack, so the
        -- minimising ones become tight. Pairs (visited a, unvisited b) lose
        -- it. All other pairs keep their dual sum.
        ys' :: (IntMap w, IntMap w)
        ys' = (IntMap.mapWithKey f ysA, IntMap.mapWithKey g ysB)
          where
            f a val
              | a `IntSet.notMember` l_a = val + slack
              | otherwise = val
            g b val
              | b `IntSet.notMember` l_b = val - slack
              | otherwise = val
        -- Incremental equality-graph update, with b mapped to its tight
        -- neighbours. Visited b gain the newly tight unvisited a. Unvisited b
        -- lose their edges to visited a, which are no longer tight.
        g_eq' :: IntMap IntSet
        g_eq' = IntMap.mapWithKey f g_eq
          where
            f b as3
              | b `IntSet.member` l_b =
                  as3 `IntSet.union` IntSet.filter (\a -> w a b == (fst ys' ! a + snd ys' ! b)) l_a'
              | otherwise = as3 `IntSet.difference` l_a
    -- Tight pairs (w a b == ysA!a + ysB!b), indexed by the B endpoint.
    equalityGraph :: (IntMap w, IntMap w) -> IntMap IntSet
    equalityGraph (ysA,ysB) =
      IntMap.fromSet (\b -> IntSet.filter (\a -> w a b == ysA!a + ysB!b) as) bs
-- | Maximum-weight perfect matching on a sparse bipartite graph.
--
-- The edge weights are negated and the problem is delegated to
-- 'minimumWeightPerfectMatching'. The objective and the dual variables of
-- the result are then negated back. Returns 'Nothing' exactly when the
-- minimisation does.
maximumWeightPerfectMatching
  :: forall w. (Real w)
  => IntSet           -- ^ vertex set A
  -> IntSet           -- ^ vertex set B
  -> [(Int,Int,w)]    -- ^ weighted edges (a, b, weight)
  -> Maybe (w, IntMap Int, (IntMap w, IntMap w))
maximumWeightPerfectMatching as bs es =
  minimumWeightPerfectMatching as bs negatedEdges >>= \(obj, m, (ysA, ysB)) ->
    Just (negate obj, m, (IntMap.map negate ysA, IntMap.map negate ysB))
  where
    negatedEdges = [(a, b, negate c) | (a, b, c) <- es]
-- | Minimum-weight perfect matching on a sparse bipartite graph, computed
-- with the Hungarian method.
--
-- Returns 'Nothing' in three cases:
--
-- * the vertex sets differ in size;
-- * some vertex of B has no incident edge;
-- * a search gets stuck with no edge leading from a visited B vertex to an
--   unvisited A vertex. The visited B vertices then outnumber their
--   neighbours, so no perfect matching exists.
--
-- Otherwise returns the objective (the sum of the dual variables), the
-- matching (each a mapped to its b) and the duals @(ysA, ysB)@.
--
-- NOTE(review): edges are assumed to lie in A×B. An edge whose A endpoint is
-- not in the first set makes @ysA!a@ fail.
minimumWeightPerfectMatching
  :: forall w. (Real w)
  => IntSet           -- ^ vertex set A
  -> IntSet           -- ^ vertex set B
  -> [(Int,Int,w)]    -- ^ weighted edges (a, b, weight)
  -> Maybe (w, IntMap Int, (IntMap w, IntMap w))
minimumWeightPerfectMatching as bs es
  | n /= IntSet.size bs = Nothing
  | F.any IntMap.null e_b2a = Nothing
  | otherwise = loop m0 ys0 (equalityGraph ys0)
  where
    n = IntSet.size as
    -- Adjacency: each b mapped to its (a, weight) entries. The union gives
    -- every b in B an entry, possibly empty. For a repeated (a, b) edge the
    -- first occurrence in es wins: (++) puts newer entries first, and
    -- IntMap.fromList keeps the last one.
    e_b2a :: IntMap (IntMap w)
    e_b2a = fmap IntMap.fromList $ IntMap.fromListWith (++) [(b,[(a,w)]) | (a,b,w) <- es]
            `IntMap.union` IntMap.fromSet (\_ -> []) bs
    -- Initial duals: ysA is 0 and each b gets its cheapest incident weight.
    -- F.minimum is safe because ys0 is only demanded after the guard above
    -- has rejected empty rows.
    ys0 :: (IntMap w, IntMap w)
    ys0 = ( IntMap.fromSet (\_ -> 0) as
          , fmap F.minimum e_b2a
          )
    m0 = IntMap.empty
    -- One iteration: maximise the matching inside the equality graph. If it
    -- is not perfect yet, shift the duals by the minimum slack and update the
    -- equality graph incrementally.
    loop
      :: IntMap Int -> (IntMap w, IntMap w) -> IntMap IntSet
      -> Maybe (w, IntMap Int, (IntMap w, IntMap w))
    loop m_pre ys@(ysA,ysB) g_eq
      | IntMap.size m == n = Just (F.sum ysA + F.sum ysB, m, ys)
      | null slacks = Nothing
      | otherwise = loop m ys' g_eq'
      where
        -- l_a / l_b: A and B vertices visited by the last search for
        -- augmenting paths from exposed B vertices (l_b includes the roots).
        (m, l_a, l_b) = maximumCardinalityMatching' bs (g_eq !) m_pre
        -- Gaps between weight and dual sum on edges joining a visited b to
        -- an unvisited a.
        slacks :: [w]
        slacks = [w - (ysA!a + ysB!b) | b <- IntSet.toList l_b, (a,w) <- IntMap.toList (e_b2a ! b), a `IntSet.notMember` l_a]
        slack :: w
        slack = minimum slacks
        -- Raise the duals of unvisited a and lower those of unvisited b by
        -- the slack. Edges (unvisited a, visited b) gain it and edges
        -- (visited a, unvisited b) lose it.
        ys' :: (IntMap w, IntMap w)
        ys' = (IntMap.mapWithKey f ysA, IntMap.mapWithKey g ysB)
          where
            f a val
              | a `IntSet.notMember` l_a = val + slack
              | otherwise = val
            g b val
              | b `IntSet.notMember` l_b = val - slack
              | otherwise = val
        -- Incremental equality-graph update. Visited b gain the newly tight
        -- edges to unvisited a. Unvisited b lose their edges to visited a.
        g_eq' :: IntMap IntSet
        g_eq' = IntMap.mapWithKey f g_eq
          where
            f b as3
              | b `IntSet.member` l_b =
                  as3 `IntSet.union` IntSet.fromAscList [a | (a,w) <- IntMap.toAscList (e_b2a ! b), a `IntSet.notMember` l_a, w == fst ys' ! a + snd ys' ! b]
              | otherwise = as3 `IntSet.difference` l_a
    -- Tight edges (weight == ysA!a + ysB!b), indexed by the B endpoint.
    equalityGraph :: (IntMap w, IntMap w) -> IntMap IntSet
    equalityGraph (ysA,ysB) = IntMap.mapWithKey f e_b2a
      where
        f b xs = IntSet.fromAscList [a | (a,w) <- IntMap.toAscList xs, w == ysA!a + ysB!b]
-- | Minimum-cardinality edge cover of a bipartite graph.
--
-- Takes a maximum-cardinality matching and, for each vertex it leaves
-- unmatched, adds one incident edge: the last one listed for that vertex.
-- Returns 'Nothing' when some vertex has no incident edge.
--
-- NOTE(review): the missing-edge check compares the number of distinct edge
-- endpoints with the vertex-set sizes, so it assumes the edges lie in A×B.
minimumCardinalityEdgeCover
  :: IntSet          -- ^ vertex set A
  -> IntSet          -- ^ vertex set B
  -> [(Int,Int)]     -- ^ edges (a, b)
  -> Maybe (Set (Int,Int))
minimumCardinalityEdgeCover as bs es
  | IntMap.size ca /= IntSet.size as = Nothing
  | IntMap.size cb /= IntSet.size bs = Nothing
  | otherwise = Just (Set.unions [matchedEdges, coverA, coverB])
  where
    matching = maximumCardinalityMatching as bs es
    matchedEdges = Set.fromList (IntMap.toList matching)

    unmatchedA = as `IntSet.difference` IntMap.keysSet matching
    unmatchedB = bs `IntSet.difference` IntSet.fromList (IntMap.elems matching)

    coverA = Set.fromList [ (a, ca IntMap.! a) | a <- IntSet.toList unmatchedA ]
    coverB = Set.fromList [ (cb IntMap.! b, b) | b <- IntSet.toList unmatchedB ]

    -- One incident edge per vertex (IntMap.fromList keeps the last one).
    ca = IntMap.fromList es
    cb = IntMap.fromList [ (b, a) | (a, b) <- es ]
-- | Minimum-weight edge cover of a sparse bipartite graph.
--
-- Every negative-weight edge is part of the cover. Write mu(v) for the
-- lightest weight incident to v. For the vertices with mu(v) >= 0, a
-- maximum-weight matching is computed over the non-negative edges, with
-- edge weights mu(a) + mu(b) - w(a, b). Each vertex left unmatched is then
-- covered by its lightest incident edge. Returns 'Nothing' when some vertex
-- has no incident edge.
--
-- NOTE(review): the missing-edge check compares the number of distinct edge
-- endpoints with the vertex-set sizes, so it assumes the edges lie in A×B.
minimumWeightEdgeCover
  :: forall w. (Real w)
  => IntSet           -- ^ vertex set A
  -> IntSet           -- ^ vertex set B
  -> [(Int,Int,w)]    -- ^ weighted edges (a, b, weight)
  -> Maybe (Set (Int,Int))
minimumWeightEdgeCover as bs es
  | IntMap.size ca /= IntSet.size as = Nothing
  | IntMap.size cb /= IntSet.size bs = Nothing
  | otherwise =
      case maximumWeightMatching as' bs' es' of
        (_, m) ->
          let unmatchedA = as `IntSet.difference` IntMap.keysSet m
              unmatchedB = bs `IntSet.difference` IntSet.fromList (IntMap.elems m)
          in Just $ Set.unions
               [ Set.fromList (IntMap.toList m)
               , Set.fromList [ (a, fst (ca IntMap.! a)) | a <- IntSet.toList unmatchedA ]
               , Set.fromList [ (fst (cb IntMap.! b), b) | b <- IntSet.toList unmatchedB ]
               , Set.fromList [ (a, b) | (a, b, c) <- es, c < 0 ]
               ]
  where
    -- Keep the lighter (endpoint, weight) candidate. fromListWith passes the
    -- later edge as the first argument, so ties go to the later edge.
    lighter p@(_, c1) q@(_, c2) = if c1 <= c2 then p else q

    -- Lightest incident edge of each vertex, as (other endpoint, weight).
    ca = IntMap.fromListWith lighter [ (a, (b, c)) | (a, b, c) <- es ]
    cb = IntMap.fromListWith lighter [ (b, (a, c)) | (a, b, c) <- es ]

    -- Vertices whose lightest edge is non-negative, i.e. not already covered
    -- by a negative edge.
    as' = IntSet.fromAscList [ a | (a, (_, c)) <- IntMap.toAscList ca, c >= 0 ]
    bs' = IntSet.fromAscList [ b | (b, (_, c)) <- IntMap.toAscList cb, c >= 0 ]

    -- Saving from covering a and b with one shared edge instead of their two
    -- lightest edges.
    es' = [ (a, b, snd (ca IntMap.! a) + snd (cb IntMap.! b) - c) | (a, b, c) <- es, c >= 0 ]
-- | Minimum-weight edge cover of the complete bipartite graph A×B, with
-- pairs weighted by the given function.
--
-- Every negative-weight pair is part of the cover. Write mu(v) for the
-- lightest pair weight at v. For the vertices with mu(v) >= 0, a
-- maximum-weight matching is computed over the non-negative pairs, with
-- weights mu(a) + mu(b) - w a b. Each vertex left unmatched is then covered
-- by its lightest pair.
--
-- Returns 'Nothing' exactly when one of the two sets is empty and the other
-- is not, because a vertex then has no pair at all.
minimumWeightEdgeCoverComplete
  :: forall w. (Real w)
  => IntSet              -- ^ vertex set A
  -> IntSet              -- ^ vertex set B
  -> (Int -> Int -> w)   -- ^ weight function
  -> Maybe (Set (Int,Int))
minimumWeightEdgeCoverComplete as bs w
  | IntMap.size ca /= IntSet.size as = Nothing
  | IntMap.size cb /= IntSet.size bs = Nothing
  | otherwise =
      case maximumWeightMatching as' bs' es' of
        (_, m) ->
          let ma = IntMap.keysSet m
              mb = IntSet.fromList $ IntMap.elems m
              m2 = Set.unions
                [ Set.fromList (IntMap.toList m)
                , Set.fromList [(a,b) | a <- IntSet.toList (as `IntSet.difference` ma), let (b,_) = ca IntMap.! a]
                , Set.fromList [(a,b) | b <- IntSet.toList (bs `IntSet.difference` mb), let (a,_) = cb IntMap.! b]
                , Set.fromList [(a,b) | a <- IntSet.toList as, b <- IntSet.toList bs, let w' = w a b, w' < 0]
                ]
          in Just m2
  where
    -- Keep the lighter (endpoint, weight) candidate.
    minOnSnd xw1@(_,w1) xw2@(_,w2) = if w1 <= w2 then xw1 else xw2
    -- Lightest pair of each vertex, as (other endpoint, weight).
    ca = IntMap.fromListWith minOnSnd [(a, (b, w a b)) | a <- IntSet.toList as, b <- IntSet.toList bs]
    cb = IntMap.fromListWith minOnSnd [(b, (a, w a b)) | a <- IntSet.toList as, b <- IntSet.toList bs]
    -- Vertices not already covered by a negative pair. The weight pattern
    -- variable is named c rather than w so it does not shadow the weight
    -- function (a -Wname-shadowing warning under -Wall).
    as' = IntSet.fromAscList [a | (a,(_,c)) <- IntMap.toAscList ca, c >= 0]
    bs' = IntSet.fromAscList [b | (b,(_,c)) <- IntMap.toAscList cb, c >= 0]
    -- Saving from covering a and b with one shared pair.
    es' = [ (a, b, (snd (ca IntMap.! a) + snd (cb IntMap.! b)) - w')
          | a <- IntSet.toList as, b <- IntSet.toList bs, let w' = w a b, w' >= 0 ]