module Numeric.MCMC.Util
where
import Control.Monad.Primitive (PrimMonad, PrimState)
import System.Random.MWC
import Data.Function (fix)
import qualified Data.Sequence as Seq
import Data.Foldable (toList)
import Control.Monad
import Control.Monad.ST
import Data.STRef
-- | Draw a uniformly random 'Int' from @bounds@ that differs from @a@.
--
-- Uses rejection sampling: any draw equal to @a@ is discarded and redrawn.
--
-- Calls 'error' when both bounds equal @a@. In that case no acceptable value
-- exists, and the rejection loop would otherwise never terminate.
genDiffInt :: PrimMonad m => Int -> (Int, Int) -> Gen (PrimState m) -> m Int
genDiffInt a bounds@(lo, hi) gen
  | lo == a && hi == a =
      error "genDiffInt - bounds contain no value different from the excluded one."
  | otherwise = fix $ \loopB -> do
      b <- uniformR bounds gen
      if a == b then loopB else return b
-- | Arithmetic mean of a list of 'Double's, computed in a single pass.
--
-- An empty list yields @0 / 0@, i.e. NaN.
mean :: [Double] -> Double
mean = go 0.0 0
  where
    -- Running sum and running count. Both are forced on every step with
    -- 'seq' so that long inputs do not accumulate chains of unevaluated
    -- @(+)@ thunks.
    go :: Double -> Int -> [Double] -> Double
    go s l [] = s / fromIntegral l
    go s l (x:xs) =
      let s' = s + x
          l' = l + 1
      in s' `seq` l' `seq` go s' l' xs
-- | Apply one function to both components of a homogeneous pair.
--
-- >>> mapPair (* 2) (3, 4)
-- (6,8)
mapPair :: (a -> b) -> (a, a) -> (b, b)
mapPair f pair = case pair of
  (left, right) -> (f left, f right)
-- | Turn a list of exactly two elements into a pair.
--
-- Partial: any list whose length is not 2 results in a call to 'error'.
shortListToPair :: [a] -> (a, a)
shortListToPair xs = case xs of
  [first, second] -> (first, second)
  _ -> error "shortListToPair - list must have length 2."
-- | Shuffle a list with the Fisher–Yates algorithm.
--
-- Randomness comes from an MWC generator restored from @seed@, so the same
-- seed and input always produce the same permutation. Each swap is an
-- O(log n) 'Seq.update', giving O(n log n) overall. Element values are
-- never forced.
shuffle :: [a] -> Seed -> [a]
shuffle xs seed = runST $ do
  gen <- restore seed
  let n = length xs
      -- Walk i from the last index down to 1. At each step, swap position i
      -- with a position j drawn uniformly from [0, i] (inclusive bounds).
      -- '$!' evaluates each step's sequence before moving on.
      step s i = do
        j <- uniformR (0, i) gen
        return $! swapAt i j s
  shuffled <- foldM step (Seq.fromList xs) [n - 1, n - 2 .. 1]
  return (toList shuffled)
  where
    -- Exchange the elements at positions i and j. 'Seq.lookup' retrieves an
    -- element immediately without forcing it. 'Seq.index' would instead
    -- leave a thunk that keeps the previous sequence alive inside the new
    -- one.
    swapAt i j s = case (Seq.lookup i s, Seq.lookup j s) of
      (Just xi, Just xj) -> Seq.update i xj (Seq.update j xi s)
      _ -> error "shuffle - index out of range (unreachable)."
-- | Pick @k@ elements of @xs@ without replacement by taking the first @k@
-- entries of a seeded 'shuffle'.
--
-- If @k@ exceeds the list length, the whole shuffled list is returned. If
-- @k <= 0@, the result is @[]@.
sample :: Int -> [a] -> Seed -> [a]
sample k xs = take k . shuffle xs