{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE CPP #-}
module Std.Data.Vector.Sort (
mergeSort
, mergeSortBy
, mergeTileSize
, insertSort
, insertSortBy
, Down(..)
, radixSort
, Radix(..)
, RadixDown(..)
) where
import Control.Monad.ST
import Data.Bits
import Data.Int
import Data.Ord (Down (..))
import Data.Primitive (sizeOf)
import Data.Primitive.Types (Prim (..))
import Data.Word
import Prelude hiding (splitAt)
import Std.Data.Array
import Std.Data.Vector.Base
import Std.Data.Vector.Extra
import Std.Data.PrimArray.Cast
-- | Sort a vector in ascending order according to the 'Ord' instance of its
-- elements.
--
-- This is 'mergeSortBy' applied to 'compare'.
mergeSort :: forall v a. (Vec v a, Ord a) => v a -> v a
{-# INLINABLE mergeSort #-}
mergeSort = mergeSortBy compare
-- | Stable merge sort driven by a user supplied comparison function.
--
-- Vectors no longer than 'mergeTileSize' are handed to 'insertSortBy'.
-- Longer vectors are first cut into tiles of 'mergeTileSize' elements, each
-- tile being insertion-sorted into a work array. Sorted runs are then merged
-- pairwise back and forth between two work arrays, the run length doubling
-- on every pass, until a single run covers the whole vector. The array that
-- received the last pass is frozen and becomes the result.
--
-- Equal elements keep their input order: when merging, the element of the
-- right run is taken only if it compares strictly less than the left one.
mergeSortBy :: forall v a. Vec v a => (a -> a -> Ordering) -> v a -> v a
{-# INLINE mergeSortBy #-}
mergeSortBy cmp vec@(Vec _ _ len)
    | len > mergeTileSize = runST (do
        bufA <- newArr len
        bufB <- newArr len
        sortTiles vec 0 bufA
        sorted <- mergePasses bufA bufB mergeTileSize
        return $! fromArr sorted 0 len)
    | otherwise = insertSortBy cmp vec
  where
    -- Insertion-sort consecutive tiles of the input into the work array; the
    -- tile starting at offset off in the input lands at offset off as well.
    sortTiles !remaining !off !dest
        | off < len = do
            let (tile, others) = splitAt mergeTileSize remaining
            insertSortToMArr cmp tile off dest
            sortTiles others (off + mergeTileSize) dest
        | otherwise = return ()

    -- Run merge passes, swapping the roles of the two buffers after each
    -- one, until the run length reaches the vector length.
    mergePasses !source !dest !runLen
        | runLen < len = do
            mergeRuns source dest runLen 0
            mergePasses dest source (runLen * 2)
        | otherwise = unsafeFreezeArr source

    -- One pass: merge every pair of adjacent runs of length runLen.
    -- A trailing run without a partner is copied over unchanged.
    mergeRuns !source !dest !runLen !off
        | off >= len          = return ()
        | off >= len - runLen = copyMutableArr dest off source off (len - off)
        | otherwise = do
            let !mid  = off + runLen
                !stop = min (mid + runLen) len
            mergeTwo source dest mid stop off mid off
            mergeRuns source dest runLen stop

    -- Merge the runs [li, leftStop) and [ri, rightStop) of source into dest
    -- starting at index out. Once one run is exhausted the rest of the other
    -- one is block-copied.
    mergeTwo !source !dest !leftStop !rightStop !li !ri !out = do
        lhs <- readArr source li
        rhs <- readArr source ri
        case cmp rhs lhs of
            LT -> do
                writeArr dest out rhs
                let !ri'  = ri + 1
                    !out' = out + 1
                if ri' < rightStop
                    then mergeTwo source dest leftStop rightStop li ri' out'
                    else copyMutableArr dest out' source li (leftStop - li)
            _ -> do
                writeArr dest out lhs
                let !li'  = li + 1
                    !out' = out + 1
                if li' < leftStop
                    then mergeTwo source dest leftStop rightStop li' ri out'
                    else copyMutableArr dest out' source ri (rightStop - ri)
-- | Tile size used by 'mergeSortBy': vectors of at most this many elements
-- are sorted with 'insertSortBy' directly, and longer vectors are first cut
-- into insertion-sorted tiles of this size before merging starts.
mergeTileSize :: Int
{-# INLINE mergeTileSize #-}
mergeTileSize = 16
-- | Sort a vector in ascending order according to the 'Ord' instance of its
-- elements, using insertion sort.
--
-- This is 'insertSortBy' applied to 'compare'.
insertSort :: (Vec v a, Ord a) => v a -> v a
{-# INLINE insertSort #-}
insertSort = insertSortBy compare
-- | Insertion sort driven by a user supplied comparison function.
--
-- Empty and single-element vectors are rebuilt directly without consulting
-- the comparison; anything longer is sorted into a freshly created array of
-- the same length by 'insertSortToMArr'.
insertSortBy :: Vec v a => (a -> a -> Ordering) -> v a -> v a
{-# INLINE insertSortBy #-}
insertSortBy cmp vec@(Vec arr off len) = case len of
    0 -> empty
    1 -> case indexArr' arr off of (# x #) -> singleton x
    _ -> create len (insertSortToMArr cmp vec 0)
-- | Insertion-sort the elements of a vector into a mutable array.
--
-- The sorted elements occupy the slots from the given offset up to (but not
-- including) offset plus the vector length; slots below the offset are never
-- touched. Elements are taken from the input left to right and each one is
-- moved towards the offset only past neighbours it compares strictly less
-- than, so equal elements keep their input order.
--
-- NOTE(review): no bounds check is visible here, so the mutable array is
-- presumably required to have room for the whole vector at that offset.
insertSortToMArr  :: Vec v a
                  => (a -> a -> Ordering)  -- ^ comparison function
                  -> v a                   -- ^ input vector
                  -> Int                   -- ^ offset into the mutable array
                  -> MArray v s a          -- ^ destination array
                  -> ST s ()
{-# INLINE insertSortToMArr #-}
insertSortToMArr cmp (Vec arr s l) moff marr = copyFrom s
  where
    !stop  = s + l
    -- Distance between an index in the input array and the matching
    -- destination slot.
    !shift = moff - s

    -- Walk the input, inserting each element into the sorted prefix that
    -- has been built so far.
    copyFrom !src
        | src < stop = case indexArr' arr src of
            (# x #) -> do
                sink x (src + shift)
                copyFrom (src + 1)
        | otherwise = return ()

    -- Place x at slot, shifting larger predecessors one position to the
    -- right until x is in order or the start of the region is reached.
    sink !x !slot
        | slot > moff = do
            prev <- readArr marr (slot - 1)
            case cmp x prev of
                LT -> do
                    writeArr marr slot prev
                    sink x (slot - 1)
                _  -> writeArr marr slot x
        | otherwise = writeArr marr moff x
-- | Types whose values can be decomposed into a fixed number of small
-- integer digits, as required by 'radixSort'.
--
-- Every digit returned by 'radixLSB', 'radix' and 'radixMSB' is used by
-- 'radixSort' as an index into a counting array of 'bucketSize' slots.
class Radix a where
    -- | Number of slots in the counting array. 'radixSort' calls this with
    -- @undefined@, so an implementation must not inspect its argument.
    bucketSize :: a -> Int
    -- | Number of sorting passes. 'radixSort' calls this with @undefined@,
    -- so an implementation must not inspect its argument.
    passes :: a -> Int
    -- | Digit used by the first pass.
    radixLSB :: a -> Int
    -- | Digit used by the pass with the given number, for the passes that
    -- are neither the first nor the last one.
    radix :: Int -> a -> Int
    -- | Digit used by the last pass.
    radixMSB :: a -> Int
-- | Single pass over 256 buckets. The byte is xor-ed with 128, which flips
-- the sign bit, so that -128 maps to bucket 0 and 127 maps to bucket 255.
instance Radix Int8 where
    {-# INLINE bucketSize #-}
    bucketSize _ = 256
    {-# INLINE passes #-}
    passes _ = 1
    {-# INLINE radixLSB #-}
    radixLSB a = (fromIntegral a .&. 255) `xor` 128
    {-# INLINE radix #-}
    radix _ a = (fromIntegral a .&. 255) `xor` 128
    {-# INLINE radixMSB #-}
    radixMSB a = (fromIntegral a .&. 255) `xor` 128
-- CPP template providing the Radix method bodies for a multi-byte signed
-- integer type T: 256 buckets, one pass per byte of T (taken from sizeOf),
-- and pass i keyed on byte i counted from the least significant end.
-- The most significant digit is extracted after an xor with minBound, which
-- flips the sign bit so that negative values land in lower buckets than
-- non-negative ones.
#define MULTI_BYTES_INT_RADIX(T) \
{-# INLINE bucketSize #-}; \
bucketSize _ = 256; \
{-# INLINE passes #-}; \
passes _ = sizeOf (undefined :: T); \
{-# INLINE radixLSB #-}; \
radixLSB a = fromIntegral (255 .&. a); \
{-# INLINE radix #-}; \
radix i a = fromIntegral (a `unsafeShiftR` (i `unsafeShiftL` 3)) .&. 255; \
{-# INLINE radixMSB #-}; \
radixMSB a = fromIntegral ((a `xor` minBound) `unsafeShiftR` ((passes a-1) `unsafeShiftL` 3)) .&. 255
-- Instances for the multi-byte signed integer types, all generated from the
-- signed-integer radix macro defined just above.
instance Radix Int where MULTI_BYTES_INT_RADIX(Int)
instance Radix Int16 where MULTI_BYTES_INT_RADIX(Int16)
instance Radix Int32 where MULTI_BYTES_INT_RADIX(Int32)
instance Radix Int64 where MULTI_BYTES_INT_RADIX(Int64)
-- | Single pass over 256 buckets; the byte value itself is the bucket index.
instance Radix Word8 where
    {-# INLINE bucketSize #-}
    bucketSize _ = 256
    {-# INLINE passes #-}
    passes _ = 1
    {-# INLINE radixLSB #-}
    radixLSB w = fromIntegral w
    {-# INLINE radix #-}
    radix _ w = fromIntegral w
    {-# INLINE radixMSB #-}
    radixMSB w = fromIntegral w
-- CPP template providing the Radix method bodies for a multi-byte unsigned
-- integer type T: 256 buckets, one pass per byte of T (taken from sizeOf),
-- and pass i keyed on byte i counted from the least significant end.
-- The most significant digit is the plain top byte; no sign adjustment is
-- applied.
#define MULTI_BYTES_WORD_RADIX(T) \
{-# INLINE bucketSize #-}; \
bucketSize _ = 256; \
{-# INLINE passes #-}; \
passes _ = sizeOf (undefined :: T); \
{-# INLINE radixLSB #-}; \
radixLSB a = fromIntegral (255 .&. a); \
{-# INLINE radix #-}; \
radix i a = fromIntegral (a `unsafeShiftR` (i `unsafeShiftL` 3)) .&. 255; \
{-# INLINE radixMSB #-}; \
radixMSB a = fromIntegral (a `unsafeShiftR` ((passes a-1) `unsafeShiftL` 3)) .&. 255
-- Instances for the multi-byte unsigned integer types, generated from the
-- unsigned radix macro: the most significant digit is the plain top byte.
-- The signed macro would additionally xor with minBound before extracting
-- that byte, which is a no-op for unsigned types because minBound is 0 there,
-- so the resulting bucket indices are the same either way.
instance Radix Word where MULTI_BYTES_WORD_RADIX(Word)
instance Radix Word16 where MULTI_BYTES_WORD_RADIX(Word16)
instance Radix Word32 where MULTI_BYTES_WORD_RADIX(Word32)
instance Radix Word64 where MULTI_BYTES_WORD_RADIX(Word64)
-- | Wrapper that reverses the bucket order of the wrapped type through its
-- 'Radix' instance, so that 'radixSort' arranges wrapped values in the
-- opposite order.
newtype RadixDown a = RadixDown a
    deriving (Show, Eq, Prim)
-- | Same bucket count and pass count as the wrapped type; every digit @d@
-- is mirrored to @bucketSize - 1 - d@, which reverses the bucket order.
instance Radix a => Radix (RadixDown a) where
    {-# INLINE bucketSize #-}
    bucketSize (RadixDown x) = bucketSize x
    {-# INLINE passes #-}
    passes (RadixDown x) = passes x
    {-# INLINE radixLSB #-}
    radixLSB (RadixDown x) = bucketSize x - 1 - radixLSB x
    {-# INLINE radix #-}
    radix i (RadixDown x) = bucketSize x - 1 - radix i x
    {-# INLINE radixMSB #-}
    radixMSB (RadixDown x) = bucketSize x - 1 - radixMSB x
-- | Stable least-significant-digit radix sort based on the 'Radix' instance
-- of the elements.
--
-- Empty and single-element vectors are rebuilt directly. Otherwise one
-- counting-sort pass is run per digit: occurrences of each digit are
-- counted, the counts are turned into starting offsets, and the elements
-- are scattered to those offsets in input order. The first pass reads the
-- input vector and uses 'radixLSB', the last pass uses 'radixMSB', and the
-- passes in between use 'radix'.
--
-- One work array of the input length is always allocated; a second one is
-- allocated only when there is more than one pass. The counting array has
-- 'bucketSize' slots.
radixSort :: forall v a. (Vec v a, Radix a) => v a -> v a
{-# INLINABLE radixSort #-}
radixSort (Vec _ _ 0) = empty
radixSort (Vec arr s 1) = case indexArr' arr s of (# x #) -> singleton x
radixSort (Vec arr s l) = runST (do
    counts <- newArrWith nBuckets 0 :: ST s (MutablePrimArray s Int)
    bufA <- newArr l
    countFirst arr counts s
    prefixSums counts nBuckets 0 0
    scatterFirst arr s counts bufA
    sorted <- if nPasses == 1
        then unsafeFreezeArr bufA
        else do
            bufB <- newArr l
            laterPasses bufA bufB counts nBuckets 1
    return $! fromArr sorted 0 l)
  where
    -- Both class methods are called with undefined: only the type matters.
    nPasses  = passes (undefined :: a)
    nBuckets = bucketSize (undefined :: a)
    !end = s + l

    -- First pass, counting step: histogram of the first digit, read
    -- straight from the immutable input array.
    {-# INLINABLE countFirst #-}
    countFirst !src !counts !i
        | i < end = case indexArr' src i of
            (# x #) -> do
                let !d = radixLSB x
                n <- readArr counts d
                writeArr counts d (n + 1)
                countFirst src counts (i + 1)
        | otherwise = return ()

    -- Replace every count by the sum of the counts before it, i.e. by the
    -- offset at which the first element of that bucket has to be written.
    {-# INLINABLE prefixSums #-}
    prefixSums !counts !size !i !total
        | i < size = do
            n <- readArr counts i
            writeArr counts i total
            prefixSums counts size (i + 1) (total + n)
        | otherwise = return ()

    -- First pass, scatter step: write each input element to the next free
    -- slot of its bucket and advance that slot.
    {-# INLINABLE scatterFirst #-}
    scatterFirst !src !i !slots !dst
        | i < end = case indexArr' src i of
            (# x #) -> do
                let !d = radixLSB x
                pos <- readArr slots d
                writeArr slots d (pos + 1)
                writeArr dst pos x
                scatterFirst src (i + 1) slots dst
        | otherwise = return ()

    -- Remaining passes, alternating between the two work arrays. The last
    -- pass uses the most significant digit and freezes its target array.
    {-# INLINABLE laterPasses #-}
    laterPasses !src !dst !counts !size !pass
        | pass < nPasses - 1 = do
            setArr counts 0 size 0
            countDigit src counts pass 0
            prefixSums counts size 0 0
            scatterDigit src counts pass dst 0
            laterPasses dst src counts size (pass + 1)
        | otherwise = do
            setArr counts 0 size 0
            countLast src counts 0
            prefixSums counts size 0 0
            scatterLast src counts dst 0
            unsafeFreezeArr dst

    -- Histogram of the digit belonging to an intermediate pass.
    {-# INLINABLE countDigit #-}
    countDigit !src !counts !pass !i
        | i < l = do
            x <- readArr src i
            let !d = radix pass x
            n <- readArr counts d
            writeArr counts d (n + 1)
            countDigit src counts pass (i + 1)
        | otherwise = return ()

    -- Scatter step of an intermediate pass.
    {-# INLINABLE scatterDigit #-}
    scatterDigit !src !slots !pass !dst !i
        | i < l = do
            x <- readArr src i
            let !d = radix pass x
            pos <- readArr slots d
            writeArr slots d (pos + 1)
            writeArr dst pos x
            scatterDigit src slots pass dst (i + 1)
        | otherwise = return ()

    -- Histogram of the most significant digit.
    {-# INLINABLE countLast #-}
    countLast !src !counts !i
        | i < l = do
            x <- readArr src i
            let !d = radixMSB x
            n <- readArr counts d
            writeArr counts d (n + 1)
            countLast src counts (i + 1)
        | otherwise = return ()

    -- Scatter step of the last pass.
    {-# INLINABLE scatterLast #-}
    scatterLast !src !slots !dst !i
        | i < l = do
            x <- readArr src i
            let !d = radixMSB x
            pos <- readArr slots d
            writeArr slots d (pos + 1)
            writeArr dst pos x
            scatterLast src slots dst (i + 1)
        | otherwise = return ()