{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
where
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Foreign.CUDA.Ptr ( DevicePtr )
import qualified Foreign.CUDA.BLAS as C
import GHC.Base
import GHC.Ptr
-- | The device pointer type handed to a continuation for an array whose
-- elements have type @e@ (see the continuation argument of 'withArray' and
-- 'withArrayData').
--
-- The family is open; the instances below cover the element types matched by
-- 'withArrayData'.
type family DevicePtrs e :: *
-- Real scalars map to a device pointer of the same element type.
type instance DevicePtrs Float = DevicePtr Float
type instance DevicePtrs Double = DevicePtr Double
-- Two-element vectors map to a device pointer of the underlying component
-- type. NOTE(review): 'Vec2' appears to be the representation of complex
-- numbers here (cf. 'withV2') -- confirm against the NumericR definition.
type instance DevicePtrs (Vec2 Float) = DevicePtr Float
type instance DevicePtrs (Vec2 Double) = DevicePtr Double
-- | Convert this package's 'Transpose' flag into the corresponding
-- 'C.Operation' value of the BLAS bindings imported as @C@.
--
-- The mapping is constructor-for-constructor: 'N' becomes 'C.N', 'T' becomes
-- 'C.T' and 'H' becomes 'C.C'.
-- NOTE(review): 'H' / 'C.C' presumably denote the conjugate (Hermitian)
-- transpose -- confirm against the definitions of both types.
encodeTranspose :: Transpose -> C.Operation
encodeTranspose N = C.N
encodeTranspose T = C.T
encodeTranspose H = C.C
{-# INLINE withV2 #-}
-- | Run an 'IO' action with a host pointer to the bytes of a two-element
-- vector, viewed as a @'Complex' a@.
--
-- The first argument selects the element width: 8 bytes in total for
-- 'NumericRcomplex32' and 16 bytes for 'NumericRcomplex64'.
--
-- * If the byte array backing the vector is already pinned, the continuation
--   receives a pointer directly into it.
--
-- * Otherwise the contents are first copied into a freshly allocated pinned
--   byte array (aligned to 16 bytes) and the continuation receives a pointer
--   to that copy. Writes made through the pointer are therefore not reflected
--   in the original vector in this case.
--
-- In both cases the byte array is kept alive with 'touch#' until the
-- continuation has returned; the pointer must not be used after that.
withV2
    :: NumericR s (Vec2 a)
    -> Vec2 a
    -> (Ptr (Complex a) -> IO b)
    -> IO b
withV2 nR (Vec ba#) k =
  let -- Total size in bytes of the two-component value.
      !(I# bytes#) = case nR of
                       NumericRcomplex32 -> 8
                       NumericRcomplex64 -> 16
  in
  case isByteArrayPinned# ba# of
    -- Pinned: the address is stable, so hand it out directly.
    1# -> IO $ \s0 ->
      case k (Ptr (byteArrayContents# ba#)) of { IO k# ->
      case k# s0                            of { (# s1, r #) ->
      case touch# ba# s1                    of { s2 ->
        (# s2, r #) }}}

    -- Unpinned: copy into a pinned buffer first. The copy is done with
    -- copyByteArray# rather than by taking the address of the unpinned
    -- source, because byteArrayContents# is only safe on pinned arrays.
    _  -> IO $ \s0 ->
      case newAlignedPinnedByteArray# bytes# 16# s0 of { (# s1, mba# #) ->
      case copyByteArray# ba# 0# mba# 0# bytes# s1  of { s2 ->
      case unsafeFreezeByteArray# mba# s2           of { (# s3, ba'# #) ->
      case k (Ptr (byteArrayContents# ba'#))        of { IO k# ->
      case k# s3                                    of { (# s4, r #) ->
      case touch# ba'# s4                           of { s5 ->
        (# s5, r #) }}}}}}
{-# INLINE withArray #-}
-- | Run a continuation with the device pointer(s) of an array.
--
-- This unwraps the 'Array' constructor, ignores the shape component, and
-- defers to 'withArrayData' on the underlying array data with the same
-- element-type witness, stream and continuation.
withArray
    :: NumericR s e
    -> Array sh e
    -> Stream
    -> (DevicePtrs e -> LLVM PTX b)
    -> LLVM PTX b
withArray eR (Array _ adata) s k = withArrayData eR adata s k
{-# INLINE withArrayData #-}
-- | Run a continuation with the device pointer of some array data.
--
-- The element-type witness selects the scalar type passed to
-- 'withDevicePtr': 'Float' for 'NumericRfloat32' and 'NumericRcomplex32',
-- 'Double' for 'NumericRfloat64' and 'NumericRcomplex64'. The complex cases
-- therefore hand the continuation a pointer to the underlying component type,
-- matching the 'DevicePtrs' instances.
--
-- After the continuation has run, a 'waypoint' is recorded on the given
-- stream and returned to 'withDevicePtr' as @Just event@.
-- NOTE(review): presumably this is how the memory manager learns when the
-- stream has finished using the array -- confirm against 'withDevicePtr'.
withArrayData
    :: NumericR s e
    -> ArrayData e
    -> Stream
    -> (DevicePtrs e -> LLVM PTX b)
    -> LLVM PTX b
withArrayData NumericRfloat32 ad s k =
  withDevicePtr (singleType @Float) ad $ \p -> do
    r <- k p
    e <- waypoint s
    return (Just e, r)
withArrayData NumericRfloat64 ad s k =
  withDevicePtr (singleType @Double) ad $ \p -> do
    r <- k p
    e <- waypoint s
    return (Just e, r)
withArrayData NumericRcomplex32 ad s k =
  withDevicePtr (singleType @Float) ad $ \p -> do
    r <- k p
    e <- waypoint s
    return (Just e, r)
withArrayData NumericRcomplex64 ad s k =
  withDevicePtr (singleType @Double) ad $ \p -> do
    r <- k p
    e <- waypoint s
    return (Just e, r)
{-# INLINE withLifetime' #-}
-- | Run an @LLVM PTX@ continuation on the value held by a 'Lifetime'.
--
-- The value is extracted with 'unsafeGetValue' and passed to the
-- continuation; once the continuation has finished, the lifetime is touched
-- with 'touchLifetime' and the continuation's result is returned. The
-- extracted value should not be used after this function returns.
withLifetime' :: Lifetime a -> (a -> LLVM PTX b) -> LLVM PTX b
withLifetime' l k = do
  r <- k (unsafeGetValue l)
  -- Touch only after the continuation has run, so the lifetime stays
  -- reachable while the continuation is using the value.
  liftIO $ touchLifetime l
  return r