{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level2
where
import Data.Complex
import Foreign.Marshal.Utils                                        ( fillBytes )
import Foreign.Ptr
import Foreign.Storable                                             ( sizeOf )

import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.LLVM.Native.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Base

import qualified Blas.Primitive.Types as C
import qualified Blas.Primitive.Unsafe as C
-- | Foreign (native CPU) implementation of matrix-vector multiplication
-- @y = alpha * op(A) * x@, registered under the name @"native.gemv"@.
--
-- The actual work is done by 'gemv''; this only packages it as a
-- 'ForeignAcc' for the given numeric type and transpose mode.
gemv :: NumericR s e
     -> Transpose
     -> ForeignAcc (((((), Scalar e), Matrix e), Vector e) -> Vector e)
gemv nR opA = ForeignAcc "native.gemv" (gemv' nR opA)
-- | Native CPU kernel behind 'gemv': computes @y = alpha * op(A) * x@ by
-- calling the CBLAS @?gemv@ routine matching the element type, with
-- @beta = 0@ and unit strides for @x@ and @y@.
--
-- @A@ is handed to BLAS as a row-major @rowsA × colsA@ matrix. The result
-- has length @rowsA@ when @opA@ is 'N' and @colsA@ otherwise.
--
-- NOTE(review): the length of @x@ is not checked against the summed
-- dimension of @op(A)@ here; presumably the caller guarantees it. Confirm.
gemv' :: NumericR s e
      -> Transpose
      -> ((((), Scalar e), Matrix e), Vector e)
      -> Par Native (Future (Vector e))
gemv' nR opA ((((), alpha), matA), vecx) = do
  let
      (((), rowsA), colsA) = shape matA

      -- op(A) is rowsA × colsA for N and colsA × rowsA otherwise:
      -- sizeY is the result length, sizeK the dimension summed over.
      (sizeY, sizeK) = case opA of
                         N -> (rowsA, colsA)
                         _ -> (colsA, rowsA)

      -- BLAS requires lda >= max 1 colsA for row-major storage, so a
      -- matrix with no columns must still report a leading dimension of 1.
      ldA = max 1 colsA

      -- With an empty summed dimension BLAS returns early without writing
      -- y. The buffer from allocateRemote is presumably uninitialised, so
      -- the (all-zero) result has to be written explicitly.
      emptyContraction = sizeK == 0

      -- Zero the first sizeY elements of y; all-zero bytes encode +0.0.
      zeroY :: Ptr a -> Int -> IO ()
      zeroY p bytesPerElt = fillBytes p 0 (sizeY * bytesPerElt)

      opA'   = encodeTranspose opA
      alpha' = indexArray (ArrayR dim0 eR) alpha ()
      aR     = ArrayR dim1 eR
      eR     = case nR of
                 NumericRfloat32   -> eltR @Float
                 NumericRfloat64   -> eltR @Double
                 NumericRcomplex32 -> eltR @(Complex Float)
                 NumericRcomplex64 -> eltR @(Complex Double)

  future <- new
  vecy   <- allocateRemote aR ((), sizeY)
  ()     <- liftIO $
    withArray nR matA $ \ptr_A ->
    withArray nR vecx $ \ptr_x ->
    withArray nR vecy $ \ptr_y ->
      case nR of
        NumericRfloat32
          | emptyContraction -> zeroY ptr_y (sizeOf (0 :: Float))
          | otherwise        ->
              C.sgemv C.RowMajor opA' rowsA colsA
                      alpha' ptr_A ldA ptr_x 1 0 ptr_y 1
        NumericRfloat64
          | emptyContraction -> zeroY ptr_y (sizeOf (0 :: Double))
          | otherwise        ->
              C.dgemv C.RowMajor opA' rowsA colsA
                      alpha' ptr_A ldA ptr_x 1 0 ptr_y 1
        -- For complex types the buffers are Ptr Float / Ptr Double that
        -- are reinterpreted as Complex, so each element spans two slots.
        NumericRcomplex32
          | emptyContraction -> zeroY ptr_y (2 * sizeOf (0 :: Float))
          | otherwise        ->
              C.cgemv C.RowMajor opA' rowsA colsA
                      (toElt alpha') (castPtr ptr_A) ldA
                      (castPtr ptr_x) 1 0 (castPtr ptr_y) 1
        NumericRcomplex64
          | emptyContraction -> zeroY ptr_y (2 * sizeOf (0 :: Double))
          | otherwise        ->
              C.zgemv C.RowMajor opA' rowsA colsA
                      (toElt alpha') (castPtr ptr_A) ldA
                      (castPtr ptr_x) 1 0 (castPtr ptr_y) 1
  put future vecy
  return future