{-# LANGUAGE TypeFamilies #-}
{- |
Simple manual implementation of embedding cufft functionality
in the @accelerate@ framework.

In this example, a plan is created (and destroyed again) for every transform,
and the transform is run within 'PTX.run1'.
-}
module Main where

import Control.Exception (bracket)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Acc, Vector, Z(Z), (:.)((:.)))
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import qualified Data.Array.Accelerate.LLVM.PTX as PTX
import qualified Data.Array.Accelerate.LLVM.PTX.Foreign as AF
import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.FFT as CUFFT


-- | Foreign (PTX) implementation of a single-precision real-to-complex FFT.
--
-- For an input of length @n@ the result holds @n/2 + 1@ complex
-- coefficients (integer division), i.e. @2 * (n/2 + 1)@ floats, as produced
-- by cuFFT's R2C transform.
--
-- A fresh cuFFT plan is created for every call and released again via
-- 'bracket', so repeated transforms do not leak plans, even if plan
-- execution throws.
transformForeign :: AF.Stream -> Vector Float -> AF.LLVM AF.PTX (Vector Float)
transformForeign stream input = do
  output <- AF.allocateRemote (Z :. outlen)
  withArray input stream $ \iptr ->
    withArray output stream $ \optr ->
      -- NOTE(review): the plan is not bound to 'stream' (no CUFFT.setStream),
      -- so cuFFT launches on its default stream; confirm the ordering with
      -- the checkpoint recorded on 'stream' is what is intended.
      AF.liftIO $
        bracket (CUFFT.plan1D inlen CUFFT.R2C 1) CUFFT.destroy $ \plan ->
          CUFFT.execR2C plan iptr optr
  return output
  where
    Z :. inlen = A.arrayShape input
    -- R2C only yields the non-redundant half of the spectrum:
    -- n/2 + 1 complex values, each stored as two floats.
    outlen = (inlen `div` 2 + 1) * 2

-- | Run an action on the raw device pointer backing an Accelerate array.
--
-- After the continuation @k@ has run, an event is recorded on the given
-- stream with 'AF.checkpoint'. That event is handed back to
-- 'AF.withDevicePtr' together with the continuation's result.
-- Presumably this lets Accelerate track when the device memory is no longer
-- in use; TODO confirm against the accelerate-llvm-ptx documentation.
withArray ::
  (Sugar.EltRepr e ~ er, AF.ArrayPtrs er ~ Ptr er, Storable er) =>
  A.Array sh e ->
  AF.Stream ->
  (CUDA.DevicePtr er -> AF.LLVM AF.PTX r) ->
  AF.LLVM AF.PTX r
withArray (Sugar.Array _ adata) s k =
  AF.withDevicePtr adata $ \p -> do
    r <- k p
    e <- AF.checkpoint s
    return (Just e, r)

-- | Embed 'transformForeign' as a foreign Accelerate computation.
--
-- No pure-Accelerate fallback exists: a backend that cannot run the PTX
-- foreign function will hit the 'error' below.
transform :: Acc (Vector Float) -> Acc (Vector Float)
transform =
  A.foreignAcc
    (AF.ForeignAcc "transformForeign" transformForeign)
    (error "no fft fallback implemented")

-- | Transform a length-5 unit impulse on the PTX backend and print the result.
--
-- The DFT of a unit impulse is 1 for every coefficient, so the printed
-- vector is expected to alternate 1 (real part) and 0 (imaginary part).
main :: IO ()
main = print $ PTX.run1 transform $ A.fromList (Z :. 5) [1, 0, 0, 0, 0]