{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Execute (
executeAcc, executeAfun,
executeOpenAcc,
) where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.Execute
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch ( multipleOf )
import Data.Array.Accelerate.LLVM.PTX.Array.Data
import Data.Array.Accelerate.LLVM.PTX.Array.Prim ( memsetArrayAsync )
import Data.Array.Accelerate.LLVM.PTX.Execute.Async
import Data.Array.Accelerate.LLVM.PTX.Execute.Environment
import Data.Array.Accelerate.LLVM.PTX.Execute.Marshal
import Data.Array.Accelerate.LLVM.PTX.Link
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import Data.Range.Range ( Range(..) )
import Control.Parallel.Meta ( runExecutable )
import qualified Foreign.CUDA.Driver as CUDA
import Control.Monad ( when )
import Control.Monad.State ( gets, liftIO )
import Data.ByteString.Short.Char8 ( ShortByteString, unpack )
import Data.Int ( Int32 )
import Data.List ( find )
import Data.Maybe ( fromMaybe )
import Data.Word ( Word32 )
import Text.Printf ( printf )
import Prelude hiding ( exp, map, sum, scanl, scanr )
import qualified Prelude as P
-- Skeleton execution for the PTX target. Every class method is bound to one
-- of the operations defined in the remainder of this module.
--
instance Execute PTX where
  -- Operations that run a single kernel over the output shape
  map         = simpleOp
  generate    = simpleOp
  transform   = simpleOp
  backpermute = simpleOp

  -- Reductions
  fold        = foldOp
  fold1       = fold1Op
  foldSeg     = foldSegOp
  fold1Seg    = foldSegOp

  -- Prefix sums; the left and right variants share an implementation
  scanl       = scanOp
  scanr       = scanOp
  scanl1      = scan1Op
  scanr1      = scan1Op
  scanl'      = scan'Op
  scanr'      = scan'Op

  -- Scatter
  permute     = permuteOp

  -- Stencils
  stencil1    = stencil1Op
  stencil2    = stencil2Op
-- | Allocate an output array of shape @sh@ and run the first kernel found in
-- the executable's function table over the range @IE 0 (size sh)@.
--
-- An internal error is raised if the function table is empty.
--
simpleOp
    :: (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh
    -> LLVM PTX (Array sh e)
simpleOp exe gamma aenv stream sh =
  withExecutable exe $ \table -> do
    result <- allocateRemote sh
    target <- gets llvmTarget
    liftIO $ executeOp target (firstKernel table) gamma aenv stream (IE 0 (size sh)) result
    return result
  where
    -- The head of the function table is the kernel to run
    firstKernel table =
      case functionTable table of
        []      -> $internalError "simpleOp" "no kernels found"
        (k : _) -> k
-- | Like 'simpleOp', but the kernel to run is looked up by name (via '!#')
-- rather than taken from the head of the function table.
--
simpleNamed
    :: (Shape sh, Elt e)
    => ShortByteString
    -> ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh
    -> LLVM PTX (Array sh e)
simpleNamed fun exe gamma aenv stream sh =
  withExecutable exe $ \table -> do
    result <- allocateRemote sh
    target <- gets llvmTarget
    let kernel = table !# fun
        range  = IE 0 (size sh)
    liftIO (executeOp target kernel gamma aenv stream range result)
    return result
-- | Reduction along the innermost dimension without a seed element.
--
-- A bounds check requires the innermost extent @sz@ to be positive. If the
-- array as a whole is nevertheless empty (@size sh == 0@), only the output
-- array of shape @sx@ is allocated and no kernel is run.
--
fold1Op
    :: (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> (sh :. Int)
    -> LLVM PTX (Array sh e)
fold1Op exe gamma aenv stream sh@(sx :. sz) =
  $boundsCheck "fold1" "empty array" (sz > 0) reduction
  where
    reduction
      | size sh == 0 = allocateRemote sx
      | otherwise    = foldCore exe gamma aenv stream sh
-- | Reduction along the innermost dimension with a seed element.
--
-- When the input is empty the "generate" kernel is run instead, over the
-- outer shape @sx@ with every extent clamped to be at least one. Non-empty
-- inputs are handed to 'foldCore'.
--
foldOp
    :: (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> (sh :. Int)
    -> LLVM PTX (Array sh e)
foldOp exe gamma aenv stream sh@(sx :. _)
  | size sh == 0 = simpleNamed "generate" exe gamma aenv stream clamped
  | otherwise    = foldCore exe gamma aenv stream sh
  where
    -- outer shape with each extent raised to a minimum of one
    clamped = listToShape [ max 1 extent | extent <- shapeToList sx ]
-- | Choose the reduction strategy from the shape of the result: if the
-- result shape @sh@ is @Z@ the whole (one-dimensional) input is reduced with
-- 'foldAllOp', otherwise the innermost dimension is reduced with 'foldDimOp'.
--
foldCore
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> (sh :. Int)
    -> LLVM PTX (Array sh e)
foldCore exe gamma aenv stream sh =
  case matchShapeType (undefined :: sh) (undefined :: Z) of
    Just Refl -> foldAllOp exe gamma aenv stream sh
    _         -> foldDimOp exe gamma aenv stream sh
-- | Reduce a vector of @n@ elements to a single value.
--
-- If 'kernelThreadBlocks' of the "foldAllS" kernel reports exactly one block
-- for @n@, that kernel writes the scalar result directly. Otherwise
-- "foldAllM1" writes into a temporary vector, which is then fed through
-- "foldAllM2" repeatedly until at most one element remains.
--
foldAllOp
    :: forall aenv e. Elt e
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> DIM1
    -> LLVM PTX (Scalar e)
foldAllOp exe gamma aenv stream (Z :. n) = withExecutable exe $ \table -> do
  target <- gets llvmTarget
  let
      single = table !# "foldAllS"
      multi1 = table !# "foldAllM1"
      multi2 = table !# "foldAllM2"

      -- Keep reducing the intermediate vector until it holds at most one
      -- element, then reinterpret its payload as a scalar array.
      reduce :: Vector e -> LLVM PTX (Scalar e)
      reduce partial@(Array ((), len) adata)
        | len <= 1  = return (Array () adata)
        | otherwise = do
            let blocks = len `multipleOf` kernelThreadBlockSize multi2
            next <- allocateRemote (Z :. blocks)
            liftIO $ executeOp target multi2 gamma aenv stream (IE 0 blocks) (partial, next)
            reduce next
  --
  case kernelThreadBlocks single n of
    -- a single kernel invocation produces the final result
    1 -> do
      out <- allocateRemote Z
      liftIO $ executeOp target single gamma aenv stream (IE 0 n) out
      return out

    -- first pass into a temporary, then reduce that recursively
    _ -> do
      let blocks = n `multipleOf` kernelThreadBlockSize multi1
      partial <- allocateRemote (Z :. blocks)
      liftIO $ executeOp target multi1 gamma aenv stream (IE 0 blocks) partial
      reduce partial
-- | Reduce the innermost dimension of a multidimensional array.
--
-- The "fold" kernel is used when the innermost extent @sz@ is positive and
-- the "generate" kernel otherwise. The kernel is run over the range
-- @IE 0 (size sh)@ of the output shape.
--
foldDimOp
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> (sh :. Int)
    -> LLVM PTX (Array sh e)
foldDimOp exe gamma aenv stream (sh :. sz) =
  withExecutable exe $ \table -> do
    let name :: ShortByteString
        name
          | sz > 0    = "fold"
          | otherwise = "generate"
    result <- allocateRemote sh
    target <- gets llvmTarget
    liftIO $ executeOp target (table !# name) gamma aenv stream (IE 0 (size sh)) result
    return result
-- | Segmented reduction along the innermost dimension.
--
-- The second shape argument is the extent @ss@ of the segment descriptor;
-- the output has @ss - 1@ elements in its innermost dimension (presumably
-- because the descriptor carries one more entry than there are segments --
-- TODO confirm against the caller).
--
-- The "foldSeg_warp" kernel is selected when @sz `quot` ss@ is smaller than
-- twice the thread block size of "foldSeg_block"; otherwise "foldSeg_block"
-- is used.
--
foldSegOp
    :: (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> (sh :. Int)
    -> (Z :. Int)
    -> LLVM PTX (Array (sh :. Int) e)
foldSegOp exe gamma aenv stream (sh :. sz) (Z :. ss) =
  withExecutable exe $ \table -> do
    let blockKernel = table !# "foldSeg_block"
        warpKernel  = table !# "foldSeg_warp"
        segments    = ss - 1
        total       = size sh * segments
        kernel
          | sz `quot` ss < 2 * kernelThreadBlockSize blockKernel = warpKernel
          | otherwise                                            = blockKernel
    result <- allocateRemote (sh :. segments)
    target <- gets llvmTarget
    liftIO $ executeOp target kernel gamma aenv stream (IE 0 total) result
    return result
-- | Scan whose output has one more element along the innermost dimension
-- than its input.
--
-- For an input with innermost extent zero the "generate" kernel fills an
-- array of shape @sz :. 1@; otherwise 'scanCore' is run with an output
-- extent of @n + 1@.
--
scanOp
    :: (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh :. Int
    -> LLVM PTX (Array (sh:.Int) e)
scanOp exe gamma aenv stream (sz :. n)
  | n == 0    = simpleNamed "generate" exe gamma aenv stream (sz :. 1)
  | otherwise = scanCore exe gamma aenv stream sz n (n + 1)
-- | Scan without a seed element: the output has the same innermost extent
-- @n@ as the input. A bounds check requires @n@ to be positive.
--
scan1Op
    :: (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh :. Int
    -> LLVM PTX (Array (sh:.Int) e)
scan1Op exe gamma aenv stream (sz :. n) =
  $boundsCheck "scan1" "empty array" (n > 0)
    (scanCore exe gamma aenv stream sz n n)
-- | Choose the scan strategy from the outer shape: if @sh@ is @Z@ the input
-- is a vector and is handled by 'scanAllOp', otherwise 'scanDimOp' is used.
--
-- The two 'Int' arguments are the innermost extent of the input (@n@) and of
-- the output (@m@) respectively.
--
scanCore
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh
    -> Int
    -> Int
    -> LLVM PTX (Array (sh:.Int) e)
scanCore exe gamma aenv stream sz n m =
  case matchShapeType (undefined :: sh) (undefined :: Z) of
    Just Refl -> scanAllOp exe gamma aenv stream n m
    _         -> scanDimOp exe gamma aenv stream sz m
-- | Scan of a vector with @n@ input elements into @m@ output elements.
--
-- The "scanP1" kernel is always run, writing to both a temporary vector and
-- the output. When the temporary has more than one element, "scanP2" is run
-- on the temporary and then "scanP3" on the temporary and the output,
-- receiving the thread block size of "scanP1" as an extra argument.
--
scanAllOp
    :: forall aenv e. Elt e
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> Int
    -> Int
    -> LLVM PTX (Vector e)
scanAllOp exe gamma aenv stream n m = withExecutable exe $ \table -> do
  let
      phase1    = table !# "scanP1"
      phase2    = table !# "scanP2"
      phase3    = table !# "scanP3"
      --
      blockSize = kernelThreadBlockSize phase1
      blocks    = n `multipleOf` blockSize
  --
  target <- gets llvmTarget
  out    <- allocateRemote (Z :. m)
  tmp    <- allocateRemote (Z :. blocks) :: LLVM PTX (Vector e)

  -- First phase: always executed
  liftIO $ executeOp target phase1 gamma aenv stream (IE 0 blocks) (tmp, out)

  -- Second and third phases: only when the temporary has several elements
  when (blocks > 1) $
    liftIO $ do
      executeOp target phase2 gamma aenv stream (IE 0 blocks)       tmp
      executeOp target phase3 gamma aenv stream (IE 0 (blocks - 1)) (tmp, out, i32 blockSize)

  return out
-- | Scan along the innermost dimension of a multidimensional array. The
-- output has shape @sz :. m@ and the "scan" kernel is run over the range
-- @IE 0 (size sz)@ of the outer shape.
--
scanDimOp
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh
    -> Int
    -> LLVM PTX (Array (sh:.Int) e)
scanDimOp exe gamma aenv stream sz m =
  withExecutable exe $ \table -> do
    target <- gets llvmTarget
    result <- allocateRemote (sz :. m)
    let kernel = table !# "scan"
        range  = IE 0 (size sz)
    liftIO (executeOp target kernel gamma aenv stream range result)
    return result
-- | Scan returning two arrays: one with the same shape as the input and one
-- with the shape of the outer dimensions.
--
-- For an input with innermost extent zero, the first result is an allocated
-- array of shape @sz :. 0@ and the second is filled by the "generate" kernel.
-- Any other input is handled by 'scan''Core'.
--
scan'Op
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh :. Int
    -> LLVM PTX (Array (sh:.Int) e, Array sh e)
scan'Op exe gamma aenv stream sh@(sz :. n)
  | n /= 0    = scan'Core exe gamma aenv stream sh
  | otherwise = do
      out   <- allocateRemote (sz :. 0)
      total <- simpleNamed "generate" exe gamma aenv stream sz
      return (out, total)
-- | Choose the strategy for the two-result scan from the outer shape: if
-- @sh@ is @Z@ the input is a vector and is handled by 'scan''AllOp',
-- otherwise 'scan''DimOp' is used.
--
scan'Core
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh :. Int
    -> LLVM PTX (Array (sh:.Int) e, Array sh e)
scan'Core exe gamma aenv stream sh =
  case matchShapeType (undefined :: sh) (undefined :: Z) of
    Just Refl -> scan'AllOp exe gamma aenv stream sh
    _         -> scan'DimOp exe gamma aenv stream sh
-- | Two-result scan of a vector with @n@ elements.
--
-- The "scanP1" kernel is always run, writing to a temporary vector and the
-- output. If the temporary has exactly one element its payload is returned
-- directly as the scalar result. Otherwise a fresh scalar is allocated,
-- "scanP2" is run on the temporary and that scalar, and "scanP3" is run on
-- the temporary and the output with the thread block size of "scanP1" as an
-- extra argument.
--
scan'AllOp
    :: forall aenv e. Elt e
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> DIM1
    -> LLVM PTX (Vector e, Scalar e)
scan'AllOp exe gamma aenv stream (Z :. n) = withExecutable exe $ \table -> do
  let
      phase1    = table !# "scanP1"
      phase2    = table !# "scanP2"
      phase3    = table !# "scanP3"
      --
      blockSize = kernelThreadBlockSize phase1
      blocks    = n `multipleOf` blockSize
  --
  target <- gets llvmTarget
  out    <- allocateRemote (Z :. n)
  tmp    <- allocateRemote (Z :. blocks) :: LLVM PTX (Vector e)

  -- First phase: always executed
  liftIO $ executeOp target phase1 gamma aenv stream (IE 0 blocks) (tmp, out)

  if blocks /= 1
    then do
      -- run the remaining two phases, with the second writing the scalar
      total <- allocateRemote Z
      liftIO $ do
        executeOp target phase2 gamma aenv stream (IE 0 blocks)       (tmp, total)
        executeOp target phase3 gamma aenv stream (IE 0 (blocks - 1)) (tmp, out, i32 blockSize)
      return (out, total)

    else
      -- the one-element temporary is reinterpreted as the scalar result
      case tmp of
        Array _ ad -> return (out, Array () ad)
-- | Two-result scan along the innermost dimension of a multidimensional
-- array. Both result arrays are allocated up front and passed together to
-- the "scan" kernel, which is run over the range @IE 0 (size sz)@ of the
-- outer shape.
--
scan'DimOp
    :: forall aenv sh e. (Shape sh, Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> sh :. Int
    -> LLVM PTX (Array (sh:.Int) e, Array sh e)
scan'DimOp exe gamma aenv stream sh@(sz :. _) =
  withExecutable exe $ \table -> do
    target <- gets llvmTarget
    out    <- allocateRemote sh
    total  <- allocateRemote sz
    let kernel = table !# "scan"
        range  = IE 0 (size sz)
    liftIO (executeOp target kernel gamma aenv stream range (out, total))
    return (out, total)
-- | Forward permutation into the array of default values @dfs@.
--
-- When @inplace@ is set the defaults array itself is used as the output,
-- otherwise it is first cloned with 'cloneArrayAsync'. The first kernel in
-- the function table is run over the range @IE 0 (size shIn)@, and how it is
-- invoked depends on its name:
--
--   * "permute_rmw": the kernel receives only the output array.
--
--   * "permute_mutex": a 'Word32' vector with one entry per element of the
--     defaults array is allocated, zeroed with 'memsetArrayAsync', and passed
--     to the kernel alongside the output.
--
-- An empty function table or any other kernel name is an internal error.
--
permuteOp
    :: (Shape sh, Shape sh', Elt e)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> Bool
    -> sh
    -> Array sh' e
    -> LLVM PTX (Array sh' e)
permuteOp exe gamma aenv stream inplace shIn dfs =
  withExecutable exe $ \table -> do
    let kernel = case functionTable table of
                   []      -> $internalError "permute" "no kernels found"
                   (k : _) -> k
        range  = IE 0 (size shIn)
        locks  = size (shape dfs)
    --
    target <- gets llvmTarget
    out    <- if inplace
                then return dfs
                else cloneArrayAsync stream dfs
    --
    case kernelName kernel of
      "permute_mutex" -> do
        barrier@(Array _ ad) <- allocateRemote (Z :. locks) :: LLVM PTX (Vector Word32)
        memsetArrayAsync stream locks 0 ad
        liftIO $ executeOp target kernel gamma aenv stream range (out, barrier)

      "permute_rmw"   -> liftIO $ executeOp target kernel gamma aenv stream range out

      _               -> $internalError "permute" "unexpected kernel image"
    --
    return out
-- | Single-array stencil: runs 'simpleOp' with the shape of the input array
-- as the output shape.
--
stencil1Op
    :: (Shape sh, Elt b)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> Array sh a
    -> LLVM PTX (Array sh b)
stencil1Op exe gamma aenv stream arr = simpleOp exe gamma aenv stream extent
  where
    extent = shape arr
-- | Two-array stencil: runs 'simpleOp' with the intersection of the two
-- input shapes as the output shape.
--
stencil2Op
    :: (Shape sh, Elt c)
    => ExecutableR PTX
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> Array sh a
    -> Array sh b
    -> LLVM PTX (Array sh c)
stencil2Op exe gamma aenv stream arr brr = simpleOp exe gamma aenv stream common
  where
    common = shape arr `intersect` shape brr
-- | The value passed to 'runExecutable' by 'executeOp' as its third argument
-- (32768).
--
defaultPPT :: Int
defaultPPT = 32 * 1024
-- | Convert an 'Int' to an 'Int32' using 'fromIntegral'. No range check is
-- performed.
--
{-# INLINE i32 #-}
i32 :: Int -> Int32
i32 n = fromIntegral n
-- | Look up a kernel by name in the function table. It is an internal error
-- if no kernel with that name exists.
--
(!#) :: FunctionTable -> ShortByteString -> Kernel
table !# name =
  case lookupKernel name table of
    Just kernel -> kernel
    Nothing     -> $internalError "lookupFunction" ("function not found: " ++ unpack name)
-- | Find the first kernel in the function table whose 'kernelName' equals
-- the given name, if any.
--
lookupKernel :: ShortByteString -> FunctionTable -> Maybe Kernel
lookupKernel name table = find isNamed (functionTable table)
  where
    isNamed kernel = kernelName kernel == name
-- | Run a kernel over the given range.
--
-- The range is handed to 'runExecutable' together with the target's 'fillP'
-- field, the kernel's name and 'defaultPPT'. For every @start@ and @end@
-- that 'runExecutable' passes back, the kernel arguments are marshalled
-- (the two bounds as 'Int32', then the operation-specific arguments, then
-- the array environment) and the kernel is launched for @end - start@
-- elements.
--
executeOp
    :: Marshalable args
    => PTX
    -> Kernel
    -> Gamma aenv
    -> Aval aenv
    -> Stream
    -> Range
    -> args
    -> IO ()
executeOp ptx kernel gamma aenv stream r args =
  runExecutable (fillP ptx) (kernelName kernel) defaultPPT r dispatch
  where
    -- the third argument supplied by 'runExecutable' is not needed here
    dispatch start end _ = do
      argv <- marshal ptx stream (i32 start, i32 end, args, (gamma, aenv))
      launch kernel stream (end - start) argv
-- | Launch a kernel on the given stream for @n@ elements with the supplied
-- arguments. Nothing happens when @n@ is not positive.
--
-- The launch uses a one-dimensional grid of @kernelThreadBlocks n@ blocks,
-- each of @kernelThreadBlockSize@ threads, with @kernelSharedMemBytes@ passed
-- as the shared memory argument. It is wrapped in 'Debug.monitorProcTime',
-- whose callback records the GPU time and emits a trace line.
--
launch :: Kernel -> Stream -> Int -> [CUDA.FunParam] -> IO ()
launch Kernel{..} stream n args
  | n <= 0    = return ()
  | otherwise =
      withLifetime stream $ \st ->
        Debug.monitorProcTime query report (Just st) $
          CUDA.launchKernel kernelFun grid cta kernelSharedMemBytes (Just st) args
  where
    threads = kernelThreadBlockSize
    blocks  = kernelThreadBlocks n
    --
    cta     = (threads, 1, 1)
    grid    = (blocks,  1, 1)

    -- Passed to 'Debug.monitorProcTime': always True when monitoring is
    -- enabled, otherwise the state of the dump_exec flag.
    query
      | Debug.monitoringIsEnabled = return True
      | otherwise                 = Debug.queryFlag Debug.dump_exec

    report wall cpu gpu = do
      Debug.addProcessorTime Debug.PTX gpu
      Debug.traceIO Debug.dump_exec $
        printf "exec: %s <<< %d, %d, %d >>> %s"
               (unpack kernelName) blocks threads kernelSharedMemBytes (Debug.elapsed wall cpu gpu)