module Mcmc.Internal.Shuffle
( shuffle,
grabble,
)
where
import Control.Monad
import Control.Monad.ST
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
import System.Random.Stateful
-- | Randomly permute a list.
--
-- This is 'grabble' asked for every element of the input, i.e. a complete
-- Fisher–Yates shuffle driven by the supplied stateful generator.
shuffle :: StatefulGen g m => [a] -> g -> m [a]
shuffle xs gen = grabble xs n gen
  where
    -- Request all elements so the whole list is permuted.
    n = length xs
-- | Draw the first @m@ elements of a random permutation of @xs@.
--
-- Performs a partial Fisher–Yates shuffle. For each position @i@ (starting at
-- 0) an index @j@ is drawn uniformly from @[i, n - 1]@, where @n@ is the
-- length of @xs@, and positions @i@ and @j@ are swapped. The first @m@
-- elements of the permuted vector are returned: all of them when @m >= n@,
-- none when @m <= 0@.
--
-- Note: when @m < n - 1@ one extra swap (at position @m@) is drawn. It does
-- not affect the returned prefix, but it does consume randomness.
grabble :: StatefulGen g m => [a] -> Int -> g -> m [a]
grabble xs m gen =
  pickFirst <$> traverse drawSwap [0 .. min (lastIx - 1) m]
  where
    -- Index of the final element of xs (-1 for an empty list).
    lastIx = length xs - 1
    -- Pair position i with a partner chosen uniformly from [i, lastIx].
    drawSwap i = (,) i <$> uniformRM (i, lastIx) gen
    -- Apply the swaps to a vector built from xs and keep the first m elements.
    pickFirst = V.toList . V.take m . swapElems (V.fromList xs)
-- | Apply a sequence of transpositions to a vector, left to right.
--
-- Each pair @(i, j)@ swaps the elements at positions @i@ and @j@ of the
-- result. The argument vector is never modified: the swaps are performed on
-- a private copy, so the function is safe to call on vectors that are shared
-- elsewhere. Indices are bounds-checked, and an out-of-range index raises an
-- error instead of reading or writing outside the vector.
swapElems :: V.Vector a -> [(Int, Int)] -> V.Vector a
swapElems xs swaps = runST $ do
  -- Copy instead of 'V.unsafeThaw': the argument may be referenced elsewhere,
  -- and mutating it in place would be observable despite the pure type.
  mxs <- V.thaw xs
  -- Bounds-checked swap; 'M.unsafeSwap' would silently corrupt memory on a
  -- bad index pair.
  mapM_ (uncurry (M.swap mxs)) swaps
  -- Safe: mxs is a private copy and is not touched after freezing.
  V.unsafeFreeze mxs