{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# lANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.AmericanFlag
-- Copyright   : (c) 2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (FlexibleContexts, ScopedTypeVariables)
--
-- This module implements American flag sort: an in-place, unstable bucket
-- sort. In contrast to radix sort, values are inspected in big-endian
-- order, and buckets are sorted via recursive splitting. This makes it
-- well suited to sorting strings in lexicographic order (provided indexing
-- is fast).
--
-- The algorithm works as follows: at each stage, the array is looped over,
-- counting the number of elements for each bucket. Then, starting at the
-- beginning of the array, elements are permuted in place to reside in the
-- proper bucket, following chains until they reach back to the current
-- base index. Finally, each bucket is sorted recursively. This lends itself
-- well to the aforementioned variable-length strings, and so the algorithm
-- takes a stopping predicate, which is given a representative of the stripe,
-- rather than running for a set number of iterations.

module Data.Vector.Algorithms.AmericanFlag ( sort
                                           , sortBy
                                           , Lexicographic(..)
                                           ) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import Data.Word
import Data.Int
import Data.Bits

import qualified Data.ByteString as B

import Data.Vector.Generic.Mutable
import qualified Data.Vector.Primitive.Mutable as PV

import qualified Data.Vector.Unboxed.Mutable as U

import Data.Vector.Algorithms.Common

import qualified Data.Vector.Algorithms.Insertion as I

-- | The methods of this class describe how values of type @e@ are split
-- into buckets so that 'sort' can order them. The name 'Lexicographic'
-- reflects that 'index' should behave like indexing into a string: index 0
-- is the most significant position.
class Lexicographic e where
  -- | The number of buckets needed to sort @e@s. Every result of 'index'
  -- must lie in @[0, size x)@, because it is used directly as a position
  -- in arrays of that length. 'sort' calls this on an undefined
  -- placeholder, so implementations must not inspect the argument.
  size      :: e -> Int
  -- | @index i x@ is the bucket that @x@ belongs to during pass @i@.
  index     :: Int -> e -> Int
  -- | @terminate x i@ receives a representative @x@ of a stripe and the
  -- index @i@ of the pass that formed that stripe. It decides whether the
  -- stripe needs no further refinement. 'sort' first calls it with
  -- @i = -1@, before any pass has run.
  terminate :: e -> Int -> Bool

-- A byte is its own bucket, whatever the pass number.
instance Lexicographic Word8 where
  size _ = 256
  {-# INLINE size #-}
  index _ = fromIntegral
  {-# INLINE index #-}
  terminate _ = (0 <)
  {-# INLINE terminate #-}

-- Two bytes, most significant first. Any pass outside 0..1 maps every
-- value to bucket 0.
instance Lexicographic Word16 where
  terminate _ pass = pass >= 2
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass w
    | pass >= 0 && pass < 2 = fromIntegral $ (w `shiftR` (8 - 8 * pass)) .&. 255
    | otherwise             = 0
  {-# INLINE index #-}

-- Four bytes, most significant first. Any pass outside 0..3 maps every
-- value to bucket 0.
instance Lexicographic Word32 where
  terminate _ pass = pass >= 4
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass w
    | pass >= 0 && pass < 4 = fromIntegral $ (w `shiftR` (24 - 8 * pass)) .&. 255
    | otherwise             = 0
  {-# INLINE index #-}

-- Eight bytes, most significant first. Any pass outside 0..7 maps every
-- value to bucket 0.
instance Lexicographic Word64 where
  terminate _ pass = pass >= 8
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass w
    | pass >= 0 && pass < 8 = fromIntegral $ (w `shiftR` (56 - 8 * pass)) .&. 255
    | otherwise             = 0
  {-# INLINE index #-}

-- Treated as eight bytes, most significant first; the shift amounts assume
-- a 64-bit 'Word'. Any pass outside 0..7 maps every value to bucket 0.
instance Lexicographic Word where
  terminate _ pass = pass >= 8
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass w
    | pass >= 0 && pass < 8 = fromIntegral $ (w `shiftR` (56 - 8 * pass)) .&. 255
    | otherwise             = 0
  {-# INLINE index #-}

-- Flipping the sign bit maps -128..127 onto buckets 0..255 in order.
instance Lexicographic Int8 where
  terminate _ pass = pass >= 1
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index _ n = fromIntegral (n `xor` minBound) .&. 255
  {-# INLINE index #-}

-- Two bytes, most significant first. The top byte has its sign bit
-- flipped so that negative values land in lower buckets. Any pass outside
-- 0..1 maps every value to bucket 0.
instance Lexicographic Int16 where
  terminate _ pass = pass >= 2
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass n
    | pass == 0 = fromIntegral $ (biased `shiftR` 8) .&. 255
    | pass == 1 = fromIntegral $ n .&. 255
    | otherwise = 0
    where biased = n `xor` minBound
  {-# INLINE index #-}

-- Four bytes, most significant first. The top byte has its sign bit
-- flipped so that negative values land in lower buckets. Any pass outside
-- 0..3 maps every value to bucket 0.
instance Lexicographic Int32 where
  terminate _ pass = pass >= 4
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass n
    | pass == 0             = fromIntegral $ (biased `shiftR` 24) .&. 255
    | pass > 0 && pass < 4  = fromIntegral $ (n `shiftR` (24 - 8 * pass)) .&. 255
    | otherwise             = 0
    where biased = n `xor` minBound
  {-# INLINE index #-}

-- Eight bytes, most significant first. The top byte has its sign bit
-- flipped so that negative values land in lower buckets. Any pass outside
-- 0..7 maps every value to bucket 0.
instance Lexicographic Int64 where
  terminate _ pass = pass >= 8
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass n
    | pass == 0             = fromIntegral $ (biased `shiftR` 56) .&. 255
    | pass > 0 && pass < 8  = fromIntegral $ (n `shiftR` (56 - 8 * pass)) .&. 255
    | otherwise             = 0
    where biased = n `xor` minBound
  {-# INLINE index #-}

-- Treated as eight bytes, most significant first; the shift amounts assume
-- a 64-bit 'Int'. The top byte has its sign bit flipped so that negative
-- values land in lower buckets. Any pass outside 0..7 maps every value to
-- bucket 0.
instance Lexicographic Int where
  terminate _ pass = pass >= 8
  {-# INLINE terminate #-}
  size _ = 256
  {-# INLINE size #-}
  index pass n
    | pass == 0             = (biased `shiftR` 56) .&. 255
    | pass > 0 && pass < 8  = (n `shiftR` (56 - 8 * pass)) .&. 255
    | otherwise             = 0
    where biased = n `xor` minBound
  {-# INLINE index #-}

-- Bucket 0 stands for "no byte at this position" and byte values occupy
-- buckets 1..256. A string therefore sorts before every longer string it
-- is a prefix of.
instance Lexicographic B.ByteString where
  size _ = 257
  {-# INLINE size #-}
  index i b
    | i < B.length b = 1 + fromIntegral (B.index b i)
    | otherwise      = 0
  {-# INLINE index #-}
  terminate b i = B.length b <= i
  {-# INLINE terminate #-}

-- | Sorts an array in place using the ordering described by its
-- 'Lexicographic' instance. The 'Ord' constraint is needed as well because
-- small slices are finished with insertion sort, which compares elements
-- directly.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort arr = sortBy compare terminate bucketCount index arr
  where
    -- 'size' only receives a type witness here; the value is undefined.
    bucketCount = size (undefined :: e)
{-# INLINE sort #-}

-- | A fully parameterized version of the sorting algorithm. It takes both
-- radix information and a comparison, because the algorithm falls back to
-- insertion sort for small slices. An empty array is left untouched.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e       -- ^ a comparison for the insertion sort fallback
       -> (e -> Int -> Bool) -- ^ determines whether a stripe is complete
       -> Int                -- ^ the number of buckets necessary
       -> (Int -> e -> Int)  -- ^ the big-endian radix function
       -> v (PrimState m) e  -- ^ the array to be sorted
       -> m ()
sortBy cmp stop buckets radix v =
  unless (length v == 0) $ do
    -- Both auxiliary arrays have one slot per bucket.
    count <- new buckets
    pile  <- new buckets
    -- Prime 'count' for pass 0 over the whole array (countLoop lives in
    -- Data.Vector.Algorithms.Common; presumably a per-bucket tally).
    countLoop (radix 0) v count
    flagLoop cmp stop radix count pile v
{-# INLINE sortBy #-}

-- | The recursive driver of the sort.
--
-- On entry @count@ must already hold the bucket tallies of the whole array
-- for pass 0. 'sortBy' arranges this by calling countLoop with @radix 0@
-- beforehand.
--
-- A slice handled at pass @p@ goes through these steps:
--
-- * its first element and @p - 1@ are handed to the stop predicate (the
--   initial call therefore passes @-1@); if it answers 'True', the slice
--   is left as is;
--
-- * a slice shorter than 'threshold' is finished with insertion sort;
--
-- * otherwise 'accumulate' and 'permute' move every element into its
--   bucket for @radix p@. Each run of elements sharing the same @radix p@
--   value is then handled the same way at pass @p + 1@, after
--   'countStripe' has refilled @count@ for that run.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e
         -> (e -> Int -> Bool)           -- stop predicate: (representative, previous pass)
         -> (Int -> e -> Int)            -- radix function
         -> PV.MVector (PrimState m) Int -- auxiliary count array
         -> PV.MVector (PrimState m) Int -- auxiliary pile array
         -> v (PrimState m) e            -- source array
         -> m ()
flagLoop cmp stop radix count pile v = refine 0 v
  where
    refine pass arr = do
      rep <- unsafeRead arr 0
      unless (stop rep (pass - 1)) $ bucketize pass arr

    bucketize pass arr
      | n < threshold = I.sortByBounds cmp arr 0 n
      | otherwise     = do
          accumulate count pile
          permute (radix pass) count pile arr
          descend 0
      where
        n    = length arr
        next = pass + 1

        -- Walk the permuted slice one stripe at a time, recursing into each.
        descend lo
          | lo < n    = do
              hi <- countStripe (radix next) (radix pass) count arr lo
              refine next (unsafeSlice lo (hi - lo) arr)
              descend hi
          | otherwise = return ()
{-# INLINE flagLoop #-}

-- | Converts per-bucket sizes into bucket boundaries, in place. On entry
-- @count@ is assumed to hold the size of each bucket. Afterwards
-- @pile[i]@ is the offset where bucket @i@ starts and @count[i]@ is the
-- offset just past its end.
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int
           -> PV.MVector (PrimState m) Int
           -> m ()
accumulate count pile = step 0 0
 where
  buckets = length count

  -- @b@ is the bucket being processed; @start@ is the running prefix sum.
  step b start
    | b >= buckets = return ()
    | otherwise    = do
        bucketSize <- unsafeRead count b
        let end = start + bucketSize
        unsafeWrite pile b start
        unsafeWrite count b end
        step (b + 1) end
{-# INLINE accumulate #-}

-- | Rearranges @v@ in place so that every element @e@ ends up inside the
-- region of its bucket @rdx e@. Elements are moved along displacement
-- chains. An out-of-place element is written into a slot of its bucket,
-- the element it evicts is carried on to its own bucket, and so on, until
-- the chain returns to the position it started from.
--
-- Expects @count@ and @pile@ as left by 'accumulate', assuming @count@
-- held bucket sizes before that. @pile[r]@ is where bucket @r@ starts and
-- @count[r]@ is one past its end, so bucket @r@ begins at @count[r-1]@
-- (or 0 when @r == 0@). @count@ is only read here. @pile[r]@ is advanced
-- past slots of bucket @r@ as they are settled.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)                       -- radix function
        -> PV.MVector (PrimState m) Int     -- count array
        -> PV.MVector (PrimState m) Int     -- pile array
        -> v (PrimState m) e                -- source array
        -> m ()
permute rdx count pile v = go 0
 where
 len = length v

 -- Scans positions left to right; @i@ is the position under inspection and
 -- @e@ the element currently stored there.
 go i
   | i < len   = do e <- unsafeRead v i
                    let r = rdx e
                    -- p: the pile counter of e's bucket
                    p <- unsafeRead pile r
                    -- m: first slot of e's bucket
                    m <- if r > 0
                            then unsafeRead count (r-1)
                            else return 0
                    case () of
                      -- i lies in the prefix [m, p) of e's own bucket that
                      -- the pile counter has already moved past: skip to p
                      _ | m <= i && i < p  -> go p
                      -- e sits exactly at its bucket's next slot: keep it
                      -- there, bump the pile counter, go to the next element
                        | i == p           -> unsafeWrite pile r (p+1) >> go (i+1)
                      -- otherwise e is out of place: start a displacement
                      -- chain with slot p as its first candidate
                        | otherwise        -> follow i e p >> go (i+1)
   | otherwise = return ()

 -- follow i e j: e (lifted from position i) is being placed, and j is its
 -- candidate slot. The element en currently at j is examined first:
 --   * if j is exactly the pile slot of en's bucket, en is already in place
 --     and stays; e tries slot j+1 instead;
 --   * otherwise e is written to j and en is carried on to slot p of its
 --     own bucket; if that slot is i, the chain closes by writing en there.
 -- NOTE(review): 'inc' (from Data.Vector.Algorithms.Common) is assumed to
 -- return the old pile[r] and store pile[r]+1 -- confirm against its source.
 follow i e j = do en <- unsafeRead v j
                   let r = rdx en
                   p <- inc pile r
                   if p == j
                      -- if the target happens to be in the right pile, don't move it.
                      then follow i e (j+1)
                      else unsafeWrite v j e >> if i == p
                                             then unsafeWrite v i en
                                             else follow i en p
{-# INLINE permute #-}

-- | Starting at @lo@, finds the run of consecutive elements whose stripe
-- value (@str@) equals that of the element at @lo@. It also tallies the
-- run's elements into @count@ by bucket (@rdx@); @count@ is cleared first.
-- Returns the exclusive end @hi@ of the run, so the stripe is @[lo, hi)@.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)                   -- radix function
            -> (e -> Int)                   -- stripe function
            -> PV.MVector (PrimState m) Int -- count array
            -> v (PrimState m) e            -- source array
            -> Int                          -- starting position
            -> m Int                        -- end of stripe: [lo,hi)
countStripe rdx str count v lo = do
  set count 0
  first <- unsafeRead v lo
  scan (str first) first (lo + 1)
 where
  n = length v

  -- @cur@ belongs to the stripe and has not been tallied yet; @i@ is the
  -- next position to examine.
  scan !key cur i = do
    _ <- inc count (rdx cur)
    if i >= n
      then return n
      else do
        next <- unsafeRead v i
        if str next == key
          then scan key next (i + 1)
          else return i
{-# INLINE countStripe #-}

-- | Slices with fewer elements than this are sorted with insertion sort
-- ('I.sortByBounds', in 'flagLoop') instead of another bucketing pass.
threshold :: Int
threshold = 25