{-# LANGUAGE MagicHash #-}
-- | Evaluate an array by breaking it up into linear chunks and filling
--   each chunk in parallel.
module Data.Array.Repa.Eval.Chunked
        ( fillLinearS
        , fillBlock2S
        , fillChunkedP
        , fillChunkedIOP)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Eval.Gang

import GHC.Exts
import Prelude          as P

-------------------------------------------------------------------------------
-- | Fill something sequentially.
--
--   * Elements are written one at a time in ascending index order,
--     starting at index @0@ and stopping before the given length.
--
--   * If the length is zero or negative then nothing is written.
--
fillLinearS
        :: Int                  -- ^ Number of elements.
        -> (Int -> a -> IO ())  -- ^ Update function to write into result buffer.
        -> (Int -> a)           -- ^ Fn to get the value at a given index.
        -> IO ()

fillLinearS !(I# len) write getElem
 = fill 0#
 where  -- Walk the indices upwards on an unboxed counter.
        -- The comparison primop yields 1# for true, hence the pattern guard.
        fill :: Int# -> IO ()
        fill !ix
         | 1# <- ix >=# len
         = return ()

         | otherwise
         = do   write (I# ix) (getElem (I# ix))
                fill (ix +# 1#)
{-# INLINE [0] fillLinearS #-}


-------------------------------------------------------------------------------
-- | Fill a block in a rank-2 array, sequentially.
--
--   * Blockwise filling can be more cache-efficient than linear filling for
--     rank-2 arrays.
--
--   * The block is filled one row at a time: rows are visited from @y0@ up to
--     (but excluding) @y0 + h0@, and within each row the columns are visited
--     from @x0@ up to (but excluding) @x0 + w0@.
--
--   * The element at position @(y, x)@ is written to the linear index
--     @x + y * width@, where @width@ is the width of the whole array.
--
--   * If the block width or height is zero or negative then nothing is
--     written.
--
fillBlock2S
        :: (Int  -> a -> IO ()) -- ^ Update function to write into result buffer.
        -> (DIM2 -> a)          -- ^ Fn to get the value at the given index.
        -> Int#                 -- ^ Width of the whole array.
        -> Int#                 -- ^ x0 lower left corner of block to fill.
        -> Int#                 -- ^ y0
        -> Int#                 -- ^ w0 width of block to fill
        -> Int#                 -- ^ h0 height of block to fill
        -> IO ()

fillBlock2S
        write getElem
        !imageWidth !x0 !y0 !w0 h0

 = do   fillBlock y0 ix0
 where  -- Exclusive upper bounds of the block in each dimension.
        !x1     = x0 +# w0
        !y1     = y0 +# h0

        -- Linear index of the first element of the block.
        !ix0    = x0 +# (y0 *# imageWidth)

        -- Fill the rows from 'y' onwards. 'ix' is the linear index of the
        -- first element of row 'y' within the block, and is advanced by the
        -- width of the whole array to reach the next row.
        {-# INLINE fillBlock #-}
        fillBlock :: Int# -> Int# -> IO ()
        fillBlock !y !ix
         | 1# <- y >=# y1
         = return ()

         | otherwise
         = do   fillLine1 x0 ix
                fillBlock (y +# 1#) (ix +# imageWidth)

         where  -- Fill the remainder of row 'y', starting at column 'x'
                -- which lives at linear index 'ix''.
                {-# INLINE fillLine1 #-}
                fillLine1 :: Int# -> Int# -> IO ()
                fillLine1 !x !ix'
                 | 1# <- x >=# x1
                 = return ()

                 | otherwise
                 = do   write (I# ix') (getElem (Z :. (I# y) :. (I# x)))
                        fillLine1 (x +# 1#) (ix' +# 1#)

{-# INLINE [0] fillBlock2S #-}


-------------------------------------------------------------------------------
-- | Fill something in parallel.
--
--   * The array is split into linear chunks,
--     and each thread linearly fills one chunk.
--
--   * The chunk for a thread runs from @splitIx thread@ up to (but excluding)
--     @splitIx (thread + 1)@, so neighbouring chunks share a boundary and do
--     not overlap.
--
fillChunkedP
        :: Int                  -- ^ Number of elements.
        -> (Int -> a -> IO ())  -- ^ Update function to write into result buffer.
        -> (Int -> a)           -- ^ Fn to get the value at a given index.
        -> IO ()

fillChunkedP !(I# len) write getElem
 =      gangIO theGang
         $  \(I# thread) ->
              -- NOTE(review): 'thread' is presumably in the range
              -- 0 .. gangSize theGang - 1; confirm against 'gangIO'.
              let !start   = splitIx thread
                  !end     = splitIx (thread +# 1#)
              in  fill start end

 where
        -- Decide how to split the work across the threads.
        -- If the length of the vector doesn't divide evenly among the threads,
        -- then the first few get an extra element.
        !(I# threads)   = gangSize theGang
        !chunkLen       = len `quotInt#` threads
        !chunkLeftover  = len `remInt#`  threads

        -- Starting index of the chunk belonging to the given thread.
        -- Threads numbered below 'chunkLeftover' have chunks that are one
        -- element longer; later threads are offset by all of the leftovers.
        {-# INLINE splitIx #-}
        splitIx :: Int# -> Int#
        splitIx thread
         | 1# <- thread <# chunkLeftover
         = thread *# (chunkLen +# 1#)

         | otherwise
         = thread *# chunkLen  +# chunkLeftover

        -- Evaluate the elements of a single chunk,
        -- from 'ix' up to (but excluding) 'end'.
        {-# INLINE fill #-}
        fill :: Int# -> Int# -> IO ()
        fill !ix !end
         | 1# <- ix >=# end
         = return ()

         | otherwise
         = do   write (I# ix) (getElem (I# ix))
                fill (ix +# 1#) end
{-# INLINE [0] fillChunkedP #-}


-------------------------------------------------------------------------------
-- | Fill something in parallel, using a separate IO action for each thread.
--
--   * The array is split into linear chunks,
--     and each thread linearly fills one chunk.
--
--   * Each thread first runs the supplied action with its own thread number
--     to obtain an element-getting function, then uses that function for
--     every index in its chunk.
--
fillChunkedIOP
        :: Int  -- ^ Number of elements.
        -> (Int -> a -> IO ())
                -- ^ Update fn to write into result buffer.
        -> (Int -> IO (Int -> IO a))
                -- ^ Create a fn to get the value at a given index.
                --   The first `Int` is the thread number, so you can do some
                --   per-thread initialisation.
        -> IO ()

fillChunkedIOP !(I# len) write mkGetElem
 =      gangIO theGang
         $  \(I# thread) ->
              -- NOTE(review): 'thread' is presumably in the range
              -- 0 .. gangSize theGang - 1; confirm against 'gangIO'.
              let !start = splitIx thread
                  !end   = splitIx (thread +# 1#)
              in fillChunk thread start end

 where
        -- Decide how to split the work across the threads.
        -- If the length of the vector doesn't divide evenly among the threads,
        -- then the first few get an extra element.
        !(I# threads)   = gangSize theGang
        !chunkLen       = len `quotInt#` threads
        !chunkLeftover  = len `remInt#`  threads

        -- Starting index of the chunk belonging to the given thread.
        -- Threads numbered below 'chunkLeftover' have chunks that are one
        -- element longer; later threads are offset by all of the leftovers.
        {-# INLINE splitIx #-}
        splitIx :: Int# -> Int#
        splitIx thread
         | 1# <- thread <# chunkLeftover = thread *# (chunkLen +# 1#)
         | otherwise                     = thread *# chunkLen  +# chunkLeftover

        -- Given the threadId, starting and ending indices.
        --      Make a function to get each element for this chunk
        --      and call it for every index.
        {-# INLINE fillChunk #-}
        fillChunk !thread !ixStart !ixEnd
         = do   getElem <- mkGetElem (I# thread)
                fill getElem ixStart ixEnd

        -- Call the provided getElem function for every element
        --      in a chunk, and feed the result to the write function.
        --
        -- No local type signature here: it would have to mention the element
        -- type of the enclosing function, which is not in scope without
        -- ScopedTypeVariables.
        {-# INLINE fill #-}
        fill !getElem !ix0 !end
         = go ix0
         where  go !ix
                 | 1# <- ix >=# end
                 = return ()

                 | otherwise
                 = do   x       <- getElem (I# ix)
                        write (I# ix) x
                        go (ix +# 1#)
{-# INLINE [0] fillChunkedIOP #-}