{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
        (selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed                      as V
import Data.Vector.Unboxed.Mutable              as VM
import GHC.Base                                 (remInt, quotInt)
import Prelude                                  as P
import Control.Monad                            as P
import Data.IORef


-- | Select indices matching a predicate, sequentially.
--
--   * Visits every linear index of the extent in ascending order. For each
--     index whose shape satisfies the predicate, the produced value is handed
--     to the update function together with that (source) shape index.
--
--   * This primitive can be useful for writing filtering functions.
--
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Update function to write into result.
        -> (sh -> Bool)         -- ^ See if this predicate matches.
        -> (sh -> a)            -- ^  .. and apply fn to the matching index
        -> sh                   -- ^ Extent of indices to apply to predicate.
        -> IO Int               -- ^ Number of elements written to destination array.

{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize
 = step 0 0
 where
        -- Number of linear indices covered by the extent.
        total :: Int
        total = size shSize

        -- step i n: i is the next linear index to test, n is the number of
        -- matches seen so far (returned once all indices are consumed).
        -- NOTE(review): fnWrite is called with the *source* index of each
        -- match, not a packed destination position; confirm callers expect
        -- that, since the returned count suggests a densely packed result.
        step :: Int -> Int -> IO Int
        step !i !n
         | i < total
         = let ix = fromIndex shSize i
           in  if fnMatch ix
                 then do fnWrite ix (fnProduce ix)
                         step (i + 1) (n + 1)
                 else step (i + 1) n
         | otherwise
         = return n


-- | Select indices matching a predicate, in parallel.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * The index range @[0, len)@ is split into contiguous linear chunks, with
--     one chunk being given to each thread of 'theGang'. When @len@ is not a
--     multiple of the thread count, the first few chunks each get one extra
--     index.
--
--   * The result holds one vector per thread, in thread order, so the number
--     of chunks in the result depends on how many threads you're running the
--     program with.
--
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ See if this predicate matches.
        -> (Int -> a)           -- ^  .. and apply fn to the matching index.
        -> Int                  -- ^ Extent of indices to apply to predicate.
        -> IO [IOVector a]      -- ^ Chunks containing array elements.

{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len
 = do
        -- Make IORefs that the threads will write their result chunks to.
        -- makeChunk always overwrites its ref, so the initial value is only a
        -- placeholder and is given no capacity (allocating a real buffer here
        -- would double peak memory for nothing).
        -- NOTE(review): relies on gangIO running the action once for every
        -- thread index in [0, threads); confirm against Gang.
        refs    <- P.replicateM threads (newIORef =<< VM.new 0)

        -- Fire off a thread to fill each chunk. Chunk i covers the inclusive
        -- source range [splitIx i, splitIx (i + 1) - 1].
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                        (splitIx thread)
                        (splitIx (thread + 1) - 1)

        -- Read the result chunks back from the IORefs, in thread order.
        P.mapM readIORef refs

 where  -- See how many threads we have available.
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads


        -- Decide where to split the source array. The first chunkLeftover
        -- chunks get chunkLen + 1 indices, the remaining ones get chunkLen.
        {-# INLINE splitIx #-}
        splitIx :: Int -> Int
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen  + chunkLeftover


        -- Fill the given chunk with elements selected from the inclusive
        -- index range [ixSrc, ixSrcEnd], and store the result in the ref.
        --
        -- Each source index yields at most one element, so a buffer with one
        -- slot per source index can never overflow and never needs to grow.
        -- This is at most one element more than the per-thread chunk length.
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         = do   let !cap = max 0 (ixSrcEnd - ixSrc + 1)
                vecDst  <- VM.new cap
                nDst    <- fillChunk vecDst ixSrc ixSrcEnd 0

                -- Slice the vector down so it doesn't have any empty space
                -- at the end.
                writeIORef ref (VM.slice 0 nDst vecDst)


        -- The main filling loop: scan the source range, writing each match
        -- into the next free slot. Returns the number of elements written.
        fillChunk :: IOVector a -> Int -> Int -> Int -> IO Int
        fillChunk !vecDst !ixSrc !ixSrcEnd !ixDst
         -- We've finished selecting elements from this range.
         | ixSrc > ixSrcEnd
         =      return ixDst

         -- We've got a matching element, so add it to the chunk.
         -- The unchecked write is safe: ixDst never exceeds the number of
         -- source indices consumed so far, which is below the capacity
         -- allocated in makeChunk.
         | fnMatch ixSrc
         = do   VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
                fillChunk vecDst (ixSrc + 1) ixSrcEnd (ixDst + 1)

         -- The element doesn't match, so keep going.
         | otherwise
         =      fillChunk vecDst (ixSrc + 1) ixSrcEnd ixDst