{-# LANGUAGE FlexibleContexts #-}

-- | Internal module for sampling of cycle graph partitions.
-- Import 'RandomCycle.Vector' instead.
module RandomCycle.Vector.Cycle where

import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState, liftPrim)
import Data.STRef
import qualified Data.Vector as V
import System.Random.MWC.Distributions (uniformPermutation, uniformShuffleM)
import System.Random.Stateful

{- INTERNAL -}

-- | Internal. Rejection-sampling loop behind 'uniformCyclePartitionThin',
-- which shuffles a single mutable vector in place so that no new input
-- vector has to be built for each rejected sample.
--
-- On every attempt it reads the remaining budget from @maxit@. If that is
-- non-positive, it returns without sampling. Otherwise it shuffles @v@ in
-- place with 'uniformShuffleM' and tests every edge @(i, v[i])@ against @r@:
--
-- * if all edges pass, @chk@ is set to 'True' and @v@ is left holding the
--   accepted permutation;
-- * otherwise @maxit@ is decremented by one and another attempt is made.
--
-- Hence at most the initial value of @maxit@ shuffles are performed, which
-- guarantees termination even when @r@ can never be satisfied.
--
-- IMPORTANT: Caller's responsibility to initialise @chk@ to 'False' and
-- @maxit@ to the attempt budget, and to read @chk@ afterwards to learn
-- whether a match was found (@v@ is only meaningful when it is 'True').
uniformCyclePartitionThinM ::
  (StatefulGen g m, PrimMonad m) =>
  STRef (PrimState m) Bool ->
  STRef (PrimState m) Int ->
  ((Int, Int) -> Bool) ->
  V.MVector (PrimState m) Int ->
  g ->
  m ()
uniformCyclePartitionThinM chk maxit r v gen = do
  maxitVal <- liftPrim $ readSTRef maxit
  -- Only attempt a sample while budget remains; this guard is what bounds
  -- the number of attempts and ensures termination.
  when (maxitVal > 0) $ do
    uniformShuffleM v gen
    -- TODO: Repeated calls to freeze, indexed
    -- a possible opportunity for optimization,
    -- e.g. with imap or a check that takes 'chk'
    -- reference and shortcircuits.
    vVal <- V.freeze v
    if V.all r (V.indexed vVal)
      then liftPrim $ writeSTRef chk True
      else do
        liftPrim $ modifySTRef' maxit (subtract 1)
        -- Tail call: nothing follows the recursion, so rejected attempts
        -- do not accumulate pending continuations.
        uniformCyclePartitionThinM chk maxit r v gen

{- RANDOM -}

-- TODO: uniform (full) cycle with [Sattolo's algorithm](https://algo.inria.fr/seminars/summary/Wilson2004b.pdf)
-- uniformCycle

-- | Select a partition of '[0..n-1]' into disjoint
--  [cycle graphs](https://en.wikipedia.org/wiki/Cycle_graph),
--  uniformly over the \(n!\) possibilities. The sampler relies on the fact that such
--  partitions are in bijection with the permutations of '[0..n-1]' via the map sending
--  a permutation \(\sigma\) to the edge set \(\{(i, \sigma(i))\}_{i=0}^{n-1}\). In other
--  words, the cycle partition graphs correspond one-to-one with the \(n \times n\)
--  permutation matrices.
--
--  Therefore, this function simply calls 'uniformPermutation' and tuples the result with its
--  indices. The returned value is a vector of directed edges @(i, σ(i))@, ordered by @i@.
--  \(O(n)\), since 'uniformPermutation' implements the
--  [Fisher-Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle).
--
--  'uniformPermutation' uses in-place mutation, so this function must be run in a 'PrimMonad'
--  context. No validation of @n@ is done here; it is passed straight to 'uniformPermutation'.
--
-- ==== __Examples__
--
-- >>> import System.Random.Stateful
-- >>> import qualified RandomCycle.Vector as RV
-- >>> import Data.Vector (Vector)
-- >>> runSTGen_ (mkStdGen 1305) $ RV.uniformCyclePartition 4 :: Vector (Int, Int)
-- [(0,1),(1,3),(2,2),(3,0)]
uniformCyclePartition ::
  (StatefulGen g m, PrimMonad m) =>
  Int ->
  g ->
  m (V.Vector (Int, Int))
uniformCyclePartition n gen = V.indexed <$> uniformPermutation n gen

-- TODO: apply short-circuiting behavior by creating a modification
-- of 'uniformShuffleM' that carries a validity state and short-circuits as
-- soon as some edge does not satisfy the predicate.
-- The current implementation is the lazy one (as in human-lazy). Note that this
-- would require posting a notice in this module in accordance with the BSD2
-- license of mwc-random.

-- | Uniform selection of a cycle partition graph of '[0..n-1]' elements,
-- conditional on an edge-wise predicate. See 'uniformCyclePartition' for
-- details on the sampler.
--
-- Arguments, in order: the maximum number of sampling attempts, the edge
-- predicate (applied to every edge @(i, σ(i))@), the number of vertices @n@,
-- and the generator.
--
-- /O(n\/p)/ expected time, where /p/ is the probability that a uniformly sampled
-- cycle partition graph (over all /n!/ possible) satisfies the predicate on every
-- edge. This can be highly non-linear since /p/ in general is a function of /n/.
--
-- Since this is a rejection sampling method, the user is asked to provide
-- a maximum number of sampling attempts in order to guarantee termination
-- in cases where the edge predicate has probability of success close to zero.
-- If no sample satisfies the predicate within that many attempts, the result
-- is 'Nothing'.
--
-- Note this returns 'pure Nothing' without sampling when either the number of
-- vertices (third argument) or the maximum number of attempts (first argument)
-- is non-positive. By contrast, 'uniformCyclePartition' performs no such check
-- and passes its size straight to 'uniformPermutation'.
uniformCyclePartitionThin ::
  (StatefulGen g m, PrimMonad m) =>
  Int ->
  ((Int, Int) -> Bool) ->
  Int ->
  g ->
  m (Maybe (V.Vector (Int, Int)))
uniformCyclePartitionThin maxit _ n _
  | maxit <= 0 || n <= 0 = pure Nothing
uniformCyclePartitionThin maxit r n gen = do
  -- One mutable buffer, initialised to the identity permutation, is
  -- reshuffled in place by every attempt.
  mv <- V.thaw (V.generate n id)
  found <- liftPrim $ newSTRef False
  budget <- liftPrim $ newSTRef maxit

  uniformCyclePartitionThinM found budget r mv gen

  ok <- liftPrim $ readSTRef found
  if ok
    then -- 'mv' is local and never written again, so freezing without a copy is safe.
      Just . V.indexed <$> V.unsafeFreeze mv
    else pure Nothing