-- | Internal module whose primary exports are 'uniformPartition'
-- and 'uniformPartitionThin'. Import @RandomCycle.List@ instead.
module RandomCycle.List.Partition where

import Control.Monad (guard)
import Data.Bits
import GHC.Natural (Natural)
import System.Random.Stateful

{- UTILITIES -}

-- | Internal. A variant of @Data.List.'span'@ whose predicate is driven by the
-- bits of @bs@ rather than by the list elements. The element at position @i@
-- of @xs@ is paired with bit @i@ of @bs@. Elements are taken while @switch@
-- applied to their bit yields 'True'. Passing 'not' as @switch@ collects
-- elements whose bits are 0 instead of 1.
--
-- Returns the taken prefix together with:
--
-- * the remaining bits, shifted so that bit 0 belongs to the first untaken
--   element;
-- * the untaken suffix.
--
-- The case @bs == 0@ needs no special treatment: termination is guaranteed
-- whenever @xs@ is finite. Compare to
-- @RandomCycle.Vector.Partitions.'commonSubseqBits'@.
--
-- Using 'countTrailingZeros' would be simpler. It would also limit the input
-- list to a fixed length (e.g. 64 when using 'Word'), which is too restrictive.
spanBits :: (Bool -> Bool) -> Natural -> [a] -> ([a], (Natural, [a]))
spanBits switch = go
  where
    go :: Natural -> [b] -> ([b], (Natural, [b]))
    -- Exhausted list: nothing taken, bits returned untouched.
    go bits [] = ([], (bits, []))
    go bits rest@(y : ys)
      | switch (testBit bits 0) =
          -- Consume one bit per taken element.
          let (taken, (bits', remaining)) = go (shiftR bits 1) ys
           in (y : taken, (bits', remaining))
      -- Stop without consuming the bit, so the caller sees it at position 0.
      | otherwise = ([], (bits, rest))

-- | Utility to generate a list partition using the provided 'Natural'
-- as grouping variable, viewed as 'Bits'. The element at index @i@ is paired
-- with bit @i@. Adjacent elements share a block exactly when their bits are
-- equal, so each block is a maximal run of equal bits.
--
-- The choice of grouping variable is to improve performance, since the number
-- of partitions grows exponentially in the input list length.
--
-- Bits at positions @>= length xs@ are ignored. For a list of length @m >= 1@,
-- complementing the lowest @m@ bits of @bs@ yields the same partition. Hence
-- the values @[0 .. 2^(m-1) - 1]@ enumerate each of the @2^(m-1)@
-- order-preserving partitions exactly once.
--
-- This can be used to generate a list of all possible partitions of the input
-- list, as shown in the example. See 'RandomCycle.Vector.partitionFromBits'
-- for other examples.
--
-- >>> import GHC.Natural
-- >>> :{
--  allPartitions :: Int -> [[[Int]]]
--  allPartitions n
--    | n < 0 = []
--    | otherwise = map (`partitionFromBits` [0 .. n]) [0 :: Natural .. 2 ^ n - 1]
-- :}
--
-- >>> length (allPartitions 4)
-- 16
-- >>> take 8 (allPartitions 4)
-- [[[0,1,2,3,4]],[[0],[1,2,3,4]],[[0],[1],[2,3,4]],[[0,1],[2,3,4]],[[0,1],[2],[3,4]],[[0],[1],[2],[3,4]],[[0],[1,2],[3,4]],[[0,1,2],[3,4]]]
partitionFromBits :: Natural -> [a] -> [[a]]
partitionFromBits _ [] = []
partitionFromBits bs xs =
  -- NOTE: Grouping is determined by the first bit. This is important for
  -- correctness of grouping based on spanBits implementation, but also to
  -- ensure uniformPartition is uniform over 2^(n-1) partitions.
  --
  -- 'switch' maps the current bit 0 to True. spanBits therefore always takes
  -- the first element, so every block is non-empty and the recursion below
  -- is on a strictly shorter list.
  let switch :: Bool -> Bool
      switch = if bs `testBit` 0 then id else not
      (ys, (bs', yss)) = spanBits switch bs xs
   in ys : partitionFromBits bs' yss

-- | Primarily a testing utility. It directly computes the lengths of the blocks
-- that the bits @bs@ induce on a list of size @n@. Each maximal run of equal
-- bits is one block, and its length is measured with 'countTrailingZeros'
-- (after complementing when the run consists of ones).
--
-- Because this works on a fixed-width 'Word', results are only meaningful
-- when @n@ does not exceed the word size. For example,
-- @partitionLengths 0 100@ gives @[64,36]@ on a 64-bit platform.
-- Returns @[]@ when @n <= 0@.
partitionLengths :: Word -> Int -> [Int]
partitionLengths bits total = go bits (countTrailingZeros bits) total
  where
    -- @go w runLen remaining@: @runLen@ is the count of trailing zeros of @w@,
    -- and @remaining@ is the number of list elements not yet assigned.
    go :: (FiniteBits b) => b -> Int -> Int -> [Int]
    go w runLen remaining
      -- Bit 0 is set: the run is made of ones, so measure it on the complement.
      | runLen == 0 =
          let flipped = complement w
           in go flipped (countTrailingZeros flipped) remaining
      -- The run covers every remaining element: emit what is left, if anything.
      | runLen > remaining = [remaining | remaining > 0]
      | otherwise =
          let rest = w `shiftR` runLen
           in runLen : go rest (countTrailingZeros rest) (remaining - runLen)

{- PARTITIONING WITH THINNING -}

-- | Internal. Partition @xs@ according to the bits @bs@ (see
-- 'partitionFromBits'). Returns 'Nothing' if the local condition @r@ fails for
-- some block.
--
-- Both 'all' and 'partitionFromBits' are lazy, so blocks after the first
-- failing one are never built. An empty list yields @Just []@ without
-- consulting @r@.
partitionFromBitsThin :: ([a] -> Bool) -> Natural -> [a] -> Maybe [[a]]
partitionFromBitsThin _ _ [] = Just []
partitionFromBitsThin r bs xs = blocks <$ guard (all r blocks)
  where
    blocks = partitionFromBits bs xs

-- | Internal. Worker for 'uniformPartitionThin'. It takes the input length @n@
-- as an argument so it is not recomputed on every attempt. It is the caller's
-- job to ensure @n == length xs@.
--
-- Makes up to @maxit@ attempts. Each attempt draws bits uniformly from
-- @[0, 2^n - 1]@ and keeps the resulting partition if every block satisfies
-- @r@. Returns 'Nothing' once the attempts are exhausted, which happens
-- immediately when @maxit <= 0@.
uniformPartitionThinN ::
  (StatefulGen g m) =>
  Int ->
  Int ->
  ([a] -> Bool) ->
  [a] ->
  g ->
  m (Maybe [[a]])
uniformPartitionThinN maxit n r xs g = attempt maxit
  where
    -- Upper end of the bit range; the same for every attempt.
    upper :: Natural
    upper = 2 ^ n - 1

    attempt k
      | k <= 0 = pure Nothing
      | otherwise = do
          bs <- uniformRM (0, upper) g
          -- Rejected draws retry with one fewer attempt remaining.
          maybe (attempt (k - 1)) (pure . Just) (partitionFromBitsThin r bs xs)

{- RANDOM -}

-- | Draw a random partition of the input list @xs@ from the uniform
-- distribution over its order-preserving partitions.
--
-- One bit per element is drawn uniformly, and the bits are turned into blocks
-- by 'partitionFromBits'. A bit pattern and its complement give the same
-- partition, so each of the @2^(n-1)@ partitions of a list of length
-- /n >= 1/ is equally likely. Equivalently, each of the /n - 1/ gaps between
-- adjacent elements independently becomes a breakpoint with probability 1\/2,
-- i.e. the draw walks a random path in a perfect binary tree.
--
-- This function preserves the order of the input list.
--
-- NOTE(review): the list traversal is /O(n)/ for a list of length /n/.
-- However, each consumed bit shifts an /n/-bit 'Natural', which may copy it.
-- Confirm the overall cost for very long lists.
--
-- ==== __Examples__
--
-- >>> import System.Random.Stateful
-- >>> pureGen = mkStdGen 0
-- >>> runStateGen_ pureGen $ uniformPartition [1..5::Int]
-- [[1,2,3],[4],[5]]
-- >>> runStateGen_ pureGen $ uniformPartition ([] :: [Int])
-- []
uniformPartition :: (StatefulGen g m) => [a] -> g -> m [[a]]
uniformPartition xs g =
  -- One bit per element: draw from [0, 2^d - 1] for d = length xs. The first
  -- bit then decides which bit value marks the current block.
  (`partitionFromBits` xs) <$> uniformRM (0, 2 ^ length xs - 1) g

-- | Generate a partition with a local condition @r@ on each partition element.
-- Construction of a partition shortcircuits to failure as soon as the local
-- condition is false.
--
-- This is a rejection sampling method. The user therefore provides a counter
-- @maxit@ for the maximum number of sampling attempts. This guarantees
-- termination in cases where the local condition has a probability of success
-- close to zero. If no attempt succeeds within @maxit@ attempts, the result is
-- 'Nothing'; this happens immediately whenever @maxit <= 0@.
--
-- Run time on average is /O(n\/p)/, where /p/ is the probability that
-- @all r yss == True@ for a uniformly generated partition @yss@. This assumes
-- @r@ has run time linear in the length of its argument. The run time can be
-- highly non-linear, because /p/ in general is a function of /n/.
--
-- Some cases can be deceptively expensive. For example, the condition
-- @(>= 2) . length@ accepts only partitions without singleton blocks. Of the
-- @2^(n-1)@ order-preserving partitions of a list of length /n >= 1/, exactly
-- /F(n-1)/ have this property. Here /F/ is the Fibonacci sequence with
-- /F(0) = 0/ and /F(1) = 1/.
--
-- So /p/ decays like /(phi\/2)^n/, roughly /0.81^n/, where /phi/ is the golden
-- ratio. The expected number of attempts therefore grows exponentially, like
-- /(2\/phi)^n/, roughly /1.24^n/.
--
-- ==== __Examples__
--
-- >>> import System.Random.Stateful
-- >>> maxit = 1000
-- >>> pureGen = mkStdGen 0
-- >>> r = (>= 2) . length
-- >>> runStateGen_ pureGen $ uniformPartitionThin maxit r [1..5::Int]
-- Just [[1,2],[3,4,5]]
-- >>> runStateGen_ pureGen $ uniformPartitionThin maxit (const False) ([] :: [Int])
-- Just []
-- >>> runStateGen_ pureGen $ uniformPartitionThin maxit r [1::Int]
-- Nothing
uniformPartitionThin ::
  (StatefulGen g m) =>
  Int ->
  ([a] -> Bool) ->
  [a] ->
  g ->
  m (Maybe [[a]])
-- The length is computed once here and threaded through every attempt.
uniformPartitionThin maxit r xs = uniformPartitionThinN maxit (length xs) r xs