{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Radix
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (scoped type variables, bang patterns)
--
-- This module provides a radix sort for mutable vectors whose elements are
-- instances of the 'Radix' class. The 'Radix' class gives information on:
--
--   * the number of passes needed for the data type
--
--   * the size of the auxiliary arrays
--
--   * how to compute the pass-k radix of a value
--
-- Radix sort is not a comparison sort, so for a fixed number of passes it is
-- able to achieve O(n) run time, though it also uses O(n) auxiliary space.
-- In addition, there is a constant space overhead of size*sizeOf(Int) for
-- the radix count array, so it is not advisable to use this sort for large
-- numbers of very small arrays.
--
-- A standard example (upon which one could base their own Radix instance)
-- is Word32:
--
--   * We choose to sort on r = 8 bits at a time
--
--   * A Word32 has b = 32 bits total
--
--   Thus, b/r = 4 passes are required, 2^r = 256 elements are needed in an
--   auxiliary array, and the radix function is:
--
--    > radix k e = fromIntegral ((e `shiftR` (k*8)) .&. 255)

module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common

import Data.Bits
import Data.Int
import Data.Word


import Foreign.Storable

-- | Element types that can be sorted by 'sort'. An element is processed in
-- @'passes' e@ counting-sort passes; pass @k@ places it in bucket
-- @'radix' k e@. Pass @0@ runs first, so it should extract the least
-- significant digit.
class Radix e where
  -- | The number of passes necessary to sort an array of @e@s.
  --
  -- The argument only fixes the type: 'sort' calls this on a placeholder
  -- value that must not be evaluated.
  passes :: e -> Int
  -- | The size of an auxiliary (radix count) array. Every value returned by
  -- 'radix' must lie in @[0, size e)@, since it is used as an unchecked
  -- index into that array.
  --
  -- As with 'passes', the argument must not be evaluated.
  size   :: e -> Int
  -- | The radix function parameterized by the current pass.
  radix  :: Int -> e -> Int

instance Radix Int where
  passes _ = sizeOf (0 :: Int)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- One byte per pass, least significant byte first. On the final pass the
  -- sign bit is flipped so that negative numbers order before non-negative
  -- ones.
  radix pass e
    | pass == 0        = e .&. 255
    | pass == lastPass = byteAt (e `xor` minBound)
    | otherwise        = byteAt e
    where
      lastPass = passes e - 1
      byteAt x = (x `shiftR` (8 * pass)) .&. 255
  {-# INLINE radix #-}

instance Radix Int8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- A single pass over the whole byte. Flipping bit 7 maps the signed range
  -- [-128, 127] monotonically onto [0, 255].
  radix _ e = (fromIntegral e .&. 255) `xor` 128
  {-# INLINE radix #-}

instance Radix Int16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Low byte first; the high byte has its sign bit flipped so that negative
  -- values order before non-negative ones.
  radix pass e = case pass of
      0 -> byteAt 0 e
      1 -> byteAt 8 (e `xor` minBound)
    where
      byteAt :: Int -> Int16 -> Int
      byteAt s x = fromIntegral ((x `shiftR` s) .&. 255)
  {-# INLINE radix #-}

instance Radix Int32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Pass k extracts byte k (least significant first); the top byte has its
  -- sign bit flipped so that negative values order before non-negative ones.
  radix pass e = case pass of
      0 -> byteAt 0 e
      1 -> byteAt 8 e
      2 -> byteAt 16 e
      3 -> byteAt 24 (e `xor` minBound)
    where
      byteAt :: Int -> Int32 -> Int
      byteAt s x = fromIntegral ((x `shiftR` s) .&. 255)
  {-# INLINE radix #-}

instance Radix Int64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Pass k extracts byte k (least significant first); the top byte has its
  -- sign bit flipped so that negative values order before non-negative ones.
  radix pass e = case pass of
      0 -> byteAt 0 e
      1 -> byteAt 8 e
      2 -> byteAt 16 e
      3 -> byteAt 24 e
      4 -> byteAt 32 e
      5 -> byteAt 40 e
      6 -> byteAt 48 e
      7 -> byteAt 56 (e `xor` minBound)
    where
      byteAt :: Int -> Int64 -> Int
      byteAt s x = fromIntegral ((x `shiftR` s) .&. 255)
  {-# INLINE radix #-}

instance Radix Word where
  passes _ = sizeOf (0 :: Word)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Pass k extracts byte k, counting from the least significant end.
  radix pass e
    | pass == 0 = fromIntegral (e .&. 255)
    | otherwise = fromIntegral ((e `shiftR` (8 * pass)) .&. 255)
  {-# INLINE radix #-}

instance Radix Word8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- The byte itself is the radix; a single pass suffices.
  radix _ e = fromIntegral e
  {-# INLINE radix #-}

instance Radix Word16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Low byte on pass 0, high byte on pass 1.
  radix pass e = case pass of
      0 -> byteAt 0
      1 -> byteAt 8
    where
      byteAt :: Int -> Int
      byteAt s = fromIntegral ((e `shiftR` s) .&. 255)
  {-# INLINE radix #-}

instance Radix Word32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Pass k extracts byte k, counting from the least significant end.
  radix pass e = case pass of
      0 -> byteAt 0
      1 -> byteAt 8
      2 -> byteAt 16
      3 -> byteAt 24
    where
      byteAt :: Int -> Int
      byteAt s = fromIntegral ((e `shiftR` s) .&. 255)
  {-# INLINE radix #-}

instance Radix Word64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Pass k extracts byte k, counting from the least significant end.
  radix pass e = case pass of
      0 -> byteAt 0
      1 -> byteAt 8
      2 -> byteAt 16
      3 -> byteAt 24
      4 -> byteAt 32
      5 -> byteAt 40
      6 -> byteAt 48
      7 -> byteAt 56
    where
      byteAt :: Int -> Int
      byteAt s = fromIntegral ((e `shiftR` s) .&. 255)
  {-# INLINE radix #-}

instance (Radix i, Radix j) => Radix (i, j) where
  -- Pairs sort lexicographically: the second component supplies the
  -- low-order passes and runs first, then the first component's passes.
  -- Components are projected lazily so that 'passes' and 'size' never force
  -- the pair itself.
  passes p = passes (fst p) + passes (snd p)
  {-# INLINE passes #-}
  size p = max (size (fst p)) (size (snd p))
  {-# INLINE size #-}
  radix k p
    | k < lowPasses = radix k lo
    | otherwise     = radix (k - lowPasses) hi
    where
      hi        = fst p
      lo        = snd p
      lowPasses = passes lo
  {-# INLINE radix #-}

-- | Sorts an array based on the 'Radix' instance of its element type.
--
-- The number of passes and the auxiliary array size are obtained by
-- applying 'passes' and 'size' to a placeholder element that is never
-- meant to be evaluated; instances must therefore ignore the argument of
-- those two methods.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy (passes proxy) (size proxy) radix arr
  where
    -- Type-level stand-in for an element. Forcing it means a 'Radix'
    -- instance inspected its argument in 'passes' or 'size', so fail with a
    -- message that points at the cause rather than a bare 'undefined'.
    proxy :: e
    proxy = error ("Data.Vector.Algorithms.Radix.sort: a Radix instance "
                   ++ "evaluated the argument of passes or size")
{-# INLINABLE sort #-}

-- | Radix sorts an array using custom radix information.
--
-- Requires the number of passes needed to fully sort the array, the size of
-- the auxiliary count array (it must exceed every value returned by the
-- radix function, i.e. typically one greater than its maximum), and a radix
-- function, which takes the pass and an element and returns the relevant
-- radix. Passes run in order starting from @0@, so pass @0@ should extract
-- the least significant digit.
--
-- Allocates a scratch array of the same length as the input and one count
-- array of the given size.
sortBy :: (PrimMonad m, MVector v e)
       => Int               -- ^ the number of passes
       -> Int               -- ^ the size of auxiliary arrays
       -> (Int -> e -> Int) -- ^ the radix function
       -> v (PrimState m) e -- ^ the array to be sorted
       -> m ()
sortBy numPasses bucketCount rdx arr = do
  -- Elements move back and forth between arr and tmp on alternating passes.
  tmp   <- new (length arr)
  count <- new bucketCount
  radixLoop numPasses rdx arr tmp count
{-# INLINE sortBy #-}

-- Runs the counting-sort passes 0 .. numPasses-1, alternating the roles of
-- the two arrays so that each pass reads from the array the previous pass
-- wrote to. After an odd number of passes the data ends up in the temporary
-- array, so it is copied back into the array being sorted.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                          -- passes
          -> (Int -> e -> Int)            -- radix function
          -> v (PrimState m) e            -- array to sort
          -> v (PrimState m) e            -- temporary array
          -> PV.MVector (PrimState m) Int -- radix count array
          -> m ()
radixLoop numPasses rdx src dst count = go False 0
  where
    -- 'swapped' is True when the current data lives in dst rather than src.
    go swapped k
      | k < numPasses = do
          if swapped
            then body rdx dst src count k
            else body rdx src dst count k
          go (not swapped) (k + 1)
      | otherwise = when swapped (unsafeCopy src dst)
{-# INLINE radixLoop #-}

-- One counting-sort pass for pass number k:
--   1. countLoop (from Data.Vector.Algorithms.Common) fills 'count' from the
--      pass-k radixes of src (presumably a per-bucket tally; see Common),
--   2. accumulate turns those counts into exclusive prefix sums in place,
--   3. moveLoop scatters src into dst using those offsets.
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)            -- radix function
     -> v (PrimState m) e            -- source array
     -> v (PrimState m) e            -- destination array
     -> PV.MVector (PrimState m) Int -- radix count
     -> Int                          -- current pass
     -> m ()
body rdx src dst count k =
  countLoop (rdx k) src count
    >> accumulate count
    >> moveLoop k rdx src dst count
{-# INLINE body #-}

-- Replaces each entry of the count array, in place, with the sum of all
-- entries before it (an exclusive prefix sum). Slot i then holds the first
-- destination index for bucket i.
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -> m ()
accumulate count = loop 0 0
  where
    n = length count
    loop !i !running
      | i >= n    = return ()
      | otherwise = do
          c <- unsafeRead count i
          unsafeWrite count i running
          loop (i + 1) (running + c)
{-# INLINE accumulate #-}

-- Scatters src into dst for pass k, walking src from left to right. For
-- each element, 'inc' (from Data.Vector.Algorithms.Common) is applied to the
-- element's bucket in 'prefix'; the element is written at the index it
-- returns. Assumes inc returns the bucket's current offset and advances it
-- by one (confirm in Common). Under that assumption, equal-radix elements
-- keep their relative order.
moveLoop :: (PrimMonad m, MVector v e)
         => Int -> (Int -> e -> Int) -> v (PrimState m) e
         -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = loop 0
  where
    n = length src
    loop i
      | i >= n    = return ()
      | otherwise = do
          x    <- unsafeRead src i
          slot <- inc prefix (rdx k x)
          unsafeWrite dst slot x
          loop (i + 1)
{-# INLINE moveLoop #-}