{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Vector.Algorithms.Quicksort.Median
( Median(..)
, Median3(..)
, Median3or5(..)
, MedianResult(..)
) where
import Prelude hiding (last)
import Control.Monad.Primitive
import Data.Bits
import Data.Function
import Data.Kind (Type)
import Data.Vector.Generic.Mutable qualified as GM
-- | Outcome of pivot selection.
data MedianResult a
  = ExistingValue !a {-# UNPACK #-} !Int
    -- ^ A value read from the vector, paired with the index it was read from.
  | Guess !a
    -- ^ A pivot value with no associated index.
    -- NOTE(review): neither 'Median' instance in this module produces it.
-- | Wrap a sample (an element paired with the index it was read from, as
-- produced by 'readAt') as an 'ExistingValue' result.
existingValue :: CmpFst a Int -> MedianResult a
existingValue (CmpFst (x, n)) = ExistingValue x n
-- | A pivot-selection strategy.
--
-- Functional dependencies: the strategy type @a@ determines the element
-- type @b@, and the monad @m@ determines the state token @s@.
class Median (a :: Type) (b :: Type) (m :: Type -> Type) (s :: Type)
  | a -> b
  , m -> s
  where
    -- | Choose a pivot from the given mutable vector.
    selectMedian
      :: (GM.MVector v b, Ord b)
      => a                   -- ^ Strategy selector; instances here ignore its value.
      -> v s b               -- ^ Vector to sample.
      -> m (MedianResult b)  -- ^ The chosen pivot.
-- | Pivot strategy: the median of the first, middle and last elements.
-- The type parameter only ties the strategy to the vector's element type.
data Median3 a
  = Median3
-- | Pivot strategy: the median of the first, middle and last elements.
-- When any two of those compare equal, it switches to the median of five
-- samples (two extra elements at roughly a quarter and three quarters of
-- the length). The type parameter only ties the strategy to the vector's
-- element type.
data Median3or5 a
  = Median3or5
{-# INLINE pick3 #-}
-- | Median of three values, using two or three comparisons.
-- The result is always one of the three arguments.
pick3 :: Ord a => a -> a -> a -> a
pick3 a b c =
  if b < a
    then
      if c < a
        then
          if c < b
            then b -- c < b < a
            else c -- b <= c < a
        else a -- b < a <= c
    else
      if c < b
        then
          if c < a
            then a -- c < a <= b
            else c -- a <= c < b
        else b -- a <= b <= c
{-# INLINE sort3 #-}
-- | Sort three values into ascending order, using two or three comparisons.
-- Values that compare equal keep their relative argument order.
sort3 :: Ord a => a -> a -> a -> (a, a, a)
sort3 a b c =
  if b < a
    then
      if c < a
        then
          if c < b
            then (c, b, a) -- c < b < a
            else (b, c, a) -- b <= c < a
        else (b, a, c) -- b < a <= c
    else
      if c < b
        then
          if c < a
            then (c, a, b) -- c < a <= b
            else (a, c, b) -- a <= c < b
        else (a, b, c) -- a <= b <= c
-- | A pair whose 'Eq' and 'Ord' instances in this module compare only the
-- first component. Here it carries an element's index (the second
-- component) alongside the element while elements are compared.
newtype CmpFst a b = CmpFst {unCmpFst :: (a, b)}
-- | Equality on the first component only; the second component is ignored.
instance Eq a => Eq (CmpFst a b) where
  (==) = (==) `on` (fst . unCmpFst)
-- | Ordering on the first component only, consistent with the 'Eq' instance;
-- the second component is ignored.
instance Ord a => Ord (CmpFst a b) where
  compare = compare `on` (fst . unCmpFst)
{-# INLINE readAt #-}
-- | Read the element at index @n@ without bounds checking ('GM.unsafeRead'),
-- so the caller must ensure @0 <= n < length xs@. The element is paired with
-- @n@ so its position survives comparisons that only look at the element.
readAt
  :: (PrimMonad m, GM.MVector v a)
  => v (PrimState m) a
  -> Int
  -> m (CmpFst a Int)
readAt xs n = (\x -> CmpFst (x, n)) <$> GM.unsafeRead xs n
-- | Median of three samples: the elements at index @0@, at the middle index
-- (length shifted right by one) and at @len - 1@.
--
-- NOTE(review): for an empty vector @len - 1@ is @-1@ and the unchecked read
-- is out of bounds; presumably callers never pass an empty vector — confirm.
instance (PrimMonad m, s ~ PrimState m) => Median (Median3 a) a m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       (GM.MVector v a, Ord a)
    => Median3 a
    -> v s a
    -> m (MedianResult a)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        pi0, pi1, pi2 :: Int
        !pi0 = 0
        !pi1 = halve len
        !pi2 = len - 1
    !pv0 <- readAt v pi0
    !pv1 <- readAt v pi1
    !pv2 <- readAt v pi2
    -- Report the chosen sample together with the index it was read from.
    pure $! existingValue $ pick3 pv0 pv1 pv2
-- | Median of three samples (first, middle and last element). When any two of
-- those three compare equal, it instead uses the median of five samples. The
-- two extra samples are read at @pi1 / 2@ and @pi1 + pi1 / 2@, where @pi1@ is
-- the middle index (all divisions rounded down).
--
-- NOTE(review): for an empty vector @len - 1@ is @-1@ and the unchecked read
-- is out of bounds; presumably callers never pass an empty vector — confirm.
instance (PrimMonad m, s ~ PrimState m) => Median (Median3or5 a) a m s where
  {-# INLINE selectMedian #-}
  selectMedian
    :: forall (v :: Type -> Type -> Type).
       (GM.MVector v a, Ord a)
    => Median3or5 a
    -> v s a
    -> m (MedianResult a)
  selectMedian _ !v = do
    let len :: Int
        !len = GM.length v
        pi0, pi1, pi2 :: Int
        !pi0 = 0
        !pi1 = halve len
        !pi2 = len - 1
    !pv0 <- readAt v pi0
    !pv1 <- readAt v pi1
    !pv2 <- readAt v pi2
    if pv0 /= pv1 && pv1 /= pv2 && pv2 /= pv0
      -- All three samples are pairwise distinct: use their median directly.
      then pure $! existingValue $ pick3 pv0 pv1 pv2
      else do
        -- Some samples compare equal: read two more, at roughly a quarter and
        -- three quarters of the way through the vector (always in bounds for
        -- a non-empty vector, since len/2 + len/4 < len).
        let pi01, pi12 :: Int
            !pi01 = halve pi1
            !pi12 = pi1 + pi01
        !pv01 <- readAt v pi01
        !pv12 <- readAt v pi12
        -- (mn, med, mx): the original three samples, sorted.
        -- (mn', mx'):    the two extra samples, sorted.
        -- med': the median of all five samples --
        --   * both extras above all three -> the largest of the three;
        --   * both extras below all three -> the smallest of the three;
        --   * otherwise the median of (smaller extra, middle, larger extra).
        let (!mn, !med, !mx) = sort3 pv0 pv1 pv2
            (!mn', !mx')
              | pv01 < pv12 = (pv01, pv12)
              | otherwise   = (pv12, pv01)
            !med'
              | mn' > mx  = mx
              | mx' < mn  = mn
              | otherwise = pick3 mn' med mx'
        pure $! existingValue med'
{-# INLINE halve #-}
-- | Divide by two, rounding down, via a one-bit arithmetic right shift.
-- In this module it is only applied to lengths and indices, which are
-- non-negative.
halve :: Int -> Int
halve x = x `unsafeShiftR` 1