module Data.Vector.Algorithms.AmericanFlag ( sort
                                           , sortBy
                                           , Lexicographic(..)
                                           ) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Word
import Data.Int
import Data.Bits
import qualified Data.ByteString as B
import Data.Vector.Generic.Mutable
import qualified Data.Vector.Primitive.Mutable as PV
import qualified Data.Vector.Unboxed.Mutable as U
import Data.Vector.Algorithms.Common
import qualified Data.Vector.Algorithms.Insertion as I
-- | Types that can be sorted by the American flag sort in this module.
-- Sorting proceeds in passes @0, 1, 2, ...@; in pass @i@ every element is
-- placed into one of 'size' buckets chosen by 'index' @i@.
class Lexicographic e where
  -- | Given a representative element of a stripe (a run of elements that
  -- landed in the same bucket during the preceding pass; at the top level,
  -- the whole vector) and a pass index, decide whether that stripe needs
  -- no further radix sorting.
  terminate :: e -> Int -> Bool
  -- | Number of buckets; every result of 'index' must lie in @[0, size)@
  -- because it is used as an unchecked index into the bucket arrays.
  -- Implementations must not force the argument: 'sort' passes 'undefined'.
  size      :: e -> Int
  -- | @index i e@ is the bucket that @e@ belongs in during pass @i@.
  index     :: Int -> e -> Int
-- | A single byte: the value itself is the bucket on every pass, and
-- sorting of a stripe stops from pass index 1 onwards.
instance Lexicographic Word8 where
  terminate _ = (> 0)
  size = const 256
  index _ = fromIntegral
  
-- | Two bytes, most significant first (pass 0 = high byte, pass 1 = low
-- byte). Any other pass index yields bucket 0.
instance Lexicographic Word16 where
  terminate _ pass = pass >= 2
  size = const 256
  index pass w
    | pass < 0 || pass > 1 = 0
    | otherwise            = fromIntegral ((w `shiftR` (8 - 8 * pass)) .&. 255)
  
-- | Four bytes, most significant first (pass 0 = bits 31..24, pass 3 =
-- bits 7..0). Any other pass index yields bucket 0.
instance Lexicographic Word32 where
  terminate _ pass = pass >= 4
  size = const 256
  index pass w
    | pass < 0 || pass > 3 = 0
    | otherwise            = fromIntegral ((w `shiftR` (24 - 8 * pass)) .&. 255)
  
-- | Eight bytes, most significant first (pass 0 = bits 63..56, pass 7 =
-- bits 7..0). Any other pass index yields bucket 0.
instance Lexicographic Word64 where
  terminate _ pass = pass >= 8
  size = const 256
  index pass w
    | pass < 0 || pass > 7 = 0
    | otherwise            = fromIntegral ((w `shiftR` (56 - 8 * pass)) .&. 255)
  
-- | A machine word split into bytes, most significant first. The number of
-- bytes is derived from 'finiteBitSize' rather than hard-coded to 8, so a
-- 32-bit target does not spend extra passes on bytes that are always zero.
-- On 64-bit targets this is exactly passes 0..7 with shifts 56, 48, ..., 0.
instance Lexicographic Word where
  terminate _ n = n >= wordBytes
    where wordBytes = finiteBitSize (0 :: Word) `quot` 8

  size _ = 256

  index i n
    | i >= 0 && i < wordBytes =
        fromIntegral ((n `shiftR` (8 * (wordBytes - 1 - i))) .&. 255)
    | otherwise               = 0
    where wordBytes = finiteBitSize (0 :: Word) `quot` 8
  
-- | A single signed byte, shifted from @[-128, 127]@ to buckets @[0, 255]@
-- so that negative values land in lower buckets. The same bucket is used
-- on every pass; sorting of a stripe stops from pass index 1 onwards.
instance Lexicographic Int8 where
  terminate _ = (> 0)
  size = const 256
  index _ x = fromIntegral x + 128
  
-- | Two bytes, most significant first. In pass 0 the sign bit is flipped
-- (xor with 'minBound') so negative values sort before non-negative ones;
-- pass 1 is the raw low byte. Any other pass index yields bucket 0.
instance Lexicographic Int16 where
  terminate _ pass = pass >= 2
  size = const 256
  index pass x
    | pass == 0 = byteAt 8 (x `xor` minBound)
    | pass == 1 = byteAt 0 x
    | otherwise = 0
    where
      byteAt s y = fromIntegral ((y `shiftR` s) .&. 255)
  
-- | Four bytes, most significant first. In pass 0 the sign bit is flipped
-- (xor with 'minBound') so negative values sort first; passes 1..3 use the
-- raw remaining bytes. Any other pass index yields bucket 0.
instance Lexicographic Int32 where
  terminate _ pass = pass >= 4
  size = const 256
  index pass x
    | pass == 0            = byteAt 24 (x `xor` minBound)
    | pass > 0 && pass < 4 = byteAt (24 - 8 * pass) x
    | otherwise            = 0
    where
      byteAt s y = fromIntegral ((y `shiftR` s) .&. 255)
  
