{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
-- |
-- Module      : Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
-- Copyright   : [2017..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
  where

import Data.Array.Accelerate.Math.FFT.Type

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Data.Array.Accelerate.LLVM.PTX.Foreign

import Foreign.CUDA.Ptr                                             ( DevicePtr )


{-# INLINE withArray #-}
-- | Run a continuation against the device pointer backing an array of
-- two-element vectors.
--
-- The 'Array' constructor is unwrapped, its first field is ignored, and the
-- remaining payload is handed to 'withArrayData' together with the numeric
-- type witness. The 'Stream' and continuation arguments are passed through
-- unchanged (the definition is partially applied).
--
withArray
    :: NumericR e
    -> Array sh (Vec2 e)
    -> Stream
    -> (DevicePtr e -> LLVM PTX b)
    -> LLVM PTX b
withArray eR (Array _ adata) = withArrayData eR adata

{-# INLINE withArrayData #-}
-- | Run a continuation against the device pointer of the given array data.
--
-- The 'NumericR' witness selects the scalar element type ('Float' or
-- 'Double') with which 'withDevicePtr' is invoked. Inside the callback the
-- continuation @k@ is run on the pointer first; afterwards 'waypoint' is
-- called on the stream @s@, and the resulting event is returned to
-- 'withDevicePtr' (as @Just@) alongside the continuation's result.
--
-- NOTE(review): presumably the event lets 'withDevicePtr' keep the array
-- alive until the work queued on the stream has completed -- confirm against
-- the accelerate-llvm-ptx documentation.
--
withArrayData
    :: NumericR e
    -> ArrayData (Vec2 e)
    -> Stream
    -> (DevicePtr e -> LLVM PTX b)
    -> LLVM PTX b
withArrayData NumericRfloat32 ad s k =
  withDevicePtr (singleType @Float) ad $ \p -> do
    r <- k p          -- run the user action on the raw device pointer
    e <- waypoint s   -- obtain an event from the stream after the action
    return (Just e, r)
withArrayData NumericRfloat64 ad s k =
  withDevicePtr (singleType @Double) ad $ \p -> do
    r <- k p          -- run the user action on the raw device pointer
    e <- waypoint s   -- obtain an event from the stream after the action
    return (Just e, r)

{-# INLINE withLifetime' #-}
-- | Run a continuation on the value held by a 'Lifetime'.
--
-- The value is extracted with 'unsafeGetValue' and passed to @k@. Once the
-- continuation has finished, 'touchLifetime' is called on the lifetime and
-- the continuation's result is returned.
--
-- NOTE(review): the touch is sequenced after @k@, presumably so the value is
-- not finalised while the continuation is still using it -- confirm against
-- the "Data.Array.Accelerate.Lifetime" documentation.
--
withLifetime' :: Lifetime a -> (a -> LLVM PTX b) -> LLVM PTX b
withLifetime' l k = do
  r <- k (unsafeGetValue l)
  liftIO $ touchLifetime l
  return r