{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2
  where

import Data.Complex
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt

import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type

import Foreign.Marshal                                              ( with )
import qualified Foreign.CUDA.Ptr                                   as CUDA
import qualified Foreign.CUDA.BLAS                                  as BLAS

import Control.Monad.Reader


-- NOTE: cuBLAS requires matrices to be stored in column-major order
-- (Fortran-style), but Accelerate uses C-style arrays in row-major order.
--
-- If the operation is N or T, we can just swap the operation. For
-- conjugate-transpose (H) operations (on complex valued arguments), since there
-- is no conjugate-no-transpose operation, we implement that via 'gemm', which
-- I assume is more efficient than ?geam followed by ?gemv.
--
-- | Foreign-function entry point for matrix-vector multiplication on the PTX
-- backend.
--
-- Packages 'gemv'' as a 'ForeignAcc' registered under the name @"ptx.gemv"@.
-- The argument tuple is @(alpha, A, x)@: a scalar, a matrix and a vector. The
-- 'Transpose' argument selects the operation to apply to the matrix.
--
gemv :: NumericR s e
     -> Transpose
     -> ForeignAcc ((((((), Scalar e), Matrix e), Vector e)) -> Vector e)
gemv eR opA = ForeignAcc "ptx.gemv" (gemv' eR opA)

-- | Select the implementation strategy for a matrix-vector multiplication.
--
-- A conjugate-transpose ('H') on either complex element type is delegated to
-- 'as_gemm'; every other combination of element type and operation is
-- delegated to 'as_gemv'. The more specific equations must come first, since
-- the final equation matches everything.
--
gemv' :: NumericR s e
      -> Transpose
      -> ((((), Scalar e), Matrix e), Vector e)
      -> Par PTX (Future (Vector e))
gemv' NumericRcomplex32 H = as_gemm NumericRcomplex32 H
gemv' NumericRcomplex64 H = as_gemm NumericRcomplex64 H
gemv' nR                t = as_gemv nR t


-- | Matrix-vector multiplication expressed as a matrix-matrix multiplication.
--
-- The input vector of shape @sh@ is re-wrapped, over the same underlying array
-- data, as a matrix of shape @(sh, 1)@ and passed to 'gemm'' together with the
-- requested operation on the left matrix and 'N' on the right. Once the
-- matrix result is available, a forked computation drops its innermost
-- dimension and fills the returned future with the resulting vector.
--
as_gemm
    :: NumericR s e
    -> Transpose
    -> ((((), Scalar e), Matrix e), Vector e)
    -> Par PTX (Future (Vector e))
as_gemm nR opA ((((), alpha), matA), Array sh adata) = do
  -- View the vector as a single-column matrix; no data is copied here, only
  -- the shape wrapper changes.
  let matB = Array (sh,1) adata
  --
  future <- new
  result <- gemm' nR opA N ((((), alpha), matA), matB)
  -- Wait for the product in a forked computation, then expose it as a vector
  -- by discarding the innermost extent of the result shape.
  fork $ do Array (sh',_) vecy <- get result
            put future (Array sh' vecy)
  return future

-- | Matrix-vector multiplication via the cuBLAS @?gemv@ routines.
--
-- Allocates a fresh result vector, reads the scalar @alpha@ from index 0 of
-- its array, and calls the @gemv@ binding matching the element type
-- (@sgemv@, @dgemv@, @cgemv@ or @zgemv@) with @beta = 0@ and unit strides for
-- both vectors. The result vector is then written to the returned future.
--
-- As the NOTE at the top of this section explains, cuBLAS expects
-- column-major storage while Accelerate arrays are row-major. The requested
-- operation is therefore swapped before the call, and the dimensions are
-- passed as @colsA@ then @rowsA@ with a leading dimension of @colsA@.
--
as_gemv
    :: NumericR s e
    -> Transpose
    -> ((((), Scalar e), Matrix e), Vector e)
    -> Par PTX (Future (Vector e))
as_gemv nR opA ((((), alpha), matA), vecx) = do
  let
      (((), rowsA), colsA) = shape matA

      -- Length of the result vector: the number of rows of A when A is used
      -- as-is, otherwise the number of columns.
      sizeY   = case opA of
                  N -> rowsA
                  _ -> colsA

      -- Swap the operation to account for the storage-order mismatch. Every
      -- operation other than N is mapped to N, so no conjugation is applied
      -- on this path; within this module 'gemv'' sends conjugate-transpose on
      -- complex element types to 'as_gemm' instead.
      opA'    = encodeTranspose
              $ case opA of
                  N -> T
                  _ -> N

      -- Representation of the result array, needed for allocation, and of
      -- its element type, needed to read alpha.
      aR      = ArrayR dim1 eR
      eR      = case nR of
                  NumericRfloat32   -> eltR @Float
                  NumericRfloat64   -> eltR @Double
                  NumericRcomplex32 -> eltR @(Complex Float)
                  NumericRcomplex64 -> eltR @(Complex Double)
  --
  future  <- new
  stream  <- asks ptxStream
  vecy    <- allocateRemote aR ((), sizeY)
  alpha'  <- indexRemote eR alpha 0
  ()      <- liftPar $ do
    withArray nR matA stream   $ \ptr_A -> do
     withArray nR vecx stream  $ \ptr_x -> do
      withArray nR vecy stream $ \ptr_y -> do
       withBLAS                $ \hdl   -> do
         case nR of
           NumericRfloat32 -> liftIO $
            with alpha' $ \ptr_alpha ->
             with 0     $ \ptr_beta  ->
               BLAS.sgemv hdl opA' colsA rowsA ptr_alpha ptr_A colsA ptr_x 1 ptr_beta ptr_y 1

           NumericRfloat64 -> liftIO $
            with alpha' $ \ptr_alpha ->
             with 0     $ \ptr_beta  ->
               BLAS.dgemv hdl opA' colsA rowsA ptr_alpha ptr_A colsA ptr_x 1 ptr_beta ptr_y 1

           -- Complex cases: alpha is marshalled with 'withV2', and the
           -- device pointers are cast to pointers to complex numbers.
           NumericRcomplex32 -> liftIO $
            withV2 nR alpha' $ \ptr_alpha ->
             with 0          $ \ptr_beta  ->
               BLAS.cgemv hdl opA' colsA rowsA ptr_alpha (CUDA.castDevPtr ptr_A) colsA (CUDA.castDevPtr ptr_x) 1 ptr_beta (CUDA.castDevPtr ptr_y)  1

           NumericRcomplex64 -> liftIO $
            withV2 nR alpha' $ \ptr_alpha ->
             with 0          $ \ptr_beta  ->
               BLAS.zgemv hdl opA' colsA rowsA ptr_alpha (CUDA.castDevPtr ptr_A) colsA (CUDA.castDevPtr ptr_x) 1 ptr_beta (CUDA.castDevPtr ptr_y)  1
  --
  put future vecy
  return future