{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Mutable.Shuffle where
import Control.Monad.Primitive
import Control.Monad.Random (MonadRandom (..))
import Data.Vector.Generic.Mutable
import Prelude hiding (length, read, tail)
import System.Random (RandomGen)
import qualified System.Random as SR
-- | Shuffle a mutable vector in place with the Fisher–Yates algorithm,
-- drawing every random index from the pure generator @gen@.
--
-- For each position @i@ from 0 up to the second-to-last index, an offset is
-- drawn with 'SR.randomR' from @[0, length - 1 - i]@ and position @i@ is
-- swapped with position @i + offset@. Vectors of length 0 or 1 are left
-- untouched and consume no randomness.
--
-- Returns the generator after the last draw so the caller can keep
-- threading it.
shuffle
  :: forall m a g v
   . ( PrimMonad m
     , RandomGen g
     , MVector v a
     )
  => v (PrimState m) a -> g -> m g
{-# INLINABLE shuffle #-}
shuffle mutV gen = loop 0 gen
  where
    -- Index of the final element (-1 for an empty vector).
    lastIx :: Int
    lastIx = length mutV - 1

    -- @loop i g@ fixes position @i@ and continues with the next one.
    loop :: Int -> g -> m g
    loop i g
      | i >= lastIx = pure g
      | otherwise = do
          -- Zero-based offset, so the partner position lies in [i, lastIx].
          let (offset, g') = SR.randomR (0, lastIx - i) g
          swap mutV i (i + offset)
          loop (i + 1) g'
-- | Shuffle a mutable vector in place with the Fisher–Yates algorithm,
-- taking randomness from the surrounding 'MonadRandom'.
--
-- For each position @i@ from 0 up to the second-to-last index, an offset is
-- drawn with 'getRandomR' from @[0, length - 1 - i]@ and position @i@ is
-- swapped with position @i + offset@. Vectors of length 0 or 1 are left
-- untouched and draw no random numbers.
shuffleM
  :: forall m a v
   . ( PrimMonad m
     , MonadRandom m
     , MVector v a
     )
  => v (PrimState m) a -> m ()
{-# INLINABLE shuffleM #-}
shuffleM mutV = loop 0
  where
    -- Index of the final element (-1 for an empty vector).
    lastIx :: Int
    lastIx = length mutV - 1

    -- @loop i@ fixes position @i@ and continues with the next one.
    loop :: Int -> m ()
    loop i
      | i >= lastIx = pure ()
      | otherwise = do
          offset <- getRandomR (0, lastIx - i)
          swap mutV i (i + offset)
          loop (i + 1)
-- | Shuffle the first @numberOfShuffles@ elements of a mutable vector in
-- place (Fisher–Yates restricted to that prefix). Elements at positions
-- @numberOfShuffles@ and beyond are left untouched.
--
-- Randomness comes from the surrounding 'MonadRandom' via 'getRandomR'.
-- A count of 0 or 1 is a no-op.
--
-- Calls 'error' if @numberOfShuffles@ is negative or greater than the
-- length of the vector.
shuffleK
  :: forall m a v
   . ( PrimMonad m
     , MonadRandom m
     , MVector v a
     )
  => Int -> v (PrimState m) a -> m ()
{-# INLINABLE shuffleK #-}
shuffleK numberOfShuffles mutV
  -- Validate once, on the caller's value. Checking inside the loop on
  -- @numberOfShuffles - 1@ wrongly rejected 0 as a "negative value".
  | numberOfShuffles < 0 =
      error "Cannot pass negative value to ShuffleK"
  | numberOfShuffles > length mutV =
      error "Cannot pass value greater than the length of the vector to ShuffleK"
  | otherwise = loop 0
  where
    -- Last index of the prefix being shuffled (-1 when the prefix is empty).
    lastIx :: Int
    lastIx = numberOfShuffles - 1

    -- @loop i@ fixes prefix position @i@ by swapping it with a uniformly
    -- chosen position in [i, lastIx].
    loop :: Int -> m ()
    loop i
      | i >= lastIx = pure ()
      | otherwise = do
          offset <- getRandomR (0, lastIx - i)
          swap mutV i (i + offset)
          loop (i + 1)
-- | Permute a mutable vector in place into a single cycle through all of its
-- positions (Sattolo's algorithm), drawing indices from the pure generator
-- @gen@.
--
-- For each position @i@ from 0 up to the second-to-last index, an offset is
-- drawn with 'SR.randomR' from @[1, length - 1 - i]@ and position @i@ is
-- swapped with position @i + offset@. Because the offset is never 0, no
-- element stays where it was (for length ≥ 2). Vectors of length 0 or 1 are
-- left untouched and consume no randomness.
--
-- Returns the generator after the last draw.
maximalCycle
  :: forall m a g v
   . ( PrimMonad m
     , RandomGen g
     , MVector v a
     )
  => v (PrimState m) a -> g -> m g
{-# INLINABLE maximalCycle #-}
maximalCycle mutV gen = loop 0 gen
  where
    -- Index of the final element (-1 for an empty vector).
    lastIx :: Int
    lastIx = length mutV - 1

    -- @loop i g@ fixes position @i@ and continues with the next one.
    loop :: Int -> g -> m g
    loop i g
      | i >= lastIx = pure g
      | otherwise = do
          -- Offset starts at 1: the partner is strictly after position i.
          let (offset, g') = SR.randomR (1, lastIx - i) g
          swap mutV i (i + offset)
          loop (i + 1) g'
-- | Permute a mutable vector in place into a single cycle through all of its
-- positions (Sattolo's algorithm), taking randomness from the surrounding
-- 'MonadRandom'.
--
-- For each position @i@ from 0 up to the second-to-last index, an offset is
-- drawn with 'getRandomR' from @[1, length - 1 - i]@ and position @i@ is
-- swapped with position @i + offset@. Vectors of length 0 or 1 are left
-- untouched and draw no random numbers.
maximalCycleM
  :: forall m a v
   . ( PrimMonad m
     , MonadRandom m
     , MVector v a
     )
  => v (PrimState m) a -> m ()
{-# INLINABLE maximalCycleM #-}
maximalCycleM mutV = loop 0
  where
    -- Index of the final element (-1 for an empty vector).
    lastIx :: Int
    lastIx = length mutV - 1

    -- @loop i@ fixes position @i@ and continues with the next one.
    loop :: Int -> m ()
    loop i
      | i >= lastIx = pure ()
      | otherwise = do
          -- Offset starts at 1: the partner is strictly after position i.
          offset <- getRandomR (1, lastIx - i)
          swap mutV i (i + offset)
          loop (i + 1)
-- | Rearrange a mutable vector in place so that no position holds a value
-- equal (by '==') to the value it held before the call. Indices are drawn
-- from the pure generator @gen@; the updated generator is returned.
--
-- Works as a Fisher–Yates shuffle with rejection. A snapshot of the input is
-- taken with 'clone'. Whenever the element about to be placed at position
-- @i@ equals the snapshot's element at @i@, the vector is restored from the
-- snapshot and the shuffle restarts at position 0. The restart continues
-- from the current generator state. An empty vector is returned unchanged.
--
-- Calls 'error' for a one-element vector, which has no derangement. (The
-- rejection loop used to spin forever on such input.)
--
-- NOTE(review): if a single value occupies more than half of the positions,
-- no derangement by '==' exists either, and this function will still loop
-- forever. Callers must avoid such inputs.
derangement
  :: forall m a g v
   . ( PrimMonad m
     , RandomGen g
     , Eq a
     , MVector v a
     )
  => v (PrimState m) a -> g -> m g
{-# INLINABLE derangement #-}
derangement mutV gen
  | length mutV == 1 =
      error "derangement: a one-element vector has no derangement"
  | otherwise = do
      mutV_copy <- clone mutV
      go mutV_copy mutV gen 0 (length mutV - 1)
  where
    -- @go c v g currInd maxInd@:
    --   * @c@ is the untouched snapshot of the input;
    --   * @v@ is the suffix of 'mutV' starting at global index @currInd@;
    --   * @maxInd@ is the last index of @v@.
    go :: v (PrimState m) a -> v (PrimState m) a -> g -> Int -> Int -> m g
    go _ _ g _ (-1) = pure g
    go c v g lastInd 0 = do
      -- One candidate is left for the final position. Accept it unless it
      -- is the value that position started with.
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c g
        else pure g
    go c v oldGen currInd maxInd = do
      let (swapInd, newGen) = SR.randomR (0, maxInd) oldGen
      v_old <- read c currInd
      v_ind <- read v swapInd
      if v_old == v_ind
        then restart c newGen
        else do
          swap v 0 swapInd
          go c (tail v) newGen (currInd + 1) (maxInd - 1)

    -- Undo every swap made so far and start over from position 0.
    restart :: v (PrimState m) a -> g -> m g
    restart c g = do
      unsafeCopy mutV c
      go c mutV g 0 (length mutV - 1)
-- | Rearrange a mutable vector in place so that no position holds a value
-- equal (by '==') to the value it held before the call. Randomness comes
-- from the surrounding 'MonadRandom'.
--
-- Works as a Fisher–Yates shuffle with rejection. A snapshot of the input is
-- taken with 'clone'. Whenever the element about to be placed at position
-- @i@ equals the snapshot's element at @i@, the vector is restored from the
-- snapshot and the shuffle restarts at position 0. An empty vector is left
-- unchanged.
--
-- Calls 'error' for a one-element vector, which has no derangement. (The
-- rejection loop used to spin forever on such input.)
--
-- NOTE(review): if a single value occupies more than half of the positions,
-- no derangement by '==' exists either, and this function will still loop
-- forever. Callers must avoid such inputs.
derangementM
  :: forall m a v
   . ( PrimMonad m
     , MonadRandom m
     , Eq a
     , MVector v a
     )
  => v (PrimState m) a -> m ()
{-# INLINABLE derangementM #-}
derangementM mutV
  | length mutV == 1 =
      error "derangementM: a one-element vector has no derangement"
  | otherwise = do
      mutV_copy <- clone mutV
      go mutV_copy mutV 0 (length mutV - 1)
  where
    -- @go c v currInd maxInd@:
    --   * @c@ is the untouched snapshot of the input;
    --   * @v@ is the suffix of 'mutV' starting at global index @currInd@;
    --   * @maxInd@ is the last index of @v@.
    go :: v (PrimState m) a -> v (PrimState m) a -> Int -> Int -> m ()
    go _ _ _ (-1) = pure ()
    go c v lastInd 0 = do
      -- One candidate is left for the final position. Accept it unless it
      -- is the value that position started with.
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c
        else pure ()
    go c v currInd maxInd = do
      swapInd <- getRandomR (0, maxInd)
      v_old <- read c currInd
      v_ind <- read v swapInd
      if v_old == v_ind
        then restart c
        else do
          swap v 0 swapInd
          go c (tail v) (currInd + 1) (maxInd - 1)

    -- Undo every swap made so far and start over from position 0.
    restart :: v (PrimState m) a -> m ()
    restart c = do
      unsafeCopy mutV c
      go c mutV 0 (length mutV - 1)