{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}

-- |
-- Module      : Data.Select.Mutable.Intro
-- Description : Introselect internals on mutable, generic vectors.
-- Copyright   : (c) Donnacha Oisín Kidney, 2018
-- License     : MIT
-- Maintainer  : mail@doisinkidney.com
-- Stability   : experimental
-- Portability : portable
--
-- Introselect internals on mutable, generic vectors.
module Data.Select.Mutable.Intro
  (select)
  where

import           Data.Vector.Generic.Mutable   (MVector)
import qualified Data.Vector.Generic.Mutable   as Vector

import qualified Data.Select.Mutable.Median    as WorstCase
import           Data.Vector.Mutable.Partition

import           Data.Median.Optimal
import           Data.Select.Optimal

import           Control.Applicative.LiftMany
import           Control.Monad.ST

import           Data.Bits

#if !MIN_VERSION_base(4,8,0)
import           Data.Functor                  ((<$>))
import           Control.Applicative           (pure)
#endif

-- | @'select' ('<=') xs lb ub n@ selects the 'n'th item among the
-- indices in the inclusive range ['lb','ub'] and returns the index at
-- which it ends up.
--
-- 'n' is an absolute index into @xs@ (it is compared directly against
-- the pivot index returned by 'partition'), not an offset from 'lb'.
--
-- The search partitions around a median-of-three pivot and recurses into
-- the side containing 'n'. A depth budget of @2 * floor (log2 (ub - lb))@
-- is decremented on every partitioning step; once it reaches zero the
-- remaining range is handed to 'WorstCase.select'. Ranges of five or
-- fewer elements are resolved directly with 'select2' .. 'select5'.
--
-- The vector is accessed with 'Vector.unsafeRead', so no bounds checks
-- are performed on 'lb', 'ub' or 'n'.
select
    :: MVector v a
    => (a -> a -> Bool) -> v s a -> Int -> Int -> Int -> ST s Int
select lte !xs !l' !r' !n = go (ilg (r' - l')) l' r'
  where
    -- Depth budget: twice the integer base-2 logarithm of the range size.
    -- Both branches yield the same value for every non-negative argument
    -- (including 0, which gives -2 and therefore never matches @go 0@).
#if MIN_VERSION_base(4,8,0)
    -- The parentheses matter: without them this parses as
    -- @(2 * finiteBitSize x) - 1 - countLeadingZeros x@, a budget that is
    -- always >= 64 on 64-bit Ints and disagrees with the fallback below.
    ilg !x = 2 * (finiteBitSize x - 1 - countLeadingZeros x)
#else
    ilg !m = 2 * loop m (0 :: Int)
      where
        loop 0 !k = k - 1
        loop n' !k = loop (n' `shiftR` 1) (k + 1)
#endif
    {-# INLINE ilg #-}

    -- Budget exhausted: delegate the remaining range to the fallback.
    go 0 !l !r = WorstCase.select lte xs l r n
    go !d !l !r =
        case r - l of
            -- Tiny ranges: pick the answer directly. The selectN helpers
            -- receive the rank relative to l and their result is shifted
            -- back by l.
            0 -> pure l
            1 -> (l +) <$>
                 liftA2 (select2 lte rank) (readAt 0) (readAt 1)
            2 -> (l +) <$>
                 liftA3 (select3 lte rank) (readAt 0) (readAt 1) (readAt 2)
            3 -> (l +) <$>
                 liftA4
                     (select4 lte rank)
                     (readAt 0)
                     (readAt 1)
                     (readAt 2)
                     (readAt 3)
            4 -> (l +) <$>
                 liftA5
                     (select5 lte rank)
                     (readAt 0)
                     (readAt 1)
                     (readAt 2)
                     (readAt 3)
                     (readAt 4)
            s -> do
                -- Sample the first, middle and last elements for the pivot.
                -- NOTE(review): the pivot index is computed as
                -- @l + median3 ...@ while the samples come from offsets
                -- 0, s `div` 2 and s. If median3 returns the position
                -- (0/1/2) of the median argument, the pivot used is one of
                -- l, l+1, l+2 rather than the sampled element. Still a
                -- valid index here (s >= 5), but confirm against
                -- Data.Median.Optimal.
                pivot <-
                    (l +) <$>
                    liftA3
                        (median3 lte)
                        (readAt 0)
                        (readAt (s `div` 2))
                        (readAt s)
                i <- partition lte xs l r pivot
                case compare n i of
                    EQ -> pure n
                    LT -> go (d - 1) l (i - 1)
                    GT -> go (d - 1) (i + 1) r
      where
        -- Rank of the wanted element relative to the current lower bound.
        rank = n - l
        -- Read the element at the given offset from the lower bound.
        readAt !o = Vector.unsafeRead xs (l + o)
{-# INLINE select #-}