{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# lANGUAGE ScopedTypeVariables #-}
module Data.Vector.Algorithms.AmericanFlag ( sort
, sortBy
, Lexicographic(..)
) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Proxy
import Data.Word
import Data.Int
import Data.Bits
import qualified Data.ByteString as B
import Data.Vector.Generic.Mutable
import qualified Data.Vector.Primitive.Mutable as PV
import qualified Data.Vector.Unboxed.Mutable as U
import Data.Vector.Algorithms.Common
import qualified Data.Vector.Algorithms.Insertion as I
import Foreign.Storable
-- | Types whose values can be broken into a big-endian sequence of small
-- integer digits, so that they can be distributed into buckets one digit
-- position at a time.
class Lexicographic e where
  -- | How many digit positions this particular value occupies. This need not
  -- be the same for every value of the type (the 'B.ByteString' instance
  -- uses the string's length).
  extent :: e -> Int

  -- | How many buckets are needed for this type: every result of 'index'
  -- is used to address arrays of this length, so it must lie in
  -- @[0, size - 1]@.
  size   :: Proxy e -> Int

  -- | @index i x@ is the bucket that @x@ falls into when sorting on digit
  -- position @i@. Position 0 is the most significant digit.
  index  :: Int -> e -> Int
-- | A byte is a single digit: the position argument is ignored and the byte
-- value itself is the bucket.
instance Lexicographic Word8 where
  extent = const 1
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index = const fromIntegral
  {-# INLINE index #-}
-- | Two digits: the high byte at position 0, the low byte at position 1.
-- Any other position maps to bucket 0.
instance Lexicographic Word16 where
  extent = const 2
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index i n = case i of
      0 -> byte 8
      1 -> byte 0
      _ -> 0
    where
      -- the byte of n that starts at bit s
      byte s = fromIntegral $ (n `shiftR` s) .&. 255
  {-# INLINE index #-}
-- | Four digits, most significant byte first. Positions outside @[0, 3]@
-- map to bucket 0.
instance Lexicographic Word32 where
  extent = const 4
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index i n
    | i < 0 || i > 3 = 0
    | otherwise      = fromIntegral $ (n `shiftR` (24 - 8 * i)) .&. 255
  {-# INLINE index #-}
-- | Eight digits, most significant byte first. Positions outside @[0, 7]@
-- map to bucket 0.
instance Lexicographic Word64 where
  extent = const 8
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index i n
    | i < 0 || i > 7 = 0
    | otherwise      = fromIntegral $ (n `shiftR` (56 - 8 * i)) .&. 255
  {-# INLINE index #-}
-- | One digit per byte of the native 'Word', most significant byte first.
-- Both 'extent' and the shift amounts used by 'index' come from
-- @sizeOf (0 :: Word)@, so the two agree whatever the platform word size.
-- Positions outside @[0, extent - 1]@ map to bucket 0.
instance Lexicographic Word where
  extent _ = sizeOf (0 :: Word)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index i n
    | i >= 0 && i < bytes = fromIntegral $ (n `shiftR` (8 * (bytes - 1 - i))) .&. 255
    | otherwise           = 0
    where
      bytes = sizeOf (0 :: Word)
  {-# INLINE index #-}
-- | A single digit. Reinterpreting the value as a byte and flipping its top
-- bit maps @-128 .. 127@ onto @0 .. 255@ in order, so negative values land in
-- the lower buckets.
instance Lexicographic Int8 where
  extent = const 1
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index _ n = fromIntegral (fromIntegral n `xor` (0x80 :: Word8))
  {-# INLINE index #-}
-- | Two digits, high byte first. The sign bit is flipped so that negative
-- values sort before non-negative ones. Other positions map to bucket 0.
instance Lexicographic Int16 where
  extent = const 2
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index i n
    | i == 0    = byte 8
    | i == 1    = byte 0
    | otherwise = 0
    where
      -- only the top bit changes, so the low byte is unaffected
      biased = fromIntegral n `xor` (0x8000 :: Word16)
      byte s = fromIntegral $ (biased `shiftR` s) .&. 255
  {-# INLINE index #-}
-- | Four digits, most significant byte first, with the sign bit flipped so
-- that negative values sort first. Positions outside @[0, 3]@ map to 0.
instance Lexicographic Int32 where
  extent = const 4
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index i n
    | i < 0 || i > 3 = 0
    | otherwise      = fromIntegral $ (biased `shiftR` (24 - 8 * i)) .&. 255
    where
      -- flipping bit 31 only affects the digit at position 0
      biased = fromIntegral n `xor` (0x80000000 :: Word32)
  {-# INLINE index #-}
-- | Eight digits, most significant byte first, with the sign bit flipped so
-- that negative values sort first. Positions outside @[0, 7]@ map to 0.
instance Lexicographic Int64 where
  extent = const 8
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index i n
    | i < 0 || i > 7 = 0
    | otherwise      = fromIntegral $ (biased `shiftR` (56 - 8 * i)) .&. 255
    where
      -- flipping bit 63 only affects the digit at position 0
      biased = fromIntegral n `xor` (0x8000000000000000 :: Word64)
  {-# INLINE index #-}
-- | One digit per byte of the native 'Int', most significant byte first. In
-- the leading digit the sign bit is flipped so that negative numbers sort
-- before non-negative ones. Both 'extent' and the shift amounts used by
-- 'index' come from @sizeOf (0 :: Int)@, so they agree whatever the platform
-- word size. Positions outside @[0, extent - 1]@ map to bucket 0.
instance Lexicographic Int where
  extent _ = sizeOf (0 :: Int)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index i n
    | i == 0             = (flipped `shiftR` shiftFor 0) .&. 255
    | i > 0 && i < bytes = (n `shiftR` shiftFor i) .&. 255
    | otherwise          = 0
    where
      bytes      = sizeOf (0 :: Int)
      -- bit offset of the byte at digit position j
      shiftFor j = 8 * (bytes - 1 - j)
      -- xor with minBound toggles only the sign bit
      flipped    = n `xor` minBound
  {-# INLINE index #-}
-- | One digit per byte of the string. Byte values are shifted up by one so
-- that bucket 0 is reserved for positions past the end of the string. A
-- string therefore sorts before any longer string it is a prefix of.
instance Lexicographic B.ByteString where
  extent bs = B.length bs
  {-# INLINE extent #-}
  size = const 257
  {-# INLINE size #-}
  index i bs
    | i < B.length bs = 1 + fromIntegral (B.index bs i)
    | otherwise       = 0
  {-# INLINE index #-}
-- | A pair is the digits of its first component followed by the digits of its
-- second component. The bucket count is the larger of the two components'.
--
-- NOTE(review): no separator digit is inserted between the components. When
-- @a@ has a variable extent (e.g. 'B.ByteString'), distinct pairs such as
-- @("ab","c")@ and @("a","bc")@ produce the same digit sequence. Confirm
-- callers do not rely on the radix passes alone to separate such pairs.
instance (Lexicographic a, Lexicographic b) => Lexicographic (a, b) where
  extent (a, b) = extent a + extent b
  {-# INLINE extent #-}
  size _ = size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  index i (a, b)
    -- positions past the first component are re-based so that the second
    -- component is indexed starting from its own digit 0
    | i >= ea   = index (i - ea) b
    | otherwise = index i a
    where
      ea = extent a
  {-# INLINE index #-}
-- | Position 0 holds the constructor tag ('Left' goes to bucket 0, 'Right'
-- to bucket 1), so every 'Left' sorts before every 'Right'. The remaining
-- positions are the payload's digits, shifted along by one.
instance (Lexicographic a, Lexicographic b) => Lexicographic (Either a b) where
  extent = (1 +) . either extent extent
  {-# INLINE extent #-}
  size _ = max (size (Proxy :: Proxy a)) (size (Proxy :: Proxy b))
  {-# INLINE size #-}
  index n
    | n == 0    = either (const 0) (const 1)
    | otherwise = either (index (n - 1)) (index (n - 1))
  {-# INLINE index #-}
-- | The stopping test used by 'sort': a value has no more digits to examine
-- once the pass number has reached its 'extent'.
terminate :: Lexicographic e => e -> Int -> Bool
terminate e i = extent e <= i
{-# INLINE terminate #-}
-- | Sort a mutable vector in place. The element type's 'Lexicographic'
-- instance supplies the digits and bucket count, and 'terminate' is the
-- stopping test. Stripes that become small are finished by an insertion
-- sort ordered by 'compare'.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort v = sortBy compare terminate buckets index v
  where
    buckets = size (Proxy :: Proxy e)
{-# INLINABLE sort #-}
-- | Sort a mutable vector in place with the American flag (in-place MSD
-- radix) algorithm. The digit functions are supplied by the caller rather
-- than by a 'Lexicographic' instance. An empty vector is left untouched.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e       -- ^ ordering for the insertion-sort fallback on small stripes
       -> (e -> Int -> Bool) -- ^ @stop x p@: is the stripe represented by @x@ finished after pass @p@?
       -> Int                -- ^ number of buckets; every radix result must be below this
       -> (Int -> e -> Int)  -- ^ radix function: bucket of an element at a given pass (pass 0 first)
       -> v (PrimState m) e  -- ^ the vector to sort
       -> m ()
sortBy cmp stop buckets radix v =
  unless (length v == 0) $ do
    count <- new buckets
    pile  <- new buckets
    -- countLoop (from Common) is expected to tally the pass-0 bucket sizes
    -- into count before the flag loop starts
    countLoop (radix 0) v count
    flagLoop cmp stop radix count pile v
{-# INLINE sortBy #-}
-- | Recursive driver of the American flag sort.
--
-- For a stripe at digit position @pass@:
--
-- * its first element is the representative; if @stop@ says the stripe was
--   already finished after the previous pass, nothing more is done;
-- * stripes shorter than 'threshold' are finished by insertion sort;
-- * otherwise the stripe is permuted into buckets by digit @pass@ and each
--   resulting run of equal digits is recursed into at @pass + 1@.
--
-- On entry the count array is expected to hold the bucket sizes of the
-- stripe for digit @pass@. 'countStripe' re-establishes this for each child
-- stripe.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e
         -> (e -> Int -> Bool)
         -> (Int -> e -> Int)
         -> PV.MVector (PrimState m) Int
         -> PV.MVector (PrimState m) Int
         -> v (PrimState m) e
         -> m ()
flagLoop cmp stop radix count pile v = descend 0 v
  where
    -- check the representative, then sort the stripe at this pass
    descend pass stripe = do
      rep <- unsafeRead stripe 0
      unless (stop rep (pass - 1)) $ sortStripe pass stripe

    -- bucket the stripe on digit `pass`, or fall back to insertion sort
    sortStripe pass stripe
      | n < threshold = I.sortByBounds cmp stripe 0 n
      | otherwise = do
          accumulate count pile
          permute (radix pass) count pile stripe
          walk 0
      where
        n    = length stripe
        next = pass + 1
        -- visit each run of equal `pass` digits; countStripe returns the
        -- end of the run and fills count with the next digit's sizes
        walk lo = when (lo < n) $ do
          hi <- countStripe (radix next) (radix pass) count stripe lo
          descend next (unsafeSlice lo (hi - lo) stripe)
          walk hi
{-# INLINE flagLoop #-}
-- | Turn per-bucket sizes into bucket boundaries, in place.
--
-- On entry @count[b]@ is the size of bucket @b@. On exit @pile[b]@ is the
-- start offset of bucket @b@ (the sum of all earlier sizes) and @count[b]@
-- is its end offset (that start plus its own size).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int
           -> PV.MVector (PrimState m) Int
           -> m ()
accumulate count pile = go 0 0
  where
    buckets = length count
    go b running = when (b < buckets) $ do
      c <- unsafeRead count b
      let running' = running + c
      unsafeWrite pile b running
      unsafeWrite count b running'
      go (b + 1) running'
{-# INLINE accumulate #-}
-- | Rearrange @v@ in place so that elements are grouped by bucket under
-- @rdx@, with buckets laid out in increasing order.
--
-- Expects @count@ and @pile@ as produced by 'accumulate':
--
-- * @count[r-1]@ is the start offset of bucket @r@ (0 for bucket 0);
-- * @pile[r]@ is the next slot of bucket @r@ still to be filled. It is
--   advanced as elements are placed.
--
-- Misplaced elements are moved by following displacement cycles: an
-- element is written into its bucket's next slot, and the element it
-- evicts is carried on until the cycle returns to the starting position.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)
        -> PV.MVector (PrimState m) Int
        -> PV.MVector (PrimState m) Int
        -> v (PrimState m) e
        -> m ()
permute rdx count pile v = go 0
  where
    len = length v

    -- Scan positions left to right, fixing up each one.
    go i
      | i < len = do
          e <- unsafeRead v i
          let r = rdx e
          p <- unsafeRead pile r
          -- m: start offset of e's bucket
          m <- if r > 0
                 then unsafeRead count (r-1)
                 else return 0
          case () of
            -- i lies inside the already-filled prefix [m, p) of e's bucket:
            -- everything up to p is in place, so jump to p
            _ | m <= i && i < p -> go p
            -- e sits exactly at its bucket's next free slot: it is in place,
            -- so claim that slot and move on
              | i == p -> unsafeWrite pile r (p+1) >> go (i+1)
            -- otherwise e is out of place: start a cycle at its target slot
              | otherwise -> follow i e p >> go (i+1)
      | otherwise = return ()

    -- Carry element e towards slot j. i is the position the cycle started
    -- from, where the final displaced element is written back.
    follow i e j = do
      en <- unsafeRead v j
      let r = rdx en
      -- p: en's target slot (old pile value). inc also advances pile[r].
      p <- inc pile r
      if p == j
        -- en was already at its own bucket's next slot, so leave it there and
        -- try the following slot for e
        then follow i e (j+1)
        -- place e, then either close the cycle or keep carrying en
        else unsafeWrite v j e >> if i == p
                                     then unsafeWrite v i en
                                     else follow i en p
{-# INLINE permute #-}
-- | Measure the stripe that starts at @lo@ while preparing the count array
-- for the next pass.
--
-- The stripe is the maximal run of elements, starting at @lo@, whose @str@
-- value equals that of @v[lo]@. The count array is cleared and then, for
-- every element of the run, the bucket @rdx@ assigns it is incremented.
-- Returns the index one past the end of the run, or @length v@ if the run
-- reaches the end.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)
            -> (e -> Int)
            -> PV.MVector (PrimState m) Int
            -> v (PrimState m) e
            -> Int
            -> m Int
countStripe rdx str count v lo = do
  set count 0
  firstElem <- unsafeRead v lo
  scan (str firstElem) firstElem (lo + 1)
  where
    len = length v
    -- x belongs to the run; i is the next position to inspect
    scan !key x i = do
      _ <- inc count (rdx x)
      if i >= len
        then return len
        else do
          y <- unsafeRead v i
          if str y /= key
            then return i
            else scan key y (i + 1)
{-# INLINE countStripe #-}
-- | Stripes with fewer elements than this are finished with an insertion
-- sort instead of another radix pass.
threshold :: Int
threshold = 25