module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common
import Data.Bits
import Data.Int
import Data.Word
import Foreign.Storable
-- | Element types that can be sorted by a least-significant-digit radix sort.
--
-- 'sort' calls 'passes' and 'size' on 'undefined' purely to pick the
-- instance, so implementations of those two methods must not force their
-- argument.
class Radix e where
  -- | @radix k x@ is digit number @k@ of @x@, which must lie in
  -- @[0, size x)@. Pass 0 is sorted first and pass @passes x - 1@ last, so
  -- the highest-numbered digit is the most significant.
  radix :: Int -> e -> Int

  -- | How many digit passes a full sort of this type performs.
  passes :: e -> Int

  -- | Exclusive upper bound on the digits 'radix' returns; it is used as the
  -- length of the bucket-count array.
  size :: e -> Int
-- | Bytes of an 'Int', least significant first, one pass per byte.
--
-- On the final (most significant) pass the sign bit is flipped with
-- @xor minBound@. Read as an unsigned byte, negative values then sort before
-- non-negative ones.
instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  size _ = 256
  radix 0 e = e .&. 255
  radix i e
    -- Top byte: flip the sign bit so two's-complement order matches
    -- unsigned byte order. Note: @passes e - 1@ (function application binds
    -- tighter than @-@).
    | i == passes e - 1 = byte (e `xor` minBound)
    | otherwise         = byte e
    where
      -- Byte number i of x; (i `shiftL` 3) is the bit offset 8 * i. The
      -- arithmetic shift's sign-extended bits are discarded by the mask.
      byte x = (x `shiftR` (i `shiftL` 3)) .&. 255
-- | A single pass over the whole byte, with the sign bit flipped. Under
-- unsigned byte order, -128 maps to digit 0 and 127 maps to digit 255.
instance Radix Int8 where
  passes = const 1
  size = const 256
  radix _ e = fromIntegral (e `xor` minBound) .&. 255
-- | Two byte passes, least significant first. The sign bit is flipped on the
-- top byte so that negative values order first. Only passes 0 and 1 are
-- defined.
instance Radix Int16 where
  passes = const 2
  size = const 256
  radix k e
    | k == 1 = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)
    | k == 0 = fromIntegral (e .&. 255)
-- | Four byte passes, least significant first. The sign bit is flipped on the
-- top byte (pass 3) so that negative values order first. Only passes 0..3
-- are defined.
instance Radix Int32 where
  passes = const 4
  size = const 256
  radix k e
    | k == 3          = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)
    | k >= 0 && k < 3 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | Eight byte passes, least significant first. The sign bit is flipped on
-- the top byte (pass 7) so that negative values order first. Only passes
-- 0..7 are defined.
instance Radix Int64 where
  passes = const 8
  size = const 256
  radix k e
    | k == 7          = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)
    | k >= 0 && k < 7 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | One pass per byte of a 'Word', least significant first. No sign handling
-- is needed for an unsigned type.
instance Radix Word where
  passes = const (sizeOf (undefined :: Word))
  size = const 256
  -- For pass 0 the shift amount is 0, so this also covers the low byte.
  radix i e = fromIntegral ((e `shiftR` (8 * i)) .&. 255)
-- | A single pass in which the byte itself is the digit.
instance Radix Word8 where
  passes = const 1
  size = const 256
  radix _ e = fromIntegral e
-- | Two byte passes, least significant first. Only passes 0 and 1 are
-- defined.
instance Radix Word16 where
  passes = const 2
  size = const 256
  radix k e
    | k >= 0 && k < 2 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | Four byte passes, least significant first. Only passes 0..3 are defined.
instance Radix Word32 where
  passes = const 4
  size = const 256
  radix k e
    | k >= 0 && k < 4 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | Eight byte passes, least significant first. Only passes 0..7 are
-- defined.
instance Radix Word64 where
  passes = const 8
  size = const 256
  radix k e
    | k >= 0 && k < 8 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | Pairs, ordered with the first component most significant.
--
-- The first @passes j@ passes sort by the second component. The remaining
-- @passes i@ passes sort by the first component, with pass numbers rebased
-- to start at 0. The lazy patterns keep 'passes' and 'size' from forcing the
-- pair, which 'sort' supplies as 'undefined'.
instance (Radix i, Radix j) => Radix (i, j) where
  passes ~(i, j) = passes i + passes j
  size ~(i, j) = size i `max` size j
  radix k ~(i, j)
    | k < passes j = radix k j
    -- Rebase the pass index into the first component's own range.
    | otherwise    = radix (k - passes j) i
