{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.Bytes.Metrics
  ( levenshteinWithTolerance
  , isWithinLevenshtein
  ) where

import Control.Monad.ST (runST)
import Data.Bytes (Bytes)

import qualified Data.Bytes as Bytes
import qualified Data.Primitive.Contiguous as Arr
import qualified Data.Primitive.PrimArray as Prim

{- | Decide whether two 'Bytes' lie within the given Levenshtein distance of
each other (inclusive). A negative tolerance always yields 'False'.

This is a predicate wrapper around 'levenshteinWithTolerance' and costs the
same: at most about O(t * min(n,m)) time and O(t) space, where @n,m@ are the
lengths of the input strings and @t@ is the tolerance.
-}
isWithinLevenshtein :: Int -> Bytes -> Bytes -> Bool
isWithinLevenshtein t a b =
  case levenshteinWithTolerance t a b of
    Nothing -> False
    -- re-check against the tolerance rather than trusting any 'Just'
    Just dist -> dist <= t

{- | Determine the Levenshtein distance between two 'Bytes', provided that
distance is within (inclusive) the given tolerance.

Returns @Just d@ exactly when the distance @d@ satisfies @d <= t@, and
'Nothing' otherwise. In particular, any negative tolerance gives 'Nothing'.

Costs at most about O(t * min(n,m)) time and O(t) space, where @n,m@ are the
lengths of the input strings and @t@ is the tolerance.
-}
levenshteinWithTolerance :: Int -> Bytes -> Bytes -> Maybe Int
levenshteinWithTolerance !t !a !b
  -- zero tolerance: the only admissible distance is 0, i.e. equality
  | t == 0 = if a == b then Just 0 else Nothing
  | otherwise = case banded of
      -- The worker may report a band-restricted upper bound that exceeds the
      -- tolerance; such a value is not the true distance and must be rejected.
      Just d | d <= t -> Just d
      _ -> Nothing
 where
  -- The worker requires its first argument to be the shorter string: that
  -- string indexes the table rows, so the difference in lengths (which sizes
  -- the diagonal band) is never negative.
  banded :: Maybe Int
  banded
    | m > n = levenshteinWithWorker t b a
    | otherwise = levenshteinWithWorker t a b
  m :: Int
  m = Bytes.length a
  n :: Int
  n = Bytes.length b

-- Diagonal-band dynamic-programming core of 'levenshteinWithTolerance'.
--
-- Precondition: @Bytes.length a <= Bytes.length b@.
--
-- Conceptually the edit-distance table has one row per prefix of @a@ (rows
-- @0..m@) and one column per prefix of @b@ (columns @0..n@). Only the band of
-- columns @rowIx - slack .. rowIx + deltaN + slack@ is computed. A single
-- array of @rowLen@ slots is reused for every row: slot @bandIx@ of row
-- @rowIx@ holds column @rowIx - slack + bandIx@. An edit path costing at most
-- the effective tolerance cannot leave this band, so the band is enough to
-- compute such distances exactly.
--
-- Result:
--
-- * 'Nothing' when @t < deltaN@: the length difference alone exceeds @t@.
-- * Otherwise @Just d@. @d@ is the exact distance when that distance is at
--   most @t@; when the distance is larger, @d@ is only an upper bound (and
--   itself exceeds @t@), so callers must compare @d@ against @t@.
--
-- The tolerance is clamped to @n@, because no distance can exceed the longer
-- length. This keeps the band at most @n + 1@ slots wide, giving O(min(t, n))
-- space and O(m * min(t, n)) time even for huge tolerances such as
-- 'maxBound' (which previously sized the array by @t@ and could overflow).
levenshteinWithWorker :: Int -> Bytes -> Bytes -> Maybe Int
levenshteinWithWorker !t !a !b
  | t < deltaN = Nothing
  | otherwise = runST $ do
      -- Cells whose column falls outside the table (negative, or past n) must
      -- never win a 'min' against an in-bounds cell. Initialising with
      -- maxBound would risk overflow on increment; n + m already exceeds every
      -- real entry of the table while staying far from wrapping.
      row :: Prim.MutablePrimArray s Int <- Arr.replicateMut rowLen (n + m)
      let outerLoop !rowIx
            | rowIx > m = pure ()
            | otherwise = do
                let innerLoop !bandIx
                      | bandIx >= rowLen = pure ()
                      | otherwise = do
                          let colIx = rowIx - slack + bandIx
                              -- the table origin (0, 0) is the only free cell
                              initCost = if rowIx == 0 && colIx == 0 then 0 else maxBound
                          -- Match/substitution from (rowIx-1, colIx-1): same slot,
                          -- not yet overwritten for this row. Bytes are read only
                          -- once both indices are known to be in range; row 0
                          -- has no predecessor row at all.
                          !editCost <-
                            if
                              | rowIx < 1 || colIx < 1 || colIx > n -> pure maxBound
                              | Bytes.unsafeIndex a (rowIx - 1) == Bytes.unsafeIndex b (colIx - 1) ->
                                  Arr.read row bandIx
                              | otherwise -> (1 +) <$> Arr.read row bandIx
                          -- Insertion from (rowIx, colIx-1): the left slot, which
                          -- was already updated for this row.
                          !insCost <-
                            if bandIx >= 1
                              then (1 +) <$> Arr.read row (bandIx - 1)
                              else pure maxBound
                          -- Deletion from (rowIx-1, colIx): the right slot, which
                          -- still holds the previous row's value.
                          !delCost <-
                            if bandIx + 1 < rowLen
                              then (1 +) <$> Arr.read row (bandIx + 1)
                              else pure maxBound
                          Arr.write row bandIx (min (min initCost editCost) (min insCost delCost))
                          innerLoop (bandIx + 1)
                innerLoop 0
                outerLoop (rowIx + 1)
      outerLoop 0
      -- slot deltaN + slack of the last row is column m + deltaN = n
      d <- Arr.read row (deltaN + slack)
      pure (Just d)
 where
  m :: Int
  m = Bytes.length a
  n :: Int
  n = Bytes.length b
  -- difference in lengths; non-negative by the precondition
  deltaN :: Int
  deltaN = n - m
  -- Effective tolerance: every distance is at most n, so a larger tolerance
  -- only widens the band without changing the result.
  tol :: Int
  tol = min t n
  -- How many diagonals the band extends beyond [0, deltaN] on each side.
  -- Leaving the main diagonals by k and returning costs at least 2k + deltaN
  -- edits, so k <= (tol - deltaN) / 2. Non-negative because tol >= deltaN
  -- whenever this is evaluated.
  slack :: Int
  slack = (tol - deltaN) `quot` 2
  -- Number of band slots per row. The floor in 'slack' is why this is written
  -- this way instead of the equivalent tol - deltaN + (1 - tol `mod` 2) + deltaN.
  rowLen :: Int
  rowLen = 1 + deltaN + 2 * slack