{-# LANGUAGE TypeApplications #-}

module Streamly.External.LMDB.Tests (tests) where

import Control.Concurrent.Async (asyncBound, wait)
import Control.Exception (SomeException, bracket, onException, try)
import Control.Monad (forM_)
import Data.ByteString (ByteString, pack, unpack)
import qualified Data.ByteString as B
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.List (find, foldl', nubBy, sort)
import Data.Word (Word8)
import Foreign (castPtr, nullPtr, with)
import Streamly.Data.Stream.Prelude (fromList, toList, unfold)
import qualified Streamly.Data.Stream.Prelude as S
import Streamly.External.LMDB
import Streamly.External.LMDB.Internal (Database (..))
import Streamly.External.LMDB.Internal.Foreign
import Test.QuickCheck (Gen, NonEmptyList (..), choose, elements, frequency)
import Test.QuickCheck.Monadic (PropertyM, monadicIO, pick, run)
import Test.Tasty (TestTree)
import Test.Tasty.QuickCheck (arbitrary, testProperty)

-- | All tests in this module, each run against the database/environment produced by @dbenv@.
tests :: IO (Database ReadWrite, Environment ReadWrite) -> [TestTree]
tests dbenv =
  [ testReadLMDB dbenv,
    testUnsafeReadLMDB dbenv,
    testWriteLMDB dbenv,
    testWriteLMDB_2 dbenv,
    testWriteLMDB_3 dbenv,
    testBetween
  ]

-- | Runs the given action with a fresh read-only transaction and a cursor opened on @db@ within
-- that transaction. When the action finishes (normally or by throwing), the cursor is closed and
-- then the transaction is aborted.
--
-- The two resources are acquired with nested brackets so that the transaction is still aborted if
-- opening the cursor itself throws; acquiring both inside a single bracket would leak the
-- transaction in that case.
withReadOnlyTxnAndCurs ::
  (Mode mode) =>
  Environment mode ->
  Database mode ->
  ((ReadOnlyTxn, Cursor) -> IO r) ->
  IO r
withReadOnlyTxnAndCurs env db action =
  bracket (beginReadOnlyTxn env) abortReadOnlyTxn $ \txn ->
    bracket (openCursor txn db) closeCursor $ \curs ->
      action (txn, curs)

-- | Clear the database, write key-value pairs to it in a normal manner, read them back using our
-- library, and make sure the result is what we wrote.
testReadLMDB :: (Mode mode) => IO (Database mode, Environment mode) -> TestTree
testReadLMDB res = testProperty "readLMDB" . monadicIO $ do
  (db, env) <- run res
  keyValuePairs <- arbitraryKeyValuePairs''
  run $ clearDatabase db
  run $ writeChunk db False keyValuePairs
  let keyValuePairsInDb = sort . removeDuplicateKeys $ keyValuePairs
  (readOpts, expectedResults) <- pick $ readOptionsAndResults keyValuePairsInDb
  -- Read once without a caller-supplied transaction and once inside our own read-only
  -- transaction/cursor; both must yield the expected results.
  let unf txn = toList $ unfold (readLMDB db txn readOpts) undefined
  results <- run $ unf Nothing
  resultsTxn <- run $ withReadOnlyTxnAndCurs env db (unf . Just)
  return $ results == expectedResults && resultsTxn == expectedResults

-- | Similar to 'testReadLMDB', except that it tests the unsafe function in a different manner.
testUnsafeReadLMDB :: (Mode mode) => IO (Database mode, Environment mode) -> TestTree
testUnsafeReadLMDB res = testProperty "unsafeReadLMDB" . monadicIO $ do
  (db, env) <- run res
  keyValuePairs <- arbitraryKeyValuePairs''
  run $ clearDatabase db
  run $ writeChunk db False keyValuePairs
  let keyValuePairsInDb = sort . removeDuplicateKeys $ keyValuePairs
  (readOpts, expectedResults) <- pick $ readOptionsAndResults keyValuePairsInDb
  let expectedLengths = map (\(k, v) -> (B.length k, B.length v)) expectedResults
  -- The key and value handlers are both @return . snd@, so only lengths are compared
  -- (presumably the handlers receive (pointer, length) pairs — see 'unsafeReadLMDB').
  let unf txn =
        toList $
          unfold (unsafeReadLMDB db txn readOpts (return . snd) (return . snd)) undefined
  lengths <- run $ unf Nothing
  lengthsTxn <- run $ withReadOnlyTxnAndCurs env db (unf . Just)
  return $ lengths == expectedLengths && lengthsTxn == expectedLengths

-- | Clear the database, write key-value pairs to it using our library with key overwriting allowed,
-- read them back using our library (already covered by 'testReadLMDB'), and make sure the result is
-- what we wrote.
testWriteLMDB :: IO (Database ReadWrite, Environment ReadWrite) -> TestTree
testWriteLMDB res = testProperty "writeLMDB" . monadicIO $ do
  (db, _) <- run res
  keyValuePairs <- arbitraryKeyValuePairs
  run $ clearDatabase db
  chunkSz <- pick arbitrary
  unsafeFFI <- pick arbitrary
  let fol' =
        writeLMDB db $
          defaultWriteOptions
            { writeTransactionSize = chunkSz,
              writeOverwriteOptions = OverwriteAllow,
              writeUnsafeFFI = unsafeFFI
            }
  -- TODO: Run with new "bound" functionality in streamly.
  run $ asyncBound (S.fold fol' (fromList keyValuePairs)) >>= wait
  let keyValuePairsInDb = sort . removeDuplicateKeys $ keyValuePairs
  readPairsAll <- run . toList $ unfold (readLMDB db Nothing defaultReadOptions) undefined
  return $ keyValuePairsInDb == readPairsAll

-- | Clear the database, write key-value pairs to it using our library with key overwriting
-- disallowed, and make sure an exception occurs iff we had a duplicate key in our pairs.
-- Furthermore make sure that key-value pairs prior to a duplicate key are actually in the database.
testWriteLMDB_2 :: IO (Database ReadWrite, Environment ReadWrite) -> TestTree
testWriteLMDB_2 res = testProperty "writeLMDB_2" . monadicIO $ do
  (db, _) <- run res
  keyValuePairs <- arbitraryKeyValuePairs'
  run $ clearDatabase db
  chunkSz <- pick arbitrary
  unsafeFFI <- pick arbitrary
  -- TODO: Run with new "bound" functionality in streamly.
  let fol' =
        writeLMDB db $
          defaultWriteOptions
            { writeTransactionSize = chunkSz,
              writeOverwriteOptions = OverwriteDisallow,
              writeUnsafeFFI = unsafeFFI
            }
  e <- run $ try @SomeException $ (asyncBound (S.fold fol' (fromList keyValuePairs)) >>= wait)
  let exceptionAsExpected = case e of
        Left _ -> hasDuplicateKeys keyValuePairs
        Right _ -> not $ hasDuplicateKeys keyValuePairs
  let keyValuePairsInDb = sort . prefixBeforeDuplicate $ keyValuePairs
  readPairsAll <- run . toList $ unfold (readLMDB db Nothing defaultReadOptions) undefined
  let pairsAsExpected = keyValuePairsInDb == readPairsAll
  return $ exceptionAsExpected && pairsAsExpected

-- | Clear the database, write key-value pairs to it using our library with key overwriting
-- disallowed except when attempting to replace an existing key-value pair, and make sure an
-- exception occurs iff we had a duplicate key with different values in our pairs. Furthermore make
-- sure that key-value pairs prior to such a duplicate key are actually in the database.
testWriteLMDB_3 :: IO (Database ReadWrite, Environment ReadWrite) -> TestTree
testWriteLMDB_3 res = testProperty "writeLMDB_3" . monadicIO $ do
  (db, _) <- run res
  pairs <- arbitraryKeyValuePairs'
  run $ clearDatabase db
  txnSize <- pick arbitrary
  useUnsafeFFI <- pick arbitrary
  -- TODO: Run with new "bound" functionality in streamly.
  let writeFold =
        writeLMDB db $
          defaultWriteOptions
            { writeTransactionSize = txnSize,
              writeOverwriteOptions = OverwriteAllowSame,
              writeUnsafeFFI = useUnsafeFFI
            }
  outcome <- run . try @SomeException $ asyncBound (S.fold writeFold (fromList pairs)) >>= wait
  -- The write must throw exactly when some key occurs with two different values.
  let threw = either (const True) (const False) outcome
      shouldThrow = hasDuplicateKeysWithDiffVals pairs
      expectedPairs = sort . removeDuplicateKeys . prefixBeforeDuplicateWithDiffVal $ pairs
  actualPairs <- run . toList $ unfold (readLMDB db Nothing defaultReadOptions) undefined
  return $ (threw == shouldThrow) && expectedPairs == actualPairs

-- | Arbitrary key-value pairs with non-empty keys (LMDB rejects empty keys). Keys may repeat.
arbitraryKeyValuePairs :: PropertyM IO [(ByteString, ByteString)]
arbitraryKeyValuePairs = do
  raw <- pick arbitrary
  -- LMDB does not allow empty keys.
  return [(pack k, pack v) | (k, v) <- raw, not (null k)]

-- | A variation that makes duplicate keys more likely: sometimes re-inserts the first key (with
-- either its original value or a fresh one) at a random position. Positions outside the list
-- bounds simply place the duplicate at the start or end.
arbitraryKeyValuePairs' :: PropertyM IO [(ByteString, ByteString)]
arbitraryKeyValuePairs' = do
  base <- arbitraryKeyValuePairs
  injectDup <- pick arbitrary
  case base of
    (k, v) : _ | injectDup -> do
      keepVal <- pick arbitrary
      v' <- if keepVal then return v else pack <$> pick arbitrary
      pos <- pick $ choose (negate $ length base, 2 * length base)
      let (before, after) = splitAt pos base
      return $ before ++ (k, v') : after
    _ -> return base

-- A variation that makes it more likely to have keys with the same prefix that differ only in
-- trailing zero bytes.
arbitraryKeyValuePairs'' :: PropertyM IO [(ByteString, ByteString)]
arbitraryKeyValuePairs'' = do
  base <- arbitraryKeyValuePairs
  case base of
    [] -> return base
    (k, v) : _ -> pick $ frequency [(1, return base), (3, withZeroPaddedKeys k v base)]
  where
    -- Inserts, at a random position, the keys k++[0], k++[0,0], ... (as many as the position
    -- index plus one), all mapped to the same value (either v or a fresh one).
    withZeroPaddedKeys k v base = do
      keepVal <- arbitrary
      v' <- if keepVal then return v else pack <$> arbitrary
      pos <- choose (0, length base - 1)
      let (before, after) = splitAt pos base
          padded = [(k `B.append` B.replicate n 0, v') | n <- [1 .. pos + 1]]
      return $ before ++ padded ++ after

-- | Note that this function retains the last value for each key.
removeDuplicateKeys :: (Eq a) => [(a, b)] -> [(a, b)]
removeDuplicateKeys = foldl' keepIfUnseen [] . reverse
  where
    -- Walking the reversed list, the first pair seen for a key is its last occurrence.
    keepIfUnseen kept pair@(key, _)
      | any ((== key) . fst) kept = kept
      | otherwise = pair : kept

-- | Whether any key occurs more than once.
hasDuplicateKeys :: (Eq a) => [(a, b)] -> Bool
hasDuplicateKeys [] = False
hasDuplicateKeys ((a, _) : rest) = any ((== a) . fst) rest || hasDuplicateKeys rest

-- | Whether some key occurs with two different values.
hasDuplicateKeysWithDiffVals :: (Eq a, Eq b) => [(a, b)] -> Bool
hasDuplicateKeysWithDiffVals pairs = length (nubBy conflicting pairs) /= length pairs
  where
    -- Not an equivalence relation, but the first pair for each key is always kept, so 'nubBy'
    -- drops something exactly when a later pair has that key with a different value.
    conflicting (a1, b1) (a2, b2) = a1 == a2 && b1 /= b2

-- | The pairs strictly before the first pair whose key already occurred earlier in the list
-- (the whole list when no key repeats).
prefixBeforeDuplicate :: (Eq a) => [(a, b)] -> [(a, b)]
prefixBeforeDuplicate xs = maybe xs (`take` xs) firstDupIndex
  where
    firstDupIndex = snd <$> find isRepeat (zip xs [0 ..])
    isRepeat ((a, _), i) = a `elem` map fst (take i xs)

-- | The pairs strictly before the first pair whose key already occurred earlier with a different
-- value (the whole list when there is no such pair).
prefixBeforeDuplicateWithDiffVal :: (Eq a, Eq b) => [(a, b)] -> [(a, b)]
prefixBeforeDuplicateWithDiffVal xs = maybe xs (`take` xs) firstConflictIndex
  where
    firstConflictIndex = snd <$> find conflictsWithEarlier (zip xs [0 ..])
    conflictsWithEarlier ((a, b), i) = any (\(a', b') -> a == a' && b /= b') (take i xs)

-- Assumes first < second.
-- | @between first second acc@ returns a byte sequence lying strictly between @first@ and
-- @second@ in lexicographic order, or 'Nothing' if no such sequence exists. Callers pass @[]@ for
-- @acc@, which accumulates the reversed common prefix during recursion. Calls 'error' unless
-- @first < second@.
--
-- For @first < second@, the candidate @first ++ [0]@ is always greater than @first@, and it is
-- smaller than @second@ unless @second == first ++ [0]@ — the only case in which nothing lies
-- strictly between the two.
between :: [Word8] -> [Word8] -> [Word8] -> Maybe [Word8]
between [] [] _ = error "first = second"
between _ [] _ = error "first > second"
between [] rest commonPrefixRev
  -- @first@ is a proper prefix of @second@; @commonPrefixRev@ is all of @first@, reversed.
  | rest == [0] = Nothing
  | otherwise = Just $ reverse (0 : commonPrefixRev)
between (w1 : ws1) (w2 : ws2) commonPrefixRev
  | w1 == w2 = between ws1 ws2 (w1 : commonPrefixRev)
  | w1 > w2 = error "first > second"
  | otherwise = Just $ reverse commonPrefixRev ++ [w1] ++ ws1 ++ [0]

-- | Checks 'between' on arbitrary pairs of distinct byte lists. A returned value must lie strictly
-- between the two inputs, and 'Nothing' is only acceptable when the larger input is the smaller
-- one followed by a single zero byte (so nothing lies strictly between them).
testBetween :: TestTree
testBetween = testProperty "testBetween" $ \ws1 ws2 ->
  ws1 == ws2 || holdsFor (min ws1 ws2) (max ws1 ws2)
  where
    holdsFor smaller bigger = case between smaller bigger [] of
      Nothing -> bigger == smaller ++ [0]
      Just betw -> smaller < betw && betw < bigger

-- | 'between' lifted to 'ByteString's.
betweenBs :: ByteString -> ByteString -> Maybe ByteString
betweenBs bs1 bs2 = pack <$> between (unpack bs1) (unpack bs2) []

type PairsInDatabase = [(ByteString, ByteString)]

type ExpectedReadResult = [(ByteString, ByteString)]

-- | Given database pairs, randomly generates read options and corresponding expected results.
--
-- The pairs are expected to be sorted by key with unique keys (callers pass
-- @sort . removeDuplicateKeys@ of what they wrote). The read direction and FFI safety are random.
-- The start key is either absent (read everything), exactly an existing key, or, when possible, a
-- key strictly between two neighbouring keys (or before the first / after the last). The expected
-- results encode that a forward read starts at the first key >= the start key and a backward read
-- starts at the last key <= the start key.
readOptionsAndResults :: PairsInDatabase -> Gen (ReadOptions, ExpectedReadResult)
readOptionsAndResults pairsInDb = do
  forw <- arbitrary
  let dir = if forw then Forward else Backward
  unsafeFFI <- arbitrary
  let len = length pairsInDb
  readAll <- frequency [(1, return True), (3, return False)]
  let ropts = defaultReadOptions {readDirection = dir, readUnsafeFFI = unsafeFFI}
  if readAll
    then return (ropts {readStart = Nothing}, (if forw then id else reverse) pairsInDb)
    else
      if len == 0
        then do
          -- Empty database: any (non-empty) start key yields nothing.
          bs <- arbitrary >>= \(NonEmpty ws) -> return $ pack ws
          return (ropts {readStart = Just bs}, [])
        else do
          -- Favour the boundary indices, which are the interesting edge cases.
          idx <-
            if len < 3
              then choose (0, len - 1)
              else frequency [(1, choose (1, len - 2)), (3, elements [0, len - 1])]
          let keyAt i = fst $ pairsInDb !! i
          -- A key strictly between the key at idx and the next one (if any such key exists).
          let nextKey
                | idx + 1 <= len - 1 = betweenBs (keyAt idx) (keyAt $ idx + 1)
                | otherwise = Just $ keyAt (len - 1) `B.append` B.singleton 0
          -- A key strictly between the previous key (if any) and the key at idx.
          let prevKey
                -- Keys are known to be non-empty.
                | idx == 0 && keyAt idx /= B.singleton 0 = Just $ B.singleton 0
                | idx == 0 = Nothing
                | otherwise = betweenBs (keyAt $ idx - 1) (keyAt idx)
          let forwEq = (ropts {readStart = Just $ keyAt idx}, drop idx pairsInDb)
          let backwEq = (ropts {readStart = Just $ keyAt idx}, reverse $ take (idx + 1) pairsInDb)
          ord <- arbitrary @Ordering -- Proximity to the key at idx (if possible).
          return $ case (ord, dir) of
            (EQ, Forward) -> forwEq
            (EQ, Backward) -> backwEq
            (GT, Forward) -> case nextKey of
              Nothing -> forwEq
              Just nextKey' -> (ropts {readStart = Just nextKey'}, drop (idx + 1) pairsInDb)
            (GT, Backward) -> case nextKey of
              Nothing -> backwEq
              Just nextKey' ->
                (ropts {readStart = Just nextKey'}, reverse $ take (idx + 1) pairsInDb)
            (LT, Forward) -> case prevKey of
              Nothing -> forwEq
              Just prevKey' -> (ropts {readStart = Just prevKey'}, drop idx pairsInDb)
            (LT, Backward) -> case prevKey of
              Nothing -> backwEq
              Just prevKey' -> (ropts {readStart = Just prevKey'}, reverse $ take idx pairsInDb)

-- | Writes the given key-value pairs to the given database directly through the low-level
-- @mdb_*@ bindings (not through 'writeLMDB'), all in one write transaction run via 'asyncBound'
-- (presumably because LMDB write transactions are tied to the OS thread that began them).
-- When @noOverwrite'@ is 'True', @mdb_nooverwrite@ is passed to every @mdb_put@. If writing
-- throws, the transaction is committed before the exception propagates.
writeChunk ::
  (Foldable t, Mode mode) =>
  Database mode ->
  Bool ->
  t (ByteString, ByteString) ->
  IO ()
writeChunk (Database penv dbi) noOverwrite' keyValuePairs =
  let flags = combineOptions $ [mdb_nooverwrite | noOverwrite']
   in asyncBound
        ( do
            ptxn <- mdb_txn_begin penv nullPtr 0
            onException
              ( forM_ keyValuePairs $ \(k, v) ->
                  marshalOut k $ \k' -> marshalOut v $ \v' -> with k' $ \k'' -> with v' $ \v'' ->
                    mdb_put ptxn dbi k'' v'' flags
              )
              (mdb_txn_commit ptxn) -- Make sure the key-value pairs we have so far are committed.
            mdb_txn_commit ptxn
        )
        >>= wait

-- | Passes the bytes of a 'ByteString' to the callback as an 'MDB_val' without copying
-- (via 'unsafeUseAsCStringLen'); the 'MDB_val' must not be used after the callback returns.
marshalOut :: ByteString -> (MDB_val -> IO ()) -> IO ()
marshalOut bs f =
  unsafeUseAsCStringLen bs $ \(ptr, len) -> f $ MDB_val (fromIntegral len) (castPtr ptr)
{-# INLINE marshalOut #-}