{-# LANGUAGE FlexibleContexts #-}

-- | Internal module for sampling of cycle graph partitions.
-- Import 'RandomCycle.Vector' instead.
module RandomCycle.Vector.Cycle where

import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState, liftPrim)
import Data.STRef
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import System.Random.MWC.Distributions (uniformPermutation, uniformShuffleM)
import System.Random.Stateful

{- INTERNAL -}

-- | Internal. Rejection-sampling loop behind 'uniformCyclePartitionThin'.
-- It shuffles the caller's mutable vector in place on every attempt, so the
-- input vector is not re-allocated for each rejected sample.
--
-- Each attempt first reads the remaining-attempt counter (second argument):
--
-- * if it is @<= 0@, the loop stops and the match flag is left untouched;
-- * otherwise the vector is shuffled with 'uniformShuffleM' and every edge
--   @(i, v[i])@ is tested with the predicate. If all edges pass, 'True' is
--   written to the match flag and the loop stops, leaving the accepted sample
--   in the vector. If not, the counter is decremented by one and another
--   attempt is made.
--
-- Hence at most (initial counter value) shuffles are performed.
--
-- IMPORTANT: Caller's responsibility to ensure proper management of the
-- match flag ('chk'). It must start out 'False', because this function never
-- resets it. The vector only holds an accepted sample if the flag reads 'True'
-- afterwards.
uniformCyclePartitionThinM ::
  (StatefulGen g m, PrimMonad m) =>
  -- | Match-found flag; set to 'True' when a sample is accepted.
  STRef (PrimState m) Bool ->
  -- | Remaining number of attempts; decremented on every rejection.
  STRef (PrimState m) Int ->
  -- | Edge-wise predicate that every edge @(i, v[i])@ must satisfy.
  ((Int, Int) -> Bool) ->
  -- | Vector shuffled in place; holds the accepted sample on success.
  V.MVector (PrimState m) Int ->
  -- | Generator.
  g ->
  m ()
uniformCyclePartitionThinM chk maxit r v gen = do
  remaining <- liftPrim $ readSTRef maxit
  -- Only draw again while the attempt budget is positive. Stopping here is
  -- what guarantees termination for predicates that are rarely or never
  -- satisfied.
  when (remaining > 0) $ do
    uniformShuffleM v gen
    -- TODO: Repeated calls to freeze, indexed
    -- a possible opportunity for optimization,
    -- e.g. with imap or a check that takes 'chk'
    -- reference and shortcircuits.
    vVal <- V.freeze v
    if V.all r (V.indexed vVal)
      then liftPrim $ writeSTRef chk True
      else do
        liftPrim $ modifySTRef' maxit (subtract 1)
        -- The retry is the last action of this branch; nothing runs after it.
        uniformCyclePartitionThinM chk maxit r v gen

-- | Internal. One step of Sattolo's algorithm, used by 'uniformCycle'.
--
-- Step @i@ targets position @n-1-i@. It draws @j@ uniformly from
-- @[0 .. n-2-i]@, which is strictly below the target, and swaps the entries at
-- @j@ and @n-1-i@.
--
-- Caller's responsibility to pass the vector's length as the first argument
-- and to keep @i@ within @[0..n-2]@. Uses 'MV.swap' rather than
-- 'MV.unsafeSwap' so that bounds checks are kept.
swapIt ::
  (StatefulGen g m, PrimMonad m) =>
  -- | Vector length.
  Int ->
  -- | Vector to modify.
  MV.MVector (PrimState m) Int ->
  -- | Generator.
  g ->
  -- | Index.
  Int ->
  m ()
swapIt n mv g i = uniformRM (0, target - 1) g >>= \j -> MV.swap mv j target
  where
    -- Position settled by this step; runs from n-1 down to 1 as i goes 0..n-2.
    target = n - 1 - i

{- RANDOM -}

