{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2
where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar ( Array(..) )
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Foreign.Marshal ( with )
import qualified Foreign.CUDA.Ptr as CUDA
import qualified Foreign.CUDA.BLAS as BLAS
-- | Foreign implementation of matrix-vector multiplication for the PTX
-- backend, registered under the name @"ptx.gemv"@.
--
-- The 'Transpose' argument selects the operation applied to the matrix.
-- The element-type witness comes from 'numericR'. The actual work is done
-- by @gemv'@, which either issues a cuBLAS @?gemv@ call or goes through the
-- level-3 @gemm'@ routine.
gemv :: Numeric e
     => Transpose
     -> ForeignAcc ((Scalar e, Matrix e, Vector e) -> Vector e)
gemv opA = ForeignAcc "ptx.gemv" $ gemv' numericR opA
-- | Pick the device implementation for 'gemv' from the element type and the
-- requested operation on @A@.
--
-- A conjugate-transpose ('H') of a complex matrix is routed to 'as_gemm'.
-- 'as_gemv' issues every non-'N' operation as a plain no-transpose, and so
-- would silently drop the conjugation. Every other combination uses
-- 'as_gemv'.
gemv' :: Numeric e
      => NumericR e
      -> Transpose
      -> Stream
      -> (Scalar e, Matrix e, Vector e)
      -> LLVM PTX (Vector e)
gemv' eR opA =
  case (eR, opA) of
    (NumericRcomplex32, H) -> as_gemm H
    (NumericRcomplex64, H) -> as_gemm H
    _                      -> as_gemv opA
-- | Matrix-vector product computed through the level-3 @gemm'@ routine.
--
-- How it works:
--
-- * The vector @x@ is reinterpreted as a single-column matrix. Its shape is
--   extended with a trailing extent of 1, and it shares the same payload.
-- * That matrix is multiplied with @op(A)@ by @gemm'@, applying no operation
--   ('N') to the second operand.
-- * The resulting one-column matrix is turned back into a vector by dropping
--   its column extent, again sharing the payload.
--
-- @gemv'@ uses this path for complex conjugate-transpose, because
-- 'as_gemv' does not conjugate.
as_gemm
  :: Numeric e
  => Transpose
  -> Stream
  -> (Scalar e, Matrix e, Vector e)
  -> LLVM PTX (Vector e)
as_gemm opA stream (alpha, matA, Array shX payloadX) = do
  -- x as an (n × 1) matrix over the same payload.
  Array (shY, _) payloadY <-
    gemm' opA N stream (alpha, matA, Array (shX, 1) payloadX)
  -- Drop the trailing unit column extent to recover a vector.
  return $ Array shY payloadY
-- | Matrix-vector product @y = alpha * op(A) * x@ using the cuBLAS @?gemv@
-- family, with @beta@ fixed at 0.
--
-- Accelerate matrices are stored row-major, whereas cuBLAS assumes
-- column-major storage. The @rowsA × colsA@ buffer of @A@ is therefore
-- passed to cuBLAS as a @colsA × rowsA@ column-major matrix (i.e. @A^T@)
-- with leading dimension @colsA@. To compensate, the requested operation is
-- flipped: 'N' is issued as a cuBLAS transpose, and every other 'Transpose'
-- value is issued as no-transpose.
--
-- Consequently a conjugate-transpose request is NOT conjugated here.
-- @gemv'@ routes the complex conjugate-transpose case to 'as_gemm' instead.
--
-- The result has length @rowsA@ for 'N' and @colsA@ otherwise. Both @x@ and
-- @y@ are accessed with unit stride. This function does not check the length
-- of @x@ against @A@.
as_gemv
  :: forall e. Numeric e
  => Transpose
  -> Stream
  -> (Scalar e, Matrix e, Vector e)
  -> LLVM PTX (Vector e)
as_gemv opA stream (alpha, matA, vecx) = do
  -- Device buffer for y. alpha is fetched as a host value because it is
  -- handed to cuBLAS by pointer (via 'with') below.
  vecy   <- allocateRemote (Z :. lenY) :: LLVM PTX (Vector e)
  alpha' <- indexRemote alpha 0
  () <- withArray matA stream $ \pA ->
        withArray vecx stream $ \pX ->
        withArray vecy stream $ \pY ->
        withBLAS $ \hdl ->
          -- Choose s/d/c/zgemv from the element type. The complex bindings
          -- take differently-typed device pointers, hence the casts.
          case numericR :: NumericR e of
            NumericRfloat32 -> liftIO $
              with alpha' $ \pAlpha ->
              with 0      $ \pBeta  ->
                BLAS.sgemv hdl opBLAS colsA rowsA
                  pAlpha pA colsA pX 1 pBeta pY 1
            NumericRfloat64 -> liftIO $
              with alpha' $ \pAlpha ->
              with 0      $ \pBeta  ->
                BLAS.dgemv hdl opBLAS colsA rowsA
                  pAlpha pA colsA pX 1 pBeta pY 1
            NumericRcomplex32 -> liftIO $
              with alpha' $ \pAlpha ->
              with 0      $ \pBeta  ->
                BLAS.cgemv hdl opBLAS colsA rowsA
                  pAlpha (CUDA.castDevPtr pA) colsA
                  (CUDA.castDevPtr pX) 1
                  pBeta  (CUDA.castDevPtr pY) 1
            NumericRcomplex64 -> liftIO $
              with alpha' $ \pAlpha ->
              with 0      $ \pBeta  ->
                BLAS.zgemv hdl opBLAS colsA rowsA
                  pAlpha (CUDA.castDevPtr pA) colsA
                  (CUDA.castDevPtr pX) 1
                  pBeta  (CUDA.castDevPtr pY) 1
  return vecy
  where
    Z :. rowsA :. colsA = arrayShape matA
    -- Output length and the cuBLAS operation. Both are flipped relative to
    -- opA because of the row-major / column-major mismatch.
    (lenY, opBLAS) =
      case opA of
        N -> (rowsA, encodeTranspose T)
        _ -> (colsA, encodeTranspose N)