{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Vector samplers with and without replacement. For best results, use
-- @System.Random.Mersenne.Pure64@ as the random generator.
--
-- Naming convention used throughout: the argument @n@ is the size of the
-- desired sample, and @l@ is the length of the source vector.
module Data.Vector.Sample
  ( sampleNoReplace
  , sampleNoReplaceReservoir
  , sampleNoReplaceIntSet
  , sampleReplace
  ) where

import Control.Monad (when)
import Control.Monad.ST
import Control.Monad.State.Strict
import qualified Data.IntSet as IS
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
import System.Random
import System.Random.Mersenne.Pure64

-- | Sampling without replacement. Picks an algorithm based on the input:
-- when the sample is tiny relative to the source (@n * 1024 <= l@) it uses
-- the generate-check-retry sampler 'sampleNoReplaceIntSet', whose work scales
-- with @n@ rather than @l@; otherwise it uses reservoir sampling
-- ('sampleNoReplaceReservoir'), which is O(l).
--
-- Optimized for @System.Random.Mersenne.Pure64.PureMT@ as the randomness
-- generator. For a running time that does not depend on how many redraws
-- occur, call 'sampleNoReplaceReservoir' directly.
sampleNoReplace
  :: (RandomGen g, G.Vector v1 Int, G.Vector v1 a)
  => g          -- ^ The random generator
  -> Int        -- ^ The size of the desired sample
  -> v1 a       -- ^ The vector to sample
  -> (v1 a, g)  -- ^ The sample and the updated random generator
sampleNoReplace !seed !n !v
  -- The retry sampler is only chosen when collisions are very unlikely.
  | n * 1024 <= l = sampleNoReplaceIntSet seed n v
  | otherwise = sampleNoReplaceReservoir seed n v
  where
    l = G.length v

-- | Sampling without replacement, using reservoir sampling.
-- Optimized for @System.Random.Mersenne.Pure64.PureMT@ as the randomness generator.
--
-- O(l) in the length @l@ of the source. The first @n@ elements form the initial
-- reservoir, and each remaining element costs exactly one random draw.
-- If @n >= l@, the whole source is returned in its original order. If
-- @n <= 0@, the result is empty and the generator is returned untouched.
sampleNoReplaceReservoir
  :: (RandomGen g, G.Vector v1 a)
  => g
  -> Int
  -> v1 a
  -> (v1 a, g)
sampleNoReplaceReservoir !seed !n !v
  | n <= 0 = (G.empty, seed)
  | otherwise = runST $ do
      let (initVals, rest) = G.splitAt n v
      -- Take the first n elements as the initial sample.
      randVals <- G.thaw initVals
      -- Stream the remaining elements through the reservoir.
      seed' <- execStateT (G.imapM_ (resReplace randVals n) rest) seed
      flip (,) seed' <$> G.unsafeFreeze randVals
{-# INLINE sampleNoReplaceReservoir #-}

-- | One reservoir step (Algorithm R) for the source element at absolute
-- 0-based position @n + i@. Here @n@ is the reservoir size and @i@ is the
-- element's index within the streamed tail.
--
-- The element is kept with probability @n / (n + i + 1)@, overwriting a
-- uniformly chosen reservoir slot.
resReplace
  :: (RandomGen g, MG.MVector mv a)
  => mv s a -> Int -> Int -> a -> StateT g (ST s) ()
resReplace !randVals !n !i !x = do
  -- Draw uniformly from [0, n + i]. There are n + i + 1 outcomes, and exactly
  -- n of them land inside the reservoir. Drawing from only n + i values would
  -- always keep the first streamed element.
  r <- state . runState $ getUpTo (n + i + 1)
  when (r < n) $ MG.unsafeWrite randVals r x
{-# INLINE resReplace #-}

-- | Sampling without replacement. Draws indices at random, redrawing whenever
-- an index has already been drawn.
--
-- This is extremely fast when the chance of a redraw is low. The expected
-- number of redraws grows sharply as @n@ approaches the source length @l@, so
-- it is liable to run for a very long time once @n@ is more than about half
-- of @l@. A request for more than @l@ elements is clamped to @l@, since no
-- more distinct indices exist.
sampleNoReplaceIntSet
  :: (RandomGen g, G.Vector v Int, G.Vector v a)
  => g -> Int -> v a -> (v a, g)
sampleNoReplaceIntSet !seed !n !v = (G.backpermute v rvs, seed')
  where
    l = G.length v
    -- Without this clamp, the unique-draw loop could never finish when n > l.
    k = min n l
    (rvs, (seed', _)) = runState (G.replicateM k (getUpToUq l)) (seed, IS.empty)
{-# INLINE sampleNoReplaceIntSet #-}

-- | Sampling with replacement, through index generation. Each of the @n@
-- output positions is an independent uniform draw from the source.
--
-- O(n). Sampling a positive number of elements from an empty source fails,
-- because there is no valid index to draw.
sampleReplace
  :: (RandomGen g, G.Vector v Int, G.Vector v a)
  => g -> Int -> v a -> (v a, g)
sampleReplace !seed !n !v = (G.backpermute v rvs, seed')
  where
    l = G.length v
    (rvs, seed') = runState (G.replicateM n (getUpTo l)) seed
{-# INLINE sampleReplace #-}

-- | Draws a uniform number from @0..l - 1@ and updates the generator.
--
-- It is defined here (rather than calling 'randomR' inline) and kept NOINLINE
-- so the rewrite rule below can swap in 'getUpToMT' when the generator is
-- @PureMT@.
getUpTo :: (RandomGen g) => Int -> State g Int
getUpTo !l = state $ randomR (0, l - 1)
{-# NOINLINE getUpTo #-}

{-# RULES "getUpTo/PureMT" getUpTo = getUpToMT #-}

-- | 'getUpTo' specialized to @PureMT@. It reduces a raw 'randomInt' modulo
-- @l@; 'mod' with a positive divisor is non-negative, so the result lies in
-- @0..l - 1@, just like 'getUpTo'.
--
-- NOTE: this assumes 'randomInt' is uniform over all 'Int' values. In that
-- case the modulo bias is negligible when @l@ is small relative to the range
-- of 'Int'.
getUpToMT :: Int -> State PureMT Int
getUpToMT !l = state $ \g -> let (r, g') = randomInt g in (r `mod` l, g')

-- | Generates random numbers in @0..l - 1@ until an unseen one is found.
-- The state pairs the generator with the set of values returned so far, and
-- every returned value is added to that set.
getUpToUq :: (RandomGen g) => Int -> State (g, IS.IntSet) Int
getUpToUq !l = drawUnique (randomR (0, l - 1))
{-# NOINLINE getUpToUq #-}

{-# RULES "getUpToUq/PureMT" getUpToUq = getUpToUqMT #-}

-- | 'getUpToUq' specialized to @PureMT@: each candidate is a raw 'randomInt'
-- reduced modulo @l@, which is non-negative for positive @l@.
getUpToUqMT :: Int -> State (PureMT, IS.IntSet) Int
getUpToUqMT !l = drawUnique $ \g ->
  -- Reduce the freshly drawn value. Reducing the binding being defined
  -- would be a self-referential thunk (an infinite loop).
  let (raw, g') = randomInt g
   in (raw `mod` l, g')

-- | Shared retry loop for 'getUpToUq' and 'getUpToUqMT'.
--
-- It draws candidates with the given step function until one is not yet in
-- the seen set, records it there, and returns it. It never terminates if
-- every value the step function can produce has already been seen.
drawUnique :: (g -> (Int, g)) -> State (g, IS.IntSet) Int
drawUnique draw = go
  where
    go = do
      (g, seen) <- get
      let (r, g') = draw g
      -- Inserting an already-seen value is a no-op, so inserting before the
      -- membership check is harmless.
      put (g', IS.insert r seen)
      if IS.member r seen then go else return r