-- | Internal module whose primary export is 'uniformPartition'.  Use
-- 'RandomCycle.Vector' instead.
module RandomCycle.Vector.Partition where

import Data.Bits
import qualified Data.Vector.Generic as GV
import GHC.Natural (Natural)
import System.Random.Stateful

{- UTILITIES -}

-- | Internal. Measure the run of identical low-order bits at the start of
-- @bs@. The value is shifted right until its lowest bit differs from the
-- lowest bit of the original @bs@. The result pairs the shifted value with
-- the number of shifts performed, which is the length of that initial run.
-- For any nonzero input the run length is at least 1.
--
-- The degenerate case @bs == 0@ is a run of zeros that never ends. It
-- returns the otherwise unreachable point @(0, 0)@ to guarantee termination.
-- That case is nonsensical in 'partitionFromBits', which handles it
-- explicitly.
--
-- >>> commonSubseqBits 4
-- (1,2)
-- >>> commonSubseqBits 7
-- (0,3)
commonSubseqBits :: Natural -> (Natural, Int)
commonSubseqBits 0 = (0, 0)
commonSubseqBits bs = go bs 0
  where
    -- The bit value whose initial run we are measuring.
    firstBit = testBit bs 0

    go :: Natural -> Int -> (Natural, Int)
    go rest shifts
      -- Stop as soon as the current low bit leaves the run.
      | testBit rest 0 /= firstBit = (rest, shifts)
      | otherwise = go (rest `shiftR` 1) (shifts + 1)

-- | Split the vector @vec@ into consecutive groups described by the bits of
-- @bits@.
--
-- Grouping rule:
--
-- * Bit @i@, counting from the least significant bit, belongs to index @i@
--   of @vec@.
-- * Every maximal run of equal bits becomes one group, so element order is
--   preserved.
-- * Bits at positions after the last index of @vec@ do not contribute to the
--   grouping. In particular, if the lowest set bit lies beyond the last
--   index, the result is @[vec]@.
--
-- Special cases:
--
-- * @bits == 0@ yields @[vec]@ when @vec@ is nonempty.
-- * An empty @vec@ always yields @[]@, whatever @bits@ is.
--
-- See 'RandomCycle.List.partitionFromBits' for other examples.
--
-- >>> import qualified Data.Vector as V
-- >>> partitionFromBits 5 (V.fromList [0..2::Int])
-- [[0],[1],[2]]
-- >>> partitionFromBits 13 (V.fromList [0..2::Int])
-- [[0],[1],[2]]
-- >>> partitionFromBits 4 (V.fromList [0..2::Int])
-- [[0,1],[2]]
-- >>> partitionFromBits 8 (V.fromList [0..2::Int])
-- [[0,1,2]]
-- >>> partitionFromBits 7 (V.fromList [0..4::Int])
-- [[0,1,2],[3,4]]
partitionFromBits :: (GV.Vector v a) => Natural -> v a -> [v a]
partitionFromBits bits vec
  | GV.null vec = []
  -- No bits left means no more breakpoints: what remains is a single group.
  | bits == 0 = [vec]
  | otherwise =
      let (remainingBits, runLength) = commonSubseqBits bits
          (group, rest) = GV.splitAt runLength vec
       in group : partitionFromBits remainingBits rest

{- RANDOM -}

-- | Draw a random partition of the input vector @xs@ from the uniform
-- distribution on its order-preserving partitions.
--
-- How it works:
--
-- * One uniformly random bit string is drawn, with one bit per element,
--   i.e. a value in @[0, 2^n - 1]@.
-- * 'partitionFromBits' turns each run of equal bits into a group.
-- * For a nonempty vector, every partition arises from exactly two bit
--   strings: one and its complement. Each partition is therefore equally
--   likely.
--
-- Equivalently, each breakpoint's placement is randomized, like walking a
-- random path in a perfect binary tree. The running time is /O(n)/ in vector
-- operations for a vector of length /n/. NOTE(review): the right shifts on
-- the /n/-bit 'Natural' inside 'partitionFromBits' may add cost that grows
-- faster than /n/ for very long vectors — confirm before relying on the
-- bound.
--
-- This function preserves the order of the input vector.
uniformPartition :: (GV.Vector v a, StatefulGen g m) => v a -> g -> m [v a]
uniformPartition xs g = toGroups <$> uniformRM (0, maxBits) g
  where
    toGroups bits = partitionFromBits bits xs

    -- Largest value representable with one bit per element of xs.
    maxBits :: Natural
    maxBits = 2 ^ GV.length xs - 1