{-# LANGUAGE BangPatterns, ExplicitForAll, TypeOperators, MagicHash #-}
{-# OPTIONS -fno-warn-orphans #-}
module Data.Array.Repa.Operators.Reduction
( foldS, foldP
, foldAllS, foldAllP
, sumS, sumP
, sumAllS, sumAllP
, equalsS, equalsP)
where
import Data.Array.Repa.Base
import Data.Array.Repa.Index
import Data.Array.Repa.Eval
import Data.Array.Repa.Repr.Unboxed
import Data.Array.Repa.Operators.Mapping as R
import Data.Array.Repa.Shape as S
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import Prelude hiding (sum)
import qualified Data.Array.Repa.Eval.Reduction as E
import System.IO.Unsafe
import GHC.Exts
-- | Sequential reduction of the innermost dimension of an array of arbitrary
--   rank, producing an unboxed array one rank lower.
--
--   @foldS f z arr@, where @arr@ has extent @sz :. n@, reduces each of the
--   @size sz@ rows of length @n@ with operator @f@ and starting value @z@, and
--   stores one result per row in an array of extent @sz@.
--
--   To fold a dimension other than the innermost one, reorder the dimensions
--   of the source array first.
foldS   :: (Shape sh, Source r a, Unbox a)
        => (a -> a -> a)                -- ^ reduction operator
        -> a                            -- ^ starting value for each row
        -> Array r (sh :. Int) a        -- ^ source array
        -> Array U sh a
foldS f z arr
 = arr `deepSeqArray`
   let  sh@(sz :. n') = extent arr
        -- E.foldS takes the row length as an unboxed Int#.
        !(I# n)       = n'
   in   -- unsafePerformIO is safe here: the mutable vector is allocated and
        -- frozen inside this single action and never escapes, so the result
        -- depends only on f, z and arr.
        unsafePerformIO $ do
          -- One output slot per outer index. M.unsafeNew leaves the slots
          -- uninitialised; NOTE(review): relies on E.foldS writing every one.
          mvec <- M.unsafeNew (S.size sz)
          E.foldS mvec (\ix -> arr `unsafeIndex` fromIndex sh (I# ix)) f z n
          !vec <- V.unsafeFreeze mvec
          now $ fromUnboxed sz vec
{-# INLINE [1] foldS #-}
-- | Parallel reduction of the innermost dimension of an array of arbitrary
--   rank, producing an unboxed array one rank lower.
--
--   @f@ should be associative and @z@ neutral with respect to it (e.g. @0@ for
--   @(+)@). The work is split across threads, so partial results may be
--   grouped arbitrarily and @z@ may be used more than once.
--
--   The result array is forced (via 'now') when the action is sequenced in
--   @m@.
foldP   :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)                -- ^ associative reduction operator
        -> a                            -- ^ neutral starting value
        -> Array r (sh :. Int) a        -- ^ source array
        -> m (Array U sh a)
foldP f z arr
 = arr `deepSeqArray`
   let  sh@(sz :. n) = extent arr
   in   case rank sh of
          -- A rank-1 array is a single row. Folding row-by-row would
          -- presumably leave it all to one thread, so reduce the whole array
          -- with foldAllP instead. We test the rank rather than matching
          -- sz against Z, which would not type check for a polymorphic sh.
          1 -> do
                x <- foldAllP f z arr
                now $ fromUnboxed sz $ V.singleton x

          _ -> now
             $ unsafePerformIO $ do
                -- The mutable vector never escapes this action, so the
                -- unsafePerformIO is referentially transparent.
                mvec <- M.unsafeNew (S.size sz)
                E.foldP mvec (\ix -> arr `unsafeIndex` fromIndex sh ix) f z n
                !vec <- V.unsafeFreeze mvec
                now $ fromUnboxed sz vec
{-# INLINE [1] foldP #-}
-- | Sequential reduction of an array of arbitrary rank to a single scalar.
--
--   Every element is visited through its linear index (converted back to a
--   shape index with 'fromIndex') and combined with @f@, starting from @z@.
foldAllS :: (Shape sh, Source r a)
         => (a -> a -> a)       -- ^ reduction operator
         -> a                   -- ^ starting value
         -> Array r sh a        -- ^ source array
         -> a
foldAllS f z arr
 = arr `deepSeqArray`
   let  !ex     = extent arr
        -- E.foldAllS takes the element count as an unboxed Int#.
        !(I# n) = size ex
   in   E.foldAllS (\ix -> arr `unsafeIndex` fromIndex ex (I# ix)) f z n
{-# INLINE [1] foldAllS #-}
-- | Parallel reduction of an array of arbitrary rank to a single scalar.
--
--   @f@ should be associative and @z@ neutral with respect to it (e.g. @0@ for
--   @(+)@). The work is split across threads, so partial results may be
--   grouped arbitrarily and @z@ may be used more than once.
--
--   The reduction is forced when the returned action is sequenced in @m@.
--   This matches 'foldP', whose results are forced with 'now'. The parallel
--   work therefore cannot be deferred to an arbitrary later point, e.g. inside
--   another parallel computation.
foldAllP
        :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)        -- ^ associative reduction operator
        -> a                    -- ^ neutral starting value
        -> Array r sh a         -- ^ source array
        -> m a
foldAllP f z arr
 = arr `deepSeqArray`
   let  sh = extent arr
        n  = size sh
        -- The parallel reduction only reads arr, so running it via
        -- unsafePerformIO yields a pure result.
        x  = unsafePerformIO
           $ E.foldAllP (\ix -> arr `unsafeIndex` fromIndex sh ix) f z n
   in   return $! x
{-# INLINE [1] foldAllP #-}
-- | Sequential sum of the innermost dimension of an array.
--   Defined as @'foldS' (+) 0@.
sumS    :: (Shape sh, Source r a, Num a, Unbox a)
        => Array r (sh :. Int) a
        -> Array U sh a
sumS = foldS (+) 0
{-# INLINE [3] sumS #-}
-- | Parallel sum of the innermost dimension of an array.
--   Defined as @'foldP' (+) 0@.
sumP    :: (Shape sh, Source r a, Num a, Unbox a, Monad m)
        => Array r (sh :. Int) a
        -> m (Array U sh a)
sumP = foldP (+) 0
{-# INLINE [3] sumP #-}
-- | Sequential sum of all the elements of an array.
--   Defined as @'foldAllS' (+) 0@.
sumAllS :: (Shape sh, Source r a, Num a)
        => Array r sh a
        -> a
sumAllS = foldAllS (+) 0
{-# INLINE [3] sumAllS #-}
-- | Parallel sum of all the elements of an array.
--   Defined as @'foldAllP' (+) 0@.
sumAllP :: (Shape sh, Source r a, Unbox a, Num a, Monad m)
        => Array r sh a
        -> m a
sumAllP = foldAllP (+) 0
{-# INLINE [3] sumAllP #-}
-- | Arrays are equal when they have the same extent and equal elements at
--   every index; this is exactly 'equalsS', checked sequentially.
--
--   NOTE: orphan instance (neither 'Eq' nor 'Array' is declared in this
--   module), which is why the file disables orphan warnings.
instance (Shape sh, Eq sh, Source r a, Eq a) => Eq (Array r sh a) where
 (==) = equalsS
-- | Check, in parallel, whether two arrays have the same extent and contain
--   equal elements at every index.
--
--   The two arrays may have different representations.
equalsP :: (Shape sh, Source r1 a, Source r2 a, Eq a, Monad m)
        => Array r1 sh a
        -> Array r2 sh a
        -> m Bool
equalsP arr1 arr2
        -- Differing extents settle the answer without looking at any
        -- element, so skip building and reducing the zipped array.
        | extent arr1 /= extent arr2 = return False
        | otherwise
        = foldAllP (&&) True (R.zipWith (==) arr1 arr2)
-- | Check, sequentially, whether two arrays have the same extent and contain
--   equal elements at every index.
--
--   The extents are compared first; '(&&)' short-circuits, so the elements
--   are only compared when the extents match.
equalsS :: (Shape sh, Source r1 a, Source r2 a, Eq a)
        => Array r1 sh a
        -> Array r2 sh a
        -> Bool
equalsS arr1 arr2
        =  extent arr1 == extent arr2
        && foldAllS (&&) True (R.zipWith (==) arr1 arr2)