{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Twine (
interleave,
deinterleave,
) where
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Array.Sugar ( EltRepr, Vector, Z(..), (:.)(..) )
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Control.Concurrent.MVar
import Control.Monad.State
import Data.ByteString ( ByteString )
import Data.FileEmbed
import Data.IntMap.Strict ( IntMap )
import Foreign.Storable.Complex ( )
import System.IO.Unsafe
import qualified Data.IntMap.Strict as IM
import Foreign.CUDA.Ptr ( DevicePtr )
import Foreign.CUDA.Analysis
import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.Driver.Stream as CUDA
import GHC.Ptr
import GHC.Base
import Prelude hiding ( lookup )
-- | Combine separate real and imaginary device arrays into a single
-- temporary buffer and run a continuation on that buffer.
--
-- A remote vector of @n * 2@ scalars is allocated, the second kernel supplied
-- by 'withTwine' is enqueued on the given stream with the buffer and the two
-- component arrays as arguments, and the continuation is then applied to the
-- buffer pointer cast to a pointer to complex elements.
--
-- NOTE(review): the precise memory layout produced by the kernel lives in
-- the embedded PTX and is not visible here; presumably it is interleaved
-- (re, im) pairs -- confirm against the kernel source.
interleave
    :: forall e b. Numeric (Complex e)
    => DevicePtrs (EltRepr (Complex e))         -- ^ component arrays (bound as @d_re@, @d_im@)
    -> Stream                                   -- ^ stream the kernel is enqueued on
    -> Int                                      -- ^ element count passed to the kernel
    -> (DevicePtr (Complex e) -> LLVM PTX b)    -- ^ continuation receiving the packed buffer
    -> LLVM PTX b
interleave (((), d_re), d_im) s n k =
  -- Matching on the representation tag fixes the scalar type of the buffer.
  case numericR :: NumericR (Complex e) of
    tag@NumericRcomplex32 -> do
      buffer <- allocateRemote (Z :. n * 2) :: LLVM PTX (Vector Float)
      withTwine tag $ \(_, packK, _) ->
        withArray buffer s $ \d_buf -> do
          withLifetime' s $ \stream ->
            liftIO (launch packK stream n d_buf d_re d_im)
          -- The continuation runs while the temporary array is still in scope.
          k (CUDA.castDevPtr d_buf :: DevicePtr (Complex Float))

    tag@NumericRcomplex64 -> do
      buffer <- allocateRemote (Z :. n * 2) :: LLVM PTX (Vector Double)
      withTwine tag $ \(_, packK, _) ->
        withArray buffer s $ \d_buf -> do
          withLifetime' s $ \stream ->
            liftIO (launch packK stream n d_buf d_re d_im)
          k (CUDA.castDevPtr d_buf :: DevicePtr (Complex Double))
-- | Split a buffer of complex values back into separate real and imaginary
-- device arrays.
--
-- The third kernel supplied by 'withTwine' is enqueued on the given stream
-- with the two component arrays as its first arguments and the complex
-- buffer, cast to a pointer to its scalar type, as the last.
deinterleave
    :: forall e. Numeric (Complex e)
    => DevicePtrs (EltRepr (Complex e))   -- ^ component arrays (bound as @d_re@, @d_im@)
    -> DevicePtr (Complex e)              -- ^ packed complex buffer
    -> Stream                             -- ^ stream the kernel is enqueued on
    -> Int                                -- ^ element count passed to the kernel
    -> LLVM PTX ()
deinterleave (((), d_re), d_im) d_cplx s n =
  -- Matching on the representation tag fixes the scalar type for the cast.
  case numericR :: NumericR (Complex e) of
    tag@NumericRcomplex32 ->
      withTwine tag $ \(_, _, unpackK) ->
        withLifetime' s $ \stream ->
          liftIO (launch unpackK stream n d_re d_im (CUDA.castDevPtr d_cplx :: DevicePtr Float))

    tag@NumericRcomplex64 ->
      withTwine tag $ \(_, _, unpackK) ->
        withLifetime' s $ \stream ->
          liftIO (launch unpackK stream n d_re d_im (CUDA.castDevPtr d_cplx :: DevicePtr Double))
-- | Run a continuation with the twine module and its two kernels for the
-- current device context, loading the module the first time it is needed.
--
-- Loaded modules are kept in the 'MVar'-protected table selected by
-- 'modules', keyed by 'toKey' of the CUDA context. On a miss the embedded
-- image is loaded, the @"interleave"@ and @"deinterleave"@ kernels are looked
-- up, and two finalizers are registered: one on the context lifetime that
-- removes the table entry, and one on the new module lifetime that unloads
-- the module.
withTwine :: NumericR (Complex e) -> ((CUDA.Module, Kernel, Kernel) -> LLVM PTX b) -> LLVM PTX b
withTwine nR k = do
  target <- gets ptxContext
  let ctxLife = deviceContext target
      props   = deviceProperties target
      cache   = modules nR
  --
  entry <- liftIO $
    withLifetime ctxLife $ \ctx ->
      modifyMVar cache $ \table -> do
        let key = toKey ctx
        case IM.lookup key table of
          -- Already loaded for this context: reuse it and leave the table as is.
          Just hit -> return (table, hit)

          -- Not loaded yet: load the image and record it in the table.
          Nothing  -> do
            mdl    <- CUDA.loadData (case nR of
                                       NumericRcomplex32 -> ptx_twine_f32
                                       NumericRcomplex64 -> ptx_twine_f64)
            pack   <- mkKernel "interleave"   mdl props
            unpack <- mkKernel "deinterleave" mdl props
            fresh  <- newLifetime (mdl, pack, unpack)
            addFinalizer ctxLife $ modifyMVar_ cache (return . IM.delete key)
            addFinalizer fresh   $ CUDA.unload mdl
            return (IM.insert key fresh table, fresh)
  --
  withLifetime' entry k
-- | Derive an 'IM.Key' from a CUDA context by taking the address held in its
-- underlying pointer as an 'Int'.
toKey :: CUDA.Context -> IM.Key
toKey ctx =
  case ctx of
    CUDA.Context (Ptr addr#) -> I# (addr2Int# addr#)
-- | Enqueue a twine kernel on the given stream.
--
-- The kernel is launched on a one-dimensional grid of @kernelThreadBlocks n@
-- blocks of @kernelThreadBlockSize@ threads, with @kernelSharedMemBytes@ of
-- shared memory. It receives the three device pointers, in the order given,
-- followed by the element count converted with 'fromIntegral'.
--
-- Nothing is launched when @n <= 0@. The block count is derived from @n@, so
-- an empty input would otherwise request a grid of zero blocks, which the
-- CUDA driver rejects as an invalid launch configuration.
launch :: Kernel -> CUDA.Stream -> Int -> DevicePtr e -> DevicePtr e -> DevicePtr e -> IO ()
launch Kernel{..} s n dx dy dz
  | n <= 0    = return ()   -- nothing to do for empty input
  | otherwise =
      CUDA.launchKernel kernelFun (kernelThreadBlocks n,1,1) (kernelThreadBlockSize,1,1) kernelSharedMemBytes (Just s)
        [ CUDA.VArg dx, CUDA.VArg dy, CUDA.VArg dz, CUDA.IArg (fromIntegral n) ]
-- | Look up a named function in a loaded module and pair it with a launch
-- configuration.
--
-- The configuration uses 256 threads per block and zero bytes of shared
-- memory. The number of blocks for @n@ elements is @n@ divided by the block
-- size, rounded up, capped at the value 'maxResidentBlocks' reports for the
-- device properties, this block size and the function's register count.
mkKernel :: String -> CUDA.Module -> CUDA.DeviceProperties -> IO Kernel
mkKernel name mdl prp = do
  fun  <- CUDA.getFun mdl name
  regs <- CUDA.requires fun CUDA.NumRegs
  let threads     = 256
      smem        = 0
      resident    = maxResidentBlocks prp threads regs smem
      -- ceiling division of the element count, capped by the resident limit
      blocksFor n = min resident ((n + threads - 1) `quot` threads)
  return $ Kernel
    { kernelFun             = fun
    , kernelSharedMemBytes  = smem
    , kernelThreadBlockSize = threads
    , kernelThreadBlocks    = blocksFor
    , kernelName            = name
    }
-- | A device function bundled with the parameters used to launch it.
data Kernel = Kernel
  { kernelFun             :: {-# UNPACK #-} !CUDA.Fun   -- ^ function handle obtained from the module
  , kernelSharedMemBytes  :: {-# UNPACK #-} !Int        -- ^ shared memory size passed at launch
  , kernelThreadBlockSize :: {-# UNPACK #-} !Int        -- ^ threads per block (x dimension)
  , kernelThreadBlocks    :: Int -> Int                 -- ^ grid size (x dimension) for a given element count
  , kernelName            :: String                     -- ^ name the function was looked up by
  }
-- | Select the global module table that matches the element precision.
modules :: NumericR (Complex e) -> MVar (IntMap (Lifetime (CUDA.Module, Kernel, Kernel)))
modules nR =
  case nR of
    NumericRcomplex32 -> modules_f32
    NumericRcomplex64 -> modules_f64
-- | Global table of loaded single-precision twine modules, initially empty.
--
-- Created with 'unsafePerformIO'; the @NOINLINE@ pragma keeps this a single
-- shared 'MVar' rather than letting the allocation be duplicated at use
-- sites.
{-# NOINLINE modules_f32 #-}
modules_f32 :: MVar (IntMap (Lifetime (CUDA.Module, Kernel, Kernel)))
modules_f32 = unsafePerformIO (newMVar IM.empty)
-- | Global table of loaded double-precision twine modules, initially empty.
--
-- Created with 'unsafePerformIO'; the @NOINLINE@ pragma keeps this a single
-- shared 'MVar' rather than letting the allocation be duplicated at use
-- sites.
{-# NOINLINE modules_f64 #-}
modules_f64 :: MVar (IntMap (Lifetime (CUDA.Module, Kernel, Kernel)))
modules_f64 = unsafePerformIO (newMVar IM.empty)
-- | Contents of @cubits/twine_f32.ptx@ (path relative to the project root),
-- embedded into the binary at compile time.
ptx_twine_f32 :: ByteString
ptx_twine_f32 = $(embedFile =<< makeRelativeToProject "cubits/twine_f32.ptx")
-- | Contents of @cubits/twine_f64.ptx@ (path relative to the project root),
-- embedded into the binary at compile time.
ptx_twine_f64 :: ByteString
ptx_twine_f64 = $(embedFile =<< makeRelativeToProject "cubits/twine_f64.ptx")