{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Radix
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel
-- Stability   : Experimental
-- Portability : Non-portable (scoped type variables, bang patterns)
--
-- This module provides a radix sort for a subclass of unboxed arrays. The
-- radix class gives information on
--
--   * the number of passes needed for the data type
--
--   * the size of the auxiliary arrays
--
--   * how to compute the pass-k radix of a value
--
-- Radix sort is not a comparison sort, so it is able to achieve O(n) run
-- time, though it also uses O(n) auxiliary space. In addition, there is a
-- constant space overhead of size*sizeOf(Int) for the count array used by
-- the sort, so it is not advisable to use this sort for large numbers of
-- very small arrays.
--
-- A standard example (upon which one could base their own Radix instance)
-- is Word32:
--
--   * We choose to sort on r = 8 bits at a time
--
--   * A Word32 has b = 32 bits total
--
-- Thus, b/r = 4 passes are required, 2^r = 256 elements are needed in an
-- auxiliary array, and the radix function is:
--
-- > radix k e = (e `shiftR` (k*8)) .&. 255

module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common

import Data.Bits
import Data.Int
import Data.Word

import Foreign.Storable

-- | Types that can be radix sorted.
--
-- 'sort' calls 'passes' and 'size' with an @undefined@ argument, so
-- instances must not evaluate the element passed to those two methods.
class Radix e where
  -- | The number of passes necessary to sort an array of es
  passes :: e -> Int
  -- | The size of an auxiliary array
  size   :: e -> Int
  -- | The radix function parameterized by the current pass
  radix  :: Int -> e -> Int

-- For the signed instances below, the pass over the most significant byte
-- xors the value with 'minBound' (i.e. flips the sign bit) before the byte
-- is extracted, so negative values get smaller radices than non-negative
-- ones.

instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = e .&. 255
  radix i e
    | i == passes e - 1 = radix' (e `xor` minBound)
    | otherwise         = radix' e
   where
   -- byte number i of x; (i `shiftL` 3) is i*8
   radix' x = (x `shiftR` (i `shiftL` 3)) .&. 255
  {-# INLINE radix #-}

instance Radix Int8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- parses as (255 .&. fromIntegral e) `xor` 128
  radix _ e = 255 .&. fromIntegral e `xor` 128
  {-# INLINE radix #-}

instance Radix Int16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)
  -- total fallback for pass indices outside [0, passes)
  radix _ _ = 0
  {-# INLINE radix #-}

instance Radix Int32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)
  -- total fallback for pass indices outside [0, passes)
  radix _ _ = 0
  {-# INLINE radix #-}

instance Radix Int64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)
  -- total fallback for pass indices outside [0, passes)
  radix _ _ = 0
  {-# INLINE radix #-}

instance Radix Word where
  passes _ = sizeOf (undefined :: Word)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix i e = fromIntegral ((e `shiftR` (i `shiftL` 3)) .&. 255)
  {-# INLINE radix #-}

instance Radix Word8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix _ = fromIntegral
  {-# INLINE radix #-}

instance Radix Word16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  -- total fallback for pass indices outside [0, passes)
  radix _ _ = 0
  {-# INLINE radix #-}

instance Radix Word32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  -- total fallback for pass indices outside [0, passes)
  radix _ _ = 0
  {-# INLINE radix #-}

instance Radix Word64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral ((e `shiftR` 56) .&. 255)
  -- total fallback for pass indices outside [0, passes)
  radix _ _ = 0
  {-# INLINE radix #-}

-- The passes of the second component come first, followed by the passes of
-- the first component. The lazy patterns keep 'passes' and 'size' from
-- forcing their argument, which 'sort' supplies as @undefined@.
instance (Radix i, Radix j) => Radix (i, j) where
  passes ~(i, j) = passes i + passes j
  {-# INLINE passes #-}
  size   ~(i, j) = size i `max` size j
  {-# INLINE size #-}
  radix k ~(i, j)
    | k < passes j = radix k j
    | otherwise    = radix (k - passes j) i
  {-# INLINE radix #-}

-- | Sorts an array based on the Radix instance.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy (passes e) (size e) radix arr
 where
 -- Never evaluated; only selects the 'Radix' instance for the element type.
 e :: e
 e = undefined
{-# INLINABLE sort #-}

-- | Radix sorts an array using custom radix information.
-- Requires the number of passes to fully sort the array,
-- the size of auxiliary arrays necessary (should be
-- one greater than the maximum value returned by the radix
-- function), and a radix function, which takes the pass
-- and an element, and returns the relevant radix.
sortBy :: (PrimMonad m, MVector v e)
       => Int               -- ^ the number of passes
       -> Int               -- ^ the size of auxiliary arrays
       -> (Int -> e -> Int) -- ^ the radix function
       -> v (PrimState m) e -- ^ the array to be sorted
       -> m ()
sortBy passes size rdx arr = do
  tmp    <- new (length arr)
  count  <- new size
  radixLoop passes rdx arr tmp count
{-# INLINE sortBy #-}

-- Runs the passes in order, alternating which of the two arrays is read
-- from and which is written to. The flag tracks whether the most recent
-- pass wrote into the temporary array; if so, the temporary array is
-- copied back into the array to sort once all passes are done.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                          -- passes
          -> (Int -> e -> Int)            -- radix function
          -> v (PrimState m) e            -- array to sort
          -> v (PrimState m) e            -- temporary array
          -> PV.MVector (PrimState m) Int -- radix count array
          -> m ()
radixLoop passes rdx src dst count = go False 0
 where
 go swap k
   | k < passes = if swap
                    then body rdx dst src count k >> go (not swap) (k+1)
                    else body rdx src dst count k >> go (not swap) (k+1)
   | otherwise  = when swap (unsafeCopy src dst)
{-# INLINE radixLoop #-}

-- A single pass: fill the count array via 'countLoop' (imported from
-- "Data.Vector.Algorithms.Common"), turn the counts into starting offsets,
-- then move every element of the source into the destination.
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)            -- radix function
     -> v (PrimState m) e            -- source array
     -> v (PrimState m) e            -- destination array
     -> PV.MVector (PrimState m) Int -- radix count
     -> Int                          -- current pass
     -> m ()
body rdx src dst count k = do
  countLoop (rdx k) src count
  accumulate count
  moveLoop k rdx src dst count
{-# INLINE body #-}

-- Replaces, in place, each slot of the count array with the sum of all
-- slots that precede it (so slot 0 becomes 0).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int
           -> m ()
accumulate count = go 0 0
 where
 len = length count
 go i acc
   | i < len   = do ci <- unsafeRead count i
                    unsafeWrite count i acc
                    go (i+1) (acc + ci)
   | otherwise = return ()
{-# INLINE accumulate #-}

-- Walks the source front to back; for each element, 'inc' (imported from
-- "Data.Vector.Algorithms.Common") is applied to the offset slot of the
-- element's radix and the result is used as the destination index.
-- NOTE(review): presumably 'inc' returns the slot's current value and then
-- increments it -- confirm against Data.Vector.Algorithms.Common.
moveLoop :: (PrimMonad m, MVector v e)
         => Int                          -- current pass
         -> (Int -> e -> Int)            -- radix function
         -> v (PrimState m) e            -- source array
         -> v (PrimState m) e            -- destination array
         -> PV.MVector (PrimState m) Int -- offsets per radix
         -> m ()
moveLoop k rdx src dst prefix = go 0
 where
 len = length src
 go i
   | i < len    = do srci <- unsafeRead src i
                     pf   <- inc prefix (rdx k srci)
                     unsafeWrite dst pf srci
                     go (i+1)
   | otherwise  = return ()
{-# INLINE moveLoop #-}