module Data.Vector.Algorithms.AmericanFlag ( sort
, sortBy
, Lexicographic(..)
) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Word
import Data.Int
import Data.Bits
import qualified Data.ByteString as B
import Data.Vector.Generic.Mutable
import qualified Data.Vector.Primitive.Mutable as PV
import qualified Data.Vector.Unboxed.Mutable as U
import Data.Vector.Algorithms.Common
import qualified Data.Vector.Algorithms.Insertion as I
-- | Types whose values can be split into a sequence of small bucket indices,
-- one per sorting pass, most significant position first.
class Lexicographic e where
  -- | @terminate e i@ is consulted by the sort on the first element of a
  -- stripe; when it yields 'True' the stripe is not refined any further.
  terminate :: e -> Int -> Bool

  -- | Number of buckets needed for one pass. 'sort' calls this on an
  -- undefined value, so implementations must not inspect the argument.
  size :: e -> Int

  -- | @index i e@ is the bucket that @e@ falls into on pass @i@. The result
  -- is used as an unchecked offset into arrays of length 'size', so it must
  -- lie in @[0, size e)@.
  index :: Int -> e -> Int
-- | A byte is its own bucket index, whichever pass is asked for.
instance Lexicographic Word8 where
  terminate _ = (> 0)
  size = const 256
  index _ = fromIntegral
-- | Two big-endian bytes: pass 0 is the high byte, pass 1 the low byte.
instance Lexicographic Word16 where
  terminate _ pass = pass > 1
  size _ = 256
  index pass w
    | 0 <= pass && pass <= 1 =
        fromIntegral ((w `shiftR` (8 - 8 * pass)) .&. 0xFF)
    | otherwise = 0
-- | Four big-endian bytes: pass 0 is the most significant byte.
instance Lexicographic Word32 where
  terminate _ pass = pass > 3
  size _ = 256
  index pass w
    | 0 <= pass && pass <= 3 =
        fromIntegral ((w `shiftR` (24 - 8 * pass)) .&. 0xFF)
    | otherwise = 0
-- | Eight big-endian bytes: pass 0 is the most significant byte.
instance Lexicographic Word64 where
  terminate _ pass = pass > 7
  size _ = 256
  index pass w
    | 0 <= pass && pass <= 7 =
        fromIntegral ((w `shiftR` (56 - 8 * pass)) .&. 0xFF)
    | otherwise = 0
-- | Treated as eight big-endian bytes (shifts of 56 down to 0), pass 0 being
-- the most significant.
-- NOTE(review): the shift amounts assume a 64-bit 'Word' — confirm for other
-- targets.
instance Lexicographic Word where
  terminate _ pass = pass > 7
  size _ = 256
  index pass w
    | 0 <= pass && pass <= 7 =
        fromIntegral ((w `shiftR` (56 - 8 * pass)) .&. 0xFF)
    | otherwise = 0
-- | The low eight bits with the sign bit flipped, so that negative values
-- land in lower buckets than non-negative ones.
instance Lexicographic Int8 where
  terminate _ pass = pass > 0
  size _ = 256
  index _ n = (fromIntegral n .&. 0xFF) `xor` 0x80
-- | Two big-endian bytes; the sign bit is flipped before the high byte is
-- taken so that negative values sort first.
instance Lexicographic Int16 where
  terminate _ pass = pass > 1
  size _ = 256
  index pass n
    | pass == 0 = lowByte ((n `xor` minBound) `shiftR` 8)
    | pass == 1 = lowByte n
    | otherwise = 0
    where
      lowByte x = fromIntegral (x .&. 0xFF)
-- | Four big-endian bytes; the sign bit is flipped before the top byte is
-- taken so that negative values sort first.
instance Lexicographic Int32 where
  terminate _ pass = pass > 3
  size _ = 256
  index pass n
    | pass == 0 = lowByte ((n `xor` minBound) `shiftR` 24)
    | 1 <= pass && pass <= 3 = lowByte (n `shiftR` (24 - 8 * pass))
    | otherwise = 0
    where
      lowByte x = fromIntegral (x .&. 0xFF)
