module Control.Monad.MC.Sample (
sample,
sampleWithWeights,
sampleSubset,
sampleSubsetWithWeights,
shuffle,
sampleInt,
sampleIntWithWeights,
sampleIntSubset,
sampleIntSubsetWithWeights,
shuffleInt,
) where
import Control.Monad( forM_, liftM )
import Control.Monad.Primitive( PrimMonad )
import Control.Monad.Trans.Class( lift )
import Data.List( foldl', sort )
import qualified Data.Vector as BV
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Control.Monad.MC.GSLBase
import Control.Monad.MC.Walker
-- | Pick a single element from a list.
--
-- The index is drawn by 'sampleInt' over the length of the list, so an
-- empty list ends up in that function's @fail "invalid argument"@ branch.
sample :: (PrimMonad m) => [a] -> MC m a
sample xs = sampleHelp len xs (sampleInt len)
  where
    len = length xs
-- | Pick a single element from a list of @(weight, value)@ pairs.
--
-- The weights are split off and handed to 'sampleIntWithWeights', which
-- produces the index of the value to return.
sampleWithWeights :: (PrimMonad m) => [(Double, a)] -> MC m a
sampleWithWeights wxs =
  sampleHelp count values (sampleIntWithWeights weights count)
  where
    (weights, values) = unzip wxs
    count = length values
-- | Turn an index-producing action into an element-producing one.
--
-- The list is packed into a boxed vector and the drawn index is looked up
-- with 'BV.unsafeIndex', which performs no bounds check: the action must
-- only yield indices that are valid for @xs@.  The 'Int' argument is
-- ignored.
sampleHelp :: (PrimMonad m) => Int -> [a] -> MC m Int -> MC m a
sampleHelp _n xs f = do
  i <- f
  return (BV.unsafeIndex arr i)
  where
    arr = BV.fromList xs
-- | @sampleSubset xs k@ picks @k@ elements from @xs@.
--
-- The positions come from 'sampleIntSubset' applied to the length of the
-- list and @k@; that function rejects a negative @k@ or a @k@ larger than
-- the length.
sampleSubset :: (PrimMonad m) => [a] -> Int -> MC m [a]
sampleSubset xs k = sampleSubsetHelp len xs (sampleIntSubset len k)
  where
    len = length xs
-- | @sampleSubsetWithWeights wxs k@ picks @k@ values from a list of
-- @(weight, value)@ pairs.
--
-- The weights are split off and passed to 'sampleIntSubsetWithWeights',
-- which chooses the positions of the values to return.
sampleSubsetWithWeights :: (PrimMonad m) => [(Double,a)] -> Int -> MC m [a]
sampleSubsetWithWeights wxs k =
  sampleSubsetHelp count values (sampleIntSubsetWithWeights weights count k)
  where
    (weights, values) = unzip wxs
    count = length weights
-- | Turn an action yielding a list of indices into one yielding the
-- corresponding list elements, in the same order.
--
-- Lookups go through 'BV.unsafeIndex' on a boxed vector built from @xs@,
-- so every index must be in range (nothing is bounds-checked).  The 'Int'
-- argument is ignored.
sampleSubsetHelp :: (Monad m) => Int -> [a] -> m [Int] -> m [a]
sampleSubsetHelp _n xs f = do
  indices <- f
  return [ BV.unsafeIndex arr i | i <- indices ]
  where
    arr = BV.fromList xs
-- | @sampleInt n@ draws an integer via @uniformInt n@.
--
-- Fails with @"invalid argument"@ when @n < 1@.
-- NOTE(review): the result is presumably uniform over @0 .. n-1@; the exact
-- range is defined by 'uniformInt', which lives in another module.
sampleInt :: (PrimMonad m) => Int -> MC m Int
sampleInt n =
  if n < 1
    then fail "invalid argument"
    else uniformInt n
-- | @sampleIntWithWeights ws n@ draws an integer using the weights @ws@.
--
-- A lookup table is built from @n@ and @ws@ with 'computeTable', and a
-- single draw from @uniform 0 1@ is mapped to an index by 'indexTable'.
-- NOTE(review): presumably an alias-method table giving indices in
-- @0 .. n-1@ with probability proportional to the weights; confirm against
-- "Control.Monad.MC.Walker".
sampleIntWithWeights :: (PrimMonad m) => [Double] -> Int -> MC m Int
sampleIntWithWeights ws n = do
  u <- uniform 0 1
  return (indexTable table u)
  where
    table = computeTable n ws
-- | @sampleIntSubset n k@ draws @k@ distinct integers from @0 .. n-1@
-- without replacement, returned in the order in which they were drawn.
--
-- Fails with @"negative subset size"@ when @k < 0@ and with
-- @"subset size is too big"@ when @k > n@.
sampleIntSubset :: (PrimMonad m) => Int -> Int -> MC m [Int]
sampleIntSubset n k
  | k < 0 = fail "negative subset size"
  | k > n = fail "subset size is too big"
  | otherwise = do
      -- Mutable pool of candidates; only the first n' slots are live.
      xs <- lift $ (V.thaw . V.fromList) [ 0 .. n - 1 ]
      go xs [] n k
  where
    -- ys holds the picks so far in reverse order, n' is the number of live
    -- slots in the pool and k' the number of picks still to make.
    go xs ys n' k'
      | k' == 0 = return $ reverse ys
      | otherwise = do
          -- NOTE(review): relies on @uniformInt n'@ yielding an index in
          -- [0, n'); the reads and writes below are not bounds-checked.
          u <- uniformInt n'
          y <- lift $ do
            i <- MV.unsafeRead xs u
            -- Fill the hole with the last live entry so that the pool
            -- stays contiguous once it shrinks by one.
            j <- MV.unsafeRead xs (n' - 1)
            MV.unsafeWrite xs u j
            return i
          go xs (y : ys) (n' - 1) (k' - 1)
-- | @sampleIntSubsetWithWeights ws n k@ draws @k@ distinct integers from
-- @0 .. n-1@ without replacement.  At every step one of the remaining
-- integers is chosen with probability proportional to its weight; the
-- result lists the integers in the order in which they were drawn.
--
-- Only the first @n@ entries of @ws@ are used.  @ws@ must have at least
-- @n@ entries: a shorter list yields a pool smaller than @n@, and the
-- unchecked reads below would then run past its end.
--
-- Fails with @"negative subset size"@ when @k < 0@ and with
-- @"subset size is too big"@ when @k > n@.
sampleIntSubsetWithWeights :: (PrimMonad m) => [Double] -> Int -> Int -> MC m [Int]
sampleIntSubsetWithWeights ws n k
  | k < 0 = fail "negative subset size"
  | k > n = fail "subset size is too big"
  | otherwise = do
      xs <- lift $ (V.thaw . V.fromList) wjs
      -- The pool stores normalised weights, so the live total starts at 1
      -- (not at the raw sum of the input weights).
      go xs 1 [] n k
  where
    wsum = foldl' (+) 0 $ take n ws

    -- (normalised weight, index) pairs, heaviest first, so that the linear
    -- scan in findTarget tends to stop early.
    wjs = [ (w / wsum, j) | (w, j) <- reverse $ sort $ zip ws [ 0 .. n - 1 ] ]

    -- wsum' is the total weight of the n' live entries, ys the picks so far
    -- in reverse order and k' the number of picks still to make.
    go xs wsum' ys n' k'
      | k' == 0 = return $ reverse ys
      | otherwise = do
          target <- uniform 0 wsum'
          (w, y) <- lift $ do
            (i, wj) <- findTarget xs n' target 0 0
            -- Close the gap left by the chosen entry.
            shiftDown xs (i + 1) (n' - 1)
            return wj
          go xs (wsum' - w) (y : ys) (n' - 1) (k' - 1)

    -- Scan the live entries, accumulating weight, and stop at the first one
    -- whose cumulative weight reaches the target.  The last live entry is
    -- returned unconditionally, which also absorbs any rounding drift in
    -- the running total.
    findTarget xs n' target i acc
      | i == n' - 1 = do
          wj <- MV.unsafeRead xs i
          return (i, wj)
      | otherwise = do
          (w, j) <- MV.unsafeRead xs i
          let acc' = acc + w
          if target <= acc'
            then return (i, (w, j))
            else findTarget xs n' target (i + 1) acc'

    -- Move the entries at positions from .. to one slot towards the front.
    shiftDown xs from to =
      forM_ [ from .. to ] $ \i -> do
        wj <- MV.unsafeRead xs i
        MV.unsafeWrite xs (i - 1) wj
-- | Return the elements of a list rearranged according to a permutation
-- of their positions obtained from 'shuffleInt'.
--
-- The rearrangement uses 'BV.unsafeBackpermute', which does not
-- bounds-check the positions it is given.
shuffle :: (PrimMonad m) => [a] -> MC m [a]
shuffle xs = do
  perm <- shuffleInt (length xs)
  return (BV.toList (BV.unsafeBackpermute items (BV.fromList perm)))
  where
    items = BV.fromList xs
-- | @shuffleInt n@ is 'sampleIntSubset' asked for a subset of size @n@ out
-- of @n@ candidates, i.e. every candidate is drawn exactly once, in random
-- order.
shuffleInt :: (PrimMonad m) => Int -> MC m [Int]
shuffleInt n = sampleIntSubset n subsetSize
  where
    subsetSize = n