{-# language BangPatterns, RankNTypes, ScopedTypeVariables #-}
module Data.Vector.Algorithms where

import Prelude hiding (length)
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST (runST)

import Data.Vector.Generic.Mutable
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Unboxed.Mutable as UMV
import qualified Data.Bit as Bit

import Data.Vector.Algorithms.Common (Comparison)
import Data.Vector.Algorithms.Intro (sortUniqBy)
import qualified Data.Vector.Algorithms.Search  as S

-- | The `nub` function which removes duplicate elements from a vector.
--
-- Elements are compared with the 'compare' method of their 'Ord'
-- instance; this is `nubBy` specialised to that comparison.
nub :: forall v e . (V.Vector v e, Ord e) => v e -> v e
nub = nubBy compare

-- | A version of `nub` with a custom comparison predicate.
--
-- /Note:/ This function makes use of `sortUniqBy` using the intro
-- sort algorithm.
nubBy ::
  forall v e . (V.Vector v e) =>
  Comparison e -> v e -> v e
nubBy cmp vec = runST $ do
  mv <- V.unsafeThaw vec -- safe as the nubByMut algorithm copies the input
  destMV <- nubByMut sortUniqBy cmp mv
  -- destMV is not referenced again after this point, so freezing it
  -- without a copy cannot expose later mutation.
  v <- V.unsafeFreeze destMV
  pure (V.force v)

-- | The `nubByMut` function takes in an in-place sort algorithm
-- and uses it to do a deduplicated sort. It then uses this to
-- remove duplicate elements from the input.
--
-- The input is walked front to back and an element is written to the
-- result only the first time it is encountered, so the result keeps the
-- elements in order of first occurrence.
--
-- The supplied algorithm is expected to return a vector that is sorted
-- with respect to the comparison and contains each distinct input
-- element exactly once; the binary search used to identify elements
-- relies on this.
--
-- /Note:/ This algorithm needs the original input, so it copies the
-- input before sorting in-place. As such, it is safe to use on
-- immutable inputs.
nubByMut ::
  forall m v e . (PrimMonad m, MVector v e) =>
  (Comparison e -> v (PrimState m) e -> m (v (PrimState m) e))
  -> Comparison e -> v (PrimState m) e -> m (v (PrimState m) e)
nubByMut alg cmp inp = do
  let len = length inp
  inp' <- clone inp                -- sort a copy; the original order is still needed
  sortUniqs <- alg cmp inp'
  let uniqLen = length sortUniqs
  bitmask <- UMV.replicate uniqLen (Bit.Bit False) -- bitmask to track which elements have
                                                   -- already been seen.
  dest :: v (PrimState m) e <- unsafeNew uniqLen  -- return vector
  let
    go :: Int -> Int -> m ()
    go !srcInd !destInd
      | srcInd == len = pure ()        -- input exhausted
      | destInd == uniqLen = pure ()   -- every distinct element already written
      | otherwise = do
          curr    <- unsafeRead inp srcInd                -- read current element
          sortInd <- S.binarySearchBy cmp sortUniqs curr  -- find sorted index
          bit <- UMV.unsafeRead bitmask sortInd           -- check if we have already seen
                                                          -- this element in bitvector
          case bit of
            -- if we have seen it then iterate
            Bit.Bit True -> go (srcInd + 1) destInd
            -- if we haven't then write it into output
            -- and mark that it has been seen
            Bit.Bit False -> do
              UMV.unsafeWrite bitmask sortInd (Bit.Bit True)
              unsafeWrite dest destInd curr
              go (srcInd + 1) (destInd + 1)
  go 0 0
  pure dest