-- | Eight big-endian bytes; the sign bit is flipped before the top byte is
-- taken so that negative values sort first.
instance Lexicographic Int64 where
  terminate _ pass = pass > 7
  size _ = 256
  index pass n
    | pass == 0 = lowByte ((n `xor` minBound) `shiftR` 56)
    | 1 <= pass && pass <= 7 = lowByte (n `shiftR` (56 - 8 * pass))
    | otherwise = 0
    where
      lowByte x = fromIntegral (x .&. 0xFF)
-- | Treated as eight big-endian bytes (shifts of 56 down to 0); the sign bit
-- is flipped before the top byte is taken so that negative values sort first.
-- NOTE(review): the shift amounts assume a 64-bit 'Int' — confirm for other
-- targets.
instance Lexicographic Int where
  terminate _ pass = pass > 7
  size _ = 256
  index pass n
    | pass == 0 = lowByte ((n `xor` minBound) `shiftR` 56)
    | 1 <= pass && pass <= 7 = lowByte (n `shiftR` (56 - 8 * pass))
    | otherwise = 0
    where
      lowByte x = x .&. 0xFF
-- | Byte strings use 257 buckets: bucket 0 means "the string has already
-- ended at this position", and byte value @b@ maps to bucket @b + 1@, so a
-- proper prefix sorts before any of its extensions.
instance Lexicographic B.ByteString where
  terminate bytes pos = B.length bytes <= pos
  size _ = 257
  index pos bytes =
    if pos < B.length bytes
      then 1 + fromIntegral (B.index bytes pos)
      else 0
-- | Sorts a mutable vector in place, taking the radix functions from the
-- element type's 'Lexicographic' instance and the ordering from 'Ord'.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort v = sortBy compare terminate bucketCount index v
  where
    -- 'size' is only given a type-carrying placeholder; it is never a real
    -- element of the vector.
    bucketCount = size (undefined :: e)
-- | Sorts a mutable vector in place using caller-supplied radix functions.
--
-- Arguments, in order: the comparison handed to the insertion-sort fallback,
-- the stripe termination test, the number of buckets, the per-pass bucket
-- function, and the vector to sort. An empty vector is left untouched.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e
       -> (e -> Int -> Bool)
       -> Int
       -> (Int -> e -> Int)
       -> v (PrimState m) e
       -> m ()
sortBy cmp stop buckets radix v =
  unless (length v == 0) $ do
    -- Two scratch arrays with one slot per bucket, shared by every pass.
    count <- new buckets
    pile <- new buckets
    -- NOTE(review): countLoop comes from Data.Vector.Algorithms.Common;
    -- presumably it tallies the pass-0 bucket sizes into count — confirm.
    countLoop (radix 0) v count
    flagLoop cmp stop radix count pile v
-- | Recursive driver of the sort. On entry @count@ must hold the bucket
-- tallies of @v@ for pass 0, and @v@ must be non-empty (its first element is
-- read unconditionally).
--
-- For each stripe: stop if the termination test says so, fall back to
-- insertion sort when the stripe is shorter than 'threshold', otherwise
-- bucket the stripe on the current pass and recurse into every run of
-- elements that share a bucket.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e
         -> (e -> Int -> Bool)
         -> (Int -> e -> Int)
         -> PV.MVector (PrimState m) Int
         -> PV.MVector (PrimState m) Int
         -> v (PrimState m) e
         -> m ()
flagLoop cmp stop radix count pile v = go 0 v
  where
    -- Every element of the stripe shares its bucket for pass - 1, so the
    -- termination test is asked about that previous position, using the
    -- first element as the representative.
    go pass stripe = do
      e <- unsafeRead stripe 0
      unless (stop e $ pass - 1) $ go' pass stripe

    go' pass stripe
      | len < threshold = I.sortByBounds cmp stripe 0 len
      | otherwise = do
          accumulate count pile
          permute (radix pass) count pile stripe
          recurse 0
      where
        len = length stripe
        ppass = pass + 1

        -- Walk the stripe one same-bucket run at a time. countStripe
        -- returns the exclusive end j of the run starting at i and leaves
        -- the next pass's tallies for that run in count, which is what the
        -- recursive call expects. The slice length is therefore j - i.
        recurse i
          | i < len = do
              j <- countStripe (radix ppass) (radix pass) count stripe i
              go ppass (unsafeSlice i (j - i) stripe)
              recurse j
          | otherwise = return ()
