{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Vector.Algorithms.Merge
       ( sort
       , sortBy
       , Comparison
       ) where
import Prelude hiding (read, length)
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison, copyOffset, midPoint)
import qualified Data.Vector.Algorithms.Optimal   as O
import qualified Data.Vector.Algorithms.Insertion as I
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort = sortBy compare
{-# INLINABLE sort #-}
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec
  | len <= 1  = return ()
  | len == 2  = O.sort2ByOffset cmp vec 0
  | len == 3  = O.sort3ByOffset cmp vec 0
  | len == 4  = O.sort4ByOffset cmp vec 0
  | otherwise = do buf <- new len
                   mergeSortWithBuf cmp vec buf
 where
 len = length vec
{-# INLINE sortBy #-}
mergeSortWithBuf :: (PrimMonad m, MVector v e)
                 => Comparison e -> v (PrimState m) e -> v (PrimState m) e -> m ()
mergeSortWithBuf cmp src buf = loop 0 (length src)
 where
 loop l u
   | len < threshold = I.sortByBounds cmp src l u
   | otherwise       = do loop l mid
                          loop mid u
                          merge cmp (unsafeSlice l len src) buf (mid - l)
  where len = u - l
        mid = midPoint u l
{-# INLINE mergeSortWithBuf #-}
merge :: (PrimMonad m, MVector v e)
      => Comparison e -> v (PrimState m) e -> v (PrimState m) e
      -> Int -> m ()
merge cmp src buf mid = do unsafeCopy tmp lower
                           eTmp <- unsafeRead tmp 0
                           eUpp <- unsafeRead upper 0
                           loop tmp 0 eTmp upper 0 eUpp 0
 where
 lower = unsafeSlice 0   mid                src
 upper = unsafeSlice mid (length src - mid) src
 tmp   = unsafeSlice 0   mid                buf
 wroteHigh low iLow eLow high iHigh iIns
   | iHigh >= length high = unsafeCopy (unsafeSlice iIns (length low - iLow) src)
                                       (unsafeSlice iLow (length low - iLow) low)
   | otherwise            = do eHigh <- unsafeRead high iHigh
                               loop low iLow eLow high iHigh eHigh iIns
 wroteLow low iLow high iHigh eHigh iIns
   | iLow  >= length low  = return ()
   | otherwise            = do eLow <- unsafeRead low iLow
                               loop low iLow eLow high iHigh eHigh iIns
 loop !low !iLow !eLow !high !iHigh !eHigh !iIns = case cmp eHigh eLow of
     LT -> do unsafeWrite src iIns eHigh
              wroteHigh low iLow eLow high (iHigh + 1) (iIns + 1)
     _  -> do unsafeWrite src iIns eLow
              wroteLow low (iLow + 1) high iHigh eHigh (iIns + 1)
{-# INLINE merge #-}
threshold :: Int
threshold = 25
{-# INLINE threshold #-}