-- |
-- Module:     Data.Vector.Algorithms.Quicksort.Median
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com

{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE UndecidableInstances   #-}
{-# LANGUAGE UnboxedTuples          #-}

module Data.Vector.Algorithms.Quicksort.Median
  ( Median(..)
  , Median3(..)
  , Median3or5(..)
  , MedianResult(..)
  ) where

import Prelude hiding (last)

import Control.Monad.Primitive
import Data.Bits
import Data.Kind (Type)
import Data.Vector.Generic.Mutable qualified as GM
import GHC.Exts (Int(..), Int#)

-- | Median selection result.
--
-- Both fields are strict and the index is additionally unpacked into
-- the constructor.
data MedianResult a
  -- | Value that was located at specific index in the original array.
  -- The first field is the value itself, the second field is the
  -- index it was read from.
  = ExistingValue !a {-# UNPACK #-} !Int

-- Box an unboxed (value, index) pair, as produced by 'pick3', into a
-- 'MedianResult'. The value is forced by the bang pattern and the raw
-- 'Int#' index is rewrapped with 'I#'.
existingValue :: (# a, Int# #) -> MedianResult a
existingValue (# !x, n #) = ExistingValue x (I# n)

-- | Median selection algorithm that, given a vector, should come up
-- with an element that has good chances to be the median (i.e. to be
-- greater than half the elements and lower than the other remaining
-- half). The closer to the real median the selected element is, the
-- faster quicksort will run and the better parallelisation will be
-- achieved.
--
-- Instance can be declared for specific monad. This is useful if we want
-- to select median at random and need to thread random gen.
--
-- Parameter meaning:
--
-- - @a@ - the median parameter we're defining instance for
-- - @b@ - type of elements this median selection method is applicable to
-- - @m@ - monad the median selection operates in
-- - @s@ - the same ‘index’ as in ‘ST s’ because vector to be sorted is parameterised and @m@ may need to mention it
--
-- The functional dependencies state that @a@ determines @b@ and that
-- @m@ determines @s@.
class Median (a :: Type) (b :: Type) (m :: Type -> Type) (s :: Type) | a -> b, m -> s where
  -- | Come up with a median value of a given array
  selectMedian
    :: (GM.MVector v b, Ord b)
    => a     -- ^ Median algorithm that can carry extra info to be
             -- used during median selection (e.g. random generator)
    -> v s b -- ^ Array
    -> m (MedianResult b)

-- | Pick first, last, and the middle elements and find the one that's between the other two, e.g.
-- given elements @a@, @b@, and @c@ find @y@ among them that satisfies @x <= y <= z@,
-- where @x@ and @z@ are the two remaining elements.
--
-- The type parameter is phantom: it only fixes the element type the
-- corresponding 'Median' instance applies to.
data Median3 a = Median3

-- | Pick first, last, and the middle elements. If all of them are
-- distinct then return median of 3 like 'Median3' does, otherwise
-- take median of 5 from the already taken 3 and extra 2 elements at
-- @1/4@th and @3/4@th of array length (computed as @len / 4@ and
-- @len / 2 + len / 4@ with truncating halving).
--
-- The type parameter is phantom: it only fixes the element type the
-- corresponding 'Median' instance applies to.
data Median3or5 a = Median3or5

{-# INLINE pick3 #-}
-- Pick median among 3 values.
--
-- Arguments come as three (value, index) pairs. The result is the
-- median value together with the index that accompanied it. Only '<'
-- is used and at most three comparisons are made. All three values
-- are forced.
pick3 :: Ord a => a -> Int# -> a -> Int# -> a -> Int# -> (# a, Int# #)
pick3 !a ai !b bi !c ci
  | b < a =
    -- ... b < a ...
    if c < a
    then
      if c < b
      then (# b, bi #) -- c < b < a
      else (# c, ci #) -- b <= c < a
    else (# a, ai #)   -- b < a <= c
  | c < b =
    -- ... a <= b and c < b ...
    if c < a
    then (# a, ai #) -- c < a <= b
    else (# c, ci #) -- a <= c < b
  | otherwise =
    -- a <= b <= c
    (# b, bi #)

{-# INLINE sort3 #-}
-- Establish sorted order among 3 values.
--
-- Arguments come as three (value, index) pairs. The result holds the
-- same three pairs flattened in non-decreasing order of value:
-- minimum, median, maximum, each followed by its index. Uses the same
-- decision tree as 'pick3': only '<', at most three comparisons.
sort3 :: Ord a => a -> Int# -> a -> Int# -> a -> Int# -> (# a, Int#, a, Int#, a, Int# #)
sort3 !a ai !b bi !c ci
  | b < a =
    -- ... b < a ...
    if c < a
    then
      if c < b
      then (# c, ci, b, bi, a, ai #) -- c < b < a
      else (# b, bi, c, ci, a, ai #) -- b <= c < a
    else (# b, bi, a, ai, c, ci #)   -- b < a <= c
  | c < b =
    -- ... a <= b and c < b ...
    if c < a
    then (# c, ci, a, ai, b, bi #) -- c < a <= b
    else (# a, ai, c, ci, b, bi #) -- a <= c < b
  | otherwise =
    -- a <= b <= c
    (# a, ai, b, bi, c, ci #)

-- | Median of three: samples the elements at indices @0@, @len / 2@
-- and @len - 1@ and returns the median of those together with its
-- index.
--
-- The reads use 'GM.unsafeRead', so the vector must be non-empty.
instance (PrimMonad m, s ~ PrimState m) => Median (Median3 a) a m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       (GM.MVector v a, Ord a)
    => Median3 a
    -> v s a
    -> m (MedianResult a)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        -- Sample positions are kept as raw 'Int#' because that is
        -- what 'pick3' takes and returns.
        pi0, pi1, pi2 :: Int#
        !(I# pi0)  = 0
        !(I# pi1)  = halve len
        !(I# pi2)  = len - 1
    !pv0 <- GM.unsafeRead v (I# pi0)
    !pv1 <- GM.unsafeRead v (I# pi1)
    !pv2 <- GM.unsafeRead v (I# pi2)
    pure $! existingValue (pick3 pv0 pi0 pv1 pi1 pv2 pi2)

-- | Median of three or five: samples the elements at indices @0@,
-- @len / 2@ and @len - 1@. If they are pairwise distinct their median
-- is returned. Otherwise two more elements at @len / 4@ and
-- @len / 2 + len / 4@ are read and the median of all five is returned.
--
-- The reads use 'GM.unsafeRead', so the vector must be non-empty.
instance (PrimMonad m, s ~ PrimState m) => Median (Median3or5 a) a m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       (GM.MVector v a, Ord a)
    => Median3or5 a
    -> v s a
    -> m (MedianResult a)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        pi0, pi1, pi2 :: Int#
        !(I# pi0)  = 0
        !(I# pi1)  = halve len
        !(I# pi2)  = len - 1
    !pv0 <- GM.unsafeRead v (I# pi0)
    !pv1 <- GM.unsafeRead v (I# pi1)
    !pv2 <- GM.unsafeRead v (I# pi2)

    -- If median of 3 has chances to be good enough
    if pv0 /= pv1 && pv1 /= pv2 && pv2 /= pv0
    then pure $! existingValue (pick3 pv0 pi0 pv1 pi1 pv2 pi2)
    else do
      -- Two extra sample positions: a quarter and three quarters of
      -- the way through the vector.
      let pi01, pi12 :: Int#
          !(I# pi01) = halve (I# pi1)
          !(I# pi12) = I# pi1 + I# pi01

      !pv01 <- GM.unsafeRead v (I# pi01)
      !pv12 <- GM.unsafeRead v (I# pi12)

      -- Order the original three samples and, separately, the two
      -- extra ones.
      let !(# !mn, !mni, !med, !medi, !mx, !mxi #) = sort3 pv0 pi0 pv1 pi1 pv2 pi2
          !(# !mn', !mni', !mx', !mxi' #)
            | pv01 < pv12 = (# pv01, pi01, pv12, pi12 #)
            | otherwise   = (# pv12, pi12, pv01, pi01 #)

          -- Median of five:
          --   * both extras above the triple: the triple's max is third;
          --   * both extras below the triple: the triple's min is third;
          --   * otherwise: median of the extras and the triple's median.
          !med'@(# !_, !_ #)
            | mn' > mx  = (# mx, mxi #)
            | mx' < mn  = (# mn, mni #)
            | otherwise = pick3 mn' mni' med medi mx' mxi'

      pure $! existingValue med'

{-# INLINE halve #-}
-- Halve an 'Int' by shifting it right by one bit. For the
-- non-negative lengths and indices it is applied to in this module
-- that is the same as truncating division by 2.
halve :: Int -> Int
halve x = x `unsafeShiftR` 1