-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2019 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An efficient implementation of the Boyer-Moore string search algorithm.
-- http://www-igm.univ-mlv.fr/~lecroq/string/node14.html#SECTION00140
-- https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string-search_algorithm
--
-- This is the case-insensitive variant of the algorithm. Unlike the
-- case-sensitive variant, it has to be aware of the Unicode code points that
-- the bytes represent.
--
module Data.Text.BoyerMooreCI.Automaton
    ( Automaton
    , CaseSensitivity (..)
    , CodeUnitIndex (..)
    , Next (..)
    , buildAutomaton
    , patternLength
    , patternText
    , runText

      -- Exposed for testing
    , minimumSkipForCodePoint
    ) where

import Control.DeepSeq (NFData)
import Control.Monad.ST (runST)
import Data.Hashable (Hashable (..))
import Data.Text.Internal (Text (..))
import GHC.Generics (Generic)

#if defined(HAS_AESON)
import qualified Data.Aeson as AE
#endif

import Data.Text.CaseSensitivity (CaseSensitivity (..))
import Data.Text.Utf8 (BackwardsIter (..), CodePoint, CodeUnitIndex (..))
import Data.TypedByteArray (Prim, TypedByteArray)

import qualified Data.Char as Char
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Text as Text
import qualified Data.Text.Utf8 as Utf8
import qualified Data.TypedByteArray as TBA

-- | What the match callback passed to 'runText' wants to happen after a match.
data Next a
  = Done !a  -- ^ Stop searching and return this accumulator as the final result.
  | Step !a  -- ^ Keep searching for the next (non-overlapping) match with this accumulator.


-- | A Boyer-Moore automaton is based on lookup-tables that allow skipping through the haystack.
-- This allows for sub-linear matching in some cases, as we do not have to look at every input
-- character.
--
-- NOTE: Unlike the AcMachine, a Boyer-Moore automaton only returns non-overlapping matches.
-- This means that a Boyer-Moore automaton is not a 100% drop-in replacement for Aho-Corasick.
--
-- Returning overlapping matches would degrade the performance to /O(nm)/ in pathological cases like
-- finding @aaaa@ in @aaaaa....aaaaaa@ as for each match it would scan back the whole /m/ characters
-- of the pattern.
data Automaton = Automaton
  { automatonPattern :: !(TypedByteArray CodePoint)
    -- ^ The pattern, one element per code point (assumed to be lower case).
  , automatonPatternHash :: !Int
    -- ^ Hash of the original pattern 'Text'; the 'Hashable' instance hashes only this.
  , automatonSuffixTable :: !SuffixTable
    -- ^ Good-suffix skips, indexed by pattern code point.
  , automatonBadCharLookup :: !BadCharLookup
    -- ^ Bad-character skips, keyed by (lowered) haystack code point.
  , automatonMinPatternSkip :: !CodeUnitIndex
    -- ^ Minimum number of bytes any case variation of the whole pattern occupies.
  }
  deriving stock Generic
  deriving stock Show
  deriving anyclass NFData

-- | Hashes the precomputed hash of the pattern text, so hashing is /O(1)/ and agrees with 'Eq'
-- (which compares the patterns).
instance Hashable Automaton where
  hashWithSalt salt automaton = hashWithSalt salt (automatonPatternHash automaton)

-- | Two automatons are equal when they were built from the same pattern; all other fields are
-- derived from the pattern.
instance Eq Automaton where
  (Automaton { automatonPattern = lhs }) == (Automaton { automatonPattern = rhs }) = lhs == rhs

#if defined(HAS_AESON)
-- | Decodes the pattern as a JSON string and builds an automaton from it.
instance AE.FromJSON Automaton where
  parseJSON = fmap buildAutomaton . AE.parseJSON

-- | Encodes an automaton as its pattern text.
instance AE.ToJSON Automaton where
  toJSON automaton = AE.toJSON (patternText automaton)
#endif

