{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Data.DiGraph.FloydWarshall
(
DenseAdjMatrix
, AdjacencySets
, fromAdjacencySets
, toAdjacencySets
, ShortestPathMatrix(..)
, floydWarshall
, shortestPath
, distance
, diameter
, distMatrix_
, floydWarshall_
, diameter_
, shortestPaths_
) where
import Control.DeepSeq
import Data.Foldable
import qualified Data.HashMap.Strict as HM
import qualified Data.HashSet as HS
import Data.Massiv.Array as M
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
import GHC.Generics
import Numeric.Natural
type DenseAdjMatrix = Array U Ix2 Int
type AdjacencySets = HM.HashMap Int (HS.HashSet Int)
fromAdjacencySets :: AdjacencySets -> DenseAdjMatrix
fromAdjacencySets g = makeArray Seq (Sz (n :. n)) go
where
n = HM.size g
go (i :. j)
| isEdge (i, j) = 1
| isEdge (j, i) = 1
| otherwise = 0
isEdge (a, b) = maybe False (HS.member b) $ HM.lookup a g
toAdjacencySets :: DenseAdjMatrix -> AdjacencySets
toAdjacencySets = ifoldlS f mempty
where
f a (i :. j) x
| x == 0 = a
| otherwise = HM.insertWith (<>) i (HS.singleton j) a
newtype ShortestPathMatrix = ShortestPathMatrix (Array U Ix2 (Double, Int))
deriving (Show, Eq, Ord, Generic)
deriving newtype (NFData)
floydWarshall :: Unbox a => Real a => Array U Ix2 a -> ShortestPathMatrix
floydWarshall = ShortestPathMatrix
. floydWarshallInternal
. computeAs U
. intDistMatrix
shortestPath
:: ShortestPathMatrix
-> Int
-> Int
-> Maybe [Int]
shortestPath (ShortestPathMatrix m) src trg
| M.isEmpty mat = Nothing
| not (M.isSafeIndex (size m) (src :. trg)) = Nothing
| (mat M.! (src :. trg)) == (-1) = Nothing
| otherwise = go src trg
where
mat = M.computeAs U $ M.map snd m
go a b
| a == b = return []
| otherwise = do
n <- M.index mat (a :. b)
(:) n <$> go n b
distance :: ShortestPathMatrix -> Int -> Int -> Maybe Double
distance (ShortestPathMatrix m) src trg
| M.isEmpty m = Nothing
| otherwise = toDistance . fst =<< M.index m (src :. trg)
diameter :: ShortestPathMatrix -> Maybe Double
diameter (ShortestPathMatrix m)
| M.isEmpty m = Just 0
| otherwise = toDistance $ maximum' $ M.map fst m
toDistance :: RealFrac a => a -> Maybe a
toDistance x
| x == 1/0 = Nothing
| otherwise = Just x
intDistMatrix
:: Real a
=> Source r Ix2 a
=> Array r Ix2 a
-> Array M.D Ix2 (Double, Int)
intDistMatrix = M.imap go
where
go (x :. y) e
| x == y = (0, y)
| e > 0 = (realToFrac e, y)
| otherwise = (1/0, -1)
floydWarshallInternal
:: Array U Ix2 (Double, Int)
-> Array U Ix2 (Double,Int)
floydWarshallInternal a = foldl' go a [0..n-1]
where
Sz (n :. _) = size a
go :: Array U Ix2 (Double, Int) -> Int -> Array U Ix2 (Double,Int)
go c k = makeArray Seq (Sz (n :. n)) $ \(x :. y) ->
let
!xy = fst $! c M.! (x :. y)
!xk = fst $! c M.! (x :. k)
!ky = fst $! c M.! (k :. y)
!nxy = snd $! c M.! (x :. y)
!nxk = snd $! c M.! (x :. k)
in if xy > xk + ky then (xk + ky, nxk) else (xy, nxy)
distMatrix_
:: Source r Ix2 Int
=> Array r Ix2 Int
-> Array M.D Ix2 Double
distMatrix_ = M.imap go
where
go (x :. y) e
| x == y = 0
| e > 0 = realToFrac e
| otherwise = 1/0
floydWarshall_
:: Array U Ix2 Double
-> Array U Ix2 Double
floydWarshall_ a = foldl' go a [0..n-1]
where
Sz (n :. _) = size a
go :: Array U Ix2 Double -> Int -> Array U Ix2 Double
go c k = makeArray Seq (Sz (n :. n)) $ \(x :. y) ->
let
!xy = c M.! (x :. y)
!xk = c M.! (x :. k)
!ky = c M.! (k :. y)
in if xy > xk + ky then xk + ky else xy
shortestPaths_ :: Array U Ix2 Int -> Array U Ix2 Double
shortestPaths_ = floydWarshall_ . computeAs U . distMatrix_
diameter_ :: Array U Ix2 Int -> Maybe Natural
diameter_ g
| M.isEmpty g = Just 0
| otherwise = let x = round $ maximum' $ shortestPaths_ g
in if x == round (1/0 :: Double) then Nothing else Just x