{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Massiv.Array.Ops.Sort
( tally
, quicksort
, quicksortM_
, unsafeUnstablePartitionRegionM
) where
import Control.Monad (when)
import Control.Scheduler
import Data.Massiv.Array.Delayed.Stream
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Transform
import Data.Massiv.Core.Common
import Data.Massiv.Vector (scatMaybes, sunfoldrN)
import System.IO.Unsafe
-- | Count the occurrences of each distinct element of an array.
--
-- The input is flattened and sorted with 'quicksort', after which every run of
-- equal neighbours is collapsed into an @(element, count)@ pair. Because the
-- runs are read off the sorted vector from left to right, the pairs come out
-- in ascending order of the element. An empty input produces an empty stream
-- that keeps the computation strategy of the source array.
tally :: (Mutable r Ix1 e, Resize r ix, Load r ix e, Ord e) => Array r ix e -> Array DS Ix1 (e, Int)
tally arr
  | isEmpty arr = setComp (getComp arr) empty
  | otherwise = scatMaybes (sunfoldrN (len + 1) step (0, 0, ordered ! 0))
  where
    ordered = quicksort (flatten arr)
    len@(Sz total) = size ordered
    -- The unfold state is (next index to inspect, length of the current run,
    -- element of the current run). A step yields Nothing while the run keeps
    -- growing and Just (element, count) whenever a run is finished. One step
    -- past the last index is needed to emit the final run, which is why the
    -- unfold is given len + 1 steps.
    step (!ix, !run, !cur)
      | ix >= total = Just (Just (cur, run), (ix + 1, run, cur))
      | otherwise =
        case unsafeLinearIndex ordered ix of
          !next
            | cur == next -> Just (Nothing, (ix + 1, run + 1, cur))
            | otherwise -> Just (Just (cur, run), (ix + 1, 1, next))
    {-# INLINE step #-}
{-# INLINE tally #-}
-- | Partition a region of a mutable vector in place.
--
-- Elements of the inclusive region @[start, end]@ are rearranged so that all
-- elements satisfying the predicate precede all elements that do not. The
-- returned index is the position of the first element that fails the predicate
-- (or @end + 1@ when every element satisfies it). Relative order of elements is
-- not preserved, and neither index is bounds checked: only 'unsafeRead' and
-- 'unsafeWrite' are used.
unsafeUnstablePartitionRegionM ::
     forall r e m. (Mutable r Ix1 e, PrimMonad m)
  => MArray (PrimState m) r Ix1 e -- ^ Vector that is rearranged in place
  -> (e -> Bool) -- ^ Predicate selecting the elements that are moved to the front
  -> Ix1 -- ^ First index of the region (inclusive)
  -> Ix1 -- ^ Last index of the region (inclusive)
  -> m Ix1
unsafeUnstablePartitionRegionM marr f start end = scanUp start (end + 1)
  where
    -- Invariant for both loops: everything before `front` satisfies the
    -- predicate, everything from `back` onwards (up to `end`) does not.
    --
    -- scanUp advances `front` over elements that already satisfy the predicate.
    scanUp front back =
      if front == back
        then pure front
        else do
          x <- unsafeRead marr front
          if f x
            then scanUp (front + 1) back
            else scanDown front (back - 1)
    -- scanDown is entered once the element at `front` is known to fail the
    -- predicate; it walks `back` downwards looking for an element to swap in.
    scanDown front back =
      if front == back
        then pure front
        else do
          candidate <- unsafeRead marr back
          if f candidate
            then do
              misplaced <- unsafeRead marr front
              unsafeWrite marr back misplaced
              unsafeWrite marr front candidate
              scanUp (front + 1) back
            else scanDown front (back - 1)
{-# INLINE unsafeUnstablePartitionRegionM #-}
-- | Sort a vector in ascending order using an unstable quicksort.
--
-- The source vector is handed to 'withMArray_' together with 'quicksortM_',
-- which performs the actual sorting on a mutable array using the scheduler it
-- is given; the array produced by 'withMArray_' is the result.
--
-- NOTE(review): 'unsafePerformIO' is only sound here if 'withMArray_' mutates a
-- private copy rather than the argument itself - confirm against its
-- documentation.
quicksort ::
     (Mutable r Ix1 e, Ord e) => Array r Ix1 e -> Array r Ix1 e
quicksort source = unsafePerformIO (withMArray_ source quicksortM_)
{-# INLINE quicksort #-}
-- | In-place, unstable quicksort of a mutable vector into ascending order.
--
-- The pivot of every range is the median of its first, middle and last
-- elements. A range is split into three parts: elements smaller than the
-- pivot, elements equal to it, and the rest; only the first and the last part
-- are sorted further. All sorting is submitted to the supplied scheduler with
-- 'scheduleWork': the first @numWorkers scheduler@ levels of recursion schedule
-- each sub-range as a separate job, deeper levels recurse directly inside the
-- job that reached them.
quicksortM_ ::
     (Ord e, Mutable r Ix1 e, PrimMonad m)
  => Scheduler m () -- ^ Scheduler that receives the sorting jobs
  -> MArray (PrimState m) r Ix1 e -- ^ Vector that is sorted in place
  -> m ()
quicksortM_ scheduler marr =
  scheduleWork scheduler (sortRange (numWorkers scheduler) 0 lastIx)
  where
    lastIx = unSz (msize marr) - 1
    -- Leave the larger of the two elements at index i and the smaller one at
    -- index j; the result is the element that ends up at index j.
    orderPair i j = do
      vi <- unsafeRead marr i
      vj <- unsafeRead marr j
      if vi < vj
        then unsafeWrite marr i vj >> unsafeWrite marr j vi >> pure vi
        else pure vj
    {-# INLINE orderPair #-}
    -- Median of three: afterwards the smallest of the three sampled elements
    -- sits at lo, the largest at mid and the median at hi. The median is the
    -- returned pivot.
    medianOfThree lo hi =
      let !mid = (hi + lo) `div` 2
       in orderPair mid lo >> orderPair hi lo >> orderPair mid hi
    {-# INLINE medianOfThree #-}
    -- Sort the inclusive range [lo, hi]; `budget` is the number of recursion
    -- levels that may still be handed to the scheduler as separate jobs.
    sortRange !budget !lo !hi =
      when (lo < hi) $ do
        pivot <- medianOfThree lo hi
        -- The pivot itself is stored at hi, so the first pass stops at hi - 1.
        belowEnd <- unsafeUnstablePartitionRegionM marr (< pivot) lo (hi - 1)
        -- Gather everything equal to the pivot right after the smaller part.
        aboveStart <- unsafeUnstablePartitionRegionM marr (== pivot) belowEnd hi
        let sortBelow k = sortRange k lo (belowEnd - 1)
            sortAbove k = sortRange k aboveStart hi
        if budget > 0
          then do
            let !remaining = budget - 1
            scheduleWork scheduler (sortBelow remaining)
            scheduleWork scheduler (sortAbove remaining)
          else sortBelow budget >> sortAbove budget
{-# INLINE quicksortM_ #-}