-- |
-- Module:     Data.Vector.Algorithms.Quicksort.Median
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com

{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE UndecidableInstances   #-}

module Data.Vector.Algorithms.Quicksort.Median
  ( Median(..)
  , Median3(..)
  , Median3or5(..)
  , MedianResult(..)
  ) where

import Prelude hiding (last)

import Control.Monad.Primitive
import Data.Bits
import Data.Function
import Data.Kind (Type)
import Data.Vector.Generic.Mutable qualified as GM

-- | Median selection result.
data MedianResult a
  -- | Value that was located at a specific index in the original
  -- array; the 'Int' field is that index.
  = ExistingValue !a {-# UNPACK #-} !Int
  -- | Value that is a good guess for the real median but may not be
  -- present in the array (or whose exact position is not known).
  --
  -- A good example is to pick the first, last, and middle elements and
  -- average them. That restricts us to numeric values but may yield
  -- good results depending on the distribution of values in the array
  -- to be sorted.
  | Guess !a

-- Repackage an element paired with the index it was read from (the
-- shape produced by 'readAt') as an 'ExistingValue' result.
existingValue :: CmpFst a Int -> MedianResult a
existingValue (CmpFst (x, n)) = ExistingValue x n

-- | Median selection algorithm that, given a vector, should come up
-- with an element that has good chances to be the median (i.e. to be
-- greater than half of the elements and lower than the remaining
-- half). The closer to the real median the selected element is, the
-- faster quicksort will run and the better parallelisation will be
-- achieved.
--
-- An instance can be declared for a specific monad. This is useful if
-- we want to select the median at random and need to thread a random
-- generator.
--
-- Parameter meaning:
-- - @a@ - the median parameter we're defining the instance for
-- - @b@ - type of elements this median selection method is applicable to
-- - @m@ - monad the median selection operates in
-- - @s@ - the same ‘index’ as in ‘ST s’, because the vector to be sorted is parameterised by it and @m@ may need to mention it
class Median (a :: Type) (b :: Type) (m :: Type -> Type) (s :: Type) | a -> b, m -> s where
  -- | Come up with a median value of a given array.
  selectMedian
    :: (GM.MVector v b, Ord b)
    => a     -- ^ Median algorithm that can carry extra info to be
             -- used during median selection (e.g. random generator)
    -> v s b -- ^ Array
    -> m (MedianResult b)

-- | Pick the first, last, and middle elements and find the one that's
-- between the other two, i.e. given elements @a@, @b@, and @c@ select
-- the one, call it @y@, for which @x <= y <= z@ holds with @x@ and @z@
-- being the remaining two.
--
-- The type parameter is phantom: the constructor carries no data.
data Median3 a = Median3

-- | Pick the first, last, and middle elements. If all three are
-- distinct, return the median of 3 like 'Median3' does; otherwise
-- take the median of 5 from the 3 elements already taken plus 2 extra
-- elements at @1/4@th and @3/4@th of the array length.
--
-- The type parameter is phantom: the constructor carries no data.
data Median3or5 a = Median3or5

{-# INLINE pick3 #-}
-- Pick the median among 3 values using at most three '<' comparisons.
--
-- Only '<' is used, so when some arguments compare equal the choice of
-- which one is returned is fixed by the branch structure below. This
-- matters for callers whose 'Ord' instance ignores part of the value.
pick3 :: Ord a => a -> a -> a -> a
pick3 a b c =
  if b < a
  then
    -- ... b < a ...
    if c < a
    then
      if c < b
      then
        -- c < b < a
        b
      else
        -- b <= c < a
        c
    else
      -- b < a <= c
      a
  else
    -- ... a <= b ...
    if c < b
    then
      if c < a
      then
        -- c < a <= b
        a
      else
        -- a <= c < b
        c
    else
      -- a <= b <= c
      b

{-# INLINE sort3 #-}
-- Establish sorted (ascending) order among 3 values using at most
-- three '<' comparisons. The result is @(smallest, middle, largest)@.
--
-- The comparison tree mirrors 'pick3', so the middle component is the
-- same element 'pick3' would return for the same arguments.
sort3 :: Ord a => a -> a -> a -> (a, a, a)
sort3 a b c =
  if b < a
  then
    -- ... b < a ...
    if c < a
    then
      if c < b
      then
        -- c < b < a
        (c, b, a)
      else
        -- b <= c < a
        (b, c, a)
    else
      -- b < a <= c
      (b, a, c)
  else
    -- ... a <= b ...
    if c < b
    then
      if c < a
      then
        -- c < a <= b
        (c, a, b)
      else
        -- a <= c < b
        (a, c, b)
    else
      -- a <= b <= c
      (a, b, c)

-- Pair wrapper used to carry a payload (second component) alongside
-- the value that is actually compared (first component).
newtype CmpFst a b = CmpFst { unCmpFst :: (a, b) }

-- Equality looks only at the first component; the second component is
-- ignored entirely.
instance Eq a => Eq (CmpFst a b) where
  (==) = (==) `on` fst . unCmpFst

-- Ordering looks only at the first component; the second component is
-- ignored entirely.
instance Ord a => Ord (CmpFst a b) where
  compare = compare `on` fst . unCmpFst

{-# INLINE readAt #-}
-- Read the element at index @n@ and pair it with that index.
--
-- Uses 'GM.unsafeRead', so the caller is responsible for passing an
-- index that is valid for the vector.
readAt :: (PrimMonad m, GM.MVector v a) => v (PrimState m) a -> Int -> m (CmpFst a Int)
readAt xs n = (\x -> CmpFst (x, n)) <$> GM.unsafeRead xs n

-- Median of the first, middle and last elements.
--
-- The vector must be non-empty: the last index is computed as
-- @len - 1@ and all three reads go through 'readAt', which uses an
-- unchecked read.
instance (PrimMonad m, s ~ PrimState m) => Median (Median3 a) a m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       (GM.MVector v a, Ord a)
    => Median3 a
    -> v s a
    -> m (MedianResult a)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        -- Indices of the first, middle and last elements.
        pi0, pi1, pi2 :: Int
        !pi0  = 0
        !pi1  = halve len
        !pi2  = len - 1
    -- Each value is paired with its index; comparisons only look at
    -- the value, the index just travels along into the result.
    !pv0 <- readAt v pi0
    !pv1 <- readAt v pi1
    !pv2 <- readAt v pi2
    pure $! existingValue $ pick3 pv0 pv1 pv2

-- Median of 3 (first, middle, last) when those three are pairwise
-- distinct, otherwise median of 5 using two extra samples.
--
-- The vector must be non-empty: the last index is computed as
-- @len - 1@ and all reads go through 'readAt', which uses an
-- unchecked read.
instance (PrimMonad m, s ~ PrimState m) => Median (Median3or5 a) a m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       (GM.MVector v a, Ord a)
    => Median3or5 a
    -> v s a
    -> m (MedianResult a)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        -- Indices of the first, middle and last elements.
        pi0, pi1, pi2 :: Int
        !pi0  = 0
        !pi1  = halve len
        !pi2  = len - 1
    -- Each value is paired with its index; comparisons only look at
    -- the value, the index just travels along into the result.
    !pv0 <- readAt v pi0
    !pv1 <- readAt v pi1
    !pv2 <- readAt v pi2

    -- If median of 3 has chances to be good enough
    if pv0 /= pv1 && pv1 /= pv2 && pv2 /= pv0
    then pure $! existingValue $ pick3 pv0 pv1 pv2
    else do
      -- Two extra sample positions: half of the middle index and the
      -- middle index plus that same offset.
      let pi01, pi12 :: Int
          !pi01 = halve pi1
          !pi12 = pi1 + pi01

      !pv01 <- readAt v pi01
      !pv12 <- readAt v pi12

      -- Median of 5 from a sorted triple and a sorted pair.
      let (!mn, !med, !mx) = sort3 pv0 pv1 pv2
          (!mn', !mx')
            | pv01 < pv12 = (pv01, pv12)
            | otherwise   = (pv12, pv01)

          !med'
            -- Both extras lie above the whole triple: the third
            -- smallest of the five is the triple's maximum.
            | mn' > mx  = mx
            -- Both extras lie below the whole triple: the third
            -- smallest of the five is the triple's minimum.
            | mx' < mn  = mn
            -- Otherwise the answer is the median of the two extras
            -- and the triple's middle element.
            | otherwise = pick3 mn' med mx'

      pure $! existingValue med'

{-# INLINE halve #-}
-- Halve an 'Int' with a right shift by one bit. For the non-negative
-- lengths and indices this module passes in, that equals @x `div` 2@.
halve :: Int -> Int
halve x = x `unsafeShiftR` 1