-- |
-- Module      :  Mcmc.Internal.Shuffle
-- Description :  Shuffle a list
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Wed May 20 14:37:09 2020.
--
-- From https://wiki.haskell.org/Random_shuffle.
module Mcmc.Internal.Shuffle
  ( shuffle,
  )
where

import Control.Monad
import Control.Monad.ST
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
import System.Random.Stateful

-- TODO: Remove shuffle, use System.Random.MWC.Distributions.uniformShuffle and
-- vectors.

-- | Shuffle a list.
--
-- Returns a random permutation of @xs@, using @g@ as the source of
-- randomness. This is a full Fisher–Yates shuffle: it asks 'grabble' for all
-- @length xs@ elements.
shuffle :: StatefulGen g m => [a] -> g -> m [a]
shuffle xs = grabble xs (length xs)

-- | @grabble xs m gen@ chooses @min m (length xs)@ elements of @xs@ without
-- replacement, in random order.
--
-- It runs the first steps of a forward Fisher–Yates shuffle on a vector copy
-- of @xs@ and returns the first @m@ elements of the result. For each position
-- @i@ it swaps @i@ with a position drawn by 'uniformRM' from @[i, length xs - 1]@.
-- It performs at most @min (length xs - 1) (m + 1)@ swaps, plus an /O(n)/
-- conversion of the list to a vector.
--
-- An empty list, or @m <= 0@, yields @[]@.
grabble :: StatefulGen g m => [a] -> Int -> g -> m [a]
grabble xs m gen = do
  -- The last position has nothing left to swap with, so stop at lastIx - 1.
  -- When m < lastIx this also draws one swap at index m. That swap cannot
  -- affect the first m elements. It is kept so that the sequence of random
  -- draws, and therefore seeded runs, stays the same.
  swaps <- forM [0 .. min (lastIx - 1) m] $ \i ->
    (,) i <$> uniformRM (i, lastIx) gen
  pure $ V.toList . V.take m . swapElems (V.fromList xs) $ swaps
  where
    lastIx :: Int
    lastIx = length xs - 1

-- | Apply a list of index swaps, left to right, to a copy of a vector.
--
-- The input vector is never modified. Every pair @(i, j)@ must hold valid
-- indices into the vector. 'M.swap' is bounds-checked, so an out-of-range
-- index raises an error instead of touching arbitrary memory.
swapElems :: V.Vector a -> [(Int, Int)] -> V.Vector a
swapElems xs swaps = runST $ do
  -- Use the copying 'V.thaw' rather than 'V.unsafeThaw'. Mutating the
  -- caller's immutable vector in place would be visible through every other
  -- reference to it.
  mxs <- V.thaw xs
  mapM_ (uncurry (M.swap mxs)) swaps
  -- Safe: mxs is a private copy that is not touched after freezing.
  V.unsafeFreeze mxs