--------------------------------------------------------------------------------
-- |
-- Module      :  Algorithms.Geometry.FrechetDistance.Discrete
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--------------------------------------------------------------------------------
module Algorithms.Geometry.FrechetDistance.Discrete( discreteFrechetDistance
                                                   , discreteFrechetDistanceWith
                                                   ) where

import           Control.Lens ((^.))
import           Control.Monad.ST (ST,runST)
import           Data.Ext
import           Data.Geometry.Point
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import qualified VectorBuilder.Builder as Builder
import qualified VectorBuilder.Vector as Builder

--------------------------------------------------------------------------------


-- | Discrete Frechet distance between two point sequences, measured with
-- the squared Euclidean distance. Consequently the value returned is the
-- /square/ of the (Euclidean) discrete Frechet distance.
--
-- This is 'discreteFrechetDistanceWith' specialised to
-- 'squaredEuclideanDist'.
--
-- running time: \(O(nm)\), where \(n\) and \(m\) are the lengths of
-- the sequences.
discreteFrechetDistance :: (Foldable f, Foldable g,  Functor f, Functor g, Ord r, Num r)
                        => f (Point 2 r :+ p) -> g (Point 2 r :+ q) -> r
discreteFrechetDistance ta tb = discreteFrechetDistanceWith squaredEuclideanDist ta tb

-- | Discrete Frechet distance between two point sequences, using the
-- given distance function between points.
--
-- A memoised dynamic-programming table holds one cell per pair of
-- indices (an index into the first sequence and one into the second).
-- The result is the value of the cell that pairs the last points of both
-- sequences.
--
-- Both sequences must be non-empty; the distance is not defined
-- otherwise. On an empty sequence this function calls 'error' with a
-- descriptive message instead of failing on an out-of-bounds vector
-- access.
--
-- running time: \(O(nm)\), where \(n\) and \(m\) are the lengths of
-- the sequences (and assuming that a distance calculation takes
-- constant time).
discreteFrechetDistanceWith         :: ( Foldable f, Functor f, Functor g, Foldable g, Ord r)
                                    => (Point 2 r -> Point 2 r -> r) -- ^ distance function
                                    -> f (Point 2 r :+ p)            -- ^ first sequence
                                    -> g (Point 2 r :+ q)            -- ^ second sequence
                                    -> r
discreteFrechetDistanceWith d ta tb
    | n == 0 || m == 0 = error "discreteFrechetDistanceWith: empty point sequence"
    | otherwise        = runST $ do
        -- one cell per (i, j) index pair; Nothing marks "not computed yet"
        v <- MV.replicate (n*m) Nothing
        let dpTable = DPTable m v
            z       = Loc 0 0
        -- seed (0,0) with the distance between the two first points
        storeT dpTable z (dist z)
        evalTable dist dpTable (Loc (n-1) (m-1))
  where
    -- strip the extra data and move the points into vectors for O(1) indexing
    ta' = Builder.build . Builder.foldable . fmap (^.core) $ ta
    tb' = Builder.build . Builder.foldable . fmap (^.core) $ tb
    n = V.length ta'
    m = V.length tb'

    -- distance between the r-th point of ta and the c-th point of tb
    dist (Loc r c) = d (ta' V.! r) (tb' V.! c)

-- | A cell of the dynamic-programming table. The first component indexes
-- the first point sequence; the second component indexes the second one.
data Loc = Loc !Int !Int
         deriving (Eq, Show)

-- | Memoisation table for the dynamic program, tied to the 'ST' state
-- thread @s@. The 'Int' is the row width (the number of columns). The
-- mutable vector stores the cells in a flat row-major layout, where
-- 'Nothing' marks a cell whose value has not been computed yet.
data DPTable s r = DPTable !Int                     -- row width
                           (MV.MVector s (Maybe r))  -- cells

-- | Compute the discrete Frechet distance between the prefixes of the two
-- sequences that end at the given 'Loc', using the table for memoisation.
--
-- A cell's value is the maximum of its own point distance and the
-- smallest value among its (at most three) predecessor cells. A cell
-- without predecessors, i.e. @Loc 0 0@, simply gets its own distance, so
-- this function no longer depends on the caller pre-seeding that cell.
-- Values are forced (to WHNF) before being stored, so the table does not
-- accumulate a chain of unevaluated thunks.
evalTable              :: Ord r => (Loc -> r) -> DPTable s r -> Loc -> ST s r
evalTable dist dpTable = go
  where
    -- reuse the memoised value when present, otherwise compute it
    go p = lookupT dpTable p >>= maybe (compute p) pure

    compute p = do
      prevVals <- mapM go (prevs p)
      let here = dist p
          d    = case prevVals of
                   [] -> here                         -- base cell: nothing to extend
                   _  -> here `max` minimum prevVals
      d `seq` storeT dpTable p d
      pure d

-- | Read the cell at the given 'Loc' from the DP table. The result is
-- 'Nothing' if no value has been stored there yet.
lookupT                           :: DPTable s r -> Loc -> ST s (Maybe r)
lookupT (DPTable width cells) (Loc row col) = MV.read cells offset
  where
    -- flat row-major position of (row, col)
    offset = row * width + col

-- | Store a computed value in the DP table at the given 'Loc'.
storeT                             :: DPTable s r -> Loc -> r -> ST s ()
storeT (DPTable width cells) (Loc row col) = MV.write cells (row * width + col) . Just

-- | Candidate predecessor cells of a location: @(r-1, c)@, @(r-1, c-1)@
-- and @(r, c-1)@, in that order. Candidates with a negative coordinate
-- are dropped, so @Loc 0 0@ has no predecessors.
prevs           :: Loc -> [Loc]
prevs (Loc r c) = [ Loc r' c'
                  | (r', c') <- [(r-1, c), (r-1, c-1), (r, c-1)]
                  , r' >= 0
                  , c' >= 0
                  ]