{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Data.Array.Accelerate.Data.Sort.Quick
-- Copyright   : [2020] Ivo Gabe de Wolff, Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Data.Sort.Quick (

  sort,
  sortBy,

) where

import Data.Array.Accelerate
import Data.Array.Accelerate.Unsafe
import Data.Array.Accelerate.Data.Bits
import Data.Array.Accelerate.Data.Maybe


-- | A quick-ish stable sort of a vector, ordering the elements with
-- 'compare'. This is 'sortBy' specialised to the 'Ord' instance of the
-- element type.
--
sort :: Ord a => Acc (Vector a) -> Acc (Vector a)
sort xs = sortBy compare xs

-- | A non-overloaded version of 'sort': the caller supplies the comparison
-- function.
--
-- It is often convenient to combine this with 'Data.Function.on', for
-- instance: 'sortBy' ('compare' `on` 'fst')
--
sortBy :: Elt a => (Exp a -> Exp a -> Exp Ordering) -> Acc (Vector a) -> Acc (Vector a)
sortBy cmp input = sorted
  where
    n = length input

    -- The flags array has one slot more than the input. At the start only
    -- slot 0 and slot n are raised, so the whole array forms one segment.
    boundaries = fill (I1 1) 0 ++ fill (I1 1) n
    noFlags    = fill (I1 (1 + n)) False_
    raised     = fill (I1 2) True_
    startFlags = scatter boundaries noFlags raised

    -- Keep partitioning until 'condition' reports that every segment has
    -- shrunk to a single element; such segments are trivially sorted.
    T2 sorted _ = awhile condition (step cmp) (T2 input startFlags)


-- The loop state threaded through 'awhile'.
--
type State a =
  ( Vector a      -- Values
  , Vector Bool   -- Head flags, denoting the starting points of the unsorted segments
  )

-- One partitioning round: every segment is split around its first element
-- (the pivot), and the head flags are extended with the new segment starts.
--
step :: Elt a => (Exp a -> Exp a -> Exp Ordering) -> Acc (State a) -> Acc (State a)
step cmp (T2 values headFlags) = T2 values' headFlags'
  where
    -- The pivot of a segment is simply its first element; broadcast it to
    -- every element of that segment.
    pivots = propagateSegmentHead headFlags values

    -- An element goes to the "larger" side unless it compares LT_ against
    -- the pivot (so the pivot itself lands on the larger side).
    notBelowPivot v p = cmp v p /= LT_
    isLarger          = zipWith notBelowPivot values pivots

    -- For every element, the index at which its segment begins
    startIndex = propagateSegmentHead headFlags (generate (shape values) unindex1)

    -- Zero-based rank of each element within its side of its segment.
    -- The pair gives the weight of an element when it is larger / smaller;
    -- a segmented running sum of the weights, minus one, is the rank.
    offsetsFor :: (Exp Int, Exp Int) -> Acc (Vector Int)
    offsetsFor weights =
      map (\k -> k - 1) (postscanSegHead (+) headFlags (map (? weights) isLarger))

    indicesLarger, indicesSmaller :: Acc (Vector Int)
    indicesLarger  = offsetsFor (1, 0)
    indicesSmaller = offsetsFor (0, 1)

    -- Number of smaller elements per segment, broadcast to the whole
    -- segment. The larger elements are placed after this many slots.
    countSmaller :: Acc (Vector Int)
    countSmaller = map (+ 1) (propagateSegmentLast headFlags indicesSmaller)

    -- Destination index of every element
    permutation =
      zipWith5 partitionPermuteIndex
        isLarger
        startIndex
        indicesSmaller
        indicesLarger
        countSmaller

    -- Move the elements to their destinations
    values' = scatter permutation (fill (shape values) undef) values

    -- New segment boundaries for the next round (the 'recursive call' of a
    -- traditional implementation). Each segment head contributes a flag at:
    --   * the position of the pivot
    --   * the position of the pivot + 1
    --
    -- The "+ 1" write may land one past the end of the values array, which
    -- is why the flags array was allocated one element larger.
    headFlags' = writeFlags (boundaryAt 0) (writeFlags (boundaryAt 1) headFlags)

    boundaryAt :: Int -> Acc (Vector (Maybe DIM1))
    boundaryAt delta = zipWith3 target headFlags startIndex countSmaller
      where
        target :: Exp Bool -> Exp Int -> Exp Int -> Exp (Maybe DIM1)
        target isHead start below =
          isHead ? (Just_ (I1 (start + below + constant delta)), Nothing_)

-- Loop test for 'awhile': yields True while at least one head flag is still
-- unset, i.e. while some segment holds more than one element. Once every
-- flag is raised all segments have length 1 and the loop may terminate.
--
condition :: Elt a => Acc (State a) -> Acc (Scalar Bool)
condition (T2 _ headFlags) = map not allRaised
  where
    allRaised = fold (&&) True_ headFlags

-- Destination index of a single element after partitioning its segment:
-- smaller elements keep their rank among the smaller ones, larger elements
-- are shifted past all the smaller ones.
--
partitionPermuteIndex :: Exp Bool -> Exp Int -> Exp Int -> Exp Int -> Exp Int -> Exp Int
partitionPermuteIndex isLarger start indexIfSmaller indexIfLarger countSmaller =
  start + offsetInSegment
  where
    offsetInSegment = isLarger ? (countSmaller + indexIfLarger, indexIfSmaller)

-- Given head flags, copies the value found at the head of each segment to
-- every element of that segment (left-to-right scan that remembers the most
-- recent flagged element).
--
propagateSegmentHead
    :: Elt a
    => Acc (Vector Bool)
    -> Acc (Vector a)
    -> Acc (Vector a)
propagateSegmentHead headFlags values =
  map fst (postscanl keepLatestHead (T2 undef True_) (zip values headFlags))
  where
    keepLatestHead acc (T2 v isHead) = isHead ? (T2 v True_, acc)

-- Given head flags, copies the value found at the last position of each
-- segment to every element of that segment. The flags are shifted by one
-- ('tail') so that a raised flag marks the final element of a segment, and
-- the scan runs right-to-left.
--
propagateSegmentLast
    :: Elt a
    => Acc (Vector Bool)
    -> Acc (Vector a)
    -> Acc (Vector a)
propagateSegmentLast headFlags values =
  map fst (postscanr keepNearestEnd (T2 undef True_) (zip values (tail headFlags)))
  where
    keepNearestEnd (T2 v isEnd) acc = isEnd ? (T2 v True_, acc)

-- Segmented postscan, where the segments are delimited by head flags: the
-- running result restarts at every flagged element.
--
postscanSegHead
    :: Elt a
    => (Exp a -> Exp a -> Exp a)
    -> Acc (Vector Bool)
    -> Acc (Vector a)
    -> Acc (Vector a)
postscanSegHead f headFlags values =
  map fst (postscanl combine (T2 undef True_) (zip values headFlags))
  where
    combine (T2 accValue accFlag) (T2 value isHead) =
      T2 (isHead ? (value, f accValue value))
         (accFlag .|. isHead)

-- Raises the flag at every index named by a Just_ entry of the first
-- argument; Nothing_ entries write nothing. All other flags are kept.
--
writeFlags :: Acc (Vector (Maybe DIM1)) -> Acc (Vector Bool) -> Acc (Vector Bool)
writeFlags writes flags = permute const flags (writes !) raised
  where
    raised = fill (shape writes) True_