{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE BangPatterns #-}
module Data.Vector.Sample
( sampleNoReplace
, sampleNoReplaceReservoir
, sampleNoReplaceIntSet
, sampleReplace
) where
import Control.Monad (when)
import Control.Monad.ST
import Control.Monad.State.Strict
import qualified Data.IntSet as IS
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
import System.Random
import System.Random.Mersenne.Pure64
-- | Draw @n@ elements of @v@ without replacement, returning the sample and
-- the advanced generator.
--
-- Requests of at most 1/1024 of the input go to the IntSet rejection
-- sampler ('sampleNoReplaceIntSet'); everything else goes through a single
-- reservoir pass over the input ('sampleNoReplaceReservoir').
sampleNoReplace :: (RandomGen g, G.Mutable v1 ~ G.Mutable v1, G.Vector v1 Int, G.Vector v1 a)
                => g
                -> Int
                -> v1 a
                -> (v1 a, g)
sampleNoReplace !seed !n !v =
  if n * 1024 <= G.length v
    then sampleNoReplaceIntSet seed n v
    else sampleNoReplaceReservoir seed n v
-- | Reservoir sampling without replacement.
--
-- The first @n@ elements of @v@ seed a mutable reservoir. Every later element
-- is offered to 'resReplace', which may overwrite one reservoir slot. When
-- @n >= length v@ the tail is empty and the whole input comes back unchanged.
-- Returns the frozen reservoir and the generator left after all draws.
sampleNoReplaceReservoir :: (RandomGen g, G.Mutable v1 ~ G.Mutable v1, G.Vector v1 a)
                         => g
                         -> Int
                         -> v1 a
                         -> (v1 a, g)
sampleNoReplaceReservoir !seed !n !v = runST $ do
  let (prefix, remaining) = G.splitAt n v
  reservoir <- G.thaw prefix
  -- Thread the generator through the per-element replacement step.
  finalSeed <- execStateT (G.imapM_ (resReplace reservoir n) remaining) seed
  -- The reservoir is never touched again, so freezing in place is safe.
  sample <- G.unsafeFreeze reservoir
  return (sample, finalSeed)
{-# INLINE sampleNoReplaceReservoir #-}
-- | One replacement step of reservoir sampling (Algorithm R) for an element
-- that follows the initial reservoir.
--
-- @i@ is the element's index within the tail, so its position in the whole
-- input is @t = n + i@. To end up in the sample with probability
-- @n / (t + 1)@, draw @r@ from @[0, t]@ (that is, @getUpTo (t + 1)@) and
-- overwrite reservoir slot @r@ when @r < n@.
--
-- Drawing from @[0, t - 1]@ instead would give probability @n / t@. It would
-- always evict for the first tail element, and with @n = 0@ it could even
-- produce @r = -1@.
resReplace :: (RandomGen g, MG.MVector mv a)
           => mv s a      -- ^ the reservoir (expected to hold @n@ elements)
           -> Int         -- ^ reservoir size @n@ (assumed @>= 0@)
           -> Int         -- ^ index @i@ of the element within the tail
           -> a           -- ^ the element being offered
           -> StateT g (ST s) ()
resReplace !randVals !n !i !x = do
  r <- state . runState $ getUpTo (n + i + 1)
  -- For n >= 0, r lies in [0, n + i], so r < n keeps the unchecked write
  -- inside a length-n reservoir.
  when (r < n) $
    MG.unsafeWrite randVals r x
{-# INLINE resReplace #-}
-- | Sample distinct elements of @v@ by rejection.
--
-- Random indices are drawn and any index already chosen is redrawn (see
-- 'getUpToUq'), so this is intended for @n@ much smaller than @length v@.
-- Returns the selected elements and the advanced generator.
--
-- Only @length v@ distinct indices exist, so the request is clamped to
-- @min n (length v)@. Without the clamp, an oversized @n@ makes the rejection
-- loop spin forever, and an empty @v@ yields out-of-range indices. The clamp
-- also matches 'sampleNoReplaceReservoir', which returns at most
-- @length v@ elements.
sampleNoReplaceIntSet :: (RandomGen g, G.Vector v Int, G.Vector v a)
                      => g
                      -> Int
                      -> v a
                      -> (v a, g)
sampleNoReplaceIntSet !seed !n !v = (G.backpermute v rvs, seed')
  where
    l = G.length v
    k = min n l
    (rvs, (seed', _)) = runState (G.replicateM k (getUpToUq l)) (seed, IS.empty)
{-# INLINE sampleNoReplaceIntSet #-}
-- | Draw @n@ elements of @v@ /with/ replacement.
--
-- Every output slot gets its own independently drawn index into @v@ (via
-- 'getUpTo'). Returns the sample and the advanced generator.
sampleReplace :: (RandomGen g, G.Vector v Int, G.Vector v a)
              => g
              -> Int
              -> v a
              -> (v a, g)
sampleReplace !seed !n !v = (picked, seed')
  where
    picked = G.backpermute v indices
    drawIndex = getUpTo (G.length v)
    (indices, seed') = flip runState seed (G.replicateM n drawIndex)
-- | Draw an index in @[0, l - 1]@ with 'randomR', using the generator held in
-- the state.
--
-- Marked NOINLINE so the "getUpTo/PureMT" rewrite rule can still match call
-- sites and substitute 'getUpToMT' when the generator is a 'PureMT'.
getUpTo :: (RandomGen g) => Int -> State g Int
getUpTo !l = state (randomR bounds)
  where
    bounds = (0, l - 1)
{-# NOINLINE getUpTo #-}
{-# RULES "getUpTo/PureMT" getUpTo = getUpToMT #-}
-- | 'PureMT'-specialised replacement for 'getUpTo', installed by the
-- "getUpTo/PureMT" rewrite rule. It must honour the same contract: return an
-- index in @[0, l - 1]@.
--
-- 'randomInt' yields a value over the full 'Int' range, possibly negative.
-- Reducing it with 'mod' (non-negative for a positive divisor) maps it into
-- range. Returning the raw value fed out-of-range indices to 'G.backpermute'
-- and 'MG.unsafeWrite'.
--
-- NOTE(review): 'mod' introduces a slight modulo bias; it is negligible only
-- when @l@ is far below the 'Int' range.
getUpToMT :: Int -> State PureMT Int
getUpToMT !l = (`mod` l) <$> state randomInt
-- | Draw an index in @[0, l - 1]@ that is not already in the state's set of
-- used indices, and record it as used.
--
-- The generator is always advanced. A draw that hits an already-used index is
-- discarded and the step is retried, so the set must not already contain
-- every index in range.
getUpToUq :: (RandomGen g) => Int -> State (g, IS.IntSet) Int
getUpToUq !l = do
  (gen, used) <- get
  let (candidate, gen') = randomR (0, l - 1) gen
  put (gen', IS.insert candidate used)
  if candidate `IS.notMember` used
    then return candidate
    else getUpToUq l
{-# NOINLINE getUpToUq #-}
{-# RULES "getUpToUq/PureMT" getUpToUq = getUpToUqMT #-}
-- | 'PureMT'-specialised replacement for 'getUpToUq', installed by the
-- "getUpToUq/PureMT" rewrite rule.
--
-- It draws an index in @[0, l - 1]@ not yet in the used set, records it, and
-- retries on collision. The raw 'randomInt' draw is reduced with 'mod'. The
-- previous @r = r `mod` l@ referred to itself, so evaluating it diverged.
--
-- NOTE(review): 'mod' introduces a slight modulo bias; it is negligible only
-- when @l@ is far below the 'Int' range.
getUpToUqMT :: Int -> State (PureMT, IS.IntSet) Int
getUpToUqMT !l = do
  (g, seen) <- get
  let (raw, g') = randomInt g
      r = raw `mod` l
  put (g', IS.insert r seen)
  if IS.member r seen
    then getUpToUqMT l
    else return r