-- | Sort a mutable vector in place, taking the digit decomposition from the
-- element type's 'Radix' instance.
sort
  :: forall e m v. (PrimMonad m, MVector v e, Radix e)
  => v (PrimState m) e -> m ()
sort = sortBy (passes witness) (size witness) radix
  where
    -- Only selects the instance. Every instance in this module ignores (or
    -- lazily matches) the argument of 'passes' and 'size', so it is not
    -- forced.
    witness :: e
    witness = undefined
-- | Radix-sort @arr@ in place with an explicit digit decomposition.
--
-- Allocates a scratch vector with the same length as @arr@ and a count
-- array of @size@ entries, then runs @passes@ passes. Pass @k@ uses
-- @rdx k@ to extract the digit.
sortBy
  :: (PrimMonad m, MVector v e)
  => Int                  -- ^ number of passes
  -> Int                  -- ^ number of buckets (digits lie in @[0, size)@)
  -> (Int -> e -> Int)    -- ^ @rdx k x@: digit @k@ of @x@
  -> v (PrimState m) e    -- ^ vector to sort
  -> m ()
sortBy passes size rdx arr =
  new (length arr) >>= \scratch ->
    new size >>= radixLoop passes rdx arr scratch
-- | Run @passes@ counting-sort passes, alternating which of @src@ / @dst@ is
-- read and which is written. Pass @k@ uses digit function @rdx k@.
--
-- After an odd number of passes the data sits in @dst@. It is then copied
-- back with 'unsafeCopy', whose first argument is the target, so @src@
-- always ends up holding the result. This assumes @src@ and @dst@ have
-- equal lengths and do not overlap, as 'sortBy' arranges.
radixLoop
  :: (PrimMonad m, MVector v e)
  => Int
  -> (Int -> e -> Int)
  -> v (PrimState m) e
  -> v (PrimState m) e
  -> PV.MVector (PrimState m) Int
  -> m ()
radixLoop passes rdx src dst count = go False 0
  where
    -- swap is True when the most recent pass wrote its output into dst.
    go swap k
      | k >= passes = when swap (unsafeCopy src dst)
      | swap        = body rdx dst src count k >> go False (k + 1)
      | otherwise   = body rdx src dst count k >> go True (k + 1)
-- | One pass on digit @k@, run in three steps:
--
-- 1. 'countLoop' tallies the digits of @src@ into @count@. It is defined
--    outside this view; presumably it also zeroes @count@ first (confirm).
-- 2. 'accumulate' turns the tallies into starting offsets.
-- 3. 'moveLoop' scatters @src@ into @dst@.
body
  :: (PrimMonad m, MVector v e)
  => (Int -> e -> Int)
  -> v (PrimState m) e
  -> v (PrimState m) e
  -> PV.MVector (PrimState m) Int
  -> Int
  -> m ()
body rdx src dst count k =
  countLoop digit src count
    >> accumulate count
    >> moveLoop k rdx src dst count
  where
    digit = rdx k
-- | In-place exclusive prefix sum: each entry is replaced by the sum of all
-- entries before it. For example, @[2,0,3]@ becomes @[0,2,2]@. Each entry is
-- then the first output slot for its bucket.
accumulate :: PrimMonad m => PV.MVector (PrimState m) Int -> m ()
accumulate count = loop 0 0
  where
    n = length count
    -- running: the total of count[0 .. i-1] read so far.
    loop i running = when (i < n) $ do
      c <- unsafeRead count i
      unsafeWrite count i running
      loop (i + 1) (running + c)
-- | Scatter pass: visit @src@ in index order and write each element to
-- @dst@. The target slot is the current offset for its digit.
--
-- The offset comes from 'inc', defined outside this view. Presumably 'inc'
-- returns the old offset and bumps the stored one (confirm). If so, equal
-- digits keep their relative order, which makes the pass stable.
moveLoop
  :: (PrimMonad m, MVector v e)
  => Int -> (Int -> e -> Int) -> v (PrimState m) e
  -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = step 0
  where
    n = length src
    step i = when (i < n) $ do
      x    <- unsafeRead src i
      slot <- inc prefix (rdx k x)
      unsafeWrite dst slot x
      step (i + 1)