-- | Build a case-insensitive Boyer-Moore automaton for the given pattern.
--
-- The pattern is expected to be lower case already: 'runText' lowers every haystack character
-- with 'Utf8.lowerCodePoint' before comparing, but compares against the pattern as-is.
buildAutomaton :: Text -> Automaton
buildAutomaton pattern_ =
  let
    -- One element per code point of the pattern.
    needle = TBA.fromList (Text.unpack pattern_)
  in
    Automaton
      needle
      (hash pattern_)
      (buildSuffixTable needle)
      (buildBadCharLookup needle)
      (minimumSkipForVector needle)

-- | Finds all matches in the text, calling the match callback with the first and last byte index of
-- each match of the pattern.
--
-- Only non-overlapping matches are reported: after a match, the search resumes after its last byte.
-- Both indices are inclusive and relative to the start of the 'Text' (its internal offset is
-- subtracted). The callback returns 'Step' to keep searching with a new accumulator, or 'Done' to
-- stop and return immediately. An empty pattern returns the seed without calling the callback.
runText  :: forall a
  . a
  -> (a -> CodeUnitIndex -> CodeUnitIndex -> Next a)
  -> Automaton
  -> Text
  -> a
{-# INLINE runText #-}
runText seed f automaton !text
  | TBA.null needle = seed
  | otherwise = alignPattern seed initialHaystackMin (initialHaystackMin + minPatternSkip - 1)
  where
    Automaton needle _ suffixTable badCharTable minPatternSkip = automaton
    Text haystackArr haystackOffset haystackLen = text

    -- In the pattern we always count codepoints,
    -- in the haystack we always count code units

    -- Highest index that we're allowed to use in the text
    haystackMax :: CodeUnitIndex
    haystackMax = CodeUnitIndex (haystackOffset + haystackLen - 1)

    -- How far we can look back in the text data is first limited by the text
    -- offset, and later by what we matched before.
    initialHaystackMin :: CodeUnitIndex
    initialHaystackMin = CodeUnitIndex haystackOffset

    -- This is our _outer_ loop, called when the pattern is moved
    alignPattern
      :: a
      -> CodeUnitIndex  -- Don't read before this point in the haystack
      -> CodeUnitIndex  -- End of pattern is aligned at this point in the haystack
      -> a
    {-# INLINE alignPattern #-}
    alignPattern !acc !haystackMin !alignmentEnd
      | alignmentEnd > haystackMax = acc
      | otherwise =
          let
            !iter = Utf8.unsafeIndexAnywhereInCodePoint' haystackArr alignmentEnd
          in
            -- End of char may be somewhere different than where we started looking, so the
            -- (adjusted) alignment end passed on is the end of the character found there.
            matchLoop acc haystackMin (backwardsIterEndOfChar iter) iter (TBA.length needle - 1)

    -- The _inner_ loop, called for every pattern character back to front within a pattern alignment.
    matchLoop
      :: a
      -> CodeUnitIndex  -- haystackMin, don't read before this point in the haystack
      -> CodeUnitIndex  -- (adjusted) alignmentEnd, end of pattern is aligned at this point in the haystack
      -> BackwardsIter
      -> Int            -- index in the pattern
      -> a
    matchLoop !acc !haystackMin !alignmentEnd !iter !patternIndex
      -- Character did _not_ match at current position. Check how far the pattern has to move.
      -- The minimum stays the same.
      | haystackCharLower /= TBA.unsafeIndex needle patternIndex =
          alignPattern acc haystackMin (max fromBadChar fromSuffixLookup)

      -- We found a complete match (all pattern characters matched)
      | patternIndex == 0 =
          let
            !from = backwardsIterNext iter + 1 - initialHaystackMin
            !to = alignmentEnd - initialHaystackMin
          in
            case f acc from to of
              Done final -> final
              -- Disallow overlapping matches: the next match may only start after this one.
              Step acc' -> alignPattern acc' (alignmentEnd + 1) (alignmentEnd + minPatternSkip)

      -- The pattern may be aligned in such a way that the start is before the start of the
      -- haystack. This _only_ happens when ⱥ and ⱦ characters occur (due to how minPatternSkip
      -- is calculated).
      | backwardsIterNext iter < haystackMin =
          alignPattern acc haystackMin (alignmentEnd + 1)

      -- We continue by comparing the next character
      | otherwise =
          matchLoop acc haystackMin alignmentEnd
            (Utf8.unsafeIndexEndOfCodePoint' haystackArr (backwardsIterNext iter))
            (patternIndex - 1)
      where
        haystackCharLower = Utf8.lowerCodePoint (backwardsIterChar iter)

        -- The bad character table tells us how far we can advance to the right so that the
        -- character at the current position in the input string, where matching failed,
        -- is lined up with its rightmost occurrence in the pattern.
        fromBadChar = backwardsIterEndOfChar iter + badCharLookup badCharTable haystackCharLower

        -- This is always at least 1, ensuring that we make progress
        -- Suffixlookup tells us how far we can move the pattern
        fromSuffixLookup = alignmentEnd + suffixLookup suffixTable patternIndex

-- | Length of the pattern measured in UTF-8 code units (bytes).
--
-- This is the length of the pattern itself. A case-insensitive match in the haystack may span a
-- different number of bytes (e.g. ⱥ takes 3 bytes while Ⱥ takes 2).
patternLength :: Automaton -> CodeUnitIndex
patternLength automaton = Utf8.lengthUtf8 (patternText automaton)

-- | Return the pattern that was used to construct the automaton, O(n).
--
-- The text is rebuilt from the stored code points, not cached.
patternText :: Automaton -> Text
patternText automaton = Text.pack (TBA.toList (automatonPattern automaton))


-- | Number of bytes that we can skip in the haystack if we want to skip no more
-- than 1 pattern codepoint.
--
-- It must always be a low (safe) estimate, otherwise the algorithm can miss
-- matches. It must account for any variation of upper/lower case characters
-- that may occur in the haystack. In most cases, this is the same number of
-- bytes as the UTF-8 encoding of the given codepoint:
--
--     minimumSkipForCodePoint 'a' == 1
--     minimumSkipForCodePoint 'д' == 2
--     minimumSkipForCodePoint 'ⓟ' == 3
--     minimumSkipForCodePoint '🎄' == 4
--
-- The exceptions are ⱥ (U+2C65) and ⱦ (U+2C66), which take 3 bytes but
-- whose upper case forms Ⱥ and Ⱦ take only 2.
minimumSkipForCodePoint :: CodePoint -> CodeUnitIndex
minimumSkipForCodePoint cp
  | c < 0x80 = 1
  | c < 0x800 = 2
  -- The letters ⱥ and ⱦ are 3 UTF8 bytes, but have unlowerings Ⱥ and Ⱦ of 2 bytes
  | c == 0x2C65 || c == 0x2C66 = 2
  | c < 0x10000 = 3
  | otherwise = 4
  where
    c = Char.ord cp


-- | Number of bytes of the shortest case variation of the given needle: the sum
-- of 'minimumSkipForCodePoint' over all its code points. Needles are assumed to
-- be lower case.
--
--     minimumSkipForVector (TBA.fromList "ab..cd") == 6
--     minimumSkipForVector (TBA.fromList "aⱥ💩") == 7
--
minimumSkipForVector :: TypedByteArray CodePoint -> CodeUnitIndex
minimumSkipForVector = TBA.foldr addSkip 0
  where
    addSkip cp total = total + minimumSkipForCodePoint cp


-- | The suffix table tells us for each codepoint (not byte!) of the pattern how many bytes (not
-- codepoints!) we can jump ahead if the match fails at that point.
newtype SuffixTable = SuffixTable (TypedByteArray CodeUnitIndex)
  deriving stock Generic
  deriving anyclass NFData

-- | Renders the table as a Haskell expression that rebuilds it, e.g.
-- @SuffixTable (TBA.fromList [4,4,1])@.
--
-- The previous rendering used @TBA.toList@, which cannot rebuild the table (it would be applied
-- to a list). The instance now also respects precedence: the output is parenthesised when it
-- appears as a constructor argument (precedence > 10). Plain 'show' and derived record 'Show'
-- output is unchanged apart from @toList@ -> @fromList@.
instance Show SuffixTable where
  showsPrec d (SuffixTable table) =
    showParen (d > 10) $
      showString "SuffixTable (TBA.fromList " . shows (TBA.toList table) . showChar ')'

-- | Look up the skip (in bytes) for a mismatch at the given pattern code point index.
suffixLookup :: SuffixTable -> Int -> CodeUnitIndex
{-# INLINE suffixLookup #-}
suffixLookup (SuffixTable table) patternIndex = indexTable table patternIndex

-- | Build the suffix table for a (lower-case) pattern.
--
-- Entry @p@ holds the number of haystack bytes by which the alignment end may be moved when the
-- pattern character at code point index @p@ mismatches after everything to its right matched.
-- All byte counts are lower bounds computed with 'minimumSkipForCodePoint', so they remain safe for
-- every case variation that may occur in the haystack.
--
-- An empty pattern yields an empty table. The final write to the last entry is guarded: with
-- @patLen == 0@ it would target index -1 of a zero-length array.
buildSuffixTable :: TypedByteArray CodePoint -> SuffixTable
buildSuffixTable pattern_ = runST $ do
  let
    patLen = TBA.length pattern_
    wholePatternSkip = minimumSkipForVector pattern_

  table <- TBA.newTypedByteArray patLen

  let
    -- Case 1: For each position of the pattern we record the shift that would align the pattern so
    -- that it starts at the longest suffix that is at the same time a prefix, if a mismatch would
    -- happen at that position.
    --
    -- Suppose the length of the pattern is n, a mismatch occurs at position i in the pattern and j
    -- in the haystack, then we know that pattern[i+1..n] == haystack[j+1..j+n-i]. That is, we know
    -- that the part of the haystack that we already matched is a suffix of the pattern.
    -- If a suffix of that matched part (possibly all of it) is also a prefix of the pattern, we can
    -- shift the pattern to the right so that the pattern starts at the longest such suffix.
    --
    -- Consider the pattern `ababa`. Then we get
    --
    -- p:              0  1  2  3  4
    -- Pattern:        a  b  a  b  a
    -- lastSkipBytes:              5   empty suffix; overwritten with 1 at the end
    -- lastSkipBytes:           4  5   "a" == "a" so if we get a mismatch here we can skip
    --                                            everything but the length of "a"
    -- lastSkipBytes:        4  4  5   "ab" /= "ba" so keep skip value
    -- lastSkipBytes:     2  4  4  5   "aba" == "aba"
    -- lastSkipBytes:  2  2  4  4  5   "abab" /= "baba"
    init1 lastSkipBytes p
      | p >= 0 = do
          let
            skipBytes = case suffixIsPrefix pattern_ (p + 1) of
              Nothing -> lastSkipBytes
              -- Skip the whole pattern _except_ the bytes for the suffix(==prefix)
              Just nonSkippableBytes -> wholePatternSkip - nonSkippableBytes
          TBA.writeTypedByteArray table p skipBytes
          init1 skipBytes (p - 1)
      | otherwise = pure ()

    -- Case 2: We also have to account for the fact that the matching suffix of the pattern might
    -- occur again somewhere within the pattern. In that case, we may not shift as far as if it was
    -- a prefix. That is why the `init2` loop is run after `init1`, potentially overwriting some
    -- entries with smaller shifts.
    init2 p skipBytes
      | p < patLen - 1 = do
          -- If we find a suffix that ends at p, we can skip everything _after_ p.
          let skipBytes' = skipBytes - minimumSkipForCodePoint (TBA.unsafeIndex pattern_ p)
          case substringIsSuffix pattern_ p of
            Nothing -> pure ()
            Just suffixLen ->
              TBA.writeTypedByteArray table (patLen - 1 - suffixLen) skipBytes'
          init2 (p + 1) skipBytes'
      | otherwise = pure ()

  init1 (wholePatternSkip - 1) (patLen - 1)
  init2 0 wholePatternSkip

  -- A mismatch on the last pattern character has no matched suffix to exploit, so the minimal
  -- shift of one byte is recorded there (runText takes the max with the bad character skip).
  -- An empty pattern has no last entry; writing at index -1 would be out of bounds.
  if patLen > 0
    then TBA.writeTypedByteArray table (patLen - 1) (CodeUnitIndex 1)
    else pure ()

  SuffixTable <$> TBA.unsafeFreezeTypedByteArray table

-- | If the suffix of the pattern starting at code point index @pos@ is also a prefix of the
-- pattern, return 'Just' the minimum number of bytes ('minimumSkipForCodePoint') that this
-- prefix occupies, otherwise 'Nothing'. An empty suffix (@pos@ equal to the pattern length)
-- yields @Just 0@.
--
-- For example, @suffixIsPrefix \"aabbaa\" 4 == Just 2@.
suffixIsPrefix :: TypedByteArray CodePoint -> Int -> Maybe CodeUnitIndex
suffixIsPrefix pattern_ pos = go 0 (CodeUnitIndex 0)
  where
    suffixLen = TBA.length pattern_ - pos

    go !i !acc
      | i >= suffixLen = Just acc
      | prefixChar /= TBA.unsafeIndex pattern_ (pos + i) = Nothing
      | otherwise = go (i + 1) (acc + minimumSkipForCodePoint prefixChar)
      where
        prefixChar = TBA.unsafeIndex pattern_ i

-- | Length in code points (not bytes) of the longest suffix of the pattern that also ends at
-- position @pos@. For example, @substringIsSuffix \"abaacbbaac\" 4 == Just 4@, because the
-- substring \"baac\" ends at position 4 and is at the same time the longest suffix that does so,
-- having a length of 4 characters.
--
-- 'Nothing' is returned when not even the last character matches, and also when the match would
-- reach the start of the pattern (that prefix==suffix case is covered by 'suffixIsPrefix').
--
-- For a string like "abaacaabcbaac", when we detect at pos=4 that baac==baac,
-- it means that if we get a mismatch before the "baac" suffix, we can skip the
-- "aabcbaac" characters _after_ the "baac" substring. So we can put
-- (minimumSkipForVector "aabcbaac") at that point in the suffix table.
--
--   substringIsSuffix (TBA.fromList "ababa") 0 == Nothing  -- a == a, but not a proper substring
--   substringIsSuffix (TBA.fromList "ababa") 1 == Nothing  -- b /= a
--   substringIsSuffix (TBA.fromList "ababa") 2 == Nothing  -- aba == aba, but not a proper substring
--   substringIsSuffix (TBA.fromList "ababa") 3 == Nothing  -- b /= a
--   substringIsSuffix (TBA.fromList "ababa") 4 == Nothing  -- ababa == ababa, but not a proper substring
--   substringIsSuffix (TBA.fromList "baba") 0 == Nothing  -- b /= a
--   substringIsSuffix (TBA.fromList "baba") 1 == Nothing  -- ba == ba, but not a proper substring
--   substringIsSuffix (TBA.fromList "abaacaabcbaac") 4 == Just 4  -- baac == baac
--   substringIsSuffix (TBA.fromList "abaacaabcbaac") 8 == Just 1  -- c == c
--
substringIsSuffix :: TypedByteArray CodePoint -> Int -> Maybe Int
substringIsSuffix pattern_ pos = go 0
  where
    lastIndex = TBA.length pattern_ - 1

    go i
      | i > pos = Nothing  -- prefix==suffix, so already covered by suffixIsPrefix
      | TBA.unsafeIndex pattern_ (pos - i) == TBA.unsafeIndex pattern_ (lastIndex - i) = go (i + 1)
      | otherwise = if i == 0 then Nothing else Just i  -- i == 0 means nothing matched


-- | The bad char table tells us how many bytes we may skip ahead when encountering a certain
-- character in the input string. For example, if there's a character that is not contained in the
-- pattern at all, we can skip ahead until after that character.
data BadCharLookup = BadCharLookup
  { badCharLookupTable :: {-# UNPACK #-} !(TypedByteArray CodeUnitIndex)
    -- ^ Skips for code points below 'badCharTableSize', indexed by code point.
  , badCharLookupMap :: !(HashMap.HashMap CodePoint CodeUnitIndex)
    -- ^ Skips for pattern code points at or above 'badCharTableSize'.
  , badCharLookupDefault :: !CodeUnitIndex
    -- ^ Skip for code points not found in the map.
  }
  deriving stock Generic
  deriving stock Show
  deriving anyclass NFData

-- | Number of entries in the fixed-size lookup-table of the bad char table.
-- Code points below this bound are looked up in the flat table; all others
-- go through the hash map.
badCharTableSize :: Int
{-# INLINE badCharTableSize #-}
badCharTableSize = 256

-- | Look up the bad character skip (in bytes) for a haystack code point.
--
-- Small code points are served from the flat table; larger ones from the hash
-- map, falling back to the default skip when absent.
badCharLookup :: BadCharLookup -> CodePoint -> CodeUnitIndex
{-# INLINE badCharLookup #-}
badCharLookup bcl char =
  let codePoint = fromEnum char
  in
    if codePoint < badCharTableSize
      then indexTable (badCharLookupTable bcl) codePoint
      else HashMap.lookupDefault (badCharLookupDefault bcl) char (badCharLookupMap bcl)



-- | Build the bad character lookup for a (lower-case) pattern.
--
-- Every entry starts at the minimum byte length of the whole pattern ('minimumSkipForVector').
-- Each pattern character except the last then gets the minimum byte length of the characters to
-- the right of its rightmost occurrence. Code points below 'badCharTableSize' go into the flat
-- table and all others into the hash map.
buildBadCharLookup :: TypedByteArray CodePoint -> BadCharLookup
buildBadCharLookup pattern_ = runST $ do

  let
    defaultSkip = minimumSkipForVector pattern_

  -- Initialize the table with the maximum skip distance: the minimum number of bytes the whole
  -- pattern can occupy. This applies to all characters that are not part of the pattern.
  table <- TBA.replicate badCharTableSize defaultSkip

  let
    -- Fill the bad character table based on the rightmost occurrence of a character in the pattern.
    -- Note that there is also a variant of Boyer-Moore that records all positions (see Wikipedia),
    -- but that requires even more storage space.
    -- Also note that we exclude the last character of the pattern when building the table.
    -- This is because
    --
    -- 1. If the last character does not occur anywhere else in the pattern and we encounter it
    --    during a mismatch, we can advance the pattern to just after that character:
    --
    --    Haystack: aaadcdabcdbb
    --    Pattern:    abcd
    --
    --    In the above example, we would match `d` and `c`, but then fail because `d` != `b`.
    --    Since `d` only occurs at the very last position of the pattern, we can shift to
    --
    --    Haystack: aaadcdabcdbb
    --    Pattern:      abcd
    --
    -- 2. If it does occur anywhere else in the pattern, we can only shift as far as it's necessary
    --    to align it with the haystack:
    --
    --    Haystack: aaadddabcdbb
    --    Pattern:    adcd
    --
    --    We match `d`, and then there is a mismatch `d` != `c`, which allows us to shift only up to:
    --
    --    Haystack: aaadddabcdbb
    --    Pattern:     adcd
    insertSkips !charMap !_ [] = pure charMap
    -- The last pattern character doesn't count.
    insertSkips !charMap !_ [_] = pure charMap
    insertSkips !charMap !skipBytes (!patChar : !patChars)
      | charIndex < badCharTableSize = do
          TBA.writeTypedByteArray table charIndex skipBytes'
          insertSkips charMap skipBytes' patChars
      | otherwise =
          insertSkips (HashMap.insert patChar skipBytes' charMap) skipBytes' patChars
      where
        charIndex = fromEnum patChar
        -- Bytes remaining in the pattern to the right of this character.
        skipBytes' = skipBytes - minimumSkipForCodePoint patChar

  badCharMap <- insertSkips HashMap.empty defaultSkip (TBA.toList pattern_)

  frozenTable <- TBA.unsafeFreezeTypedByteArray table

  pure BadCharLookup
    { badCharLookupTable = frozenTable
    , badCharLookupMap = badCharMap
    , badCharLookupDefault = defaultSkip
    }


-- Helper functions for easily toggling the safety of this module

-- | Read from a lookup table at the specified index.
--
-- The suffix and bad char lookups go through this single definition. It is currently unchecked
-- ('TBA.unsafeIndex'); swapping in a bounds-checked read here makes those lookups safe.
indexTable :: Prim a => TypedByteArray a -> Int -> a
{-# INLINE indexTable #-}
indexTable table index = TBA.unsafeIndex table index