{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Data.Select.Mutable.Intro
(select)
where
import Data.Vector.Generic.Mutable (MVector)
import qualified Data.Vector.Generic.Mutable as Vector
import qualified Data.Select.Mutable.Median as WorstCase
import Data.Vector.Mutable.Partition
import Data.Median.Optimal
import Data.Select.Optimal
import Control.Applicative.LiftMany
import Control.Monad.ST
import Data.Bits
#if !MIN_VERSION_base(4,8,0)
import Data.Functor ((<$>))
import Control.Applicative (pure)
#endif
-- | Introspective selection over the inclusive index range @[l', r']@ of a
-- mutable vector.
--
-- The range is repeatedly narrowed by partitioning around a pivot. Each
-- narrowing step spends one unit of a depth budget that is twice the floor of
-- the base-2 logarithm of the initial range width; once the budget reaches
-- zero the remaining range is handed to the selection routine from
-- "Data.Select.Mutable.Median".
--
-- The result is an index into @xs@. For ranges of at most five elements it is
-- the left bound plus the offset returned by the matching fixed-size selector;
-- otherwise it is @n@ itself, returned once a pivot comes to rest at index @n@.
select
    :: MVector v a
    => (a -> a -> Bool) -- ^ ordering predicate, forwarded unchanged to every helper
    -> v s a            -- ^ vector to search (passed to @partition@; presumably rearranged in place - TODO confirm)
    -> Int              -- ^ left bound of the range, inclusive
    -> Int              -- ^ right bound of the range, inclusive
    -> Int              -- ^ target index; absolute, i.e. compared directly against pivot indices
    -> ST s Int
select lte !xs !l' !r' !n = go (ilg (r' - l')) l' r'
  where
    -- Depth budget: 2 * floor (log2 width). Both branches must agree. The
    -- parentheses in the first branch are essential: without them the
    -- doubling applies only to the bit size and the budget is inflated by a
    -- whole word size.
#if MIN_VERSION_base(4,8,0)
    ilg !x = 2 * (finiteBitSize x - 1 - countLeadingZeros x)
#else
    ilg !m = 2 * loop m (0 :: Int)
      where
        -- Counts how many right shifts are needed to reach zero, minus one.
        loop 0 !k = k - 1
        loop n' !k = loop (n' `shiftR` 1) (k + 1)
#endif
    {-# INLINE ilg #-}

    -- Budget exhausted: defer to the worst-case selection routine.
    go 0 !l !r = WorstCase.select lte xs l r n
    go !d !l !r =
        case r - l of
            -- A single element: it is the answer.
            0 -> pure l
            -- Two to five elements: read them all and use the fixed-size
            -- selector, which is given the target as an offset from l.
            1 ->
                (l +) <$>
                liftA2
                    (select2 lte (n - l))
                    (Vector.unsafeRead xs l)
                    (Vector.unsafeRead xs (l + 1))
            2 ->
                (l +) <$>
                liftA3
                    (select3 lte (n - l))
                    (Vector.unsafeRead xs l)
                    (Vector.unsafeRead xs (l + 1))
                    (Vector.unsafeRead xs (l + 2))
            3 ->
                (l +) <$>
                liftA4
                    (select4 lte (n - l))
                    (Vector.unsafeRead xs l)
                    (Vector.unsafeRead xs (l + 1))
                    (Vector.unsafeRead xs (l + 2))
                    (Vector.unsafeRead xs (l + 3))
            4 ->
                (l +) <$>
                liftA5
                    (select5 lte (n - l))
                    (Vector.unsafeRead xs l)
                    (Vector.unsafeRead xs (l + 1))
                    (Vector.unsafeRead xs (l + 2))
                    (Vector.unsafeRead xs (l + 3))
                    (Vector.unsafeRead xs (l + 4))
            -- Larger ranges: pick a pivot from the first, middle and last
            -- elements, partition, then continue on the side holding n.
            w -> do
                -- NOTE(review): the median3 result is added to l and used as
                -- the pivot index, although the three samples sit at l,
                -- l + w/2 and r. Confirm against median3 that this yields the
                -- index of the chosen sample.
                i <-
                    partition lte xs l r =<<
                    ((l +) <$>
                     liftA3
                         (median3 lte)
                         (Vector.unsafeRead xs l)
                         (Vector.unsafeRead xs (l + (w `div` 2)))
                         (Vector.unsafeRead xs r))
                case compare n i of
                    EQ -> pure n
                    LT -> go (d - 1) l (i - 1)
                    GT -> go (d - 1) (i + 1) r
{-# INLINE select #-}