{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides functions to perform shuffles on immutable vectors.
-- The shuffling is uniform amongst all permutations and performs the minimal
-- number of transpositions.

module Immutable.Shuffle where

import           Control.Monad.Primitive
import           Control.Monad.Random    (MonadRandom (..))
import           Control.Monad.ST        (runST)
import           Data.Vector.Generic
import qualified Mutable.Shuffle         as MS
import           Prelude                 hiding (length, take)
import           System.Random           (RandomGen (..))


-- |
-- Shuffle an immutable vector using the supplied random generator, returning
-- the shuffled vector together with the updated generator.
--
-- A vector with at most one element is returned unchanged, along with the
-- generator exactly as it was passed in.
--
-- This uses the Fisher--Yates--Knuth algorithm.
shuffle :: forall a g v. (RandomGen g, Vector v a) => v a -> g -> (v a, g)
shuffle v g
  | length v <= 1 = (v, g)
  | otherwise     = runST $ do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV   <- thaw v
      newGen <- MS.shuffle mutV g
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is local to this 'ST'
      -- block and is not referenced again after freezing.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)


-- |
-- Shuffle an immutable vector in a monad which has a source of randomness,
-- returning the shuffled vector.
--
-- A vector with at most one element is returned unchanged without consulting
-- the source of randomness.
--
-- This uses the Fisher--Yates--Knuth algorithm.
shuffleM :: forall m a v . (MonadRandom m, PrimMonad m, Vector v a) => v a -> m (v a)
shuffleM v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV <- thaw v
      MS.shuffleM mutV
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is not referenced again
      -- in this block after freezing.
      unsafeFreeze mutV


-- |
-- Perform a shuffle on the first @k@ elements of a vector in a monad which has
-- a source of randomness, returning the resulting vector.
--
-- A vector with at most one element is returned unchanged, regardless of @k@.
--
-- NOTE(review): @k@ is forwarded unchecked to 'MS.shuffleK'; the behaviour for
-- @k@ outside @[0, length v]@ is determined by that function -- confirm there.
shuffleK :: forall m a v. (MonadRandom m, PrimMonad m, Vector v a) => Int -> v a -> m (v a)
shuffleK k v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV <- thaw v
      MS.shuffleK k mutV
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is not referenced again
      -- in this block after freezing.
      unsafeFreeze mutV


-- |
-- Get a random sample of @k@ elements without replacement from a vector.
--
-- The sample is obtained by running 'shuffleK' with the same @k@ and keeping
-- the first @k@ elements of the result, so the input vector itself is left
-- untouched.
sampleWithoutReplacement :: forall m a v . (MonadRandom m, PrimMonad m, Vector v a) => Int -> v a -> m (v a)
{-# INLINEABLE sampleWithoutReplacement #-}
sampleWithoutReplacement k v = take k <$> shuffleK k v


-- |
-- Shuffle an immutable vector such that the shuffled indices form a maximal
-- cycle, returning the shuffled vector together with the updated generator.
--
-- A vector with at most one element is returned unchanged, along with the
-- generator exactly as it was passed in.
--
-- This uses the Sattolo algorithm.
maximalCycle :: forall a g v. (RandomGen g, Vector v a) => v a -> g -> (v a, g)
maximalCycle v g
  | length v <= 1 = (v, g)
  | otherwise     = runST $ do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV   <- thaw v
      newGen <- MS.maximalCycle mutV g
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is local to this 'ST'
      -- block and is not referenced again after freezing.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)

-- |
-- Shuffle an immutable vector such that the shuffled indices form a maximal
-- cycle, in a monad with a source of randomness.
--
-- A vector with at most one element is returned unchanged without consulting
-- the source of randomness.
--
-- This uses the Sattolo algorithm.
maximalCycleM :: forall m a v . (MonadRandom m, PrimMonad m, Vector v a) => v a -> m (v a)
maximalCycleM v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV <- thaw v
      MS.maximalCycleM mutV
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is not referenced again
      -- in this block after freezing.
      unsafeFreeze mutV


-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) of an
-- immutable vector with a given random generator, returning the deranged
-- vector together with the updated generator.
--
-- A vector with at most one element is returned unchanged, along with the
-- generator exactly as it was passed in. In particular a singleton vector,
-- for which no derangement exists, comes back as-is.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
--
-- This uses the "early refusal" algorithm.
derangement :: forall a g v . (Eq a, RandomGen g, Vector v a) => v a -> g -> (v a, g)
derangement v g
  | length v <= 1 = (v, g)
  | otherwise     = runST $ do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV   <- thaw v
      newGen <- MS.derangement mutV g
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is local to this 'ST'
      -- block and is not referenced again after freezing.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)


-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) of an
-- immutable vector in a monad which has a source of randomness, returning the
-- deranged vector.
--
-- A vector with at most one element is returned unchanged without consulting
-- the source of randomness. In particular a singleton vector, for which no
-- derangement exists, comes back as-is.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
--
-- This uses the "early refusal" algorithm.
derangementM :: forall m a v . (Eq a, MonadRandom m, PrimMonad m, Vector v a) => v a -> m (v a)
derangementM v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a mutable copy so the caller's vector is never modified.
      mutV <- thaw v
      MS.derangementM mutV
      -- 'unsafeFreeze' avoids a second copy; 'mutV' is not referenced again
      -- in this block after freezing.
      unsafeFreeze mutV