{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Text.BoyerMoore.Automaton
( Automaton
, CaseSensitivity (..)
, CodeUnitIndex (..)
, Next (..)
, buildAutomaton
, patternLength
, patternText
, runText
) where
import Prelude hiding (length)
import Control.DeepSeq (NFData)
import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Hashable (Hashable (..))
import GHC.Generics (Generic)
#if defined(HAS_AESON)
import qualified Data.Aeson as AE
#endif
import Data.Text.CaseSensitivity (CaseSensitivity (..))
import Data.Text.Utf8 (CodeUnit, CodeUnitIndex (..), Text)
import Data.TypedByteArray (Prim, TypedByteArray)
import qualified Data.Text.Utf8 as Utf8
import qualified Data.TypedByteArray as TBA
-- | What a 'runText' match callback asks the scanner to do next.
data Next a
  = Done !a
    -- ^ Stop scanning immediately; this value is the final result.
  | Step !a
    -- ^ Keep scanning for further matches, carrying this accumulator.
-- | A compiled Boyer-Moore searcher for a single pattern. Build one with
-- 'buildAutomaton' and run it with 'runText'.
data Automaton = Automaton
  { automatonPattern :: !Text
    -- ^ The pattern being searched for.
  , automatonPatternHash :: !Int
    -- ^ Cached hash of the pattern, used by the 'Hashable' and 'Eq' instances
    -- so the pattern text does not have to be rehashed.
  , automatonSuffixTable :: !SuffixTable
    -- ^ Good-suffix shift table.
  , automatonBadCharTable :: !BadCharTable
    -- ^ Bad-character shift table.
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Hashes only the cached pattern hash, so hashing an automaton never
-- re-traverses its pattern text.
instance Hashable Automaton where
  hashWithSalt salt = hashWithSalt salt . automatonPatternHash
-- | Automata are equal when their patterns are equal. The cached hashes are
-- compared first, which cheaply rejects most unequal pairs before the
-- pattern texts themselves are compared.
instance Eq Automaton where
  lhs == rhs =
    automatonPatternHash lhs == automatonPatternHash rhs
      && automatonPattern lhs == automatonPattern rhs
#if defined(HAS_AESON)
-- | Decodes the JSON representation of the pattern and compiles it with
-- 'buildAutomaton', so the shift tables are rebuilt rather than deserialized.
instance AE.FromJSON Automaton where
  parseJSON = fmap buildAutomaton . AE.parseJSON

-- | Encodes an automaton as just its pattern; the shift tables are not
-- serialized.
instance AE.ToJSON Automaton where
  toJSON = AE.toJSON . automatonPattern
  -- aeson's default 'toEncoding' goes through 'toJSON' and an intermediate
  -- 'AE.Value'; encoding the pattern directly produces the same output
  -- without that allocation.
  toEncoding = AE.toEncoding . automatonPattern
#endif
-- | Compile a pattern into an 'Automaton', precomputing its hash together
-- with the good-suffix and bad-character shift tables.
buildAutomaton :: Text -> Automaton
buildAutomaton needle =
  Automaton
    { automatonPattern = needle
    , automatonPatternHash = hash needle
    , automatonSuffixTable = buildSuffixTable needle
    , automatonBadCharTable = buildBadCharTable needle
    }
-- | Fold over the non-overlapping occurrences of the automaton's pattern in a
-- text, from left to right.
--
-- For each occurrence, the callback receives the current accumulator and the
-- code-unit index of the /first/ code unit of the match. Returning 'Done'
-- ends the scan with that value. Returning 'Step' resumes the search just
-- after the end of the match, so reported matches never overlap. An empty
-- pattern never matches, and the seed is returned unchanged.
runText
  :: forall a
   . a
  -> (a -> CodeUnitIndex -> Next a)
  -> Automaton
  -> Text
  -> a
{-# INLINE runText #-}
runText seed onMatch automaton haystack
  | needleLen == 0 = seed
  | otherwise = tryAlignment seed (needleLen - 1)
  where
    needle :: Text
    needle = automatonPattern automaton

    goodSuffixTable :: SuffixTable
    goodSuffixTable = automatonSuffixTable automaton

    badCharTable :: BadCharTable
    badCharTable = automatonBadCharTable automaton

    needleLen :: CodeUnitIndex
    needleLen = Utf8.lengthUtf8 needle

    haystackLen :: CodeUnitIndex
    haystackLen = Utf8.lengthUtf8 haystack

    haystackAt :: CodeUnitIndex -> CodeUnit
    haystackAt = Utf8.unsafeIndexCodeUnit haystack

    needleAt :: CodeUnitIndex -> CodeUnit
    needleAt = Utf8.unsafeIndexCodeUnit needle

    -- Try the alignment whose last needle code unit sits at haystack index
    -- @end@; once that index runs off the haystack, the scan is over.
    {-# INLINE tryAlignment #-}
    tryAlignment :: a -> CodeUnitIndex -> a
    tryAlignment acc end
      | end >= haystackLen = acc
      | otherwise = scanBack acc end (needleLen - 1)

    -- Compare haystack and needle right-to-left, starting from haystack index
    -- @hi@ and needle index @ni@.
    scanBack :: a -> CodeUnitIndex -> CodeUnitIndex -> a
    scanBack acc hi ni
      | ni < 0 =
          -- Every needle code unit matched, so @hi@ is one before the match.
          case onMatch acc (hi + 1) of
            Done final -> final
            -- The match ends at @hi + needleLen@; the next non-overlapping
            -- alignment ends one further needle length to the right.
            Step acc' -> tryAlignment acc' (hi + 2 * needleLen)
      | haystackAt hi == needleAt ni = scanBack acc (hi - 1) (ni - 1)
      | otherwise =
          -- Mismatch: advance by whichever rule permits the larger shift.
          tryAlignment acc (hi + max badCharShift goodSuffixShift)
      where
        badCharShift = badCharLookup badCharTable (haystackAt hi)
        goodSuffixShift = suffixLookup goodSuffixTable ni
-- | Length of the automaton's pattern in code units, as reported by
-- 'Utf8.lengthUtf8'.
patternLength :: Automaton -> CodeUnitIndex
patternLength automaton = Utf8.lengthUtf8 (patternText automaton)
-- | The pattern this automaton was built from.
patternText :: Automaton -> Text
patternText (Automaton needle _ _ _) = needle
-- | Good-suffix shift table. Entry @j@ is how far the haystack index may
-- advance after a mismatch at needle position @j@.
newtype SuffixTable = SuffixTable (TypedByteArray Int)
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Shift permitted by the good-suffix rule after a mismatch at the given
-- needle position.
suffixLookup :: SuffixTable -> CodeUnitIndex -> CodeUnitIndex
{-# INLINE suffixLookup #-}
suffixLookup (SuffixTable shifts) needleIx =
  CodeUnitIndex (indexTable shifts (codeUnitIndex needleIx))
-- | Build the good-suffix shift table (Boyer-Moore "delta2").
--
-- The table has one entry per needle position and is filled in two passes:
--
-- 1. Right to left: each entry gets a shift derived from the start of the
--    longest needle suffix, beginning after that position, that is also a
--    needle prefix.
-- 2. Left to right: for each position @p@, let @s@ be the longest run ending
--    at @p@ that matches the needle's own suffix. If the code unit before that
--    run differs from the one before the needle's suffix of length @s@, the
--    entry for the corresponding mismatch position is overwritten with the
--    shift that aligns the earlier occurrence.
buildSuffixTable :: Text -> SuffixTable
buildSuffixTable needle = runST $ do
  let
    len = Utf8.lengthUtf8 needle
    lastIx = len - 1
    unitAt = Utf8.unsafeIndexCodeUnit needle

  shifts <- TBA.newTypedByteArray (codeUnitIndex len)

  let
    -- Pass 1, for p from lastIx down to 0.
    prefixPass prefixStart p
      | p < 0 = pure ()
      | otherwise = do
          let
            prefixStart'
              | isPrefix needle (p + 1) = p + 1
              | otherwise = prefixStart
          TBA.writeTypedByteArray shifts
            (codeUnitIndex p)
            (codeUnitIndex (prefixStart' + lastIx - p))
          prefixPass prefixStart' (p - 1)

    -- Pass 2, for p from 0 up to lastIx - 1.
    suffixPass p
      | p >= lastIx = pure ()
      | otherwise = do
          let
            sLen = suffixLength needle p
            slot = lastIx - sLen
          when (unitAt (p - sLen) /= unitAt slot) $
            TBA.writeTypedByteArray shifts
              (codeUnitIndex slot)
              (codeUnitIndex (lastIx - p + sLen))
          suffixPass (p + 1)

  prefixPass lastIx lastIx
  suffixPass 0
  SuffixTable <$> TBA.unsafeFreezeTypedByteArray shifts
-- | Bad-character shift table, indexed by code unit value.
data BadCharTable = BadCharTable
  { badCharTableEntries :: {-# UNPACK #-} !(TypedByteArray Int)
    -- ^ One entry per code unit value ('badcharTableSize' entries). Each entry
    -- is the shift to apply after a mismatch on that code unit.
  , badCharTablePatternLen :: !CodeUnitIndex
    -- ^ Pattern length in code units, which is also the default shift stored
    -- for code units that do not occur in the pattern. 'badCharLookup' does
    -- not read this field. It is strict like every other field of the
    -- automaton's tables, so it never holds a thunk.
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)
-- | Number of entries in the bad-character table. This is enough to index by
-- any byte value, assuming 'CodeUnit' is an 8-bit UTF-8 code unit.
badcharTableSize :: Int
{-# INLINE badcharTableSize #-}
badcharTableSize = 2 ^ (8 :: Int)
-- | Shift permitted by the bad-character rule after a mismatch on the given
-- haystack code unit.
badCharLookup :: BadCharTable -> CodeUnit -> CodeUnitIndex
{-# INLINE badCharLookup #-}
badCharLookup table unit =
  CodeUnitIndex (indexTable (badCharTableEntries table) (fromIntegral unit))
-- | Whether the part of @needle@ from position @pos@ to its end is also a
-- prefix of @needle@. This is trivially 'True' when @pos@ equals the needle's
-- length.
isPrefix :: Text -> CodeUnitIndex -> Bool
isPrefix needle pos = agreesFrom 0
  where
    tailLen = Utf8.lengthUtf8 needle - pos
    unitAt = Utf8.unsafeIndexCodeUnit needle
    agreesFrom i =
      i >= tailLen || (unitAt i == unitAt (pos + i) && agreesFrom (i + 1))
-- | Count how many code units, walking backwards from @pos@, agree with the
-- code units walking backwards from the end of @needle@. Counting stops at
-- the first disagreement, and the result never exceeds @pos@.
suffixLength :: Text -> CodeUnitIndex -> CodeUnitIndex
suffixLength needle pos = countFrom 0
  where
    lastIx = Utf8.lengthUtf8 needle - 1
    unitAt = Utf8.unsafeIndexCodeUnit needle
    countFrom i
      | unitAt (pos - i) == unitAt (lastIx - i) && i < pos = countFrom (i + 1)
      | otherwise = i
-- | Build the bad-character shift table (Boyer-Moore "delta1").
--
-- Every code unit value starts with a shift equal to the full pattern length.
-- Each code unit that occurs in the pattern before its last position is then
-- given the distance from its rightmost such occurrence to the pattern's end.
-- Later positions overwrite earlier ones, which is what makes the stored
-- occurrence the rightmost one.
buildBadCharTable :: Text -> BadCharTable
buildBadCharTable needle = runST $ do
  let
    len = Utf8.lengthUtf8 needle
    lastIx = len - 1

  shifts <- TBA.replicate badcharTableSize (codeUnitIndex len)

  let
    recordOccurrence !i
      | i >= lastIx = pure ()
      | otherwise = do
          let unit = Utf8.unsafeIndexCodeUnit needle i
          TBA.writeTypedByteArray shifts
            (fromIntegral unit)
            (codeUnitIndex (lastIx - i))
          recordOccurrence (i + 1)

  recordOccurrence 0
  frozen <- TBA.unsafeFreezeTypedByteArray shifts
  pure (BadCharTable frozen len)
-- | Read an element via 'TBA.unsafeIndex'. As that name indicates, callers
-- are responsible for keeping the index within bounds.
indexTable :: Prim a => TypedByteArray a -> Int -> a
{-# INLINE indexTable #-}
indexTable table ix = TBA.unsafeIndex table ix