-- | Turns per-bucket tallies into bucket boundaries, in place.
--
-- Afterwards @pile@ holds the offset at which each bucket starts and @count@
-- holds the offset just past each bucket's end (the running total).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int
           -> PV.MVector (PrimState m) Int
           -> m ()
accumulate count pile = foldM_ step 0 [0 .. length count - 1]
  where
    -- The accumulator is the total of all buckets before the current one.
    step before i = do
      here <- unsafeRead count i
      let after = before + here
      unsafeWrite pile i before
      unsafeWrite count i after
      return after
-- | Moves every element of @v@ into its bucket, in place.
--
-- Expects the arrays as 'accumulate' leaves them: @pile@ holds the next free
-- slot of each bucket (initially the bucket's start) and @count@ holds the
-- offset just past each bucket's end, so the start of bucket @r@ is
-- @count[r - 1]@, or 0 for the first bucket.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)
        -> PV.MVector (PrimState m) Int
        -> PV.MVector (PrimState m) Int
        -> v (PrimState m) e
        -> m ()
permute rdx count pile v = go 0
  where
    len = length v

    go i
      | i < len = do
          e <- unsafeRead v i
          let r = rdx e
          p <- unsafeRead pile r
          -- m is where bucket r begins: the end of the previous bucket.
          m <- if r > 0
                 then unsafeRead count (r - 1)
                 else return 0
          case () of
            -- Already inside the filled part of its own bucket: jump to the
            -- first slot of that bucket that is still open.
            _ | m <= i && i < p -> go p
            -- Sitting exactly on its bucket's next free slot: claim the slot
            -- and move on.
              | i == p -> unsafeWrite pile r (p + 1) >> go (i + 1)
            -- Otherwise push it to its bucket and follow the displaced
            -- elements until the cycle closes.
              | otherwise -> follow i e p >> go (i + 1)
      | otherwise = return ()

    -- follow i e j: e wants slot j, and slot i is the hole that started the
    -- cycle.
    -- NOTE(review): inc comes from Data.Vector.Algorithms.Common and is
    -- assumed to bump the counter and return its previous value — confirm.
    follow i e j = do
      en <- unsafeRead v j
      let r = rdx en
      p <- inc pile r
      if p == j
        -- The occupant of slot j is already where it belongs; leave it and
        -- try the following slot for e.
        then follow i e (j + 1)
        else do
          unsafeWrite v j e
          if i == p
            -- The displaced element belongs in the original hole: cycle done.
            then unsafeWrite v i en
            else follow i en p
-- | Finds the run of elements starting at @lo@ that share the same @str@
-- value, and tallies that run by @rdx@.
--
-- @count@ is zeroed first, then incremented once per element of the run at
-- that element's @rdx@ bucket. Returns the exclusive end offset of the run,
-- which is always greater than @lo@. @lo@ must be a valid offset into @v@.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)
            -> (e -> Int)
            -> PV.MVector (PrimState m) Int
            -> v (PrimState m) e
            -> Int
            -> m Int
countStripe rdx str count v lo = do
  set count 0
  first <- unsafeRead v lo
  scan (str first) first (lo + 1)
  where
    end = length v

    -- cur has been read but not yet tallied; next is the offset after it.
    scan !stripe cur next = do
      _ <- inc count (rdx cur)
      if next >= end
        then return end
        else do
          candidate <- unsafeRead v next
          if str candidate /= stripe
            then return next
            else scan stripe candidate (next + 1)
-- | Stripe length below which 'flagLoop' stops bucketing and hands the
-- stripe to insertion sort instead.
threshold :: Int
threshold = 25