-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2019 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An efficient implementation of the Aho-Corasick string matching algorithm.
-- See http://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/02/Small02.pdf
-- for a good explanation of the algorithm.
--
-- The memory layout of the automaton, and the function that steps it, were
-- optimized to the point where string matching compiles roughly to a loop over
-- the code units in the input text, that keeps track of the current state.
-- Lookup of the next state is either just an array index (for the root state),
-- or a linear scan through a small array (for non-root states). The pointer
-- chases that are common for traversing Haskell data structures have been
-- eliminated.
--
-- The construction of the automaton has not been optimized that much, because
-- construction time is usually negligible in comparison to matching time.
-- Therefore construction is a two-step process, where first we build the
-- automaton as int maps, which are convenient for incremental construction.
-- Afterwards we pack the automaton into unboxed vectors.
module Data.Text.AhoCorasick.Automaton
  ( AcMachine (..)
  , build
  , runText
  , runLower
  , debugBuildDot
  , CaseSensitivity (..)
  , CodeUnitIndex (..)
  , Match (..)
  , Next (..)
  ) where

import Prelude hiding (length)

import Control.DeepSeq (NFData)
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import Data.Foldable (foldl')
import Data.Hashable (Hashable)
import Data.IntMap.Strict (IntMap)
import Data.Text.Internal (Text (..))
import Data.Word (Word64)
import GHC.Generics (Generic)

#if defined(HAS_AESON)
import Data.Aeson (FromJSON, ToJSON)
#endif

import qualified Data.IntMap.Strict as IntMap
import qualified Data.List as List
import qualified Data.Vector as Vector
import qualified Data.Vector.Unboxed as UVector

import Data.Text.Utf16 (CodeUnit, CodeUnitIndex (..), indexTextArray, lowerCodeUnit)

-- | Whether matching compares code units as-is, or lowercases the input text
-- on the fly first. See also `runText` and `runLower`.
data CaseSensitivity
  = CaseSensitive
  | IgnoreCase
  deriving stock (Eq, Generic, Show)
#if defined(HAS_AESON)
  deriving anyclass (Hashable, NFData, FromJSON, ToJSON)
#else
  deriving anyclass (Hashable, NFData)
#endif

-- | A numbered state in the Aho-Corasick automaton.
type State = Int

-- | A transition is a pair of (code unit, next state). The code unit is 16 bits,
-- and the state index is 32 bits. We pack these together as a manually unlifted
-- tuple, because an unboxed Vector of tuples is a tuple of vectors, but we want
-- the elements of the tuple to be adjacent in memory. (The Word64 still needs
-- to be unpacked in the places where it is used.) The code unit is stored in
-- the least significant 32 bits, with the special value 2^16 indicating a
-- wildcard; the "failure" transition. Bit 17 through 31 (starting from zero,
-- both bounds inclusive) are always 0.
--
-- Bit 63 (most significant)                   Bit 0 (least significant)
-- |                                                                   |
-- v                                                                   v
-- |<--        goto state        -->|<--  zeros  -->| |<--   input  -->|
-- |SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSS|000000000000000|W|IIIIIIIIIIIIIIII|
--                                                   |
--                                         Wildcard bit (bit 16)
--
type Transition = Word64

-- | A single reported match: where it ends, and which payload it carries.
data Match v = Match
  { matchPos :: {-# UNPACK #-} !CodeUnitIndex
  -- ^ The code unit index past the last code unit of the match. Note that this
  -- is not a code *point* (Haskell `Char`) index; a code point might be encoded
  -- as two code units.
  , matchValue :: v
  -- ^ The payload associated with the matched needle.
  } deriving stock (Show, Eq)

-- | An Aho-Corasick automaton.
data AcMachine v = AcMachine
  { machineValues :: !(Vector.Vector [v])
  -- ^ For every state, the values associated with its needles. If the state is
  -- not a match state, the list is empty.
  , machineTransitions :: !(UVector.Vector Transition)
  -- ^ A packed vector of transitions. For every state, there is a slice of this
  -- vector that starts at the offset given by `machineOffsets`, and ends at the
  -- first wildcard transition.
  , machineOffsets :: !(UVector.Vector Int)
  -- ^ For every state, the index into `machineTransitions` where the transition
  -- list for that state starts.
  , machineRootAsciiTransitions :: !(UVector.Vector Transition)
  -- ^ A lookup table for transitions from the root state, an optimization to
  -- avoid having to walk all transitions, at the cost of using a bit of
  -- additional memory.
  } deriving stock (Generic)

instance NFData v => NFData (AcMachine v)

-- | The wildcard value is 2^16, one more than the maximal 16-bit code unit.
wildcard :: Integral a => a
wildcard = 0x10000

-- | Extract the code unit from a transition. The special wildcard transition
-- will return 0.
transitionCodeUnit :: Transition -> CodeUnit
transitionCodeUnit t = fromIntegral (t .&. 0xffff)

-- | Extract the goto state from a transition.
transitionState :: Transition -> State
transitionState t = fromIntegral (t `shiftR` 32)

-- | Test if the transition is not for a specific code unit, but the wildcard
-- transition to take if nothing else matches.
transitionIsWildcard :: Transition -> Bool
transitionIsWildcard t = (t .&. wildcard) == wildcard

-- | Pack a transition on a specific code unit: the goto state goes into the
-- upper 32 bits, the code unit into the low bits. The wildcard bit stays clear.
newTransition :: CodeUnit -> State -> Transition
newTransition input state =
  let
    input64 = fromIntegral input :: Word64
    state64 = fromIntegral state :: Word64
  in
    (state64 `shiftL` 32) .|. input64

-- | Pack a wildcard ("failure") transition: the goto state goes into the upper
-- 32 bits, and only the wildcard bit is set in the lower half.
newWildcardTransition :: State -> Transition
newWildcardTransition state =
  let
    state64 = fromIntegral state :: Word64
  in
    (state64 `shiftL` 32) .|. wildcard

-- | Pack transitions for each state into one contiguous array. In order to find
-- the transitions for a specific state, we also produce a vector of start
-- indices. All transition lists are terminated by a wildcard transition, so
-- there is no need to record the length.
packTransitions :: [[Transition]] -> (UVector.Vector Transition, UVector.Vector Int)
packTransitions transitions =
  let
    packed = UVector.fromList $ concat transitions
    -- Running sum of the list lengths; element i is where list i starts. The
    -- scan also yields one trailing element, the total number of transitions.
    offsets = UVector.fromList $ scanl (+) 0 $ fmap List.length transitions
  in
    (packed, offsets)

-- | Construct an Aho-Corasick automaton for the given needles.
-- Takes a list of code units rather than `Text`, to allow mapping the code
-- units before construction, for example to lowercase individual code points,
-- rather than doing proper case folding (which might change the number of code
-- units).
build :: [([CodeUnit], v)] -> AcMachine v
build needlesWithValues =
  let
    -- Construct the Aho-Corasick automaton using IntMaps, which are a suitable
    -- representation when building the automaton. We use int maps rather than
    -- hash maps to ensure that the iteration order is the same as that of a
    -- vector.
    (numStates, transitionMap, initialValueMap) = buildTransitionMap needlesWithValues
    fallbackMap = buildFallbackMap transitionMap
    valueMap = buildValueMap transitionMap fallbackMap initialValueMap

    -- Convert the map of transitions, and the map of fallback states, into a
    -- list of transition lists, where every transition list is terminated by
    -- a wildcard transition to the fallback state.
    prependTransition ts input state = newTransition (fromIntegral input) state : ts
    makeTransitions fallback ts = IntMap.foldlWithKey' prependTransition [newWildcardTransition fallback] ts
    transitionsList = zipWith makeTransitions (IntMap.elems fallbackMap) (IntMap.elems transitionMap)

    -- Pack the transition lists into one contiguous array, and build the lookup
    -- table for the transitions from the root state.
    (transitions, offsets) = packTransitions transitionsList
    rootTransitions = buildAsciiTransitionLookupTable $ transitionMap IntMap.! 0
    values = Vector.generate numStates (valueMap IntMap.!)
  in
    AcMachine values transitions offsets rootTransitions

-- | Build the automaton, and format it as Graphviz Dot, for visual debugging.
-- Every needle gets its position in the input list as payload.
debugBuildDot :: [[CodeUnit]] -> String
debugBuildDot needles =
  let
    (_numStates, transitionMap, initialValueMap) =
      buildTransitionMap $ zip needles ([0..] :: [Int])
    fallbackMap = buildFallbackMap transitionMap
    valueMap = buildValueMap transitionMap fallbackMap initialValueMap

    dotEdge extra state nextState =
      " " ++ show state ++ " -> " ++ show nextState ++ " [" ++ extra ++ "];"

    -- Fallback edges are drawn dashed.
    dotFallbackEdge :: [String] -> State -> State -> [String]
    dotFallbackEdge edges state nextState =
      dotEdge "style = dashed" state nextState : edges

    -- Regular transitions are labelled with the (numeric) input code unit.
    dotTransitionEdge :: State -> [String] -> Int -> State -> [String]
    dotTransitionEdge state edges input nextState =
      dotEdge ("label = \"" ++ show input ++ "\"") state nextState : edges

    prependTransitionEdges edges state =
      IntMap.foldlWithKey' (dotTransitionEdge state) edges (transitionMap IntMap.! state)

    -- States that report at least one value are drawn as a double circle.
    dotMatchState :: [String] -> State -> [Int] -> [String]
    dotMatchState edges _ [] = edges
    dotMatchState edges state _ = (" " ++ show state ++ " [shape = doublecircle];") : edges

    dot0 = foldBreadthFirst prependTransitionEdges [] transitionMap
    dot1 = IntMap.foldlWithKey' dotFallbackEdge dot0 fallbackMap
    dot2 = IntMap.foldlWithKey' dotMatchState dot1 valueMap
  in
    -- Set rankdir = "LR" to prefer a left-to-right graph, rather than top to
    -- bottom. I have dual widescreen monitors and I don't use them in portrait
    -- mode. Reverse the instructions because order affects node lay-out, and by
    -- prepending we built up a reversed list.
    unlines $ ["digraph {", " rankdir = \"LR\";"] ++ reverse dot2 ++ ["}"]

-- Different int maps that are used during construction of the automaton. The
-- transition map represents the trie of states, the fallback map contains the
-- fallback (or "failure" or "suffix") edge for every state.

-- | For every state, a map from input code unit to the next state.
type TransitionMap = IntMap (IntMap State)

-- | For every state, the state to fall back to when no transition applies.
type FallbackMap = IntMap State

-- | For every match state, the payload values to report there.
type ValuesMap v = IntMap [v]

-- | Build the trie of the Aho-Corasick state machine for all input needles.
-- Returns the number of states, the trie, and the payloads per final state.
buildTransitionMap :: forall v. [([CodeUnit], v)] -> (Int, TransitionMap, ValuesMap v)
buildTransitionMap =
  let
    go :: State
      -> (Int, TransitionMap, ValuesMap v)
      -> ([CodeUnit], v)
      -> (Int, TransitionMap, ValuesMap v)

    -- End of the current needle, insert the associated payload value.
    -- If a needle occurs multiple times, then at this point we will merge
    -- their payload values, so the needle is reported twice, possibly with
    -- different payload values.
    go !state (!numStates, transitions, values) ([], v) =
      (numStates, transitions, IntMap.insertWith (++) state [v] values)

    -- Follow the edge for the given input from the current state, creating it
    -- if it does not exist.
    go !state (!numStates, transitions, values) (!input : needleTail, vs) =
      let
        transitionsFromState = transitions IntMap.! state
      in
        case IntMap.lookup (fromIntegral input) transitionsFromState of
          Just nextState ->
            go nextState (numStates, transitions, values) (needleTail, vs)
          Nothing ->
            let
              -- Allocate a new state, and insert a transition to it.
              -- Also insert an empty transition map for it.
              nextState = numStates
              transitionsFromState' = IntMap.insert (fromIntegral input) nextState transitionsFromState
              transitions'
                = IntMap.insert state transitionsFromState'
                $ IntMap.insert nextState IntMap.empty
                $ transitions
            in
              go nextState (numStates + 1, transitions', values) (needleTail, vs)

    -- Initially, the root state (state 0) exists, and it has no transitions
    -- to anywhere.
    stateInitial = 0
    initialTransitions = IntMap.singleton stateInitial IntMap.empty
    initialValues = IntMap.empty
    insertNeedle = go stateInitial
  in
    foldl' insertNeedle (1, initialTransitions, initialValues)

-- | Size of the ascii transition lookup table.
asciiCount :: Integral a => a
asciiCount = 128

-- | Build a lookup table for the first 128 code units, that can be used for
-- O(1) lookup of a transition, rather than doing a linear scan over all
-- transitions. The fallback goes back to the initial state, state 0.
buildAsciiTransitionLookupTable :: IntMap State -> UVector.Vector Transition
buildAsciiTransitionLookupTable transitions = UVector.generate asciiCount $ \i ->
  case IntMap.lookup i transitions of
    Just state -> newTransition (fromIntegral i) state
    Nothing -> newWildcardTransition 0

-- | Traverse the state trie in breadth-first order.
foldBreadthFirst :: (a -> State -> a) -> a -> TransitionMap -> a
foldBreadthFirst f seed transitions = go [0] [] seed
  where
    -- For the traversal, we keep a queue of states to visit. Every iteration we
    -- take one off the front, and all states reachable from there get added to
    -- the back. Rather than using a list for this, we use the functional
    -- amortized queue to avoid O(n²) append. This makes a measurable difference
    -- when the backlog can grow large. In one of our benchmark inputs for
    -- example, we have roughly 160 needles that are 10 characters each (but
    -- with some shared prefixes), and the backlog size grows to 148 during
    -- construction. Construction time goes down from ~0.80 ms to ~0.35 ms by
    -- using the amortized queue.
    -- See also section 3.1.1 of Purely Functional Data Structures by Okasaki
    -- https://www.cs.cmu.edu/~rwh/theses/okasaki.pdf.
    go [] [] !acc = acc
    go [] revBacklog !acc = go (reverse revBacklog) [] acc
    go (state : backlog) revBacklog !acc =
      let
        -- Note that the backlog never contains duplicates, because we traverse
        -- a trie that only branches out. For every state, there is only one
        -- path from the root that leads to it.
        extra = IntMap.elems $ transitions IntMap.! state
      in
        go backlog (extra ++ revBacklog) (f acc state)

-- | Determine the fallback transition for every state, by traversing the
-- transition trie breadth-first.
buildFallbackMap :: TransitionMap -> FallbackMap
buildFallbackMap transitions =
  let
    -- Suppose that in state `state`, there is a transition for input `input`
    -- to state `nextState`, and we already know the fallback for `state`. Then
    -- this function returns the fallback state for `nextState`.
    getFallback :: FallbackMap -> State -> Int -> State
    -- When `state` is the root (state 0), `nextState` is a direct child of the
    -- root, and those fall back to the root. This equation also ends the
    -- recursion below once the chain of fallbacks has reached the root.
    getFallback _ 0 _ = 0
    getFallback fallbacks !state !input =
      let
        fallback = fallbacks IntMap.! state
        transitionsFromFallback = transitions IntMap.! fallback
      in
        case IntMap.lookup input transitionsFromFallback of
          Just st -> st
          Nothing -> getFallback fallbacks fallback input

    insertFallback :: State -> FallbackMap -> Int -> State -> FallbackMap
    insertFallback !state fallbacks !input !nextState =
      IntMap.insert nextState (getFallback fallbacks state input) fallbacks

    insertFallbacks :: FallbackMap -> State -> FallbackMap
    insertFallbacks fallbacks !state =
      IntMap.foldlWithKey' (insertFallback state) fallbacks (transitions IntMap.! state)
  in
    -- The root falls back to itself.
    foldBreadthFirst insertFallbacks (IntMap.singleton 0 0) transitions

-- | Determine which matches to report at every state, by traversing the
-- transition trie breadth-first, and appending all the matches from a fallback
-- state to the matches for the current state.
buildValueMap :: forall v. TransitionMap -> FallbackMap -> ValuesMap v -> ValuesMap v
buildValueMap transitions fallbacks valuesInitial =
  let
    insertValues :: ValuesMap v -> State -> ValuesMap v
    insertValues values !state =
      let
        fallbackValues = values IntMap.! (fallbacks IntMap.! state)
        valuesForState = case IntMap.lookup state valuesInitial of
          Just vs -> vs ++ fallbackValues
          Nothing -> fallbackValues
      in
        IntMap.insert state valuesForState values
  in
    foldBreadthFirst insertValues (IntMap.singleton 0 []) transitions

-- Define aliases for array indexing so we can turn bounds checks on and off
-- in one place. We ran this code with `Vector.!` (bounds-checked indexing) in
-- production for two months without failing the bounds check, so we have turned
-- the check off for performance now.

-- | Index into a boxed vector, without bounds check.
at :: forall a. Vector.Vector a -> Int -> a
at = Vector.unsafeIndex

-- | Index into an unboxed vector, without bounds check.
uAt :: forall a. UVector.Unbox a => UVector.Vector a -> Int -> a
uAt = UVector.unsafeIndex

-- | Result of handling a match: stepping the automaton can exit early by
-- returning a `Done`, or it can continue with a new accumulator with `Step`.
data Next a
  = Done !a
  -- ^ Stop searching; the carried value is the final result.
  | Step !a
  -- ^ Continue searching with the carried value as the new accumulator.

-- | Run the automaton, possibly lowercasing the input text on the fly if case
-- insensitivity is desired. See also `lowerCodeUnit` and `runLower`.
-- WARNING: Run benchmarks when modifying this function; its performance is
-- fragile. It took many days to discover the current formulation which compiles
-- to fast code; removing the wrong bang pattern could cause a 10% performance
-- regression.
{-# INLINE runWithCase #-}
runWithCase
  :: forall a v
  .  CaseSensitivity
  -> a
  -> (a -> Match v -> Next a)
  -> AcMachine v
  -> Text
  -> a
runWithCase caseSensitivity seed f machine text =
  let
    Text u16data !initialOffset !initialRemaining = text
    !values = machineValues machine
    !transitions = machineTransitions machine
    !offsets = machineOffsets machine
    !rootAsciiTransitions = machineRootAsciiTransitions machine
    !stateInitial = 0

    -- NOTE: All of the arguments are strict here, because we want to compile
    -- them down to unpacked variables on the stack, or even registers.
    -- The INLINE / NOINLINE annotations here were added to fix a regression we
    -- observed when going from GHC 8.2 to GHC 8.6, and this particular
    -- combination of INLINE and NOINLINE is the fastest one. Removing increases
    -- the benchmark running time by about 9%.

    -- Read the next code unit and step the automaton with it, or return the
    -- accumulator when no input remains.
    {-# NOINLINE consumeInput #-}
    consumeInput :: Int -> Int -> a -> State -> a
    consumeInput !offset !remaining !acc !state =
      let
        inputCodeUnit = fromIntegral $ indexTextArray u16data offset
        -- NOTE: Although doing this match here entangles the automaton a bit
        -- with case sensitivity, doing so is faster than passing in a function
        -- that transforms each code unit.
        casedCodeUnit = case caseSensitivity of
          IgnoreCase -> lowerCodeUnit inputCodeUnit
          CaseSensitive -> inputCodeUnit
      in
        case remaining of
          0 -> acc
          _ -> followEdge (offset + 1) (remaining - 1) acc state casedCodeUnit

    {-# INLINE followEdge #-}
    followEdge :: Int -> Int -> a -> State -> CodeUnit -> a
    followEdge !offset !remaining !acc !state !input =
      let
        !tssOffset = offsets `uAt` state
      in
        -- When we follow an edge, we look in the transition table and do a
        -- linear scan over all transitions until we find the right one, or
        -- until we hit the wildcard transition at the end. For 0 or 1 or 2
        -- transitions that is fine, but the initial state often has more
        -- transitions, so we have a dedicated lookup table for it, that takes
        -- up a bit more space, but provides O(1) lookup of the next state. We
        -- only do this for the first 128 code units (all of ascii).
        if state == stateInitial && input < asciiCount
          then lookupRootAsciiTransition offset remaining acc input
          else lookupTransition offset remaining acc state input tssOffset

    {-# NOINLINE collectMatches #-}
    collectMatches :: Int -> Int -> a -> State -> a
    collectMatches !offset !remaining !acc !state =
      let
        matchedValues = values `at` state
        -- Fold over the matched values. If at any point the user-supplied fold
        -- function returns `Done`, then we early out. Otherwise continue.
        handleMatch !acc' vs = case vs of
          [] -> consumeInput offset remaining acc' state
          v:more -> case f acc' (Match (CodeUnitIndex $ offset - initialOffset) v) of
            Step newAcc -> handleMatch newAcc more
            Done finalAcc -> finalAcc
      in
        handleMatch acc matchedValues

    -- NOTE: there is no `state` argument here, because this case applies only
    -- to the root state `stateInitial`.
    {-# INLINE lookupRootAsciiTransition #-}
    lookupRootAsciiTransition :: Int -> Int -> a -> CodeUnit -> a
    lookupRootAsciiTransition !offset !remaining !acc !input =
      case rootAsciiTransitions `uAt` fromIntegral input of
        t | transitionIsWildcard t -> consumeInput offset remaining acc stateInitial
          | otherwise -> collectMatches offset remaining acc (transitionState t)

    {-# INLINE lookupTransition #-}
    lookupTransition :: Int -> Int -> a -> State -> CodeUnit -> Int -> a
    lookupTransition !offset !remaining !acc !state !input !i =
      case transitions `uAt` i of
        -- There is no transition for the given input. Follow the fallback edge,
        -- and try again from that state, etc. If we are in the base state
        -- already, then nothing matched, so move on to the next input.
        t | transitionIsWildcard t ->
              if state == stateInitial
                then consumeInput offset remaining acc state
                else followEdge offset remaining acc (transitionState t) input

        -- We found the transition, switch to that new state, collecting matches.
        -- NOTE: This comes after wildcard checking, because the code unit of
        -- the wildcard transition is 0, which is a valid input.
        t | transitionCodeUnit t == input ->
              collectMatches offset remaining acc (transitionState t)

        -- The transition we inspected is not for the current input, and it is not
        -- a wildcard either; look at the next transition then.
        _ -> lookupTransition offset remaining acc state input (i + 1)
  in
    consumeInput initialOffset initialRemaining seed stateInitial

-- | Run the automaton over the text as-is, folding the user-supplied function
-- over all matches.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the compiling module with -fllvm and the same optimization flags as
-- this module.
{-# INLINE runText #-}
runText :: forall a v. a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runText = runWithCase CaseSensitive

-- | Finds all matches in the lowercased text. This function lowercases the text
-- on the fly to avoid allocating a second lowercased text array. Lowercasing is
-- applied to individual code units, so the indexes into the lowercased text can
-- be used to index into the original text. It is still the responsibility of
-- the caller to lowercase the needles. Needles that contain uppercase code
-- points will not match.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the compiling module with -fllvm and the same optimization flags as
-- this module.
{-# INLINE runLower #-}
runLower :: forall a v. a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runLower = runWithCase IgnoreCase