{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Intro
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel
-- Stability   : Experimental
-- Portability : Non-portable (type operators, bang patterns)
--
-- This module implements various algorithms based on the introsort algorithm,
-- originally described by David R. Musser in the paper /Introspective Sorting
-- and Selection Algorithms/. It is also in widespread practical use, as the
-- standard unstable sort used in the C++ Standard Template Library.
--
-- Introsort is at its core a quicksort. The version implemented here has the
-- following optimizations that make it perform better in practice:
--
-- * Small segments of the array are left unsorted until a final insertion
--   sort pass. This is faster than recursing all the way down to
--   one-element arrays.
--
-- * The pivot for segment [l,u) is chosen as the median of the elements at
--   l, u-1 and (u+l)/2. This yields good behavior on mostly sorted (or
--   reverse-sorted) arrays.
--
-- * The algorithm tracks its recursion depth, and if it decides it is
--   taking too long (depth greater than 2 * lg n), it switches to a heap
--   sort to maintain O(n lg n) worst case behavior. (This is what makes the
--   algorithm an introsort.)
module Data.Vector.Algorithms.Intro
  ( -- * Sorting
    sort
  , sortBy
  , sortByBounds
    -- * Selecting
  , select
  , selectBy
  , selectByBounds
    -- * Partial sorting
  , partialSort
  , partialSortBy
  , partialSortByBounds
  , Comparison
  ) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import Data.Bits

import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common (Comparison)

import qualified Data.Vector.Algorithms.Insertion as I
import qualified Data.Vector.Algorithms.Optimal   as O
import qualified Data.Vector.Algorithms.Heap      as H

-- | Sorts the whole array in place, ordering elements by their 'Ord'
-- instance.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort = sortBy compare
{-# INLINE sort #-}

-- | Sorts the whole array in place, ordering elements with the supplied
-- comparison.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp v = sortByBounds cmp v 0 (length v)
{-# INLINE sortBy #-}

-- | Sorts the slice [l,u) of the array in place with the supplied
-- comparison, leaving elements outside the slice untouched.
--
-- Slices of fewer than two elements are already sorted. Slices of two to
-- four elements go to the dedicated small-size routines from
-- "Data.Vector.Algorithms.Optimal". Anything larger goes through the full
-- introsort with a depth budget of @ilg (u - l)@.
sortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e -> v (PrimState m) e -> Int -> Int -> m ()
sortByBounds cmp a l u =
  case u - l of
    n | n < 2     -> return ()
      | n == 2    -> O.sort2ByOffset cmp a l
      | n == 3    -> O.sort3ByOffset cmp a l
      | n == 4    -> O.sort4ByOffset cmp a l
      | otherwise -> introsort cmp a (ilg n) l u
{-# INLINE sortByBounds #-}

-- Internal introsort driver. Besides the comparison, the array and the
-- bounds [l,u), it takes an explicit recursion-depth budget so that the
-- partial-sort functions can call it with whatever budget they have left.
introsort :: (PrimMonad m, MVector v e)
          => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
-- Arguments: comparison, array, recursion-depth budget, and the bounds
-- [l,u) of the segment to sort.
--
-- The quicksort phase stops on segments shorter than 'threshold' and leaves
-- them unsorted. A single insertion-sort pass over the whole of [l,u)
-- finishes them off. When the depth budget reaches 0, the segment at hand
-- is heap sorted instead of being partitioned further.
introsort cmp a i l u = qsort i l u >> I.sortByBounds cmp a l u
  where
    -- Named 'qsort' (not 'sort') to avoid shadowing the exported 'sort'.
    qsort 0 lo hi = H.sortByBounds cmp a lo hi
    qsort d lo hi
      | hi - lo < threshold = return ()
      | otherwise = do
          -- Median-of-three: order the elements at the midpoint, lo and
          -- hi-1 so that the median of the three lands at lo.
          O.sort3ByIndex cmp a ((hi + lo) `div` 2) lo (hi - 1)
          p   <- unsafeRead a lo
          mid <- partitionBy cmp a p (lo + 1) hi
          -- Put the pivot between the halves:
          -- [lo, mid-1) <= p == a[mid-1] <= [mid, hi).
          unsafeSwap a lo (mid - 1)
          qsort (d - 1) mid hi
          qsort (d - 1) lo (mid - 1)
{-# INLINE introsort #-}

-- | Moves the least k elements to the front of the array in
-- no particular order.
select :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
select = selectBy compare
{-# INLINE select #-}

-- | Moves the least k elements (as defined by the comparison) to
-- the front of the array in no particular order.
selectBy :: (PrimMonad m, MVector v e)
         => Comparison e -> v (PrimState m) e -> Int -> m ()
selectBy cmp a k = selectByBounds cmp a k 0 (length a)
{-# INLINE selectBy #-}

-- | Moves the least k elements in the interval [l,u) to the positions
-- [l,k+l) in no particular order.
--
-- If @k <= 0@ there is nothing to select. If @k >= u - l@ every element
-- of the interval is already in the target range. In both cases the array
-- is left untouched.
selectByBounds :: (PrimMonad m, MVector v e)
               => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
selectByBounds cmp a k l u
  -- Rejecting out-of-range k matters for safety: the loop below uses
  -- unchecked reads, and with k outside (0, u-l) its window can shrink
  -- onto indices outside [l,u).
  | l >= u || k <= 0 || k >= u - l = return ()
  | otherwise                      = go (ilg (u - l)) l (l + k) u
  where
    -- Invariant: lo < m < hi (m is the absolute selection boundary), so
    -- the window [lo,hi) is never empty and all indices stay inside it.
    go 0 lo m hi = H.selectByBounds cmp a (m - lo) lo hi
    go n lo m hi = do
      O.sort3ByIndex cmp a ((hi + lo) `div` 2) lo (hi - 1)
      p   <- unsafeRead a lo
      mid <- partitionBy cmp a p (lo + 1) hi
      unsafeSwap a lo (mid - 1)
      -- The pivot now sits at mid-1. Recurse only into the side that
      -- still contains the boundary; if the boundary is at mid or mid-1,
      -- the split is already final.
      if m > mid
        then go (n - 1) mid m hi
        else when (m < mid - 1) $ go (n - 1) lo m (mid - 1)
{-# INLINE selectByBounds #-}

-- | Moves the least k elements to the front of the array, sorted.
partialSort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
partialSort = partialSortBy compare
{-# INLINE partialSort #-}

-- | Moves the least k elements (as defined by the comparison) to
-- the front of the array, sorted.
partialSortBy :: (PrimMonad m, MVector v e)
              => Comparison e -> v (PrimState m) e -> Int -> m ()
partialSortBy cmp a k = partialSortByBounds cmp a k 0 (length a)
{-# INLINE partialSortBy #-}

-- | Moves the least k elements in the interval [l,u) to the positions
-- [l,k+l), sorted.
--
-- A @k <= 0@ does nothing. A @k@ larger than @u - l@ is treated as
-- @u - l@, which sorts the whole interval.
partialSortByBounds :: (PrimMonad m, MVector v e)
                    => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
partialSortByBounds cmp a k l u
  -- The loop uses unchecked reads, so an out-of-range k must never reach
  -- it. A negative k would walk the window below l; an oversized k would
  -- walk it past u.
  | l >= u || k <= 0 = return ()
  | otherwise        = go (ilg len) l (l + min k len) u
  where
    isort = introsort cmp a
    {-# INLINE [1] isort #-}
    len = u - l
    -- Invariant: lo < hi and lo <= m <= hi, where m is the absolute end of
    -- the prefix that must end up sorted. Within [l,u), everything left of
    -- lo is already sorted and no greater than anything in [lo,hi), and
    -- everything from hi up to u is no smaller. Only [lo,hi) needs work.
    --
    -- The heap fallback works on the current window [lo,hi). The original
    -- code used the outer bound u here (its parameter was misnamed 'n'),
    -- which heap-processed elements already known to be irrelevant.
    go 0 lo m hi = H.partialSortByBounds cmp a (m - lo) lo hi
    go n lo m hi
      | lo == m   = return ()
      | otherwise = do
          O.sort3ByIndex cmp a ((hi + lo) `div` 2) lo (hi - 1)
          p   <- unsafeRead a lo
          mid <- partitionBy cmp a p (lo + 1) hi
          unsafeSwap a lo (mid - 1)
          case compare m mid of
            -- The whole left half (plus the pivot at mid-1) belongs to the
            -- prefix: sort it fully, then continue in the right half.
            GT -> do isort (n - 1) lo (mid - 1)
                     go (n - 1) mid m hi
            EQ -> isort (n - 1) lo m
            -- The boundary lies in the left half. Every branch must consume
            -- depth budget, otherwise repeated bad pivots on this side would
            -- never trigger the heap fallback (quadratic worst case).
            LT -> go (n - 1) lo m (mid - 1)
{-# INLINE partialSortByBounds #-}

-- | Partitions the segment [l,u) around the pivot value @p@ and returns the
-- split index @mid@. Afterwards every element in [l,mid) compares not
-- greater than @p@, and every element in [mid,u) compares not less than @p@.
partitionBy :: forall m v e. (PrimMonad m, MVector v e)
            => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> m Int
partitionBy cmp a = partUp
  where
    -- 6.10 panics without the signatures for partUp and partDown, 6.12 and
    -- later versions don't need them.

    -- Scan forward past elements strictly less than the pivot.
    -- Here u is exclusive.
    partUp :: e -> Int -> Int -> m Int
    partUp p l u
      | l < u = do
          e <- unsafeRead a l
          case cmp e p of
            LT -> partUp p (l + 1) u
            _  -> partDown p l (u - 1)
      | otherwise = return l

    -- Scan backward past elements strictly greater than the pivot, then swap
    -- the out-of-place pair at l and u. Here u is inclusive.
    partDown :: e -> Int -> Int -> m Int
    partDown p l u
      | l < u = do
          e <- unsafeRead a u
          case cmp p e of
            LT -> partDown p l (u - 1)
            _  -> unsafeSwap a l u >> partUp p (l + 1) u
      | otherwise = return l
{-# INLINE partitionBy #-}

-- | Recursion-depth budget for a segment of length @m@: twice the floor of
-- the base-2 logarithm of @m@. Once this many partitioning levels have been
-- used, the algorithms above switch to the heap-based routines.
-- Callers only pass @m >= 1@.
ilg :: Int -> Int
ilg m = 2 * loop m 0
  where
    -- Counts the bits of n, minus one, i.e. floor (lg n) for n >= 1.
    loop 0 !k = k - 1
    loop n !k = loop (n `shiftR` 1) (k + 1)

-- | Segments shorter than this are left for the final insertion-sort pass
-- in 'introsort' instead of being partitioned further.
threshold :: Int
threshold = 18
{-# INLINE threshold #-}