{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleContexts #-}
-- |
-- Module      : Data.Massiv.Array.Ops.Sort
-- Copyright   : (c) Alexey Kuleshevich 2018-2021
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Ops.Sort
  ( tally
  , quicksort
  , quicksortM_
  , unsafeUnstablePartitionRegionM
  ) where

import Control.Monad (when)
import Control.Scheduler
import Data.Massiv.Array.Delayed.Stream
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Transform
import Data.Massiv.Core.Common
import Data.Massiv.Vector (scatMaybes, sunfoldrN)
import System.IO.Unsafe

-- | Count how many occurrences of each element there are in the array. Results will be
-- sorted in ascending order of the element.
--
-- The array is flattened and sorted first, after which a single pass over the sorted
-- vector collapses each run of equal elements into an @(element, count)@ pair. For an
-- empty array an empty stream vector is returned, carrying over the computation strategy
-- of the input.
--
-- ==== __Example__
--
-- >>> import Data.Massiv.Array as A
-- >>> xs = fromList Seq [2, 4, 3, 2, 4, 5, 2, 1] :: Array P Ix1 Int
-- >>> xs
-- Array P Seq (Sz1 8)
--   [ 2, 4, 3, 2, 4, 5, 2, 1 ]
-- >>> tally xs
-- Array DS Seq (Sz1 5)
--   [ (1,1), (2,3), (3,1), (4,2), (5,1) ]
--
-- @since 0.4.4
tally :: (Mutable r Ix1 e, Resize r ix, Load r ix e, Ord e) => Array r ix e -> Array DS Ix1 (e, Int)
tally arr
  | isEmpty arr = setComp (getComp arr) empty
  -- One step per element plus one extra step that flushes the last run, hence @sz + 1@.
  -- Indexing at 0 is fine here because the empty case was handled by the guard above.
  | otherwise = scatMaybes $ sunfoldrN (sz + 1) count (0, 0, sorted ! 0)
  where
    -- Flat, sorted version of the input: equal elements end up adjacent to each other.
    sorted = quicksort $ flatten arr
    sz@(Sz k) = size sorted
    -- Unfolding step. The state is (index to inspect next, length of the current run so
    -- far, element the current run consists of). 'Nothing' is produced while the run
    -- continues and is later dropped by 'scatMaybes'; a 'Just' is produced whenever a run
    -- ends.
    count (!i, !n, !prev)
      | i < k =
        let !e' = unsafeLinearIndex sorted i
         in if prev == e'
              then Just (Nothing, (i + 1, n + 1, prev))
              else Just (Just (prev, n), (i + 1, 1, e'))
      -- Past the last element: emit the run that was still being accumulated.
      | otherwise = Just (Just (prev, n), (i + 1, n, prev))
    {-# INLINE count #-}
{-# INLINE tally #-}



-- | Partition a segment of a vector. Starting and ending indices are unchecked.
--
-- Elements of the region that satisfy the predicate are moved in front of the ones that
-- do not; relative order of elements is not preserved. Both bounds are inclusive. The
-- returned index is the position of the first element in the region that does not
-- satisfy the predicate, or one past the end index when every element satisfies it.
--
-- @since 0.3.2
unsafeUnstablePartitionRegionM ::
     forall r e m. (Mutable r Ix1 e, PrimMonad m)
  => MArray (PrimState m) r Ix1 e
  -> (e -> Bool)
  -> Ix1 -- ^ Start index of the region
  -> Ix1 -- ^ End index of the region
  -> m Ix1
unsafeUnstablePartitionRegionM marr f start end = fromLeft start (end + 1)
  where
    -- Invariant for both loops: everything in the region before @i@ satisfies @f@ and
    -- everything from @j@ onwards does not. @j@ is an exclusive upper bound.
    --
    -- Advance from the left until an element that fails the predicate is found.
    fromLeft i j
      | i == j = pure i
      | otherwise = do
        x <- unsafeRead marr i
        if f x
          then fromLeft (i + 1) j
          else fromRight i (j - 1)
    -- Element at @i@ is known to fail the predicate: retreat from the right looking for
    -- one that passes, so the two can be exchanged.
    fromRight i j
      | i == j = pure i
      | otherwise = do
        x <- unsafeRead marr j
        if f x
          then do
            -- Swap elements at @i@ and @j@, then resume scanning from the left.
            unsafeWrite marr j =<< unsafeRead marr i
            unsafeWrite marr i x
            fromLeft (i + 1) j
          else fromRight i (j - 1)
{-# INLINE unsafeUnstablePartitionRegionM #-}


-- | This is an implementation of [Quicksort](https://en.wikipedia.org/wiki/Quicksort), which is an
-- efficient, but unstable sort that uses Median-of-three for pivot choosing, as such it performs
-- very well not only for random values, but also for common edge cases like already sorted,
-- reversed sorted and arrays with many duplicate elements. It will also respect the computation
-- strategy and will result in a nice speed up for systems with multiple CPUs.
--
-- @since 0.3.2
quicksort ::
     (Mutable r Ix1 e, Ord e) => Array r Ix1 e -> Array r Ix1 e
-- All of the mutation is performed by 'quicksortM_' on the mutable array supplied by
-- 'withMArray_', which in turn produces a new immutable array as the result.
-- NOTE(review): purity of this function relies on 'withMArray_' never mutating @arr@
-- itself (i.e. working on a copy) - confirm against its documentation.
quicksort arr = unsafePerformIO $ withMArray_ arr quicksortM_
{-# INLINE quicksort #-}



-- | Mutable version of 'quicksort'. Elements of the supplied mutable vector are
-- rearranged in place.
--
-- The sorting itself is submitted to the supplied scheduler with 'scheduleWork'. For the
-- first @'numWorkers' scheduler@ levels of recursion both partitions are scheduled as
-- separate jobs, while deeper levels are sorted directly within the job that reached them.
--
-- @since 0.3.2
quicksortM_ ::
     (Ord e, Mutable r Ix1 e, PrimMonad m)
  => Scheduler m ()
  -> MArray (PrimState m) r Ix1 e
  -> m ()
quicksortM_ scheduler marr =
  scheduleWork scheduler $ qsort (numWorkers scheduler) 0 (unSz (msize marr) - 1)
  where
    -- Make sure the element at @j@ is not greater than the one at @i@, swapping them when
    -- necessary. Returns the element that ends up at index @j@, i.e. the smaller one.
    leSwap i j = do
      ei <- unsafeRead marr i
      ej <- unsafeRead marr j
      if ei < ej
        then do
          unsafeWrite marr i ej
          unsafeWrite marr j ei
          pure ei
        else pure ej
    {-# INLINE leSwap #-}
    -- Median-of-three: after the three conditional swaps the smallest of the elements at
    -- @lo@, @mid@ and @hi@ is at @lo@, the largest is at @mid@ and the median, which is
    -- what gets returned as the pivot, is at @hi@.
    getPivot lo hi = do
      -- Same value as @(hi + lo) `div` 2@ for @lo <= hi@, but cannot overflow.
      let !mid = lo + (hi - lo) `div` 2
      _ <- leSwap mid lo
      _ <- leSwap hi lo
      leSwap mid hi
    {-# INLINE getPivot #-}
    -- Sort the inclusive region @[lo, hi]@. @n@ is the number of recursion levels that
    -- are still allowed to hand their partitions over to the scheduler.
    qsort !n !lo !hi =
      when (lo < hi) $ do
        p <- getPivot lo hi
        -- Three-way split: elements smaller than the pivot end up in @[lo, l - 1]@,
        -- elements equal to it in @[l, h - 1]@ and the greater ones in @[h, hi]@. The
        -- pivot itself sits at @hi@, therefore it is excluded from the first pass.
        l <- unsafeUnstablePartitionRegionM marr (< p) lo (hi - 1)
        h <- unsafeUnstablePartitionRegionM marr (== p) l hi
        if n > 0
          then do
            let !n' = n - 1
            scheduleWork scheduler $ qsort n' lo (l - 1)
            scheduleWork scheduler $ qsort n' h hi
          else do
            qsort n lo (l - 1)
            qsort n h hi
{-# INLINE quicksortM_ #-}