module Data.Repa.Array.Internals.Operator.Partition
( partition
, partitionBy
, partitionByIx)
where
import Data.Repa.Array.Meta.Delayed as A
import Data.Repa.Array.Meta.Linear as A
import Data.Repa.Array.Meta.Tuple as A
import Data.Repa.Array.Internals.Bulk as A
import Data.Repa.Array.Internals.Target as A
import Data.Repa.Array.Internals.Layout as A
import Data.Repa.Array.Material.Nested as A
import Data.Repa.Eval.Elt as A
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import System.IO.Unsafe
#include "repa-array.h"
-- | Distribute the values of a keyed array into a fixed number of segments.
--
--   Each source element is a pair @(k, v)@ where @k@ is the number of the
--   destination segment. Values whose key lies in @[0, iSegs)@ are appended
--   to segment @k@, keeping their relative order from the source array.
--   Values whose key is negative or @>= iSegs@ are discarded.
--
--   When @iSegs <= 0@ the result contains no segments.
--
--   Implementation: a single buffer of @iSegs * length aSrc@ elements is
--   allocated, so every segment has room for the whole source array. The
--   source is then scanned once, writing each value into its segment.
partition
        :: (BulkI lSrc (Int, a), Target lDst a, Index lDst ~ Int, Elt a)
        => Name lDst                    -- ^ Name of destination layout.
        -> Int                          -- ^ Number of segments.
        -> Array lSrc (Int, a)          -- ^ Segment numbers and values.
        -> Array N (Array lDst a)       -- ^ One inner array per segment.
partition nDst iSegs aSrc
 | iSegs <= 0
 = A.fromLists nDst []

 | otherwise
   -- The mutable buffer and length vector are created here and only
   -- escape after being frozen, so the result is observably pure.
 = unsafePerformIO
 $ do   let !len     = A.length aSrc

        -- Segment k starts at offset k * len in the destination buffer.
        let !vStarts = U.prescanl (+) 0 $ U.replicate iSegs len

        -- Number of values written so far into each segment.
        !mLens  <- UM.replicate iSegs 0

        let !lenDst  = iSegs * len
        !buf    <- unsafeNewBuffer (A.create nDst lenDst)

        -- Fill every slot with 'zero' before scattering.
        -- NOTE(review): presumably so that slots past a segment's final
        -- length still hold a defined element -- confirm.
        let loop_partition_init !iDst
             | iDst >= lenDst = return ()
             | otherwise
             = do unsafeWriteBuffer buf iDst zero
                  loop_partition_init (iDst + 1)
        loop_partition_init 0

        let loop_partition !iSrc
             | iSrc >= len = return ()
             | otherwise
             = do let !(k, v) = aSrc `A.index` iSrc
                  -- Drop values with out-of-range keys. Both bounds must be
                  -- checked: the vector accesses below are unchecked, so a
                  -- negative key would read and write outside vStarts/mLens.
                  if k < 0 || k >= iSegs
                   then loop_partition (iSrc + 1)
                   else do
                        let !s = U.unsafeIndex vStarts k
                        !o <- UM.unsafeRead mLens k
                        unsafeWriteBuffer buf (s + o) v
                        UM.unsafeWrite mLens k (o + 1)
                        loop_partition (iSrc + 1)
        loop_partition 0

        vLens   <- U.unsafeFreeze mLens
        aElems  <- unsafeFreezeBuffer buf
        return  $ NArray vStarts vLens aElems
-- | Like 'partition', but the segment number of each element is computed
--   by applying a function to the element itself. Keys are handled
--   exactly as in 'partition'.
partitionBy
        :: (BulkI lSrc a, Target lDst a, Index lDst ~ Int, Elt a)
        => Name lDst                    -- ^ Name of destination layout.
        -> Int                          -- ^ Number of segments.
        -> (a -> Int)                   -- ^ Segment number for an element.
        -> Array lSrc a                 -- ^ Source array.
        -> Array N (Array lDst a)       -- ^ One inner array per segment.
partitionBy nDst iSeg fSeg aSrc
 = partition nDst iSeg aKeyed
 where
        -- Pair every element with its computed segment number.
        aKeyed  = tup2 aKeys aSrc
        aKeys   = A.map fSeg aSrc
-- | Like 'partitionBy', but the segment-number function also receives the
--   position of each element in the source array. Keys are handled
--   exactly as in 'partition'.
partitionByIx
        :: (BulkI lSrc a, Target lDst a, Index lDst ~ Int, Elt a)
        => Name lDst                    -- ^ Name of destination layout.
        -> Int                          -- ^ Number of segments.
        -> (Int -> a -> Int)            -- ^ Segment number from index and element.
        -> Array lSrc a                 -- ^ Source array.
        -> Array N (Array lDst a)       -- ^ One inner array per segment.
partitionByIx nDst iSeg fSeg aSrc
 = partition nDst iSeg $ tup2 aKeys aSrc
 where
        -- Zip the elements with a linear array of the same length.
        -- Assumes 'linear n' yields the indices 0 .. n-1 -- confirm.
        aIndexed = tup2 (linear $ A.length aSrc) aSrc
        aKeys    = A.map (\(ix, x) -> fSeg ix x) aIndexed