-- | List shuffling and sampling with optimal asymptotic time and space complexity using the imperative Fisher–Yates
-- algorithm.
module List.Shuffle
  ( -- * Shuffling
    shuffle,
    shuffle_,
    shuffleIO,

    -- * Sampling
    sample,
    sample_,
    sampleIO,

    -- * Adapting to other monads

    -- ** Reader monad
    -- $example-reader

    -- ** State monad
    -- $example-state
  )
where

import Control.Monad.IO.Class (MonadIO)
import Control.Monad.ST (runST)
import Control.Monad.ST.Strict (ST)
import Data.Foldable qualified as Foldable
import Data.Primitive.Array qualified as Array
import System.Random (RandomGen)
import System.Random qualified as Random

-- $example-reader
--
-- You are working in a reader monad, with access to a pseudo-random number generator somewhere in the environment,
-- in a mutable cell like an @IORef@ or @TVar@:
--
-- > import System.Random qualified as Random
-- > import System.Random.Stateful qualified as Random
-- >
-- > data MyMonad a
-- >
-- > instance MonadIO MyMonad
-- > instance MonadReader MyEnv MyMonad
-- >
-- > data MyEnv = MyEnv
-- >   { ...
-- >   , prng :: Random.AtomicGenM Random.StdGen
-- >   , ...
-- >   }
--
-- In this case, you can adapt 'shuffle' to work in your monad as follows:
--
-- > import List.Shuffle qualified as List
-- > import System.Random.Stateful qualified as Random
-- >
-- > shuffleList :: [a] -> MyMonad [a]
-- > shuffleList list = do
-- >   MyEnv {prng} <- ask
-- >   Random.applyAtomicGen (List.shuffle list) prng

-- $example-state
--
-- You are working in a state monad with access to a pseudo-random number generator somewhere in the state type. You
-- also have a lens onto this field, which is commonly either provided by @generic-lens@/@optics@ or written manually:
--
-- > import System.Random qualified as Random
-- >
-- > data MyState = MyState
-- >   { ...
-- >   , prng :: Random.StdGen
-- >   , ...
-- >   }
-- >
-- > prngLens :: Lens' MyState Random.StdGen
--
-- In this case, you can adapt 'shuffle' to work in your monad as follows:
--
-- > import Control.Lens qualified as Lens
-- > import Control.Monad.Trans.State.Strict qualified as State
-- > import List.Shuffle qualified as List
-- >
-- > shuffleList :: Monad m => [a] -> State.StateT MyState m [a]
-- > shuffleList =
-- >   Lens.zoom prngLens . State.state . List.shuffle

