-- |
-- Module:     Data.Vector.Algorithms.Quicksort.Predefined.AveragingMedian
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com

{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}

module Data.Vector.Algorithms.Quicksort.Predefined.AveragingMedian
  ( AveragingMedian(..)
  ) where

import Control.Monad.Primitive
import Data.Bits
import Data.Int
import Data.Kind
import Data.Vector.Generic.Mutable qualified as GM

import Data.Vector.Algorithms.Quicksort.Median
import Data.Vector.Algorithms.Quicksort.Predefined.Pair

-- | Median-selection strategy tag. It carries no data; the type parameter
-- @a@ only selects which 'Median' instance in this module applies. Those
-- instances return a 'Guess' derived from the first, middle and last
-- elements of the vector.
data AveragingMedian a
  = AveragingMedian

-- | Guess a pivot for an 'Int64' vector as the (truncated) arithmetic mean
-- of its first, middle and last elements. The guess need not be an element
-- of the vector.
instance (PrimMonad m, s ~ PrimState m) => Median (AveragingMedian Int64) Int64 m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       GM.MVector v Int64
    => AveragingMedian Int64
    -> v s Int64
    -> m (MedianResult Int64)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        pi0, pi1, pi2 :: Int
        !pi0 = 0
        !pi1 = halve len
        !pi2 = len - 1
    -- NOTE(review): indices are read with 'GM.unsafeRead', so an empty
    -- vector (pi2 = -1) would be read out of bounds; callers presumably
    -- only pass non-empty vectors — confirm.
    !pv0 <- GM.unsafeRead v pi0
    !pv1 <- GM.unsafeRead v pi1
    !pv2 <- GM.unsafeRead v pi2
    -- The sum must be parenthesised: `quot` (infixl 7) binds tighter than
    -- '+' (infixl 6). It is computed in 'Integer' so that three large
    -- 'Int64' values cannot overflow. The mean lies between the smallest and
    -- largest sample, so 'fromInteger' back to 'Int64' is lossless.
    let avg :: Int64
        !avg = fromInteger ((toInteger pv0 + toInteger pv1 + toInteger pv2) `quot` 3)
    pure $! Guess avg

-- | Guess a pivot for a vector of 'TestPair's.
--
-- The first component is the (truncated) arithmetic mean of the first
-- components of the first, middle and last elements. The second component
-- is copied from the first element.
instance (PrimMonad m, s ~ PrimState m) => Median (AveragingMedian (TestPair Int32 b)) (TestPair Int32 b) m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       GM.MVector v (TestPair Int32 b)
    => AveragingMedian (TestPair Int32 b)
    -> v s (TestPair Int32 b)
    -> m (MedianResult (TestPair Int32 b))
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        pi0, pi1, pi2 :: Int
        !pi0 = 0
        !pi1 = halve len
        !pi2 = len - 1
    -- NOTE(review): indices are read with 'GM.unsafeRead', so an empty
    -- vector (pi2 = -1) would be read out of bounds; callers presumably
    -- only pass non-empty vectors — confirm.
    !(TestPair !pv0 pv0') <- GM.unsafeRead v pi0
    !pv1                  <- fst . toTuple <$> GM.unsafeRead v pi1
    !pv2                  <- fst . toTuple <$> GM.unsafeRead v pi2
    -- The sum must be parenthesised: `quot` (infixl 7) binds tighter than
    -- '+' (infixl 6). It is computed in 'Int64' because three 'Int32' values
    -- can overflow 'Int32'. The mean lies between the smallest and largest
    -- sample, so narrowing back to 'Int32' is lossless.
    let widen :: Int32 -> Int64
        widen = fromIntegral
        avg :: Int32
        !avg = fromIntegral ((widen pv0 + widen pv1 + widen pv2) `quot` 3)
    pure $! Guess $ TestPair avg pv0'

-- | Halve an 'Int' with a single one-bit right shift.
--
-- For non-negative arguments (such as vector lengths) this equals
-- @n `div` 2@. The shift on 'Int' is sign-extending, so negative arguments
-- round towards negative infinity.
halve :: Int -> Int
{-# INLINE halve #-}
halve n = unsafeShiftR n 1