{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common
import Data.Bits
import Data.Int
import Data.Word
import Foreign.Storable
-- | Keys that can be split into a fixed number of digits for a
-- least-significant-digit radix sort.
--
-- Digit 0 is the least significant digit: passes run from digit 0 upwards,
-- so the last pass decides the final order. 'sort' applies 'passes' and
-- 'size' to an 'undefined' witness, so those two methods must not force
-- their argument.
class Radix e where
  -- | 'passes' is the number of digit passes needed for this type; 'size'
  -- is the number of distinct digit values, i.e. buckets per pass.
  passes, size :: e -> Int
  -- | @radix k e@ is digit @k@ of @e@. The result indexes a bucket array of
  -- length @size e@, so it should lie in @[0, size e)@.
  radix :: Int -> e -> Int
-- | 'Int' keys are split into @sizeOf (undefined :: Int)@ bytes, least
-- significant byte first. The most significant byte has its sign bit
-- flipped so negative values fall into lower buckets than non-negative
-- ones, which yields ascending signed order.
instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix i e = (key `shiftR` (8 * i)) .&. 255
    where
      -- Only the top digit is biased; digit 0 is always taken unmodified.
      key
        | i /= 0 && i == passes e - 1 = e `xor` minBound
        | otherwise                   = e
  {-# INLINE radix #-}
-- | A single-byte signed key. Flipping the sign bit maps -128..127
-- monotonically onto bucket indices 0..255.
instance Radix Int8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix _ e = fromIntegral (e `xor` minBound) .&. 255
  {-# INLINE radix #-}
-- | A two-byte signed key, least significant byte first. The sign bit is
-- flipped on the top byte (digit 1) so negative values sort before
-- non-negative ones.
instance Radix Int16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)
  -- Only digits 0..1 exist. Fail with a descriptive message instead of a
  -- bare non-exhaustive-pattern error; this also satisfies
  -- -Wincomplete-patterns.
  radix i _ = error ("Data.Vector.Algorithms.Radix.radix @Int16: digit index "
                     ++ show i ++ " is outside 0..1")
  {-# INLINE radix #-}
-- | A four-byte signed key, least significant byte first. The sign bit is
-- flipped on the top byte (digit 3) so negative values sort before
-- non-negative ones.
instance Radix Int32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)
  -- Only digits 0..3 exist. Fail with a descriptive message instead of a
  -- bare non-exhaustive-pattern error; this also satisfies
  -- -Wincomplete-patterns.
  radix i _ = error ("Data.Vector.Algorithms.Radix.radix @Int32: digit index "
                     ++ show i ++ " is outside 0..3")
  {-# INLINE radix #-}
-- | An eight-byte signed key, least significant byte first. The sign bit is
-- flipped on the top byte (digit 7) so negative values sort before
-- non-negative ones.
instance Radix Int64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)
  -- Only digits 0..7 exist. Fail with a descriptive message instead of a
  -- bare non-exhaustive-pattern error; this also satisfies
  -- -Wincomplete-patterns.
  radix i _ = error ("Data.Vector.Algorithms.Radix.radix @Int64: digit index "
                     ++ show i ++ " is outside 0..7")
  {-# INLINE radix #-}
-- | A machine-word key split into @sizeOf (undefined :: Word)@ bytes, least
-- significant first. Because the key is unsigned, no sign-bit adjustment is
-- needed.
instance Radix Word where
  passes _ = sizeOf (undefined :: Word)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix i e = fromIntegral (255 .&. (e `shiftR` (8 * i)))
  {-# INLINE radix #-}
-- | A byte is its own single digit; its value is used directly as the
-- bucket index.
instance Radix Word8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix _ byte = fromIntegral byte
  {-# INLINE radix #-}
-- | A two-byte unsigned key, least significant byte first.
instance Radix Word16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  -- Only digits 0..1 exist. Fail with a descriptive message instead of a
  -- bare non-exhaustive-pattern error; this also satisfies
  -- -Wincomplete-patterns.
  radix i _ = error ("Data.Vector.Algorithms.Radix.radix @Word16: digit index "
                     ++ show i ++ " is outside 0..1")
  {-# INLINE radix #-}
-- | A four-byte unsigned key, least significant byte first.
instance Radix Word32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  -- Only digits 0..3 exist. Fail with a descriptive message instead of a
  -- bare non-exhaustive-pattern error; this also satisfies
  -- -Wincomplete-patterns.
  radix i _ = error ("Data.Vector.Algorithms.Radix.radix @Word32: digit index "
                     ++ show i ++ " is outside 0..3")
  {-# INLINE radix #-}
-- | An eight-byte unsigned key, least significant byte first.
instance Radix Word64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral ((e `shiftR` 56) .&. 255)
  -- Only digits 0..7 exist. Fail with a descriptive message instead of a
  -- bare non-exhaustive-pattern error; this also satisfies
  -- -Wincomplete-patterns.
  radix i _ = error ("Data.Vector.Algorithms.Radix.radix @Word64: digit index "
                     ++ show i ++ " is outside 0..7")
  {-# INLINE radix #-}
-- | Pairs are ordered lexicographically. The second component's digits
-- occupy the early (less significant) passes and the first component's
-- digits the later ones, so the first component dominates the final order.
instance (Radix i, Radix j) => Radix (i, j) where
  -- Lazy patterns: 'sort' calls 'passes' and 'size' on an undefined witness.
  passes ~(hi, lo) = passes hi + passes lo
  {-# INLINE passes #-}
  size ~(hi, lo) = max (size hi) (size lo)
  {-# INLINE size #-}
  radix k ~(hi, lo) =
    let loPasses = passes lo
    in if k < loPasses then radix k lo else radix (k - loPasses) hi
  {-# INLINE radix #-}
-- | Sort a mutable vector in place in ascending order. The element type's
-- own 'Radix' instance supplies the pass count, the bucket count and the
-- digit function.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy digits buckets radix arr
  where
    -- Type witness only. 'passes' and 'size' are applied to it, so 'Radix'
    -- instances must not force the argument of those two methods.
    witness :: e
    witness = undefined
    digits = passes witness
    buckets = size witness
{-# INLINABLE sort #-}
-- | Radix-sort a mutable vector in place with a caller-supplied digit
-- function.
--
-- Allocates one scratch vector of the same length as the input and one
-- 'Int' vector with one slot per bucket, then hands both to the pass loop.
sortBy :: (PrimMonad m, MVector v e)
       => Int                -- ^ number of digit passes
       -> Int                -- ^ number of buckets (distinct digit values)
       -> (Int -> e -> Int)  -- ^ @digit k x@: digit @k@ of @x@, 0 = least significant
       -> v (PrimState m) e  -- ^ vector to sort in place
       -> m ()
sortBy nPasses nBuckets digit arr = do
  let n = length arr
  scratch <- new n
  buckets <- new nBuckets
  radixLoop nPasses digit arr scratch buckets
{-# INLINE sortBy #-}
-- | Run passes @0 .. passes-1@, ping-ponging between @src@ and @dst@.
--
-- Even-numbered passes read @src@ and write @dst@; odd-numbered passes go
-- the other way. If the final pass wrote into @dst@ (an odd pass count),
-- the result is copied back, so the sorted data always ends up in @src@.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                          -- ^ number of passes
          -> (Int -> e -> Int)            -- ^ digit function
          -> v (PrimState m) e            -- ^ vector being sorted
          -> v (PrimState m) e            -- ^ scratch vector ('sortBy' allocates it with src's length)
          -> PV.MVector (PrimState m) Int -- ^ bucket counts / offsets
          -> m ()
radixLoop passes rdx src dst count = go False 0
  where
    -- @swap@ is True exactly when the current data lives in @dst@.
    go swap k
      | k < passes = do
          if swap
            then body rdx dst src count k
            else body rdx src dst count k
          go (not swap) (k + 1)
      -- unsafeCopy's argument order is target, source: copy dst back into src.
      | otherwise = when swap (unsafeCopy src dst)
{-# INLINE radixLoop #-}
-- | One counting-sort pass over digit @k@, reading @src@ and writing @dst@.
--
-- The pass tallies the digits of @src@ into @count@ with 'countLoop', turns
-- those tallies into starting offsets with 'accumulate', and then scatters
-- @src@ into @dst@ with 'moveLoop'.
--
-- NOTE(review): correctness across multiple passes relies on 'countLoop'
-- (defined in Common, not visible here) clearing @count@ before tallying.
-- Confirm there.
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)            -- ^ digit function
     -> v (PrimState m) e            -- ^ source vector
     -> v (PrimState m) e            -- ^ destination vector
     -> PV.MVector (PrimState m) Int -- ^ bucket counts / offsets
     -> Int                          -- ^ digit index @k@ for this pass
     -> m ()
body rdx src dst count k =
  countLoop digitK src count
    >> accumulate count
    >> moveLoop k rdx src dst count
  where
    digitK = rdx k
{-# INLINE body #-}
-- | Replace the per-bucket counts with their exclusive prefix sums, in place.
-- Afterwards entry @i@ holds the sum of the original entries @0 .. i-1@,
-- which is the first output index for bucket @i@.
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -> m ()
accumulate count = loop 0 0
  where
    n = length count
    loop !i !runningTotal
      | i >= n = return ()
      | otherwise = do
          c <- unsafeRead count i
          unsafeWrite count i runningTotal
          loop (i + 1) (runningTotal + c)
{-# INLINE accumulate #-}
-- | Scatter every element of @src@ into @dst@ at the offset that @prefix@
-- holds for the element's digit @k@.
--
-- Elements are visited left to right, so elements with equal digits keep
-- their relative order; that is what makes each LSD pass stable.
-- 'inc' (from Common) is assumed to return the current offset and then
-- bump it. Confirm there.
moveLoop :: (PrimMonad m, MVector v e)
         => Int -> (Int -> e -> Int) -> v (PrimState m) e
         -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = loop 0
  where
    n = length src
    loop !i
      | i >= n = return ()
      | otherwise = do
          x <- unsafeRead src i
          slot <- inc prefix (rdx k x)
          unsafeWrite dst slot x
          loop (i + 1)
{-# INLINE moveLoop #-}