{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
(selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed as V
import Data.Vector.Unboxed.Mutable as VM
import GHC.Base (remInt, quotInt)
import Prelude as P
import Control.Monad as P
import Data.IORef
-- | Sequentially select the indices of a shape that satisfy a predicate.
--
--   Linear positions @0 .. size shSize - 1@ are visited in increasing order.
--   Each one is converted to a shape index with 'fromIndex'; when the
--   predicate accepts that index, the produced value is handed to the write
--   action together with the index.
--
--   Returns the number of indices that were accepted.
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Action invoked for every accepted index.
        -> (sh -> Bool)         -- ^ Predicate deciding whether an index is accepted.
        -> (sh -> a)            -- ^ Computes the value passed to the write action.
        -> sh                   -- ^ Extent of the index space to scan.
        -> IO Int               -- ^ How many indices were accepted.

{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize
 = go 0 0
 where  total   = size shSize

        -- srcPos:  linear position currently being examined.
        -- written: number of indices accepted so far.
        go !srcPos !written
         | srcPos >= total
         =      return written

         | otherwise
         = do   let ixSrc = fromIndex shSize srcPos
                if fnMatch ixSrc
                 then do
                        fnWrite ixSrc (fnProduce ixSrc)
                        go (srcPos + 1) (written + 1)
                 else   go (srcPos + 1) written
-- | Select the linear indices in @[0, len)@ that satisfy a predicate, using
--   the global gang.
--
--   The index range is cut into one contiguous sub-range per gang thread.
--   Each worker collects the produced values for its accepted indices into its
--   own mutable vector, growing the vector when it runs out of room and
--   trimming it to the number of elements written once its range is exhausted.
--
--   The result contains one vector per gang thread, in thread-index order.
--   A thread whose sub-range is empty contributes an empty vector.
--
--   NOTE(review): this relies on 'gangIO' invoking the worker for every thread
--   index in @0 .. gangSize theGang - 1@ and returning only after all of them
--   have finished -- confirm against "Data.Array.Repa.Eval.Gang".
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ Predicate deciding whether an index is accepted.
        -> (Int -> a)           -- ^ Computes the element stored for an accepted index.
        -> Int                  -- ^ Number of linear indices to examine.
        -> IO [IOVector a]      -- ^ One chunk of selected elements per gang thread.

{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len
 = do
        -- One result slot per thread. Every worker allocates its own buffer and
        -- overwrites its slot, so the slots only need an empty placeholder
        -- rather than a full-size vector that would be thrown away.
        refs    <- P.replicateM threads (VM.new 0 >>= newIORef)

        -- Have each thread fill the chunk for its own sub-range of indices.
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                        (splitIx thread)
                        (splitIx (thread + 1) - 1)

        -- Collect the chunks the workers stored in their slots.
        P.mapM readIORef refs

 where
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads

        -- First source index owned by the given thread. The first
        -- 'chunkLeftover' threads each take one extra index so the whole
        -- range is covered when 'len' is not a multiple of 'threads'.
        {-# INLINE splitIx #-}
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen  + chunkLeftover

        -- Build the chunk for the inclusive index range [ixSrc, ixSrcEnd]
        -- and store it in the given slot.
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         -- Empty range: nothing to select.
         | ixSrc > ixSrcEnd
         = do   vecDst  <- VM.new 0
                writeIORef ref vecDst

         -- Start with room for 'chunkLen' elements; 'fillChunk' grows the
         -- buffer if more elements than that are accepted.
         | otherwise
         = do   vecDst  <- VM.new chunkLen
                vecDst' <- fillChunk ixSrc ixSrcEnd vecDst 0 (VM.length vecDst)
                writeIORef ref vecDst'

        -- The filling loop.
        --   ixSrc, ixSrcEnd: remaining inclusive source range.
        --   vecDst:          buffer being filled.
        --   ixDst:           next free position in the buffer.
        --   ixDstLen:        capacity of the buffer (always VM.length vecDst).
        fillChunk :: Int -> Int -> IOVector a -> Int -> Int -> IO (IOVector a)
        fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst !ixDstLen
         -- Range exhausted: drop the unused tail of the buffer.
         | ixSrc > ixSrcEnd
         =      return $ VM.slice 0 ixDst vecDst

         -- Buffer full: enlarge it and retry the same source index.
         -- 'VM.grow' extends the vector *by* the given number of elements,
         -- so the new capacity is read back from the grown vector instead of
         -- being assumed equal to the growth amount.
         | ixDst >= ixDstLen
         = do   vecDst' <- VM.grow vecDst ((VM.length vecDst + 1) * 2)
                fillChunk ixSrc ixSrcEnd vecDst' ixDst (VM.length vecDst')

         -- Accepted index: store its element and advance both cursors.
         -- The write is in bounds because ixDst < ixDstLen at this point.
         | fnMatch ixSrc
         = do   VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
                fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1) ixDstLen

         -- Rejected index: skip it.
         | otherwise
         =      fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst ixDstLen