{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Bench.Vector.Algo.NextPermutation (generatePermTests) where

import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector.Generic.Mutable as G
import System.Random.Stateful
    ( StatefulGen, UniformRange(uniformRM) )

-- | Build the permutation benchmarks as @(name, action)@ pairs. Running the
-- actions is left to the benchmarking framework.
--
-- The benchmarks, in order:
--
-- * @nextPermutation@ on a short vector (length from 'useSizeToPermLen'),
--   stepped until the permutation sequence ends.
--
-- * @nextPermutationBijective@ applied @n@ times to a vector of size @n@ that
--   starts as the ascending, the descending, or a random permutation.
--
-- * The same two groups again for @prevPermutation@ /
--   @prevPermutationBijective@.
--
-- * A baseline that only copies (thaws) a vector of size @n@. Each bijective
--   benchmark also begins by copying its input vector, so this baseline
--   measures that shared overhead.
generatePermTests :: StatefulGen g IO => g -> Int -> IO [(String, IO ())]
generatePermTests gen useSize = do
  -- Force the inputs here (bang patterns) so that building them is not
  -- deferred into the benchmark actions.
  let !smallLen = useSizeToPermLen useSize
      !ascPerm  = V.generate useSize id
      !descPerm = V.generate useSize (\i -> useSize - 1 - i)
  !rndPerm <- randomPermutationWith gen useSize
  pure
    [ ( "nextPermutation (small vector, until end)"
      , loopPermutations smallLen )
    , ( "nextPermutationBijective (ascending perm of size n, n times)"
      , repeatNextPermutation ascPerm useSize )
    , ( "nextPermutationBijective (descending perm of size n, n times)"
      , repeatNextPermutation descPerm useSize )
    , ( "nextPermutationBijective (random perm of size n, n times)"
      , repeatNextPermutation rndPerm useSize )
    , ( "prevPermutation (small vector, until end)"
      , loopRevPermutations smallLen )
    , ( "prevPermutationBijective (ascending perm of size n, n times)"
      , repeatPrevPermutation ascPerm useSize )
    , ( "prevPermutationBijective (descending perm of size n, n times)"
      , repeatPrevPermutation descPerm useSize )
    , ( "prevPermutationBijective (random perm of size n, n times)"
      , repeatPrevPermutation rndPerm useSize )
    , ( "baseline for *Bijective (just copying the vector of size n)"
      , () <$ V.thaw rndPerm )
    ]

-- | Draw a random permutation of @[0..n-1]@ using the supplied stateful PRNG.
--
-- Algorithm (Fisher–Yates):
--
-- 1. Start from the identity permutation.
-- 2. For each position @i@ from @0@ to @n-2@, swap it with a position chosen
--    uniformly from @[i, n-1]@.
randomPermutationWith :: (StatefulGen g IO) => g -> Int -> IO (V.Vector Int)
randomPermutationWith gen n = do
  perm <- M.generate n id
  let shuffleFrom !i
        | i >= n - 1 = pure ()
        | otherwise = do
            j <- uniformRM (i, n - 1) gen
            M.swap perm i j
            shuffleFrom (i + 1)
  shuffleFrom 0
  -- The mutable vector is not touched after this point, so freezing it
  -- without a copy is safe.
  V.unsafeFreeze perm

-- | Given the @useSize@ benchmark option, compute the largest @n <= 12@ such
-- that @n! <= useSize@. The repeat-nextPermutation-until-end benchmark uses
-- @n@ as the length of its vector.
--
-- When @useSize < 1@, no @n@ satisfies the bound (because @0! = 1@). The
-- result is then clamped to @0@ (an empty vector) instead of a negative
-- length.
--
-- Note that 12 is the largest @n@ such that @n!@ can be represented as an
-- 'Int32'.
useSizeToPermLen :: Int -> Int
useSizeToPermLen us = max 0 (length (takeWhile (<= us) factorials) - 1)
  where
    -- [0!, 1!, 2!, ..., 12!] (index i holds i!). The list is non-decreasing,
    -- so the prefix of entries <= us ends exactly at the largest admissible n.
    factorials :: [Int]
    factorials = scanl (*) 1 [1 .. 12]

-- | A wrapping variant of 'G.nextPermutation'. If the vector is already in
-- descending order, it is reversed.
--
-- "Bijective" means the function cycles through all permutations of the
-- vector's elements.
--
-- Returns 'True' when 'G.nextPermutation' succeeded. Returns 'False' when it
-- did not, in which case the vector has been reversed in place.
--
-- Property worth benchmarking: each call takes amortized constant time when
-- applied successively either Omega(n) times to one vector of distinct
-- elements, or any number of times to one vector that starts in strictly
-- ascending order.
nextPermutationBijective :: (G.MVector v a, Ord a) => v G.RealWorld a -> IO Bool
nextPermutationBijective v =
  G.nextPermutation v >>= \advanced ->
    if advanced
      then pure True
      else G.reverse v >> pure False

-- | A wrapping variant of 'G.prevPermutation'. If the vector is already in
-- ascending order, it is reversed.
--
-- "Bijective" means the function cycles through all permutations of the
-- vector's elements.
--
-- Returns 'True' when 'G.prevPermutation' succeeded. Returns 'False' when it
-- did not, in which case the vector has been reversed in place.
--
-- Property worth benchmarking: each call takes amortized constant time when
-- applied successively either Omega(n) times to one vector of distinct
-- elements, or any number of times to one vector that starts in strictly
-- descending order.
prevPermutationBijective :: (G.MVector v a, Ord a) => v G.RealWorld a -> IO Bool
prevPermutationBijective v = do
  stepped <- G.prevPermutation v
  if stepped then pure () else G.reverse v
  pure stepped

-- | Start from the identity vector @[0..n-1]@ and keep applying
-- 'M.nextPermutation' until it returns 'False'.
loopPermutations :: Int -> IO ()
loopPermutations n =
  M.generate n id >>= \perm ->
    let advance = M.nextPermutation perm >>= \more ->
          if more then advance else pure ()
    in advance

-- | Start from the reversed vector @[n-1,n-2..0]@ and keep applying
-- 'M.prevPermutation' until it returns 'False'.
loopRevPermutations :: Int -> IO ()
loopRevPermutations n = do
  perm <- M.generate n (\i -> n - 1 - i)
  let retreat = M.prevPermutation perm >>= \more ->
        if more then retreat else pure ()
  retreat

-- | Thaw a copy of the given vector, then apply 'nextPermutationBijective' to
-- it the given number of times. A count @<= 0@ makes no calls.
--
-- The copy is made inside the action, so every execution starts from the
-- same input vector.
repeatNextPermutation :: V.Vector Int -> Int -> IO ()
repeatNextPermutation !v !n = do
  !mv <- V.thaw v
  -- Pin the counter to Int. Without the signature, the local loop is
  -- let-generalised to (Num t, Ord t) => t -> IO (). The benchmark would then
  -- depend on the specialiser to remove dictionary passing from the hot loop.
  let loop :: Int -> IO ()
      loop !i
        | i <= 0 = return ()
        | otherwise = do
            _ <- nextPermutationBijective mv
            loop (i - 1)
  loop n

-- | Thaw a copy of the given vector, then apply 'prevPermutationBijective' to
-- it the given number of times. A count @<= 0@ makes no calls.
--
-- The copy is made inside the action, so every execution starts from the
-- same input vector.
repeatPrevPermutation :: V.Vector Int -> Int -> IO ()
repeatPrevPermutation !v !n = do
  !mv <- V.thaw v
  -- Pin the counter to Int. Without the signature, the local loop is
  -- let-generalised to (Num t, Ord t) => t -> IO (). The benchmark would then
  -- depend on the specialiser to remove dictionary passing from the hot loop.
  let loop :: Int -> IO ()
      loop !i
        | i <= 0 = return ()
        | otherwise = do
            _ <- prevPermutationBijective mv
            loop (i - 1)
  loop n