{-|
  Functions that sort mutable vectors using a sorting network.
 -}

module Data.SortingNetwork.MutableVector (
  unsafeSortBy,
  maySortBy,
) where

import Control.Monad
import Control.Monad.Primitive
import Data.SortingNetwork.Types
import qualified Data.Vector.Generic.Mutable as VM

{- TODO: test coverage -}
{-|
  Sorts a mutable vector in place by applying the compare-and-swap operations
  that the given 'MkPairs' generates for the vector's length.

  For each index pair @(i, j)@ produced by 'MkPairs', the elements at @i@ and
  @j@ are read and swapped when the comparator returns 'GT' for them. Equal
  elements ('EQ') are never swapped.

  Raises an 'error' if 'MkPairs' returns 'Nothing' for the vector's length,
  i.e. when the sorting network cannot handle that size.

  Elements are accessed with 'VM.unsafeRead' and 'VM.unsafeSwap', so no bounds
  checking is done. Every index produced by 'MkPairs' for length @n@ is
  assumed to lie in @[0, n)@.
 -}
unsafeSortBy ::
  (PrimMonad m, VM.MVector v e) =>
  MkPairs ->
  (e -> e -> Ordering) ->
  v (PrimState m) e ->
  m ()
unsafeSortBy mkPairs cmp v = case mkPairs n of
  Just pairs ->
    forM_ pairs $ \(i, j) -> do
      vi <- VM.unsafeRead v i
      vj <- VM.unsafeRead v j
      -- Swap only on a strict 'GT' so that equal elements stay where they are.
      when (cmp vi vj == GT) $
        VM.unsafeSwap v i j
  Nothing ->
    error $ "MkPairs returned Nothing on length " <> show n
  where
    n :: Int
    n = VM.length v

{-|
  Safe version of 'unsafeSortBy'.

  First checks whether 'MkPairs' can produce a sorting network for the
  vector's length. If it can, the vector is sorted in place (as by
  'unsafeSortBy') and the same vector reference is returned in 'Just'.
  Otherwise the vector is left unmodified and 'Nothing' is returned.
 -}
maySortBy ::
  (PrimMonad m, VM.MVector v e) =>
  MkPairs ->
  (e -> e -> Ordering) ->
  v (PrimState m) e ->
  m (Maybe (v (PrimState m) e))
maySortBy mkPairs cmp v = case mkPairs (VM.length v) of
  Nothing -> pure Nothing
  -- 'mkPairs' is used here as a pure function of the length, so
  -- 'unsafeSortBy' gets the same 'Just' result for this vector and its
  -- 'error' branch cannot be reached from this call.
  Just _ -> Just v <$ unsafeSortBy mkPairs cmp v