-- | Samples a single full cycle on @[0..n-1]@ using
-- [Sattolo's algorithm](https://algo.inria.fr/seminars/summary/Wilson2004b.pdf).
-- Each of the /(n-1)!/ full cycle permutations is equally likely, and the
-- work is /O(n)/.
--
-- Sattolo's algorithm differs from the Fisher-Yates shuffle on @[0..n-1]@
-- only in never letting an element swap with itself. This implementation
-- therefore closely follows that of 'uniformPermutation'. The result is the
-- edge list @(i, sigma(i))@ of the sampled cycle.
--
-- A negative size throws an error, analogous to 'uniformPermutation'.
-- A size of @0@ yields an empty vector.
--
-- ==== __Examples__
--
-- >>> import System.Random.Stateful
-- >>> import RandomCycle.Vector
-- >>> runSTGen_ (mkStdGen 1901) $ uniformCycle 4
-- [(0,3),(1,0),(2,1),(3,2)]
uniformCycle ::
  (StatefulGen g m, PrimMonad m) =>
  -- | Size /n/ of cycle.
  Int ->
  g ->
  m (V.Vector (Int, Int))
uniformCycle n gen
  | n < 0 = error "RandomCycle.Vector.Cycle: size must be >= 0"
  | otherwise = do
      -- Start from the identity permutation [0 .. n-1].
      perm <- MV.generateM n pure
      -- Steps 0 .. n-2 settle positions n-1 down to 1.
      mapM_ (swapIt n perm gen) [0 .. n - 2]
      images <- V.freeze perm
      pure (V.indexed images)

-- | Select a partition of @[0..n-1]@ into disjoint
--  [cycle graphs](https://en.wikipedia.org/wiki/Cycle_graph),
--  uniformly over the /n!/ possibilities.
--
--  Such partitions correspond one-to-one with the permutations of @[0..n-1]@.
--  The correspondence sends a permutation /\sigma/ to the edge set
--  /\{(i, \sigma(i))\}_0^{n-1}/. Sampling a partition therefore amounts to
--  sampling a permutation with 'uniformPermutation' and pairing each image with
--  its index. The result is that vector of edges.
--
--  Runs in /O(n)/, since 'uniformPermutation' implements the
--  [Fisher-Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle).
--  Because 'uniformPermutation' mutates in place, this must run in a
--  'PrimMonad' context.
--
-- ==== __Examples__
--
-- >>> import System.Random.Stateful
-- >>> import qualified RandomCycle.Vector as RV
-- >>> import Data.Vector (Vector)
-- >>> runSTGen_ (mkStdGen 1305) $ RV.uniformCyclePartition 4 :: Vector (Int, Int)
-- [(0,1),(1,3),(2,2),(3,0)]
uniformCyclePartition ::
  (StatefulGen g m, PrimMonad m) =>
  Int ->
  g ->
  m (V.Vector (Int, Int))
uniformCyclePartition n gen = do
  perm <- uniformPermutation n gen
  pure (V.indexed perm)

-- TODO: apply short-circuiting behavior by creating a modification
-- of 'uniformShuffleM' that carries a validity state and short-circuits as
-- soon as some edge does not satisfy the predicate.
-- The current implementation is the simple one (as in human-lazy). Note that
-- such a modification would require posting a notice in this module in
-- accordance with the BSD2 license of mwc-random.

-- | Uniform selection of a cycle partition graph of @[0..n-1]@ elements,
-- conditional on an edge-wise predicate.
--
-- Draws are made as in 'uniformCyclePartition'. A draw is rejected unless
-- every edge @(i, sigma(i))@ satisfies the predicate. See
-- 'uniformCyclePartition' for details on the sampler.
--
-- /O(n\/p)/, where /p/ is the probability that a uniformly sampled
-- cycle partition graph (over all /n!/ possible) satisfies the conditions.
-- This can be highly non-linear since /p/ in general is a function of /n/.
--
-- Since this is a rejection sampling method, the caller provides a maximum
-- number of sampling attempts. This guarantees termination even when the
-- predicate's probability of success is close to zero. If no draw is accepted
-- within that budget, the result is 'Nothing'.
--
-- The result is also @pure Nothing@, without any sampling, when the attempt
-- budget (first argument) or the number of vertices (third argument) is
-- non-positive. By contrast, 'uniformCyclePartition' has no such guard and
-- leaves size checking to 'uniformPermutation', which errors on negative sizes.
--
-- ==== __Examples__
--
-- >>> import System.Random.Stateful
-- >>> import qualified RandomCycle.Vector as RV
-- >>> import Data.Vector (Vector)
-- >>> -- No self-loops
-- >>> rule = uncurry (/=)
-- >>> n = 5
-- >>> maxit = n * 1000
-- >>> runSTGen_ (mkStdGen 3) $ RV.uniformCyclePartitionThin maxit rule n
-- Just [(0,2),(1,3),(2,0),(3,4),(4,1)]
uniformCyclePartitionThin ::
  (StatefulGen g m, PrimMonad m) =>
  -- | maximum number of draws to attempt
  Int ->
  -- | edge-wise predicate, which all edges in the result must satisfy
  ((Int, Int) -> Bool) ->
  -- | number of vertices, which will be labeled @[0..n-1]@
  Int ->
  g ->
  m (Maybe (V.Vector (Int, Int)))
uniformCyclePartitionThin maxit r n gen
  | maxit <= 0 || n <= 0 = pure Nothing
  | otherwise = do
      -- Scratch buffer holding the identity [0 .. n-1]; shuffled in place
      -- by the helper on every attempt.
      scratch <- V.thaw (V.generate n id)
      found <- liftPrim (newSTRef False)
      budget <- liftPrim (newSTRef maxit)
      uniformCyclePartitionThinM found budget r scratch gen
      accepted <- liftPrim (readSTRef found)
      if not accepted
        then pure Nothing
        else Just . V.indexed <$> V.freeze scratch