-- | Eight bytes, most significant first. In pass 0 the sign bit is flipped
-- (xor with 'minBound') so negative values sort first; passes 1..7 use the
-- raw remaining bytes. Any other pass index yields bucket 0.
instance Lexicographic Int64 where
  terminate _ pass = pass >= 8
  size = const 256
  index pass x
    | pass == 0            = byteAt 56 (x `xor` minBound)
    | pass > 0 && pass < 8 = byteAt (56 - 8 * pass) x
    | otherwise            = 0
    where
      byteAt s y = fromIntegral ((y `shiftR` s) .&. 255)
  
-- | A machine 'Int' split into bytes, most significant first. In pass 0 the
-- sign bit is flipped (xor with 'minBound') so negative values sort first.
-- The byte count comes from 'finiteBitSize' instead of being hard-coded to
-- 8, so a 32-bit target does not run extra passes over sign-fill bytes.
-- On 64-bit targets this is exactly passes 0..7 with shifts 56, 48, ..., 0.
instance Lexicographic Int where
  terminate _ n = n >= intBytes
    where intBytes = finiteBitSize (0 :: Int) `quot` 8

  size _ = 256

  index i n
    | i == 0                = ((n `xor` minBound) `shiftR` (8 * (intBytes - 1))) .&. 255
    | i > 0 && i < intBytes = (n `shiftR` (8 * (intBytes - 1 - i))) .&. 255
    | otherwise             = 0
    where intBytes = finiteBitSize (0 :: Int) `quot` 8
  
-- | Byte strings are compared byte by byte. Bucket 0 marks "past the end of
-- the string", so shorter strings sort before longer strings that share
-- their prefix; actual bytes occupy buckets 1..256. A stripe is finished
-- once the pass index reaches the representative's length.
instance Lexicographic B.ByteString where
  terminate bs pass = B.length bs <= pass
  size = const 257
  index pass bs
    | pass < B.length bs = 1 + fromIntegral (B.index bs pass)
    | otherwise          = 0
  
-- | Sort a mutable vector in place with American flag sort, taking the
-- bucket count, radix function and stopping rule from the element type's
-- 'Lexicographic' instance. Stripes too short to radix-sort are ordered
-- with 'compare'.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort vec = sortBy compare terminate bucketCount index vec
  where
    -- Only the type matters here; 'size' must not inspect its argument.
    bucketCount = size (undefined :: e)
-- | Sort a mutable vector in place with American flag sort, using an
-- explicitly supplied radix scheme.
--
-- An empty vector is left untouched. Otherwise two bucket arrays of length
-- @buckets@ are allocated, the buckets of pass 0 are counted with
-- @radix 0@, and sorting proceeds pass by pass. Every value returned by
-- @radix@ must lie in @[0, buckets)@, since it indexes the bucket arrays
-- without bounds checks.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e       -- ^ ordering used to insertion-sort short stripes
                             --   (NOTE(review): presumably must agree with the radix order)
       -> (e -> Int -> Bool) -- ^ stop predicate: stripe representative and pass index
       -> Int                -- ^ number of buckets
       -> (Int -> e -> Int)  -- ^ bucket of an element for a given pass
       -> v (PrimState m) e  -- ^ vector to sort in place
       -> m ()
sortBy cmp stop buckets radix v =
  unless (length v == 0) $ do
    count <- new buckets
    pile  <- new buckets
    countLoop (radix 0) v count
    flagLoop cmp stop radix count pile v
-- | Recursive driver for the radix passes.
--
-- @go pass sub@ reads the first element of @sub@ as the stripe's
-- representative and stops if @stop rep (pass - 1)@ holds. Otherwise @go'@
-- either insertion-sorts @sub@ (when it is shorter than 'threshold') or
-- permutes it into buckets by @radix pass@ and then recurses, with pass
-- @pass + 1@, on every run of elements that share a bucket.
--
-- On entry @count@ must hold the bucket counts of @radix pass@ for the
-- vector being sorted (as done by the caller for pass 0 and by
-- 'countStripe' for each stripe); @pile@ is overwritten by 'accumulate'.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e                 -- ^ ordering for short stripes
         -> (e -> Int -> Bool)           -- ^ stop predicate
         -> (Int -> e -> Int)            -- ^ bucket of an element for a given pass
         -> PV.MVector (PrimState m) Int -- ^ bucket counts, then bucket ends
         -> PV.MVector (PrimState m) Int -- ^ next free slot of each bucket
         -> v (PrimState m) e            -- ^ vector to sort in place
         -> m ()
