{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
(selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed as V
import Data.Vector.Unboxed.Mutable as VM
import GHC.Base (remInt, quotInt)
import Prelude as P
import Control.Monad as P
import Data.IORef
-- | Select indices matching a predicate, sequentially.
--
--   Walks every index of the given extent in linear order (via 'fromIndex').
--   For each index @ix@ where @fnMatch ix@ holds, it calls
--   @fnWrite ix (fnProduce ix)@.
--
--   * 'fnWrite' receives the /source/ index, not a destination position.
--     A caller that packs results into a buffer has to track its own write
--     position; the total number of writes is returned.
--
--   * This primitive can be useful for writing filtering functions.
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Write the produced element for a matching index.
        -> (sh -> Bool)         -- ^ Predicate selecting which indices to keep.
        -> (sh -> a)            -- ^ Produces the element for a matching index.
        -> sh                   -- ^ Extent of indices to test.
        -> IO Int               -- ^ Number of indices that matched (= number of writes).

{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize
 = fill 0 0
 where  lenSrc :: Int
        lenSrc  = size shSize

        -- nSrc: linear index of the next source index to test.
        -- nDst: number of matches (writes) performed so far.
        fill :: Int -> Int -> IO Int
        fill !nSrc !nDst
         | nSrc >= lenSrc
         = return nDst

         | ixSrc <- fromIndex shSize nSrc
         , fnMatch ixSrc
         = do   fnWrite ixSrc (fnProduce ixSrc)
                fill (nSrc + 1) (nDst + 1)

         | otherwise
         = fill (nSrc + 1) nDst
-- | Select indices matching a predicate, in parallel.
--
--   * The index range @[0 .. len - 1]@ is split into one contiguous chunk per
--     gang thread. The first @len \`rem\` threads@ chunks get one extra index.
--
--   * Each thread fills its own mutable vector with @fnProduce ix@ for every
--     index @ix@ in its range where @fnMatch ix@ holds. The result is one
--     vector per thread, in increasing index order.
--
--   * The number of vectors in the result equals the gang size. A
--     non-positive @len@ yields only empty vectors.
--
--   * This primitive can be useful for writing filtering functions.
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ Predicate selecting which indices to keep.
        -> (Int -> a)           -- ^ Produces the element for a matching index.
        -> Int                  -- ^ Number of indices to test, starting from 0.
        -> IO [IOVector a]      -- ^ One chunk of selected elements per thread.

{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len
 = do
        -- One result slot per thread. 'makeChunk' always overwrites its slot,
        -- so seed the slots with empty vectors instead of allocating buffers
        -- that would just be discarded.
        refs    <- P.replicateM threads (newIORef =<< VM.new 0)

        -- Fill each chunk from its own slice of the index range.
        -- NOTE(review): assumes gangIO runs the action once for each thread id
        -- in [0, gangSize theGang) and returns only after all have finished,
        -- which is what makes the (!!) lookup and the reads below safe.
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                        (splitIx thread)
                        (splitIx (thread + 1) - 1)

        P.mapM readIORef refs

 where  -- Gang size and the even split of the index range over it.
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads

        -- First source index of the given thread's chunk. The first
        -- 'chunkLeftover' chunks are one element longer than the rest.
        {-# INLINE splitIx #-}
        splitIx :: Int -> Int
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen + chunkLeftover

        -- Select from the inclusive index range [ixSrc, ixSrcEnd] into a fresh
        -- vector and store it in the thread's slot.
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         | ixSrc > ixSrcEnd
         = writeIORef ref =<< VM.new 0

         | otherwise
         = do   -- A chunk can never select more elements than it has indices,
                -- so this buffer is always large enough and never has to grow.
                vecDst  <- VM.new (ixSrcEnd - ixSrc + 1)
                nDst    <- fillChunk ixSrc ixSrcEnd vecDst 0
                -- Trim the unused tail so the chunk holds exactly the matches.
                writeIORef ref (VM.slice 0 nDst vecDst)

        -- Main selection loop. Returns the number of elements written.
        -- unsafeWrite is in bounds because ixDst never exceeds the number of
        -- source indices visited, which is at most the buffer length.
        fillChunk :: Int -> Int -> IOVector a -> Int -> IO Int
        fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst
         | ixSrc > ixSrcEnd
         = return ixDst

         | fnMatch ixSrc
         = do   VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
                fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1)

         | otherwise
         = fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst