{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Merge
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Portable
--
-- This module implements a simple top-down merge sort. The temporary buffer
-- is preallocated to half the size of the input array (rounded up), and is
-- shared through the entire sorting process to reduce the total amount of
-- allocation performed. This is a stable sort.

module Data.Vector.Algorithms.Merge
       ( sort
       , sortBy
       , Comparison
       ) where

import Prelude hiding (read, length)

import Control.Monad.Primitive

import Data.Bits
import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common (Comparison, copyOffset, midPoint)

import qualified Data.Vector.Algorithms.Optimal   as O
import qualified Data.Vector.Algorithms.Insertion as I

-- | Sorts an array using the default comparison.
--
-- This is 'sortBy' applied to 'compare' from the element type's 'Ord'
-- instance; the array is sorted in place.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort = sortBy compare
{-# INLINABLE sort #-}

-- | Sorts an array using a custom comparison.
--
-- The strategy is picked from the length of the array:
--
-- * fewer than 2 elements: nothing is done;
--
-- * exactly 2, 3 or 4 elements: the fixed-size routines
--   @sort2ByOffset@ \/ @sort3ByOffset@ \/ @sort4ByOffset@ from
--   "Data.Vector.Algorithms.Optimal" are applied at offset 0;
--
-- * fewer than @threshold@ elements: @sortByBounds@ from
--   "Data.Vector.Algorithms.Insertion" is run over the whole array;
--
-- * otherwise: a single temporary buffer is allocated and the array is
--   merge sorted using that buffer.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec
  | len < 2         = return ()
  | len == 2        = O.sort2ByOffset cmp vec 0
  | len == 3        = O.sort3ByOffset cmp vec 0
  | len == 4        = O.sort4ByOffset cmp vec 0
  | len < threshold = I.sortByBounds cmp vec 0 len
  | otherwise       = do buf <- new halfLen
                         mergeSortWithBuf cmp vec buf
 where
 len :: Int
 len     = length vec
 -- odd lengths have a larger half that needs to fit, so use ceiling, not floor
 halfLen :: Int
 halfLen = (len + 1) `div` 2
{-# INLINE sortBy #-}

-- Top-down merge sort of @src@, using @buf@ as scratch space for merging.
--
-- The local @loop l u@ sorts the half-open index range @[l, u)@ of @src@:
--
-- * a range shorter than @threshold@ is handed to 'I.sortByBounds';
--
-- * a longer range is split at @midPoint u l@, both sides are sorted
--   recursively, and the two sides are then combined with 'merge'.
--
-- 'merge' copies the lower side (@mid - l@ elements) into @buf@ with an
-- unchecked copy, so @buf@ must be at least as long as the lower side of the
-- outermost split.
mergeSortWithBuf :: (PrimMonad m, MVector v e)
                 => Comparison e -> v (PrimState m) e -> v (PrimState m) e -> m ()
mergeSortWithBuf cmp src buf = loop 0 (length src)
 where
 loop l u
   | len < threshold = I.sortByBounds cmp src l u
   | otherwise       = do loop l mid
                          loop mid u
                          -- 'merge' works on the sub-array @[l, u)@ and is told
                          -- where the split sits relative to its start.
                          merge cmp (unsafeSlice l len src) buf (mid - l)
  where len = u - l
        mid = midPoint u l
{-# INLINE mergeSortWithBuf #-}

-- Merges the two adjacent runs @src[0 .. mid)@ (lower) and @src[mid ..]@
-- (upper) back into @src@, using the first @mid@ slots of @buf@ as scratch.
--
-- Arguments: the comparison, the array holding both runs, the scratch buffer
-- (at least @mid@ elements long), and the index at which the upper run
-- starts.
--
-- The lower run is first copied into the scratch buffer; elements are then
-- taken from the scratch copy and from the upper run (still sitting in
-- @src@) and written to @src@ from index 0 upwards. The write index always
-- equals @iLow + iHigh@, which never exceeds the position @mid + iHigh@ of
-- the next unread upper element, so no unread element is overwritten.
--
-- An upper element is taken only when it compares strictly less than the
-- current lower element, so on ties the lower element goes first.
--
-- Both runs must be non-empty: their first elements are read unconditionally
-- with unchecked reads.
merge :: (PrimMonad m, MVector v e)
      => Comparison e -> v (PrimState m) e -> v (PrimState m) e
      -> Int -> m ()
merge cmp src buf mid = do unsafeCopy tmp lower
                           eTmp <- unsafeRead tmp 0
                           eUpp <- unsafeRead upper 0
                           loop tmp 0 eTmp upper 0 eUpp 0
 where
 lower = unsafeSlice 0   mid                src
 upper = unsafeSlice mid (length src - mid) src
 tmp   = unsafeSlice 0   mid                buf

 -- An element of @high@ was just written. If @high@ is exhausted, the
 -- remaining @low@ elements are copied to the output position in one go;
 -- otherwise the next @high@ element is fetched and merging continues.
 wroteHigh low iLow eLow high iHigh iIns
   | iHigh >= length high = unsafeCopy (unsafeSlice iIns (length low - iLow) src)
                                       (unsafeSlice iLow (length low - iLow) low)
   | otherwise            = do eHigh <- unsafeRead high iHigh
                               loop low iLow eLow high iHigh eHigh iIns

 -- An element of @low@ was just written. If @low@ is exhausted there is
 -- nothing left to do: the remaining @high@ elements already occupy their
 -- final positions in @src@. Otherwise the next @low@ element is fetched and
 -- merging continues.
 wroteLow low iLow high iHigh eHigh iIns
   | iLow  >= length low  = return ()
   | otherwise            = do eLow <- unsafeRead low iLow
                               loop low iLow eLow high iHigh eHigh iIns

 -- Compares the current heads of both runs, writes the chosen one at @iIns@
 -- and advances the corresponding index.
 loop !low !iLow !eLow !high !iHigh !eHigh !iIns = case cmp eHigh eLow of
     LT -> do unsafeWrite src iIns eHigh
              wroteHigh low iLow eLow high (iHigh + 1) (iIns + 1)
     _  -> do unsafeWrite src iIns eLow
              wroteLow low (iLow + 1) high iHigh eHigh (iIns + 1)
{-# INLINE merge #-}

-- Length cutoff shared by 'sortBy' and 'mergeSortWithBuf': a range with fewer
-- than this many elements is passed to 'I.sortByBounds' instead of being
-- split and merged.
threshold :: Int
threshold = 25
{-# INLINE threshold #-}