{-# LANGUAGE BangPatterns, MagicHash #-}
module Data.Array.Repa.Eval.Reduction
( foldS, foldP
, foldAllS, foldAllP)
where
import Data.Array.Repa.Eval.Gang
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import GHC.Base ( quotInt, divInt )
import GHC.Exts
-- | Sequential reduction along the innermost dimension.
--
--   The source is read as consecutive runs of @n@ elements. The fold of run
--   @k@ is written to slot @k@ of the output vector, and the number of runs
--   processed is the length of that vector.
foldS :: V.Unbox a
      => M.IOVector a   -- ^ output vector, one slot per run
      -> (Int# -> a)    -- ^ fetch the source element at a linear index
      -> (a -> a -> a)  -- ^ combining function
      -> a              -- ^ seed that starts the fold of every run
      -> Int#           -- ^ length of each run
      -> IO ()
{-# INLINE [1] foldS #-}
foldS !out get combine !seed !runLen
  = go 0# 0#
  where
    !(I# slots) = M.length out

    -- @slot@ is the output position being filled, @lo@ the first source
    -- index of the run that belongs to it.
    {-# INLINE go #-}
    go !slot !lo
      = case slot >=# slots of
          1# -> return ()
          _  -> do
            let !hi = lo +# runLen
            M.unsafeWrite out (I# slot) (reduceAny get combine seed lo hi)
            go (slot +# 1#) hi
-- | Parallel reduction along the innermost dimension.
--
--   The output slots are divided into contiguous ranges, one range per
--   thread index handed out by 'gangIO' (presumably one call per gang
--   thread; confirm against "Data.Array.Repa.Eval.Gang"). Every output slot
--   is computed in full by the thread that owns its range, and the seed is
--   folded into every output slot.
foldP :: V.Unbox a
      => M.IOVector a   -- ^ output vector, one slot per run
      -> (Int -> a)     -- ^ fetch the source element at a linear index
      -> (a -> a -> a)  -- ^ combining function
      -> a              -- ^ seed that starts the fold of every run
      -> Int            -- ^ length of each run
      -> IO ()
{-# INLINE [1] foldP #-}
foldP out get combine !seed (I# runLen)
  = gangIO theGang
  $ \(I# tid) -> fillRange (bound tid) (bound (tid +# 1#))
  where
    !(I# workers) = gangSize theGang
    !(I# slots)   = M.length out

    -- Output slots per thread, rounded up.
    !share        = (slots +# workers -# 1#) `quotInt#` workers

    -- First slot owned by thread @k@, clamped to the number of slots.
    {-# INLINE bound #-}
    bound !k
      | 1# <- slots <# scaled = slots
      | otherwise             = scaled
      where
        !scaled = k *# share

    -- Fill output slots @[from, upto)@; slot @s@ covers the source indices
    -- starting at @s * runLen@.
    {-# INLINE fillRange #-}
    fillRange !from !upto
      = go from (from *# runLen)
      where
        {-# INLINE go #-}
        go !slot !lo
          = case slot >=# upto of
              1# -> return ()
              _  -> do
                let !hi = lo +# runLen
                M.unsafeWrite out (I# slot)
                    (reduce get combine seed (I# lo) (I# hi))
                go (slot +# 1#) hi
-- | Sequential reduction of every element to a single value.
--
--   Folds the elements at indices @0 .. count - 1@ from left to right,
--   starting from the seed.
foldAllS :: (Int# -> a)     -- ^ fetch the element at a linear index
         -> (a -> a -> a)   -- ^ combining function
         -> a               -- ^ seed value
         -> Int#            -- ^ number of elements
         -> a
{-# INLINE [1] foldAllS #-}
foldAllS get combine !seed !count
  = reduceAny (\ix -> get ix) combine seed 0# count
-- | Parallel reduction of every element to a single value.
--
--   The index range @[0, len)@ is cut into at most @gangSize theGang@
--   contiguous chunks. Each chunk is folded on its own, seeded with the
--   chunk's first element, and the partial result is stored in a scratch
--   vector. The partial results are then folded together with the starting
--   value, so the starting value is applied exactly once and does not have
--   to be a neutral element of the combining function.
--
--   A length of zero or less yields the starting value unchanged, which is
--   also what 'foldAllS' produces for such a length.
foldAllP :: V.Unbox a
         => (Int -> a)      -- ^ fetch the element at a linear index
         -> (a -> a -> a)   -- ^ combining function
         -> a               -- ^ starting value
         -> Int             -- ^ number of elements
         -> IO a
{-# INLINE [1] foldAllP #-}
foldAllP f c !r !len
  -- A negative length must never reach the chunking arithmetic below:
  -- 'step' can become zero (division by zero in 'chunks') or negative (a
  -- scratch slot that no thread writes is then folded into the result).
  | len <= 0    = return r
  | otherwise   = do
      mvec <- M.unsafeNew chunks
      gangIO theGang $ \tid -> fill mvec tid (split tid) (split (tid + 1))
      vec  <- V.unsafeFreeze mvec
      return $! V.foldl' c r vec
  where
    !threads = gangSize theGang

    -- Elements per chunk, rounded up.
    !step    = (len + threads - 1) `quotInt` threads

    -- Number of non-empty chunks; never more than the number of threads.
    chunks   = ((len + step - 1) `divInt` step) `min` threads

    -- First index of chunk @ix@, clamped to the total length.
    {-# INLINE split #-}
    split !ix = len `min` (ix * step)

    -- Fold one chunk and store its partial result in slot @tid@. Threads
    -- whose chunk is empty write nothing; 'chunks' counts only non-empty
    -- chunks, so every slot of the scratch vector is written.
    {-# INLINE fill #-}
    fill !mvec !tid !start !end
      | start >= end = return ()
      | otherwise
      = M.unsafeWrite mvec tid (reduce f c (f start) (start + 1) end)
-- | Sequentially fold the elements at indices @[start, end)@, left to
--   right, starting from the given seed. Boxed-index wrapper around
--   'reduceAny'.
{-# INLINE [0] reduce #-}
reduce :: (Int -> a)        -- ^ fetch the element at a linear index
       -> (a -> a -> a)     -- ^ combining function
       -> a                 -- ^ seed value
       -> Int               -- ^ first index (inclusive)
       -> Int               -- ^ last index (exclusive)
       -> a
reduce get combine !seed (I# lo) (I# hi)
  = reduceAny (\ix -> get (I# ix)) combine seed lo hi
-- | Sequentially fold the elements at indices @[start, end)@, left to
--   right, starting from the given seed. The accumulator is forced at
--   every step. An empty range returns the seed.
{-# INLINE [0] reduceAny #-}
reduceAny :: (Int# -> a) -> (a -> a -> a) -> a -> Int# -> Int# -> a
reduceAny get combine !seed !lo !hi
  = go lo seed
  where
    {-# INLINE go #-}
    go !ix !acc
      = case ix >=# hi of
          1# -> acc
          _  -> go (ix +# 1#) (acc `combine` get ix)
-- | Left-to-right fold over the indices @[start, end)@ working entirely on
--   unboxed 'Int#' values. An empty range returns the seed.
{-# INLINE [0] reduceInt #-}
reduceInt
  :: (Int# -> Int#)             -- ^ fetch the element at a linear index
  -> (Int# -> Int# -> Int#)     -- ^ combining function
  -> Int#                       -- ^ seed value
  -> Int# -> Int#               -- ^ first (inclusive) and last (exclusive) index
  -> Int#
reduceInt get combine !seed !lo !hi
  = go lo seed
  where
    {-# INLINE go #-}
    go !ix !acc
      = case ix >=# hi of
          1# -> acc
          _  -> go (ix +# 1#) (acc `combine` get ix)
-- | Left-to-right fold over the indices @[start, end)@ working entirely on
--   unboxed 'Float#' values. An empty range returns the seed.
{-# INLINE [0] reduceFloat #-}
reduceFloat
  :: (Int# -> Float#)               -- ^ fetch the element at a linear index
  -> (Float# -> Float# -> Float#)   -- ^ combining function
  -> Float#                         -- ^ seed value
  -> Int# -> Int#                   -- ^ first (inclusive) and last (exclusive) index
  -> Float#
reduceFloat get combine !seed !lo !hi
  = go lo seed
  where
    {-# INLINE go #-}
    go !ix !acc
      = case ix >=# hi of
          1# -> acc
          _  -> go (ix +# 1#) (acc `combine` get ix)
-- | Left-to-right fold over the indices @[start, end)@ working entirely on
--   unboxed 'Double#' values. An empty range returns the seed.
{-# INLINE [0] reduceDouble #-}
reduceDouble
  :: (Int# -> Double#)                  -- ^ fetch the element at a linear index
  -> (Double# -> Double# -> Double#)    -- ^ combining function
  -> Double#                            -- ^ seed value
  -> Int# -> Int#                       -- ^ first (inclusive) and last (exclusive) index
  -> Double#
reduceDouble get combine !seed !lo !hi
  = go lo seed
  where
    {-# INLINE go #-}
    go !ix !acc
      = case ix >=# hi of
          1# -> acc
          _  -> go (ix +# 1#) (acc `combine` get ix)
-- | Strip the 'I#' box from an 'Int', exposing the raw 'Int#'.
{-# INLINE unboxInt #-}
unboxInt :: Int -> Int#
unboxInt boxed
  = case boxed of
      I# raw -> raw
-- | Strip the 'F#' box from a 'Float', exposing the raw 'Float#'.
{-# INLINE unboxFloat #-}
unboxFloat :: Float -> Float#
unboxFloat boxed
  = case boxed of
      F# raw -> raw
-- | Strip the 'D#' box from a 'Double', exposing the raw 'Double#'.
{-# INLINE unboxDouble #-}
unboxDouble :: Double -> Double#
unboxDouble boxed
  = case boxed of
      D# raw -> raw
-- Rewrite a 'reduceAny' whose element type is 'Int' into 'reduceInt', so
-- that the accumulator is an unboxed 'Int#'. The element fetch, the
-- combining function and the seed are unwrapped going in, and the final
-- result is boxed again with 'I#'.
{-# RULES "reduceInt"
    forall (fetch :: Int# -> Int) combine seed lo hi
    . reduceAny fetch combine seed lo hi
    = I# (reduceInt
            (\ix  -> unboxInt (fetch ix))
            (\x y -> unboxInt (combine (I# x) (I# y)))
            (unboxInt seed)
            lo
            hi)
 #-}
-- Rewrite a 'reduceAny' whose element type is 'Float' into 'reduceFloat',
-- so that the accumulator is an unboxed 'Float#'. The element fetch, the
-- combining function and the seed are unwrapped going in, and the final
-- result is boxed again with 'F#'.
{-# RULES "reduceFloat"
    forall (fetch :: Int# -> Float) combine seed lo hi
    . reduceAny fetch combine seed lo hi
    = F# (reduceFloat
            (\ix  -> unboxFloat (fetch ix))
            (\x y -> unboxFloat (combine (F# x) (F# y)))
            (unboxFloat seed)
            lo
            hi)
 #-}
-- Rewrite a 'reduceAny' whose element type is 'Double' into 'reduceDouble',
-- so that the accumulator is an unboxed 'Double#'. The element fetch, the
-- combining function and the seed are unwrapped going in, and the final
-- result is boxed again with 'D#'.
{-# RULES "reduceDouble"
    forall (fetch :: Int# -> Double) combine seed lo hi
    . reduceAny fetch combine seed lo hi
    = D# (reduceDouble
            (\ix  -> unboxDouble (fetch ix))
            (\x y -> unboxDouble (combine (D# x) (D# y)))
            (unboxDouble seed)
            lo
            hi)
 #-}