{-# LANGUAGE BangPatterns, GeneralizedNewtypeDeriving #-}

-- | A compact bag (multiset) of 'TermId's.
--
-- Each distinct term is stored as a single 32 bit word that packs the term
-- id together with a saturating 8 bit occurrence count, and the words are
-- kept sorted by term id so that lookups can use binary search.
module Data.SearchEngine.TermBag (
    TermId(..),
    TermCount,
    TermBag,
    size,
    fromList,
    toList,
    elems,
    termCount,
    denseTable,
    invariant
  ) where

import qualified Data.Vector.Unboxed         as Vec
import qualified Data.Vector.Unboxed.Mutable as MVec
import qualified Data.Vector.Generic.Base    as VecGen
import qualified Data.Vector.Unboxed.Base    as VecBase
import qualified Data.Vector.Generic.Mutable as VecMut
import Control.Monad.ST
-- The strict Map keeps the per-term counters evaluated while 'fromList'
-- accumulates them, instead of building a chain of (+) thunks per term.
import qualified Data.Map.Strict as Map
import Data.Word (Word32, Word8)
import Data.Bits
import Data.List (sortBy, foldl')
import Data.Function (on)

-- | Identifier of a term.
--
-- Inside a 'TermBag' only the low 24 bits of the id are stored (the top
-- 8 bits of the word hold the count), which is why the 'Bounded' instance
-- tops out at @0x00FFFFFF@.
newtype TermId = TermId Word32
  deriving (Eq, Ord, Show, Enum, Vec.Unbox,
            VecGen.Vector VecBase.Vector,
            VecMut.MVector VecBase.MVector)

instance Bounded TermId where
  minBound = TermId 0
  maxBound = TermId 0x00FFFFFF

-- | A bag of terms.
--
-- The 'Int' field is the value reported by 'size'. The vector holds one
-- packed entry per distinct term; 'invariant' checks that the entries are
-- in strictly ascending term id order.
data TermBag = TermBag !Int !(Vec.Vector TermIdAndCount)
  deriving Show

-- We sneakily stuff both the TermId and the bag count into one 32bit word
type TermIdAndCount = Word32
type TermCount      = Word8

-- | Pack a term id and its frequency into one word.
--
-- Bottom 24 bits is the TermId, top 8 bits is the bag count. The count
-- saturates at 255. The clamp is applied while the frequency is still an
-- 'Int', so that very large frequencies saturate rather than wrapping
-- around when narrowed to 'Word32'.
termIdAndCount :: TermId -> Int -> TermIdAndCount
termIdAndCount (TermId termid) freq =
      (fromIntegral (min 255 freq) `shiftL` 24)
  .|. (termid .&. 0x00FFFFFF)

-- | Extract the term id (low 24 bits) from a packed entry.
getTermId :: TermIdAndCount -> TermId
getTermId word = TermId (word .&. 0x00FFFFFF)

-- | Extract the occurrence count (top 8 bits) from a packed entry.
getTermCount :: TermIdAndCount -> TermCount
getTermCount word = fromIntegral (word `shiftR` 24)

-- | Check the internal invariant: the term ids in the vector are strictly
-- ascending (and therefore contain no duplicates).
invariant :: TermBag -> Bool
invariant (TermBag _ vec) =
    strictlyAscending (Vec.toList vec)
  where
    strictlyAscending (a:xs@(b:_)) = getTermId a < getTermId b
                                  && strictlyAscending xs
    strictlyAscending _            = True

-- | The size recorded in the bag. For bags built with 'fromList' this is the
-- total number of term occurrences, i.e. the length of the input list; it is
-- not affected by the saturation of the individual 8 bit counts.
size :: TermBag -> Int
size (TermBag sz _) = sz

-- | The distinct term ids in the bag, in the order they are stored.
elems :: TermBag -> [TermId]
elems (TermBag _ vec) = map getTermId (Vec.toList vec)

-- | The distinct term ids paired with their (saturated) counts, in the
-- order they are stored.
toList :: TermBag -> [(TermId, TermCount)]
toList (TermBag _ vec) = [ (getTermId x, getTermCount x)
                         | x <- Vec.toList vec ]

-- | Look up the (saturated) count of a term. Terms that are not in the bag
-- have count 0.
--
-- This is a binary search over the packed vector, so it relies on the
-- ordering checked by 'invariant'.
termCount :: TermBag -> TermId -> TermCount
termCount (TermBag _ vec) =
    binarySearch 0 (Vec.length vec - 1)
  where
    -- Search the inclusive index range [a, b] for the given key.
    binarySearch :: Int -> Int -> TermId -> TermCount
    binarySearch !a !b !key
      | a > b     = 0
      | otherwise =
        let mid         = (a + b) `div` 2
            tidAndCount = vec Vec.! mid
         in case compare key (getTermId tidAndCount) of
              LT -> binarySearch a (mid-1) key
              EQ -> getTermCount tidAndCount
              GT -> binarySearch (mid+1) b key

-- | Build a bag from a list of term occurrences. Repeated term ids are
-- counted; each stored count saturates at 255 while 'size' records the
-- unsaturated total.
fromList :: [TermId] -> TermBag
fromList termids =
    let -- Frequency of every distinct term (strict in the counters).
        bag = Map.fromListWith (+) [ (t, 1) | t <- termids ]
        sz  = Map.foldl' (+) 0 bag
        -- toAscList yields the entries in ascending term id order.
        vec = Vec.fromListN (Map.size bag)
                            [ termIdAndCount termid freq
                            | (termid, freq) <- Map.toAscList bag ]
     in TermBag sz vec

-- | Given a bunch of term bags, merge them into a table for easier subsequent
-- processing. This is basically a sparse to dense conversion. Missing entries
-- are filled in with 0. We represent the table as one vector for the
-- term ids and a 2d array for the counts.
--
-- Unfortunately vector does not directly support 2d arrays and array does
-- not make it easy to trim arrays.
--
-- The counts are stored row by row in a single flat vector: the row for the
-- bag at position @n@ of the input list occupies the indices
-- @[n * numTerms, (n + 1) * numTerms)@, where @numTerms@ is the length of the
-- returned term id vector.
--
denseTable :: [TermBag] -> (Vec.Vector TermId, Vec.Vector TermCount)
denseTable termbags =
    (tids, tcts)
  where
    -- The union of all term ids gives the columns of the table; every bag
    -- then contributes one row of counts, with 0 for the terms it lacks.
    !numBags   = length termbags
    !tids      = unionsTermId termbags
    !numTerms  = Vec.length tids
    !tableSize = numTerms * numBags
    !tcts      = Vec.create $ do
      out <- MVec.new tableSize
      mapM_ (\(n, TermBag _ bag) ->
                writeMergedTermCounts tids bag out (n * numTerms))
            (zip [0 ..] termbags)
      return out

-- | Write one row of the dense count table.
--
-- Walks the full term id vector and the packed entries of one bag in step,
-- starting at offset @i0@ of the output: a term that is present in the bag
-- gets its count, every other term gets 0.
writeMergedTermCounts :: Vec.Vector TermId
                      -> Vec.Vector TermIdAndCount
                      -> MVec.MVector s TermCount
                      -> Int
                      -> ST s ()
writeMergedTermCounts xs0 ys0 !out i0 =
    -- assume xs & ys are sorted, and ys contains a subset of xs
    go xs0 ys0 i0
  where
    go !allTids !entries !pos
      -- The bag is exhausted: the remainder of the row is all zeros.
      | Vec.null entries =
          MVec.set (MVec.slice pos (Vec.length allTids) out) 0
      | Vec.null allTids = return ()
      | Vec.head allTids == getTermId entry =
          put (getTermCount entry) (Vec.tail entries)
      | otherwise = put 0 entries
      where
        entry = Vec.head entries

        -- Store one cell, then continue with the next column.
        put count entries' = do
          MVec.write out pos count
          go (Vec.tail allTids) entries' (pos + 1)

-- | Given a set of term bags, form the set of TermIds
--
-- The bags are merged starting from the ones with the fewest distinct terms.
--
unionsTermId :: [TermBag] -> Vec.Vector TermId
unionsTermId tbs =
    case bySize of
      []               -> Vec.empty
      [TermBag _ only] -> Vec.map getTermId only
      (b0 : b1 : rest) -> foldl' union3 (union2 b0 b1) rest
  where
    bySize = sortBy (compare `on` distinctTerms) tbs

    distinctTerms (TermBag _ vec) = Vec.length vec

-- | The sorted union of the term ids of two bags.
union2 :: TermBag -> TermBag -> Vec.Vector TermId
union2 (TermBag _ xs) (TermBag _ ys) =
    Vec.create $ do
      -- The union can never be larger than both inputs put together.
      out <- MVec.new (Vec.length xs + Vec.length ys)
      writeMergedUnion2 xs ys out

-- | Merge the term ids of two sorted packed vectors into @out@, writing each
-- id once, and return @out@ trimmed to the part that was actually filled.
writeMergedUnion2 :: Vec.Vector TermIdAndCount
                  -> Vec.Vector TermIdAndCount
                  -> MVec.MVector s TermId
                  -> ST s (MVec.MVector s TermId)
writeMergedUnion2 xs0 ys0 !out = do
    used <- go xs0 ys0 0
    return $! MVec.take used out
  where
    go !xs !ys !i
      | Vec.null xs = flush (Vec.map getTermId ys)
      | Vec.null ys = flush (Vec.map getTermId xs)
      | x < y       = emit x (Vec.tail xs) ys
      | x > y       = emit y xs (Vec.tail ys)
      | otherwise   = emit x (Vec.tail xs) (Vec.tail ys)
      where
        x = getTermId (Vec.head xs)
        y = getTermId (Vec.head ys)

        -- One side is exhausted: copy what is left of the other side in
        -- bulk and report the final fill level.
        flush rest = do
          let n = Vec.length rest
          Vec.copy (MVec.slice i n out) rest
          return (i + n)

        -- Write one id and carry on with the given remainders.
        emit tid xs' ys' = do
          MVec.write out i tid
          go xs' ys' (i + 1)

-- | The sorted union of an already merged term id vector and the term ids
-- of one more bag.
union3 :: Vec.Vector TermId -> TermBag -> Vec.Vector TermId
union3 xs (TermBag _ ys) =
    Vec.create $ do
      -- The union can never be larger than both inputs put together.
      out <- MVec.new (Vec.length xs + Vec.length ys)
      writeMergedUnion3 xs ys out

-- | Like 'writeMergedUnion2', except that the first input is a plain vector
-- of term ids rather than a packed vector.
writeMergedUnion3 :: Vec.Vector TermId
                  -> Vec.Vector TermIdAndCount
                  -> MVec.MVector s TermId
                  -> ST s (MVec.MVector s TermId)
writeMergedUnion3 xs0 ys0 !out = do
    used <- go xs0 ys0 0
    return $! MVec.take used out
  where
    go !xs !ys !i
      | Vec.null xs = flush (Vec.map getTermId ys)
      | Vec.null ys = flush xs
      | x < y       = emit x (Vec.tail xs) ys
      | x > y       = emit y xs (Vec.tail ys)
      | otherwise   = emit x (Vec.tail xs) (Vec.tail ys)
      where
        x = Vec.head xs
        y = getTermId (Vec.head ys)

        -- One side is exhausted: copy what is left of the other side in
        -- bulk and report the final fill level.
        flush rest = do
          let n = Vec.length rest
          Vec.copy (MVec.slice i n out) rest
          return (i + n)

        -- Write one id and carry on with the given remainders.
        emit tid xs' ys' = do
          MVec.write out i tid
          go xs' ys' (i + 1)