flagLoop cmp stop radix count pile v = go 0 v
 where
  go pass sub = do
    rep <- unsafeRead sub 0
    -- Ask about the key index used by the previous pass; pass 0 asks about
    -- index -1, which no instance here reports as terminated.
    unless (stop rep (pass - 1)) $ go' pass sub

  go' pass sub
    | len < threshold = I.sortByBounds cmp sub 0 len
    | otherwise       = do
        accumulate count pile
        permute (radix pass) count pile sub
        recurse 0
   where
    len   = length sub
    ppass = pass + 1
    -- Walk the stripes (runs of equal @radix pass@) left to right. Each
    -- stripe is counted for the next pass and sorted recursively.
    recurse i
      | i < len   = do
          j <- countStripe (radix ppass) (radix pass) count sub i
          go ppass (unsafeSlice i (j - i) sub)
          recurse j
      | otherwise = return ()
-- | Turn per-bucket counts into bucket boundaries, in place.
--
-- Afterwards @pile ! b@ is the start offset of bucket @b@ (sum of the
-- counts of buckets @0 .. b-1@) and @count ! b@ is its end offset (sum of
-- the counts of buckets @0 .. b@).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -- ^ per-bucket counts; replaced by end offsets
           -> PV.MVector (PrimState m) Int -- ^ receives start offsets
           -> m ()
accumulate count pile = step 0 0
 where
  buckets = length count
  step b runningTotal
    | b >= buckets = return ()
    | otherwise    = do
        c <- unsafeRead count b
        let total' = runningTotal + c
        unsafeWrite pile  b runningTotal
        unsafeWrite count b total'
        step (b + 1) total'
-- | In-place permutation step of American flag sort: moves every element of
-- @v@ into the bucket chosen by @rdx@ by following displacement cycles, so
-- no auxiliary copy of @v@ is needed.
--
-- Expects the arrays as left by 'accumulate': @count ! r@ is the end
-- offset of bucket @r@ and @pile ! r@ the slot where bucket @r@ is filled
-- next. @pile@ is advanced as elements are placed.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)                       -- ^ bucket of an element for this pass
        -> PV.MVector (PrimState m) Int     -- ^ bucket end offsets
        -> PV.MVector (PrimState m) Int     -- ^ next slot to fill in each bucket
        -> v (PrimState m) e                -- ^ vector to permute in place
        -> m ()
permute rdx count pile v = go 0
 where
 len = length v

 -- Scan positions left to right, dealing with the element found at each.
 go i
   | i < len   = do e <- unsafeRead v i
                    let r = rdx e
                    p <- unsafeRead pile r
                    -- Start of bucket r is the end of bucket r - 1.
                    m <- if r > 0
                            then unsafeRead count (r - 1)
                            else return 0
                    case () of
                      -- i is inside bucket r, before its next fill slot:
                      -- skip ahead to that slot.
                      _ | m <= i && i < p  -> go p
                      -- e sits exactly at its bucket's next fill slot:
                      -- claim the slot and move on.
                        | i == p           -> unsafeWrite pile r (p+1) >> go (i+1)
                      -- otherwise carry e along the displacement cycle.
                        | otherwise        -> follow i e p >> go (i+1)
   | otherwise = return ()

 -- @follow i e j@ carries @e@ to slot @j@; position @i@ is where the cycle
 -- started (its original element was taken out by 'go').
 -- NOTE(review): relies on 'inc' returning the pre-increment value.
 follow i e j = do en <- unsafeRead v j
                   let r = rdx en
                   p <- inc pile r
                   if p == j
                      -- the occupant's reserved slot is j itself: leave it
                      -- in place and try the next slot for e.
                      then follow i e (j+1)
                      else unsafeWrite v j e >> if i == p
                                                 then unsafeWrite v i en
                                                 else follow i en p
-- | Find the stripe starting at index @lo@ and count it for the next pass.
--
-- Zeroes @count@, then walks the maximal run of consecutive elements whose
-- @str@ key equals that of @v ! lo@, incrementing @count@ at @rdx@ of each
-- one (including @v ! lo@). Returns the index just past the run, which is
-- @length v@ if the run reaches the end.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)                   -- ^ bucket function to count (next pass)
            -> (e -> Int)                   -- ^ key that defines the stripe (current pass)
            -> PV.MVector (PrimState m) Int -- ^ overwritten with the counts
            -> v (PrimState m) e            -- ^ vector being sorted
            -> Int                          -- ^ start index of the stripe
            -> m Int                        -- ^ end index (exclusive) of the stripe
countStripe rdx str count v lo = do
  set count 0
  first <- unsafeRead v lo
  scan (str first) first (lo + 1)
 where
  n = length v
  scan !key x i = do
    _ <- inc count (rdx x)
    if i >= n
      then return n
      else do
        next <- unsafeRead v i
        if str next /= key
          then return i
          else scan key next (i + 1)
-- | Stripes shorter than this many elements are sorted with insertion sort
-- ('I.sortByBounds') instead of another radix pass.
threshold :: Int
threshold = 25