{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE UnboxedTuples       #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
  where

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Data.Array.Accelerate.LLVM.PTX.Foreign

import Foreign.CUDA.Ptr                                             ( DevicePtr )
import qualified Foreign.CUDA.BLAS                                  as C

import GHC.Base
import GHC.Ptr


-- | The device pointer type handed to a continuation for an array whose
-- elements have type @e@.
--
-- Every instance resolves to a 'DevicePtr' of a scalar type. The 'Vec2'
-- instances map to a pointer to the component type ('Float' or 'Double'),
-- not to a pointer to the two-element vector itself.
--
type family DevicePtrs e :: *
type instance DevicePtrs Float         = DevicePtr Float
type instance DevicePtrs Double        = DevicePtr Double
type instance DevicePtrs (Vec2 Float)  = DevicePtr Float
type instance DevicePtrs (Vec2 Double) = DevicePtr Double


-- | Translate this package's 'Transpose' flag into the corresponding
-- 'C.Operation' of the BLAS binding: 'N' becomes 'C.N', 'T' becomes 'C.T'
-- and 'H' becomes 'C.C'.
--
encodeTranspose :: Transpose -> C.Operation
encodeTranspose op =
  case op of
    N -> C.N
    T -> C.T
    H -> C.C


{-# INLINE withV2 #-}
-- | Run an action with a host pointer to the contents of a 'Vec2', typed as a
-- pointer to @'Complex' a@.
--
-- If the byte array backing the vector is pinned, the pointer refers directly
-- to it. Otherwise the payload (8 bytes for 'NumericRcomplex32', 16 bytes for
-- 'NumericRcomplex64') is first copied into a freshly allocated pinned byte
-- array with 16-byte alignment, and the pointer refers to that copy.
--
-- In both cases the array behind the pointer is kept alive with 'touch#'
-- until the continuation has returned, so the pointer must not be used after
-- the continuation finishes.
--
withV2
    :: NumericR s (Vec2 a)
    -> Vec2 a
    -> (Ptr (Complex a) -> IO b)
    -> IO b
withV2 nR (Vec ba#) k =
  let !(I# bytes#) = case nR of
                       NumericRcomplex32 -> 8
                       NumericRcomplex64 -> 16
  in
  case isByteArrayPinned# ba# of
    -- Pinned: the address is stable, so hand it out directly.
    1# -> IO $ \s0 ->
      case k (Ptr (byteArrayContents# ba#)) of { IO k#       ->
      case k# s0                            of { (# s1, r #) ->
      case touch# ba# s1                    of { s2          ->
        (# s2, r #) }}}
    --
    -- Unpinned: stage the data in a pinned buffer. The copy goes through
    -- 'copyByteArray#' on the array itself; taking 'byteArrayContents#' of an
    -- unpinned array is unsafe because the collector may move it (for example
    -- during the allocation just before the copy), leaving a stale address.
    _  -> IO $ \s0 ->
      case newAlignedPinnedByteArray# bytes# 16# s0   of { (# s1, mba# #) ->
      case copyByteArray# ba# 0# mba# 0# bytes# s1    of { s2             ->
      case unsafeFreezeByteArray# mba# s2             of { (# s3, ba'# #) ->
      case k (Ptr (byteArrayContents# ba'#))          of { IO k#          ->
      case k# s3                                      of { (# s4, r #)    ->
      case touch# ba'# s4                             of { s5             ->
        (# s5, r #) }}}}}}

{-# INLINE withArray #-}
-- | Run a continuation with the device pointer(s) of an array.
--
-- The shape component of the array is ignored; the array's data is passed
-- unchanged to 'withArrayData' together with the remaining arguments.
--
withArray
    :: NumericR s e
    -> Array sh e
    -> Stream
    -> (DevicePtrs e -> LLVM PTX b)
    -> LLVM PTX b
withArray numR arr stream action =
  case arr of
    Array _ payload -> withArrayData numR payload stream action

{-# INLINE withArrayData #-}
-- | Run a continuation with the device pointer for some array data.
--
-- The numeric witness selects the scalar type passed to 'withDevicePtr':
-- 'Float' for 'NumericRfloat32' and 'NumericRcomplex32', 'Double' for
-- 'NumericRfloat64' and 'NumericRcomplex64'. Once the continuation has
-- finished, a 'waypoint' is recorded on the given stream and the resulting
-- event is returned to 'withDevicePtr' alongside the continuation's result.
--
withArrayData
    :: NumericR s e
    -> ArrayData e
    -> Stream
    -> (DevicePtrs e -> LLVM PTX b)
    -> LLVM PTX b
withArrayData numR payload stream action =
  case numR of
    NumericRfloat32   -> withDevicePtr (singleType @Float)  payload run
    NumericRfloat64   -> withDevicePtr (singleType @Double) payload run
    NumericRcomplex32 -> withDevicePtr (singleType @Float)  payload run
    NumericRcomplex64 -> withDevicePtr (singleType @Double) payload run
  where
    -- Shared by all four cases: run the user action first, then mark the
    -- stream, and report the event together with the action's result.
    run ptr = do
      result <- action ptr
      event  <- waypoint stream
      return (Just event, result)

{-# INLINE withLifetime' #-}
-- | Apply a continuation to the value held by a 'Lifetime'.
--
-- The value is extracted with 'unsafeGetValue'. After the continuation has
-- run, the lifetime is touched with 'touchLifetime', and the continuation's
-- result is returned.
--
withLifetime' :: Lifetime a -> (a -> LLVM PTX b) -> LLVM PTX b
withLifetime' lifetime action =
  action (unsafeGetValue lifetime) <* liftIO (touchLifetime lifetime)