{-# LANGUAGE CPP #-}
module Test.Hspec.Core.Shuffle (
  shuffleForest
#ifdef TEST
, shuffle
, mkArray
#endif
) where

import           Prelude ()
import           Test.Hspec.Core.Compat
import           Test.Hspec.Core.Tree

import           System.Random
import           Control.Monad.ST
import           Data.STRef
import           Data.Array.ST

-- | Randomly reorder a list of sibling trees, then recursively shuffle the
-- children inside each of them.
--
-- All randomness is drawn from, and written back to, the generator held in
-- the given 'STRef'.
shuffleForest :: STRef st StdGen -> [Tree c a] -> ST st [Tree c a]
shuffleForest ref forest = do
  -- First permute the siblings at this level ...
  reordered <- shuffle ref forest
  -- ... then descend into each tree, in the new order.
  mapM (shuffleTree ref) reordered

-- | Shuffle the children of an inner node, recursively.
--
-- 'Node' and 'NodeWithCleanup' keep their own fields unchanged and only have
-- their child forest reordered via 'shuffleForest'.  A 'Leaf' has no children
-- and is returned as is.
shuffleTree :: STRef st StdGen -> Tree c a -> ST st (Tree c a)
shuffleTree ref (Node description children) =
  Node description <$> shuffleForest ref children
shuffleTree ref (NodeWithCleanup location cleanup children) =
  NodeWithCleanup location cleanup <$> shuffleForest ref children
shuffleTree _ leaf@Leaf {} =
  return leaf

-- | Return the elements of a list in a random order.
--
-- Randomness is drawn from the generator stored in @ref@, and the updated
-- generator is stored back after every draw.
--
-- The algorithm is a Fisher–Yates shuffle over a 1-indexed mutable array:
--
-- * Positions @i@ are visited from first to last.
-- * For each @i@, an index @j@ is drawn from the closed interval @[i, hi]@.
-- * The element currently in slot @j@ is emitted as the @i@-th result.
-- * The element from slot @i@ is moved into slot @j@.
--
-- Slot @i@ is never read again, so it is not written back.  An empty list
-- performs no draws.
shuffle :: STRef st StdGen -> [a] -> ST st [a]
shuffle ref xs = do
  arr <- mkArray xs
  (lo, hi) <- getBounds arr
  let pick i = do
        j <- draw (i, hi)
        here  <- readArray arr i
        there <- readArray arr j
        -- Only half of the swap is needed: slot i is never visited again.
        writeArray arr j here
        return there
  mapM pick (range (lo, hi))
  where
    -- Draw a value in the given closed interval with 'randomR' and store
    -- the successor generator back into @ref@.
    draw interval = do
      gen <- readSTRef ref
      case randomR interval gen of
        (value, gen') -> do
          writeSTRef ref gen'
          return value

-- | Copy a list into a fresh mutable array indexed from 1 to the list's
-- length, keeping the elements in order.
--
-- For an empty list the bounds are @(1, 0)@, i.e. an empty index range.
mkArray :: [a] -> ST st (STArray st Int a)
mkArray xs =
  let size = length xs
  in newListArray (1, size) xs