-- | \(\mathcal{O}(n)\), where \(n\) is the length of the list. Shuffle a list.
--
-- Returns the elements of the input in an order chosen by the Fisher–Yates algorithm, paired with the generator
-- that remains after drawing the random indices.
shuffle :: (RandomGen g) => [a] -> g -> ([a], g)
shuffle xs g =
  runST do
    marr <- listToMutableArray xs
    -- Running Fisher–Yates up to the last index shuffles the whole array: the final element has nowhere else to go.
    let lastIndex = Array.sizeofMutableArray marr - 1
    g' <- shuffleN lastIndex marr g
    -- The mutable array is not touched after this point, so freezing it in place (without a copy) is safe.
    frozen <- Array.unsafeFreezeArray marr
    pure (Foldable.toList frozen, g')
{-# SPECIALIZE shuffle :: [a] -> Random.StdGen -> ([a], Random.StdGen) #-}

-- `shuffleN k array gen` performs the first `k` steps of the Fisher–Yates shuffle on `array` and returns the final
-- generator. Step `i` (for `i` from 0 up to, but excluding, `k`) swaps position `i` with a position drawn by
-- `Random.uniformR` from the inclusive range `[i, lastIx]`, where `lastIx` is the array's last index. Positions below
-- `i` are never touched again. Afterwards the first `k` elements are therefore a shuffled sample of the whole array.
-- The remaining elements are left in some indeterminate order.
--
-- Passing `k = lastIx` (the array's length minus 1) shuffles the entire array: once every position but the last has
-- been filled, the final element is already in its only remaining place.
--
-- Any `k` is accepted: `k <= 0` does nothing, and `k > lastIx` behaves exactly like `k = lastIx`. An empty array is
-- left untouched.
shuffleN :: forall a g s. (RandomGen g) => Int -> Array.MutableArray s a -> g -> ST s g
shuffleN requested array = loop 0
  where
    -- Index of the last element; -1 for an empty array.
    lastIx :: Int
    lastIx = Array.sizeofMutableArray array - 1

    -- Number of Fisher–Yates steps to run, clamped so a step never starts past the last index.
    steps :: Int
    steps = min requested lastIx

    loop :: Int -> g -> ST s g
    loop !i gen
      | i < steps = do
          let (j, gen') = Random.uniformR (i, lastIx) gen
          swapArrayElems i j array
          loop (i + 1) gen'
      | otherwise = pure gen
{-# SPECIALIZE shuffleN :: Int -> Array.MutableArray s a -> Random.StdGen -> ST s Random.StdGen #-}

-- | \(\mathcal{O}(n)\). Like 'shuffle', but returns only the shuffled list and drops the updated generator.
shuffle_ :: (RandomGen g) => [a] -> g -> [a]
shuffle_ xs = fst . shuffle xs
{-# SPECIALIZE shuffle_ :: [a] -> Random.StdGen -> [a] #-}

-- | \(\mathcal{O}(n)\). Like 'shuffle', but obtains its generator from the global random number generator (via
-- 'Random.newStdGen') instead of taking one as an argument.
shuffleIO :: (MonadIO m) => [a] -> m [a]
shuffleIO xs = do
  gen <- Random.newStdGen
  pure (shuffle_ xs gen)
{-# SPECIALIZE shuffleIO :: [a] -> IO [a] #-}

-- | \(\mathcal{O}(n)\), where \(n\) is the length of the list. Sample elements of a list, without replacement.
--
-- @sample c xs g@ picks @min c (length xs)@ elements from distinct positions of @xs@, in random order. The result is
-- empty if @c <= 0@. The picked elements are returned together with the updated generator.
--
-- The returned list equals @take c (shuffle_ xs g)@. The difference is that only the first @c@ Fisher–Yates steps
-- are run, so the shuffling work is proportional to @c@. Copying @xs@ into an array is still linear in its length.
sample :: (RandomGen g) => Int -> [a] -> g -> ([a], g)
sample c xs g =
  runST do
    marr <- listToMutableArray xs
    -- Only the first c positions need to be settled; shuffleN clamps c to the array's bounds.
    g' <- shuffleN c marr g
    frozen <- Array.unsafeFreezeArray marr
    let picked = take c (Foldable.toList frozen)
    pure (picked, g')
{-# SPECIALIZE sample :: Int -> [a] -> Random.StdGen -> ([a], Random.StdGen) #-}

-- | \(\mathcal{O}(n)\). Like 'sample', but returns only the sampled elements and drops the updated generator.
sample_ :: (RandomGen g) => Int -> [a] -> g -> [a]
sample_ c xs = fst . sample c xs
{-# SPECIALIZE sample_ :: Int -> [a] -> Random.StdGen -> [a] #-}

-- | \(\mathcal{O}(n)\). Like 'sample', but obtains its generator from the global random number generator (via
-- 'Random.newStdGen') instead of taking one as an argument.
sampleIO :: (MonadIO m) => Int -> [a] -> m [a]
sampleIO c xs = do
  gen <- Random.newStdGen
  pure (sample_ c xs gen)
{-# SPECIALIZE sampleIO :: Int -> [a] -> IO [a] #-}

-- Exchange the elements stored at indices `i` and `j` of a mutable array. Both indices must be in bounds. Passing
-- the same index twice leaves the array unchanged.
swapArrayElems :: Int -> Int -> Array.MutableArray s a -> ST s ()
swapArrayElems i j array = do
  -- Read both slots before writing either, so neither value is lost.
  (atI, atJ) <- (,) <$> Array.readArray array i <*> Array.readArray array j
  Array.writeArray array i atJ *> Array.writeArray array j atI
{-# INLINE swapArrayElems #-}

-- Construct a mutable array holding the elements of a list, in list order.
--
-- The list is traversed twice: once by 'length' to size the array, and once to write the elements.
--
-- The explicit `forall a s.` brings `a` and `s` into scope (ScopedTypeVariables). This lets the local signature of
-- `writeElems` refer to the same element and state types as `array`.
listToMutableArray :: forall a s. [a] -> ST s (Array.MutableArray s a)
listToMutableArray list = do
  -- Every slot is overwritten by writeElems below before the array is returned, so this placeholder is never
  -- forced. The message exists only to make a violation of that invariant diagnosable.
  array <- Array.newArray (length list) (error "listToMutableArray: uninitialized array element")
  let writeElems :: Int -> [a] -> ST s ()
      writeElems !i = \case
        [] -> pure ()
        x : xs -> do
          Array.writeArray array i x
          writeElems (i + 1) xs
  writeElems 0 list
  pure array
{-# INLINE listToMutableArray #-}