{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2
where
import Data.Complex
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Foreign.Marshal ( with )
import qualified Foreign.CUDA.Ptr as CUDA
import qualified Foreign.CUDA.BLAS as BLAS
import Control.Monad.Reader
-- | Foreign (PTX) implementation of the general matrix-vector product.
--
-- Wraps 'gemv'' as a 'ForeignAcc' registered under the name @"ptx.gemv"@.
-- The argument tuple carries, in order, the scalar @alpha@, the matrix and
-- the input vector; the result is the output vector.
--
gemv :: NumericR s e
     -> Transpose
     -> ForeignAcc (((((), Scalar e), Matrix e), Vector e) -> Vector e)
gemv eR opA = ForeignAcc "ptx.gemv" (gemv' eR opA)
-- | Select the implementation strategy for the matrix-vector product.
--
-- A conjugate-transpose ('H') on either complex element type is routed
-- through 'as_gemm'; every other combination goes to 'as_gemv'. The special
-- case is needed because 'as_gemv' maps every non-'N' operation to 'N', which
-- would silently drop the conjugation.
--
gemv' :: NumericR s e
      -> Transpose
      -> ((((), Scalar e), Matrix e), Vector e)
      -> Par PTX (Future (Vector e))
gemv' NumericRcomplex32 H = as_gemm NumericRcomplex32 H
gemv' NumericRcomplex64 H = as_gemm NumericRcomplex64 H
gemv' nR                t = as_gemv nR t
-- | Matrix-vector product implemented as a matrix-matrix product.
--
-- The input vector of shape @sh@ is reinterpreted, without copying its
-- payload, as a matrix of shape @(sh, 1)@ and handed to 'gemm'' with the
-- right-hand operand untransposed ('N'). Once that result is available, a
-- forked task strips the trailing dimension again and fills the returned
-- future with the resulting vector.
--
as_gemm
    :: NumericR s e
    -> Transpose
    -> ((((), Scalar e), Matrix e), Vector e)
    -> Par PTX (Future (Vector e))
as_gemm nR opA ((((), alpha), matA), Array sh adata) = do
  let matB = Array (sh,1) adata   -- same array data, viewed as a single-column matrix
  --
  future <- new
  result <- gemm' nR opA N ((((), alpha), matA), matB)
  -- Repackage asynchronously so that the caller receives its future at once.
  fork $ do Array (sh',_) vecy <- get result
            put future (Array sh' vecy)
  return future
-- | Matrix-vector product implemented with the cuBLAS @?gemv@ routines.
--
-- Computes @y = alpha * op(A) * x@, with @beta@ fixed to zero so the freshly
-- allocated output vector is only written. The routine is chosen by the
-- element representation: 'BLAS.sgemv', 'BLAS.dgemv', 'BLAS.cgemv' or
-- 'BLAS.zgemv'.
--
-- The requested operation is swapped before the call ('N' becomes 'T' and
-- anything else becomes 'N'), and the matrix extents are passed as
-- @colsA rowsA@ with leading dimension @colsA@. This is presumably because
-- cuBLAS expects column-major storage whereas these arrays are row-major;
-- confirm against the cuBLAS documentation. As a consequence a
-- conjugate-transpose on complex data must not reach this function (see
-- 'gemv'').
--
as_gemv
    :: NumericR s e
    -> Transpose
    -> ((((), Scalar e), Matrix e), Vector e)
    -> Par PTX (Future (Vector e))
as_gemv nR opA ((((), alpha), matA), vecx) = do
  let
      (((), rowsA), colsA) = shape matA

      -- Length of the result vector: one element per row of op(A).
      sizeY   = case opA of
                  N -> rowsA
                  _ -> colsA

      -- Operation actually handed to cuBLAS (swapped, see above).
      opA'    = encodeTranspose
              $ case opA of
                  N -> T
                  _ -> N

      -- Array and element representations of the result, recovered from the
      -- numeric witness.
      aR      = ArrayR dim1 eR
      eR      = case nR of
                  NumericRfloat32   -> eltR @Float
                  NumericRfloat64   -> eltR @Double
                  NumericRcomplex32 -> eltR @(Complex Float)
                  NumericRcomplex64 -> eltR @(Complex Double)
  --
  future <- new
  stream <- asks ptxStream
  vecy   <- allocateRemote aR ((), sizeY)
  alpha' <- indexRemote eR alpha 0        -- read the single scalar element
  ()     <- liftPar $ do
    withArray nR matA stream   $ \ptr_A -> do
     withArray nR vecx stream  $ \ptr_x -> do
      withArray nR vecy stream $ \ptr_y -> do
       withBLAS                $ \hdl   -> do
         case nR of
           NumericRfloat32 -> liftIO $
            with alpha' $ \ptr_alpha ->
             with 0     $ \ptr_beta  ->
               BLAS.sgemv hdl opA' colsA rowsA ptr_alpha ptr_A colsA ptr_x 1 ptr_beta ptr_y 1

           NumericRfloat64 -> liftIO $
            with alpha' $ \ptr_alpha ->
             with 0     $ \ptr_beta  ->
               BLAS.dgemv hdl opA' colsA rowsA ptr_alpha ptr_A colsA ptr_x 1 ptr_beta ptr_y 1

           -- Complex data is held as a two-element vector of the component
           -- type, so the device pointers are cast to pointers to 'Complex'.
           NumericRcomplex32 -> liftIO $
            withV2 nR alpha' $ \ptr_alpha ->
             with 0          $ \ptr_beta  ->
               BLAS.cgemv hdl opA' colsA rowsA ptr_alpha (CUDA.castDevPtr ptr_A) colsA (CUDA.castDevPtr ptr_x) 1 ptr_beta (CUDA.castDevPtr ptr_y) 1

           NumericRcomplex64 -> liftIO $
            withV2 nR alpha' $ \ptr_alpha ->
             with 0          $ \ptr_beta  ->
               BLAS.zgemv hdl opA' colsA rowsA ptr_alpha (CUDA.castDevPtr ptr_A) colsA (CUDA.castDevPtr ptr_x) 1 ptr_beta (CUDA.castDevPtr ptr_y) 1
  --
  put future vecy
  return future