{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Std.Data.Vector.Search (
findIndices, elemIndices
, find, findR
, findIndex, findIndexR
, filter, partition
, indicesOverlapping
, indices
, elemIndicesBytes, findByte, findByteR
, indicesOverlappingBytes, indicesBytes
, kmpNextTable
, sundayBloom
, elemSundayBloom
) where
import Control.Monad.ST
import Data.Bits
import GHC.Word
import Prelude hiding (filter, partition)
import Std.Data.Array
import Std.Data.PrimArray.BitTwiddle (c_memchr, memchrReverse)
import Std.Data.Vector.Base
-- | /O(n)/ Every offset (counted from the start of the slice) at which the
-- vector holds an element equal to the given value, in ascending order.
--
-- For 'Bytes' this is rewritten to the memchr-based 'elemIndicesBytes'.
elemIndices :: (Vec v a, Eq a) => a -> v a -> [Int]
{-# INLINE [1] elemIndices #-}
{-# RULES "elemIndices/Bytes" elemIndices = elemIndicesBytes #-}
elemIndices w (Vec arr s l) = loop s
  where
    !stop = s + l
    -- p walks absolute array positions; results are shifted back by s
    loop !p
        | p >= stop = []
        | otherwise = case indexArr' arr p of
            (# y #) | y == w    -> let !off = p - s in off : loop (p + 1)
                    | otherwise -> loop (p + 1)
-- | /O(n)/ Offsets of every element satisfying the predicate, in ascending
-- order.
--
-- Offsets are relative to the start of the vector slice (not to the
-- underlying array). This matches 'elemIndices', 'find' and 'findR', and it
-- matters when the vector is a slice with a non-zero offset. The rewrite
-- rules below replace byte-equality predicates with 'elemIndicesBytes'.
findIndices :: Vec v a => (a -> Bool) -> v a -> [Int]
{-# INLINE [1] findIndices #-}
{-# RULES "findIndices/Bytes1" forall w. findIndices (w `eqWord8`) = elemIndicesBytes w #-}
{-# RULES "findIndices/Bytes2" forall w. findIndices (`eqWord8` w) = elemIndicesBytes w #-}
findIndices f (Vec arr s l) = go s
  where
    !end = s + l
    -- p is an absolute index into arr; report it relative to the slice
    go !p | p >= end  = []
          | f x       = let !i = p - s in i : go (p+1)
          | otherwise = go (p+1)
      where (# x #) = indexArr' arr p
-- | /O(n)/ Byte-specialised 'elemIndices' built on @memchr@.
--
-- Returns every offset, relative to the start of the slice, at which the
-- byte occurs, in ascending order. This agrees with the generic
-- 'elemIndices', which rewrite rules replace with this function.
elemIndicesBytes :: Word8 -> Bytes -> [Int]
{-# INLINE elemIndicesBytes #-}
elemIndicesBytes w (PrimVector (PrimArray ba#) s l) = go s
  where
    !end = s + l
    -- i is an absolute index into the underlying array. c_memchr's result
    -- is added to i, i.e. it is treated as an offset from the search start
    -- (consistent with its use in findByte), or -1 when the byte is absent.
    go !i
        | i >= end = []
        | otherwise =
            case c_memchr ba# i w (end - i) of
                -1 -> []
                r  -> let !i'  = i + r
                          !off = i' - s
                      in off : go (i' + 1)
-- | /O(n)/ Offset of the first element satisfying the predicate.
-- Yields the vector length when nothing matches (see 'find').
findIndex :: Vec v a => (a -> Bool) -> v a -> Int
{-# INLINE findIndex #-}
findIndex f v = case find f v of
    (idx, _) -> idx
-- | /O(n)/ Offset of the last element satisfying the predicate.
-- Yields @-1@ when nothing matches (see 'findR').
findIndexR :: Vec v a => (a -> Bool) -> v a -> Int
{-# INLINE findIndexR #-}
findIndexR f v = case findR f v of
    (idx, _) -> idx
-- | /O(n)/ Scan left to right for the first element satisfying the predicate.
--
-- On success, returns its offset within the slice paired with the element.
-- Otherwise returns @(length, Nothing)@.
find :: Vec v a => (a -> Bool) -> v a -> (Int, Maybe a)
{-# INLINE [1] find #-}
{-# RULES "find/Bytes1" forall w. find (w `eqWord8`) = findByte w #-}
{-# RULES "find/Bytes2" forall w. find (`eqWord8` w) = findByte w #-}
find f (Vec arr s l) = scan 0
  where
    -- k counts from the slice start; arr is indexed at s + k
    scan !k
        | k >= l    = (l, Nothing)
        | otherwise = case indexArr' arr (s + k) of
            (# y #) | f y       -> (k, Just y)
                    | otherwise -> scan (k + 1)
-- | /O(n)/ Byte-specialised 'find' using @memchr@.
--
-- Returns the offset of the first occurrence of the byte with @Just@ the
-- byte, or @(length, Nothing)@ when it does not occur.
findByte :: Word8 -> Bytes -> (Int, Maybe Word8)
{-# INLINE findByte #-}
findByte w (PrimVector (PrimArray ba#) s l)
    | r == -1   = (l, Nothing)
    | otherwise = (r, Just w)
  where
    -- search the l bytes starting at absolute index s; -1 means not found
    r = c_memchr ba# s w l
-- | /O(n)/ Scan right to left for the last element satisfying the predicate.
--
-- On success, returns its offset within the slice paired with the element.
-- Otherwise returns @(-1, Nothing)@.
findR :: Vec v a => (a -> Bool) -> v a -> (Int, Maybe a)
{-# INLINE [1] findR #-}
{-# RULES "findR/Bytes1" forall w. findR (w `eqWord8`) = findByteR w #-}
{-# RULES "findR/Bytes2" forall w. findR (`eqWord8` w) = findByteR w #-}
findR f (Vec arr s l) = scanBack (l - 1)
  where
    -- k counts down from the last slice offset; arr is indexed at s + k
    scanBack !k
        | k < 0     = (-1, Nothing)
        | otherwise = case indexArr' arr (s + k) of
            (# y #) | f y       -> (k, Just y)
                    | otherwise -> scanBack (k - 1)
-- | /O(n)/ Byte-specialised 'findR' built on 'memchrReverse'.
--
-- Yields @(-1, Nothing)@ when the byte does not occur; otherwise the index
-- reported by 'memchrReverse' paired with @Just@ the byte.
--
-- NOTE(review): memchrReverse is handed the absolute index of the last
-- byte (s + l - 1) and the length l. Its result is returned unchanged, so
-- this agrees with 'findR' (slice-relative offsets) only if memchrReverse
-- itself reports slice-relative positions. Confirm against
-- Std.Data.PrimArray.BitTwiddle.
findByteR :: Word8 -> Bytes -> (Int, Maybe Word8)
{-# INLINE findByteR #-}
findByteR w (PrimVector ba s l)
    | hit == -1 = (-1, Nothing)
    | otherwise = (hit, Just w)
  where
    hit = memchrReverse ba w (s + l - 1) l
-- | /O(n)/ Keep only the elements satisfying the predicate, preserving order.
--
-- A buffer as long as the input is handed to 'createN'. The fill action
-- returns how many elements it actually wrote; presumably 'createN' uses
-- that count as the final length.
filter :: forall v a. Vec v a => (a -> Bool) -> v a -> v a
{-# INLINE filter #-}
filter f (Vec arr s l)
    | l == 0    = empty
    | otherwise = createN l (copyMatching 0 s)
  where
    !end = s + l
    -- n: elements written so far; p: absolute read position in arr
    copyMatching :: Int -> Int -> MArray v s a -> ST s Int
    copyMatching !n !p !marr
        | p >= end  = return n
        | otherwise = case indexArr' arr p of
            (# x #) | f x       -> writeArr marr n x >> copyMatching (n+1) (p+1) marr
                    | otherwise -> copyMatching n (p+1) marr
-- | /O(n)/ Split a vector into the elements that satisfy the predicate and
-- those that do not, each preserving the original order.
--
-- Both output buffers are allocated at the full input length. The fill
-- action reports how many elements went into each one.
partition :: forall v a. Vec v a => (a -> Bool) -> v a -> (v a, v a)
{-# INLINE partition #-}
partition f (Vec arr s l)
    | l == 0    = (empty, empty)
    | otherwise = createN2 l l (split 0 0 s)
  where
    !end = s + l
    -- nYes/nNo: elements written to each buffer; p: absolute read position
    split :: Int -> Int -> Int -> MArray v s a -> MArray v s a -> ST s (Int, Int)
    split !nYes !nNo !p !yes !no
        | p >= end  = return (nYes, nNo)
        | otherwise = case indexArr' arr p of
            (# x #) | f x       -> writeArr yes nYes x >> split (nYes+1) nNo (p+1) yes no
                    | otherwise -> writeArr no nNo x >> split nYes (nNo+1) (p+1) yes no
-- | Find the offsets of all, possibly overlapping, occurrences of a needle
-- in a haystack, using the Knuth-Morris-Pratt algorithm.
--
-- * Offsets are relative to the start of the haystack slice.
-- * When the @Bool@ argument is 'True' and the scan ends while @j > 0@
--   needle elements are matched, @-j@ is appended to report that trailing
--   partial match.
-- * An empty needle yields @[0 .. hlen-1]@.
--
-- The KMP table is built once per needle, so a partially applied
-- @indicesOverlapping needle@ can be reused for many haystacks.
indicesOverlapping :: (Vec v a, Eq a)
                   => v a   -- ^ needle
                   -> v a   -- ^ haystack
                   -> Bool  -- ^ report a trailing partial match as a negative length
                   -> [Int]
{-# INLINABLE[1] indicesOverlapping #-}
{-# RULES "indicesOverlapping/Bytes" indicesOverlapping = indicesOverlappingBytes #-}
indicesOverlapping needle@(Vec narr noff nlen) = search
  where
    next = kmpNextTable needle
    search haystack@(Vec harr hoff hlen) reportPartial
        | nlen <= 0 = [0..hlen-1]
        -- The needle may be a slice of a larger array, so its only element
        -- lives at noff (index 0 would read outside the slice).
        | nlen == 1 = case indexArr' narr noff of
            (# x #) -> elemIndices x haystack
        | otherwise = kmp 0 0
      where
        -- i: current haystack offset; j: needle elements matched so far
        kmp !i !j
            | i >= hlen = if reportPartial && j /= 0 then [-j] else []
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j
                         -- full match starting at i-j; next[nlen] lets the
                         -- next match overlap this one
                         in case next `indexArr` j' of
                              -1  -> i' : kmp (i+1) 0
                              j'' -> i' : kmp (i+1) j''
                    else kmp (i+1) j'
            | otherwise = case next `indexArr` j of
                -1 -> kmp (i+1) 0
                j' -> kmp i j'
-- | 'Bytes'-specialised 'indicesOverlapping'. It finds all, possibly
-- overlapping, occurrences of a needle, with the same result conventions:
-- haystack-relative offsets, a trailing partial match reported as @-j@ when
-- requested, and @[0 .. hlen-1]@ for an empty needle.
--
-- A 64-bit 'sundayBloom' filter of the needle picks the strategy:
--
-- * more than 48 bits set: plain KMP (presumably because a dense filter
--   rarely permits a skip);
-- * otherwise: KMP combined with a Sunday-style skip. On a mismatch, the
--   byte just past the current needle window is checked against the
--   filter. If it cannot occur in the needle, no alignment covering it can
--   match, so the window jumps past it.
indicesOverlappingBytes :: Bytes  -- ^ needle
                        -> Bytes  -- ^ haystack
                        -> Bool   -- ^ report a trailing partial match as a negative length
                        -> [Int]
{-# INLINABLE indicesOverlappingBytes #-}
indicesOverlappingBytes needle@(Vec narr noff nlen)
    | popCount bloom > 48 = search
    | otherwise           = search'
  where
    next = kmpNextTable needle
    bloom = sundayBloom needle

    -- Plain KMP search.
    search haystack@(Vec harr hoff hlen) reportPartial
        | nlen <= 0 = [0..hlen-1]
        -- the needle's single byte lives at noff, not at index 0
        | nlen == 1 = case indexArr' narr noff of
            (# x #) -> elemIndices x haystack
        | otherwise = kmp 0 0
      where
        -- i: current haystack offset; j: needle bytes matched so far
        kmp !i !j
            | i >= hlen = if reportPartial && j /= 0 then [-j] else []
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j
                         in case next `indexArr` j' of
                              -1  -> i' : kmp (i+1) 0
                              j'' -> i' : kmp (i+1) j''
                    else kmp (i+1) j'
            | otherwise = case next `indexArr` j of
                -1 -> kmp (i+1) 0
                j' -> kmp i j'

    -- KMP with a Sunday-style bloom-filter skip.
    search' haystack@(Vec harr hoff hlen) reportPartial
        | nlen <= 0 = [0..hlen-1]
        -- the needle's single byte lives at noff, not at index 0
        | nlen == 1 = elemIndices (indexArr narr noff) haystack
        | otherwise = sunday 0 0
      where
        kmp !i !j
            | i >= hlen = if reportPartial && j /= 0 then [-j] else []
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j
                         in case next `indexArr` j' of
                              -1  -> i' : kmp (i+1) 0
                              j'' -> i' : kmp (i+1) j''
                    else kmp (i+1) j'
            | otherwise = case next `indexArr` j of
                -1 -> kmp (i+1) 0
                j' -> kmp i j'
        -- Below hlen' the look-ahead index k = i+nlen-j stays inside the
        -- haystack; from there on, finish with plain KMP.
        !hlen' = hlen - nlen
        sunday !i !j
            | i >= hlen' = kmp i j
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j
                         in case next `indexArr` j' of
                              -1  -> i' : sunday (i+1) 0
                              j'' -> i' : sunday (i+1) j''
                    else sunday (i+1) j'
            | otherwise =
                -- k: first haystack position after the window starting at i-j
                let !k = i+nlen-j
                    !afterNeedle = indexArr harr (k+hoff)
                in if elemSundayBloom bloom afterNeedle
                    then case next `indexArr` j of
                        -1 -> sunday (i+1) 0
                        j' -> sunday i j'
                    else sunday (k+1) 0
-- | Find the offsets of all non-overlapping occurrences of a needle in a
-- haystack, using Knuth-Morris-Pratt. After a full match, the search
-- restarts with nothing matched, just after the match.
--
-- * Offsets are relative to the start of the haystack slice.
-- * When the @Bool@ argument is 'True' and the scan ends while @j > 0@
--   needle elements are matched, @-j@ is appended.
-- * An empty needle yields @[0 .. hlen-1]@.
indices :: (Vec v a, Eq a) => v a -> v a -> Bool -> [Int]
{-# INLINABLE[1] indices #-}
{-# RULES "indices/Bytes" indices = indicesBytes #-}
indices needle@(Vec narr noff nlen) = search
  where
    next = kmpNextTable needle
    search haystack@(Vec harr hoff hlen) reportPartial
        | nlen <= 0 = [0..hlen-1]
        -- The needle may be a slice of a larger array, so its only element
        -- lives at noff (index 0 would read outside the slice).
        | nlen == 1 = case indexArr' narr noff of
            (# x #) -> elemIndices x haystack
        | otherwise = kmp 0 0
      where
        -- i: current haystack offset; j: needle elements matched so far
        kmp !i !j
            | i >= hlen = if reportPartial && j /= 0 then [-j] else []
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j in i' : kmp (i+1) 0
                    else kmp (i+1) j'
            | otherwise = case next `indexArr` j of
                -1 -> kmp (i+1) 0
                j' -> kmp i j'
-- | 'Bytes'-specialised 'indices'. It finds all non-overlapping occurrences
-- of a needle, with the same result conventions: haystack-relative offsets,
-- a trailing partial match reported as @-j@ when requested, and
-- @[0 .. hlen-1]@ for an empty needle.
--
-- A 64-bit 'sundayBloom' filter of the needle picks the strategy:
--
-- * more than 48 bits set: plain KMP (presumably because a dense filter
--   rarely permits a skip);
-- * otherwise: KMP plus a Sunday-style skip past haystack bytes that
--   cannot occur in the needle.
indicesBytes :: Bytes  -- ^ needle
             -> Bytes  -- ^ haystack
             -> Bool   -- ^ report a trailing partial match as a negative length
             -> [Int]
{-# INLINABLE indicesBytes #-}
indicesBytes needle@(Vec narr noff nlen)
    | popCount bloom > 48 = search
    | otherwise           = search'
  where
    next = kmpNextTable needle
    bloom = sundayBloom needle

    -- Plain KMP search.
    search haystack@(Vec harr hoff hlen) reportPartial
        | nlen <= 0 = [0..hlen-1]
        -- the needle's single byte lives at noff, not at index 0
        | nlen == 1 = case indexArr' narr noff of
            (# x #) -> elemIndices x haystack
        | otherwise = kmp 0 0
      where
        -- i: current haystack offset; j: needle bytes matched so far
        kmp !i !j
            | i >= hlen = if reportPartial && j /= 0 then [-j] else []
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j in i' : kmp (i+1) 0
                    else kmp (i+1) j'
            | otherwise = case next `indexArr` j of
                -1 -> kmp (i+1) 0
                j' -> kmp i j'

    -- KMP with a Sunday-style bloom-filter skip.
    search' haystack@(Vec harr hoff hlen) reportPartial
        | nlen <= 0 = [0..hlen-1]
        -- the needle's single byte lives at noff, not at index 0
        | nlen == 1 = elemIndices (indexArr narr noff) haystack
        | otherwise = sunday 0 0
      where
        kmp !i !j
            | i >= hlen = if reportPartial && j /= 0 then [-j] else []
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j in i' : kmp (i+1) 0
                    else kmp (i+1) j'
            | otherwise = case next `indexArr` j of
                -1 -> kmp (i+1) 0
                j' -> kmp i j'
        -- Below hlen' the look-ahead index k = i+nlen-j stays inside the
        -- haystack; from there on, finish with plain KMP.
        !hlen' = hlen - nlen
        sunday !i !j
            | i >= hlen' = kmp i j
            | narr `indexArr` (j+noff) == harr `indexArr` (i+hoff) =
                let !j' = j+1
                in if j' >= nlen
                    then let !i' = i-j in i' : sunday (i+1) 0
                    else sunday (i+1) j'
            | otherwise =
                -- k: first haystack position after the window starting at i-j
                let !k = i+nlen-j
                    !afterNeedle = indexArr harr (k+hoff)
                in if elemSundayBloom bloom afterNeedle
                    then case next `indexArr` j of
                        -1 -> sunday (i+1) 0
                        j' -> sunday i j'
                    else sunday (k+1) 0
-- | Build the Knuth-Morris-Pratt "next" table for a needle of length @l@.
-- It follows the classic @preKmp@ formulation, including the optimisation
-- of skipping positions that hold the same element.
--
-- The table has @l+1@ entries:
--
-- * @next[0] = -1@: nothing can be reused, so the haystack position advances.
-- * For @0 < i < l@: the border length @j'@ of the first @i@ needle elements.
--   If @needle[j'] == needle[i]@, @next[j']@ is stored instead, because
--   resuming at @j'@ would compare the same element again and mismatch.
-- * @next[l]@: the border length of the whole needle. It is never replaced
--   by the optimisation, and callers use it to continue after a full match.
kmpNextTable :: (Vec v a, Eq a) => v a -> PrimArray Int
{-# INLINE kmpNextTable #-}
kmpNextTable (Vec arr s l) = runST (do
    ma <- newArr (l+1)
    writeArr ma 0 (-1)
    -- dec w j: fall back through the table until the candidate border of
    -- length j can be extended by element w (or j drops below 0), then
    -- return the extended length j+1.
    let dec !w !j
            | j < 0 || w == indexArr arr (s+j) = return $! j+1
            | otherwise = readArr ma j >>= dec w
        -- go i j: fill slot i. j is the border length computed for the
        -- first i-1 elements (-1 before any element is seen).
        go !i !j
            | i > l = unsafeFreezeArr ma
            | otherwise = do
                let !w = indexArr arr (s+i-1)
                j' <- dec w j
                -- the i < l guard avoids reading past the needle when
                -- filling the final slot next[l]
                if i < l && indexArr arr (s+j') == indexArr arr (s+i)
                    then readArr ma j' >>= writeArr ma i
                    else writeArr ma i j'
                go (i+1) j'
    go 1 (-1))
-- | Summarise the bytes of a needle as a 64-bit bloom filter: byte @w@ sets
-- bit @w .&. 0x3f@.
--
-- Bytes whose low 6 bits coincide share a bit, so membership tests may
-- produce false positives but never false negatives.
sundayBloom :: Bytes -> Word64
{-# INLINE sundayBloom #-}
sundayBloom (Vec arr s l) = accumulate s 0
  where
    !stop = s + l
    accumulate !i !acc
        | i < stop  = accumulate (i + 1) (acc .|. bitFor (indexArr arr i))
        | otherwise = acc
    bitFor w = 1 `unsafeShiftL` (fromIntegral w .&. 0x3f)
-- | Query a 'sundayBloom' filter: is the bit for this byte (@w .&. 0x3f@)
-- set?
--
-- 'False' guarantees the byte was not added to the filter. 'True' may be a
-- false positive.
elemSundayBloom :: Word64 -> Word8 -> Bool
{-# INLINE elemSundayBloom #-}
elemSundayBloom b w = testBit b (fromIntegral w .&. 0x3f)