{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Array.Data (
module Data.Array.Accelerate.LLVM.Array.Data,
module Data.Array.Accelerate.LLVM.PTX.Array.Data,
) where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Array.Unique ( UniqueArray(..) )
import Data.Array.Accelerate.Lifetime ( Lifetime(..) )
import qualified Data.Array.Accelerate.Array.Representation as R
import Data.Array.Accelerate.LLVM.Array.Data
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.State
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.LLVM.PTX.Execute.Async
import qualified Data.Array.Accelerate.LLVM.PTX.Array.Prim as Prim
import Control.Applicative
import Control.Monad.State ( liftIO, gets )
import Data.Typeable
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe
import Prelude
-- | Remote memory operations for the PTX target, implemented in terms of the
-- primitives imported qualified as @Prim@.
--
-- For every method that takes an optional stream, 'Nothing' selects the plain
-- primitive and @'Just' st@ selects the corresponding @Async@ primitive,
-- passing @st@ along to it.
--
instance Remote PTX where
  -- Allocate the array with 'allocateArray', then run 'Prim.mallocArray'
  -- (given @size sh@) over its payload via 'runArray'. The callback returns
  -- the array data it was handed, unchanged.
  {-# INLINEABLE allocateRemote #-}
  allocateRemote !sh = do
    arr <- liftIO (allocateArray sh)
    runArray arr
      (\ad -> do
          Prim.mallocArray (size sh) ad
          return ad)

  -- 'Prim.useArray' or 'Prim.useArrayAsync' on @n@ and the array data.
  {-# INLINEABLE useRemoteR #-}
  useRemoteR !n !mst !ad =
    maybe (Prim.useArray n ad)
          (\st -> Prim.useArrayAsync st n ad)
          mst

  -- 'Prim.pokeArrayR' or 'Prim.pokeArrayAsyncR' on the range @from@ / @to@.
  {-# INLINEABLE copyToRemoteR #-}
  copyToRemoteR !from !to !mst !ad =
    maybe (Prim.pokeArrayR from to ad)
          (\st -> Prim.pokeArrayAsyncR st from to ad)
          mst

  -- 'Prim.peekArrayR' or 'Prim.peekArrayAsyncR' on the range @from@ / @to@.
  {-# INLINEABLE copyToHostR #-}
  copyToHostR !from !to !mst !ad =
    maybe (Prim.peekArrayR from to ad)
          (\st -> Prim.peekArrayAsyncR st from to ad)
          mst

  -- 'Prim.copyArrayPeerR' or 'Prim.copyArrayPeerAsyncR', supplied with the
  -- context and memory table taken from the destination target @dst@.
  {-# INLINEABLE copyToPeerR #-}
  copyToPeerR !from !to !dst !mst !ad =
    maybe (Prim.copyArrayPeerR peerCtx peerTable from to ad)
          (\st -> Prim.copyArrayPeerAsyncR peerCtx peerTable st from to ad)
          mst
    where
      peerCtx   = ptxContext dst
      peerTable = ptxMemoryTable dst

  -- Read a single element by handing 'Prim.indexArray' to 'runIndexArray'.
  {-# INLINEABLE indexRemote #-}
  indexRemote arr i = runIndexArray Prim.indexArray arr i
-- | Rebuild the given arrays so that each primitive payload is copied to the
-- host on demand instead of immediately.
--
-- For every primitive 'UniqueArray' reached while walking the array data, the
-- foreign pointer stored in its 'Lifetime' is replaced by the result of an
-- 'unsafeInterleaveIO' action. That action forks a stream, calls
-- 'copyToHostR' with the range @0@ / @R.size sh@ on that stream, takes a
-- checkpoint, blocks on it, joins the stream, and finally yields the original
-- pointer. Because of 'unsafeInterleaveIO', none of this runs until the
-- pointer is first demanded.
--
copyToHostLazy
    :: Arrays arrs
    => arrs
    -> LLVM PTX arrs
copyToHostLazy arrs = do
  ptx <- gets llvmTarget
  liftIO $ runArrays arrs $ \(Array sh adata) ->
    let
        -- Number of elements handed to 'copyToHostR' as the upper bound.
        count :: Int
        count = R.size sh

        -- Wrap one primitive array: the unique id and the first two
        -- 'Lifetime' fields are reused as-is; only the pointer field is
        -- swapped for the deferred transfer.
        deferPeek
            :: (ArrayElt e, ArrayPtrs e ~ Ptr a, Storable a, Typeable a, Typeable e)
            => ArrayData e
            -> UniqueArray a
            -> IO (UniqueArray a)
        deferPeek ad (UniqueArray uid (Lifetime ref weak fp)) = do
          lazyFp <- unsafeInterleaveIO $
            evalPTX ptx $ do
              st <- fork
              copyToHostR 0 count (Just st) ad
              done <- checkpoint st
              block done
              join st
              return fp
          return (UniqueArray uid (Lifetime ref weak lazyFp))

        -- Walk the array data following its element representation,
        -- rewrapping every primitive leaf with 'deferPeek'.
        walk :: ArrayEltR e -> ArrayData e -> IO (ArrayData e)
        -- structure
        walk ArrayEltRunit         AD_Unit       = return AD_Unit
        walk (ArrayEltRpair rA rB) (AD_Pair a b) = liftA2 AD_Pair (walk rA a) (walk rB b)
        -- signed integers
        walk ArrayEltRint     ad@(AD_Int ua)     = AD_Int     <$> deferPeek ad ua
        walk ArrayEltRint8    ad@(AD_Int8 ua)    = AD_Int8    <$> deferPeek ad ua
        walk ArrayEltRint16   ad@(AD_Int16 ua)   = AD_Int16   <$> deferPeek ad ua
        walk ArrayEltRint32   ad@(AD_Int32 ua)   = AD_Int32   <$> deferPeek ad ua
        walk ArrayEltRint64   ad@(AD_Int64 ua)   = AD_Int64   <$> deferPeek ad ua
        -- unsigned integers
        walk ArrayEltRword    ad@(AD_Word ua)    = AD_Word    <$> deferPeek ad ua
        walk ArrayEltRword8   ad@(AD_Word8 ua)   = AD_Word8   <$> deferPeek ad ua
        walk ArrayEltRword16  ad@(AD_Word16 ua)  = AD_Word16  <$> deferPeek ad ua
        walk ArrayEltRword32  ad@(AD_Word32 ua)  = AD_Word32  <$> deferPeek ad ua
        walk ArrayEltRword64  ad@(AD_Word64 ua)  = AD_Word64  <$> deferPeek ad ua
        -- floating point
        walk ArrayEltRfloat   ad@(AD_Float ua)   = AD_Float   <$> deferPeek ad ua
        walk ArrayEltRdouble  ad@(AD_Double ua)  = AD_Double  <$> deferPeek ad ua
        -- Bool and Char
        walk ArrayEltRbool    ad@(AD_Bool ua)    = AD_Bool    <$> deferPeek ad ua
        walk ArrayEltRchar    ad@(AD_Char ua)    = AD_Char    <$> deferPeek ad ua
        -- C integral types
        walk ArrayEltRcshort  ad@(AD_CShort ua)  = AD_CShort  <$> deferPeek ad ua
        walk ArrayEltRcushort ad@(AD_CUShort ua) = AD_CUShort <$> deferPeek ad ua
        walk ArrayEltRcint    ad@(AD_CInt ua)    = AD_CInt    <$> deferPeek ad ua
        walk ArrayEltRcuint   ad@(AD_CUInt ua)   = AD_CUInt   <$> deferPeek ad ua
        walk ArrayEltRclong   ad@(AD_CLong ua)   = AD_CLong   <$> deferPeek ad ua
        walk ArrayEltRculong  ad@(AD_CULong ua)  = AD_CULong  <$> deferPeek ad ua
        walk ArrayEltRcllong  ad@(AD_CLLong ua)  = AD_CLLong  <$> deferPeek ad ua
        walk ArrayEltRcullong ad@(AD_CULLong ua) = AD_CULLong <$> deferPeek ad ua
        -- C floating point
        walk ArrayEltRcfloat  ad@(AD_CFloat ua)  = AD_CFloat  <$> deferPeek ad ua
        walk ArrayEltRcdouble ad@(AD_CDouble ua) = AD_CDouble <$> deferPeek ad ua
        -- C character types
        walk ArrayEltRcchar   ad@(AD_CChar ua)   = AD_CChar   <$> deferPeek ad ua
        walk ArrayEltRcschar  ad@(AD_CSChar ua)  = AD_CSChar  <$> deferPeek ad ua
        walk ArrayEltRcuchar  ad@(AD_CUChar ua)  = AD_CUChar  <$> deferPeek ad ua
    in
    Array sh <$> walk arrayElt adata
-- | Copy an array into a freshly allocated one of the same shape.
--
-- The destination is obtained from 'allocateRemote'. Every primitive payload
-- of the source is then passed, together with the matching destination
-- payload, to 'Prim.copyArrayAsync' on the supplied stream with the element
-- count @size (shape arr)@. This function does not itself wait on the stream
-- before returning the new array.
--
cloneArrayAsync
    :: (Shape sh, Elt e)
    => Stream
    -> Array sh e
    -> LLVM PTX (Array sh e)
cloneArrayAsync stream arr@(Array _ src) = do
  clone@(Array _ dst) <- allocateRemote extent
  copyR arrayElt src dst
  return clone
  where
    extent = shape arr
    count  = size extent

    -- Walk source and destination in lock-step following the element
    -- representation, issuing one primitive copy per leaf.
    copyR :: ArrayEltR e -> ArrayData e -> ArrayData e -> LLVM PTX ()
    -- structure
    copyR ArrayEltRunit _ _ = return ()
    copyR (ArrayEltRpair rA rB) from to = do
      copyR rA (fstArrayData from) (fstArrayData to)
      copyR rB (sndArrayData from) (sndArrayData to)
    -- signed integers
    copyR ArrayEltRint     from to = copyLeaf from to
    copyR ArrayEltRint8    from to = copyLeaf from to
    copyR ArrayEltRint16   from to = copyLeaf from to
    copyR ArrayEltRint32   from to = copyLeaf from to
    copyR ArrayEltRint64   from to = copyLeaf from to
    -- unsigned integers
    copyR ArrayEltRword    from to = copyLeaf from to
    copyR ArrayEltRword8   from to = copyLeaf from to
    copyR ArrayEltRword16  from to = copyLeaf from to
    copyR ArrayEltRword32  from to = copyLeaf from to
    copyR ArrayEltRword64  from to = copyLeaf from to
    -- floating point
    copyR ArrayEltRfloat   from to = copyLeaf from to
    copyR ArrayEltRdouble  from to = copyLeaf from to
    -- Bool and Char
    copyR ArrayEltRbool    from to = copyLeaf from to
    copyR ArrayEltRchar    from to = copyLeaf from to
    -- C integral types
    copyR ArrayEltRcshort  from to = copyLeaf from to
    copyR ArrayEltRcushort from to = copyLeaf from to
    copyR ArrayEltRcint    from to = copyLeaf from to
    copyR ArrayEltRcuint   from to = copyLeaf from to
    copyR ArrayEltRclong   from to = copyLeaf from to
    copyR ArrayEltRculong  from to = copyLeaf from to
    copyR ArrayEltRcllong  from to = copyLeaf from to
    copyR ArrayEltRcullong from to = copyLeaf from to
    -- C floating point
    copyR ArrayEltRcfloat  from to = copyLeaf from to
    copyR ArrayEltRcdouble from to = copyLeaf from to
    -- C character types
    copyR ArrayEltRcchar   from to = copyLeaf from to
    copyR ArrayEltRcschar  from to = copyLeaf from to
    copyR ArrayEltRcuchar  from to = copyLeaf from to

    -- Copy a single primitive payload on the caller's stream.
    copyLeaf
        :: (ArrayElt e, ArrayPtrs e ~ Ptr a, Typeable e, Storable a, Typeable a)
        => ArrayData e
        -> ArrayData e
        -> LLVM PTX ()
    copyLeaf = Prim.copyArrayAsync stream count