{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level1 (
Numeric, Vector,
sdot,
dotu,
dotc,
asum,
amax,
amin,
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex as A
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
-- | Dot product with an additive seed: @z + sum (zipWith (*) xs ys)@.
--
-- Single-precision inputs ('Float' and @'Complex' 'Float'@) are widened to
-- double precision, reduced there, and the scalar result is narrowed back,
-- so the accumulation itself is always carried out in 'Double'.
sdot :: forall e. Numeric e => Exp e -> Acc (Vector e) -> Acc (Vector e) -> Acc (Scalar e)
sdot z xs ys =
  case numericR :: NumericR e of
    NumericRfloat64   -> dsdot z xs ys
    NumericRcomplex64 -> zsdot z xs ys
    NumericRfloat32   ->
      map toFloating (dsdot (toFloating z) (map toFloating xs) (map toFloating ys))
    NumericRcomplex32 ->
      map narrow (zsdot (widen z) (map widen xs) (map widen ys))
  where
    -- Real double-precision kernel: fold the elementwise products onto the seed.
    dsdot :: Exp Double -> Acc (Vector Double) -> Acc (Vector Double) -> Acc (Scalar Double)
    dsdot seed us vs = fold (+) seed (zipWith (*) us vs)

    -- Complex double-precision kernel; same shape as 'dsdot'.
    zsdot :: Exp (Complex Double) -> Acc (Vector (Complex Double)) -> Acc (Vector (Complex Double)) -> Acc (Scalar (Complex Double))
    zsdot seed us vs = fold (+) seed (zipWith (*) us vs)

    -- Component-wise widening of a single-precision complex number.
    widen :: Exp (Complex Float) -> Exp (Complex Double)
    widen c = lift (toFloating (real c) :+ toFloating (imag c))

    -- Component-wise narrowing back to single precision.
    narrow :: Exp (Complex Double) -> Exp (Complex Float)
    narrow c = lift (toFloating (real c) :+ toFloating (imag c))
-- | Unconjugated dot product: @sum (zipWith (*) xs ys)@.
--
-- For complex element types neither operand is conjugated; 'dotc' is the
-- conjugated variant.
dotu :: Numeric e => Acc (Vector e) -> Acc (Vector e) -> Acc (Scalar e)
dotu xs = fold (+) 0 . zipWith (*) xs
-- | Conjugated dot product: the first operand is conjugated before the
-- elementwise multiply, i.e. @sum (zipWith (*) (map conjugate xs) ys)@.
--
-- The match on 'numericR' refines @e@ to a concrete precision in each branch
-- (presumably so that the constraints required by 'conjugate' are satisfied);
-- the two branches are otherwise identical.
dotc :: forall e. Numeric (Complex e)
     => Acc (Vector (Complex e))
     -> Acc (Vector (Complex e))
     -> Acc (Scalar (Complex e))
dotc xs ys =
  case numericR :: NumericR (Complex e) of
    NumericRcomplex32 -> let conjXs = map conjugate xs in dotu conjXs ys
    NumericRcomplex64 -> let conjXs = map conjugate xs in dotu conjXs ys
-- | Sum of element magnitudes.
--
-- Real elements contribute their absolute value; complex elements contribute
-- @|re| + |im|@ (not the Euclidean modulus). The result is in the base
-- (real) type of the element.
asum :: forall e. Numeric e => Acc (Vector e) -> Acc (Scalar (NumericBaseT e))
asum xs =
  case numericR :: NumericR e of
    NumericRfloat32   -> sum (map abs xs)
    NumericRfloat64   -> sum (map abs xs)
    NumericRcomplex32 -> sum (map l1 xs)
    NumericRcomplex64 -> sum (map l1 xs)
  where
    -- Manhattan magnitude of a complex number.
    l1 c = abs (real c) + abs (imag c)
-- | Zero-based index of the element with the largest magnitude.
--
-- Real elements are compared by absolute value; complex elements by
-- @|re| + |im|@.
--
-- When several elements share the largest magnitude, the smallest index is
-- returned. This matches the BLAS @i?amax@ convention ("the first one
-- encountered"). The tie-break compares indices explicitly, so it does not
-- depend on how the parallel reduction groups partial results.
--
-- The input must be non-empty: 'fold1' has no initial element.
amax :: forall e. Numeric e => Acc (Vector e) -> Acc (Scalar Int)
amax =
  case numericR :: NumericR e of
    NumericRfloat32   -> map (indexHead . fst) . fold1 cmp . indexed . map abs
    NumericRfloat64   -> map (indexHead . fst) . fold1 cmp . indexed . map abs
    NumericRcomplex32 -> map (indexHead . fst) . fold1 cmp . indexed . map mag
    NumericRcomplex64 -> map (indexHead . fst) . fold1 cmp . indexed . map mag
  where
    -- Keep the (index, magnitude) pair with the larger magnitude. On equal
    -- magnitudes keep the one with the smaller index. Previously the right
    -- operand won ties, which selected the last maximum instead of the first.
    -- With the explicit index comparison the combiner is associative and
    -- commutative (for non-NaN magnitudes).
    cmp ix iy =
      let i = indexHead (fst ix)
          j = indexHead (fst iy)
          x = snd ix
          y = snd iy
      in  x > y || (x == y && i < j) ? ( ix, iy )
    -- Manhattan magnitude of a complex number (BLAS convention).
    mag c = abs (real c) + abs (imag c)
-- | Zero-based index of the element with the smallest magnitude.
--
-- Real elements are compared by absolute value; complex elements by
-- @|re| + |im|@.
--
-- When several elements share the smallest magnitude, the smallest index is
-- returned. This matches the BLAS @i?amin@ convention ("the first one
-- encountered"). The tie-break compares indices explicitly, so it does not
-- depend on how the parallel reduction groups partial results.
--
-- The input must be non-empty: 'fold1' has no initial element.
amin :: forall e. Numeric e => Acc (Vector e) -> Acc (Scalar Int)
amin =
  case numericR :: NumericR e of
    NumericRfloat32   -> map (indexHead . fst) . fold1 cmp . indexed . map abs
    NumericRfloat64   -> map (indexHead . fst) . fold1 cmp . indexed . map abs
    NumericRcomplex32 -> map (indexHead . fst) . fold1 cmp . indexed . map mag
    NumericRcomplex64 -> map (indexHead . fst) . fold1 cmp . indexed . map mag
  where
    -- Keep the (index, magnitude) pair with the smaller magnitude. On equal
    -- magnitudes keep the one with the smaller index. Previously the right
    -- operand won ties, which selected the last minimum instead of the first.
    -- With the explicit index comparison the combiner is associative and
    -- commutative (for non-NaN magnitudes).
    cmp ix iy =
      let i = indexHead (fst ix)
          j = indexHead (fst iy)
          x = snd ix
          y = snd iy
      in  x < y || (x == y && i < j) ? ( ix, iy )
    -- Manhattan magnitude of a complex number (BLAS convention).
    mag c = abs (real c) + abs (imag c)