{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
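
-- | Code generation for parallel reductions ('fold' and 'fold1') over
-- multidimensional arrays, targeting NVIDIA PTX. Reductions proceed along
-- the innermost dimension only; reductions to a scalar are handled by the
-- specialised @foldAll@ kernels.
--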
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold
where
-- accelerate
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar                ( Array, Scalar, Vector, Shape, Z, (:.), Elt(..) )

import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic    as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop          as Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar

import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.Representation

-- cuda
import qualified Foreign.CUDA.Analysis                  as CUDA

-- standard library
import Control.Applicative                              ( (<$>), (<*>) )
import Control.Monad                                    ( (>=>), (<=<) )
import Data.Bits                                        as P
import Data.String                                      ( fromString )
import Prelude                                          as P
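
-- | Reduce a (possibly empty) array along its innermost dimension. When the
-- result is a scalar the specialised 'mkFoldAll' kernels are used, otherwise
-- 'mkFoldDim'. Since a seed element is always available, we also generate an
-- embarrassingly parallel fill kernel ('mkFoldFill') for the case where the
-- innermost dimension is empty. At the surface language this corresponds to,
-- for example:
--
-- > fold (+) 0 (use (fromList (Z :. 10) [0..]))
--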
mkFold
:: forall aenv sh e. (Shape sh, Elt e)
=> PTX
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> IRExp PTX aenv e
-> IRDelayed PTX aenv (Array (sh :. Int) e)
-> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkFold ptx@(deviceProperties . ptxContext -> dev) aenv f z acc
| Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
= (+++) <$> mkFoldAll dev aenv f (Just z) acc
<*> mkFoldFill ptx aenv z
| otherwise
= (+++) <$> mkFoldDim dev aenv f (Just z) acc
<*> mkFoldFill ptx aenv z
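
-- | Reduce a non-empty array along its innermost dimension. There is no seed
-- element, so no fill kernel is required, but the reduced dimension must not
-- be empty.
--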
mkFold1
:: forall aenv sh e. (Shape sh, Elt e)
=> PTX
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> IRDelayed PTX aenv (Array (sh :. Int) e)
-> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkFold1 (deviceProperties . ptxContext -> dev) aenv f acc
| Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
= mkFoldAll dev aenv f Nothing acc
| otherwise
= mkFoldDim dev aenv f Nothing acc
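
-- | Reduce an array of arbitrary rank to a single element. We generate three
-- kernels: a single-pass variant for arrays small enough to be handled by
-- one thread block ('mkFoldAllS'), and a two-kernel recursive algorithm for
-- larger inputs ('mkFoldAllM1' producing partial results, 'mkFoldAllM2'
-- reducing them). Which of these runs is decided at execution time, based on
-- the size of the input.
--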
mkFoldAll
:: forall aenv e. Elt e
=> DeviceProperties
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (IRExp PTX aenv e)
-> IRDelayed PTX aenv (Vector e)
-> CodeGen (IROpenAcc PTX aenv (Scalar e))
mkFoldAll dev aenv combine mseed acc =
foldr1 (+++) <$> sequence [ mkFoldAllS dev aenv combine mseed acc
, mkFoldAllM1 dev aenv combine acc
, mkFoldAllM2 dev aenv combine mseed
]
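
-- | Full reduction of an array small enough to be processed in a single step
-- by a single thread block.
--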
mkFoldAllS
:: forall aenv e. Elt e
=> DeviceProperties
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (IRExp PTX aenv e)
-> IRDelayed PTX aenv (Vector e)
-> CodeGen (IROpenAcc PTX aenv (Scalar e))
mkFoldAllS dev aenv combine mseed IRDelayed{..} =
let
(start, end, paramGang) = gangParam
(arrOut, paramOut) = mutableArray ("out" :: Name (Scalar e))
paramEnv = envParam aenv
config = launchConfig dev (CUDA.incWarp dev) smem multipleOf multipleOfQ
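      -- Shared memory requirement: 1.5 warp-widths of storage per warp for
      -- the tree reduction, plus one element per warp in which to share the
      -- warp aggregates.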
smem n = warps * (1 + per_warp) * bytes
where
ws = CUDA.warpSize dev
warps = n `P.quot` ws
per_warp = ws + ws `P.quot` 2
bytes = sizeOf (eltType (undefined :: e))
in
makeOpenAccWith config "foldAllS" (paramGang ++ paramOut ++ paramEnv) $ do
tid <- threadIdx
bd <- blockDim
i0 <- A.add numType start tid
sz <- A.sub numType end start
when (A.lt scalarType i0 sz) $ do
x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
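        -- cooperative block-wide reduction; if the block is only partially
        -- full, pass the element count so that out-of-range lanes are masked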
r0 <- if A.eq scalarType sz bd
then reduceBlockSMem dev combine Nothing x0
else reduceBlockSMem dev combine (Just sz) x0
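        -- thread 0 writes the final result, combining it with the seed if
        -- one was given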
when (A.eq scalarType tid (lift 0)) $
writeArray arrOut tid =<<
case mseed of
Nothing -> return r0
Just z -> flip (app2 combine) r0 =<< z
return_
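
-- | Multi-block reduction, step one: each thread block cooperatively reduces
-- a contiguous stripe of the input and writes its partial result to a
-- temporary array. The seed is not applied here; it is folded in by
-- 'mkFoldAllM2' on the final pass.
--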
mkFoldAllM1
:: forall aenv e. Elt e
=> DeviceProperties
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> IRDelayed PTX aenv (Vector e)
-> CodeGen (IROpenAcc PTX aenv (Scalar e))
mkFoldAllM1 dev aenv combine IRDelayed{..} =
let
(start, end, paramGang) = gangParam
(arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
paramEnv = envParam aenv
config = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
smem n = warps * (1 + per_warp) * bytes
where
ws = CUDA.warpSize dev
warps = n `P.quot` ws
per_warp = ws + ws `P.quot` 2
bytes = sizeOf (eltType (undefined :: e))
in
makeOpenAccWith config "foldAllM1" (paramGang ++ paramTmp ++ paramEnv) $ do
tid <- threadIdx
bd <- blockDim
sz <- i32 . indexHead =<< delayedExtent
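    -- grid-stride loop: each thread block reduces one block-width stripe of
    -- the input per iteration into a single element of the temporary array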
imapFromTo start end $ \seg -> do
__syncthreads
from <- A.mul numType seg bd
step <- A.add numType from bd
to <- A.min scalarType sz step
reduceFromTo dev from to combine
(app1 delayedLinearIndex <=< A.fromIntegral integralType numType)
(when (A.eq scalarType tid (lift 0)) . writeArray arrTmp seg)
return_
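
-- | Multi-block reduction, step two: repeatedly reduce the temporary array
-- of partial results until a single element remains. The seed, if any, is
-- folded in only on the final pass, when the grid has shrunk to a single
-- thread block.
--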
mkFoldAllM2
:: forall aenv e. Elt e
=> DeviceProperties
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (IRExp PTX aenv e)
-> CodeGen (IROpenAcc PTX aenv (Scalar e))
mkFoldAllM2 dev aenv combine mseed =
let
(start, end, paramGang) = gangParam
(arrTmp, paramTmp) = mutableArray ("tmp" :: Name (Vector e))
(arrOut, paramOut) = mutableArray ("out" :: Name (Vector e))
paramEnv = envParam aenv
config = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
smem n = warps * (1 + per_warp) * bytes
where
ws = CUDA.warpSize dev
warps = n `P.quot` ws
per_warp = ws + ws `P.quot` 2
bytes = sizeOf (eltType (undefined :: e))
in
makeOpenAccWith config "foldAllM2" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do
tid <- threadIdx
bd <- blockDim
gd <- gridDim
sz <- i32 . indexHead $ irArrayShape arrTmp
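    -- reduce the partial results of the previous step, a block-width stripe
    -- per iteration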
imapFromTo start end $ \seg -> do
__syncthreads
from <- A.mul numType seg bd
step <- A.add numType from bd
to <- A.min scalarType sz step
reduceFromTo dev from to combine (readArray arrTmp) $ \r ->
when (A.eq scalarType tid (lift 0)) $
writeArray arrOut seg =<<
case mseed of
Nothing -> return r
Just z -> if A.eq scalarType gd (lift 1)
then flip (app2 combine) r =<< z
else return r
return_
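
-- | Reduce a multidimensional array along its innermost dimension. Each
-- thread block reduces one innermost-dimension segment at a time, walking
-- over the segment a block-width of elements per iteration.
--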
mkFoldDim
:: forall aenv sh e. (Shape sh, Elt e)
=> DeviceProperties
-> Gamma aenv
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (IRExp PTX aenv e)
-> IRDelayed PTX aenv (Array (sh :. Int) e)
-> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkFoldDim dev aenv combine mseed IRDelayed{..} =
let
(start, end, paramGang) = gangParam
(arrOut, paramOut) = mutableArray ("out" :: Name (Array sh e))
paramEnv = envParam aenv
config = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
smem n = warps * (1 + per_warp) * bytes
where
ws = CUDA.warpSize dev
warps = n `P.quot` ws
per_warp = ws + ws `P.quot` 2
bytes = sizeOf (eltType (undefined :: e))
in
makeOpenAccWith config "fold" (paramGang ++ paramOut ++ paramEnv) $ do
tid <- threadIdx
sz <- i32 . indexHead =<< delayedExtent
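    -- if the innermost dimension is smaller than the thread block, the
    -- excess threads never contribute to the output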
when (A.lt scalarType tid sz) $ do
imapFromTo start end $ \seg -> do
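        -- wait for all threads to catch up before starting this segment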
__syncthreads
from <- A.mul numType seg sz
to <- A.add numType from sz
i0 <- A.add numType from tid
x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
bd <- blockDim
r0 <- if A.gte scalarType sz bd
then reduceBlockSMem dev combine Nothing x0
else reduceBlockSMem dev combine (Just sz) x0
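        -- keep walking over the rest of the segment, a block-width of
        -- elements at a time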
next <- A.add numType from bd
r <- iterFromStepTo next bd to r0 $ \offset r -> do
__syncthreads
i <- A.add numType offset tid
v' <- A.sub numType to offset
r' <- if A.gte scalarType v' bd
then do
x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
y <- reduceBlockSMem dev combine Nothing x
return y
else do
x <- if A.lt scalarType i to
then app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
else return r
y <- reduceBlockSMem dev combine (Just v') x
return y
if A.eq scalarType tid (lift 0)
then app2 combine r r'
else return r'
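        -- thread 0 writes the reduced segment, combined with the seed if one
        -- was given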
when (A.eq scalarType tid (lift 0)) $
writeArray arrOut seg =<<
case mseed of
Nothing -> return r
Just z -> flip (app2 combine) r =<< z
return_
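
-- | Exclusive reductions over an empty innermost dimension simply fill the
-- result array with the seed element.
--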
mkFoldFill
:: (Shape sh, Elt e)
=> PTX
-> Gamma aenv
-> IRExp PTX aenv e
-> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkFoldFill ptx aenv seed =
mkGenerate ptx aenv (IRFun1 (const seed))
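
-- | Efficient block-wide reduction using dynamically allocated shared
-- memory. Each warp first reduces its own input values, then thread 0
-- sequentially combines the per-warp aggregates. The result is only valid
-- in thread zero.
--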
reduceBlockSMem
:: forall aenv e. Elt e
=> DeviceProperties
-> IRFun2 PTX aenv (e -> e -> e)
-> Maybe (IR Int32)
-> IR e
-> CodeGen (IR e)
reduceBlockSMem dev combine size = warpReduce >=> warpAggregate
where
int32 :: Integral a => a -> IR Int32
int32 = lift . P.fromIntegral
bytes = sizeOf (eltType (undefined::e))
warp_smem_elems = CUDA.warpSize dev + (CUDA.warpSize dev `P.quot` 2)
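    -- step 1: each warp reduces its values in its own section of the shared
    -- memory area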
warpReduce :: IR e -> CodeGen (IR e)
warpReduce input = do
wid <- warpId
skip <- A.mul numType wid (int32 (warp_smem_elems * bytes))
smem <- dynamicSharedMem (int32 warp_smem_elems) skip
case size of
Nothing ->
reduceWarpSMem dev combine smem Nothing input
Just n -> do
offset <- A.mul numType wid (int32 (CUDA.warpSize dev))
valid <- A.sub numType n offset
if A.gte scalarType valid (int32 (CUDA.warpSize dev))
then reduceWarpSMem dev combine smem Nothing input
else reduceWarpSMem dev combine smem (Just valid) input
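    -- step 2: the lane-0 result of every warp is shared through the tail of
    -- the shared memory area, and thread 0 combines the warp aggregates
    -- sequentially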
warpAggregate :: IR e -> CodeGen (IR e)
warpAggregate input = do
bd <- blockDim
warps <- A.quot integralType bd (int32 (CUDA.warpSize dev))
skip <- A.mul numType warps (int32 (warp_smem_elems * bytes))
smem <- dynamicSharedMem warps skip
wid <- warpId
lane <- laneId
when (A.eq scalarType lane (lift 0)) $ do
writeArray smem wid input
__syncthreads
tid <- threadIdx
if A.eq scalarType tid (lift 0)
then do
steps <- case size of
Nothing -> return warps
Just n -> do
a <- A.add numType n (int32 (CUDA.warpSize dev - 1))
b <- A.quot integralType a (int32 (CUDA.warpSize dev))
return b
iterFromStepTo (lift 1) (lift 1) steps input $ \step x ->
app2 combine x =<< readArray smem step
else
return input
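
-- | Efficient warp-wide reduction in shared memory. The temporary buffer
-- must hold at least 1.5 warp-widths of elements, so that every lane can
-- read one offset ahead at each step of the tree reduction without bounds
-- checks; values read past the live data are simply discarded. The result
-- is only valid in lane zero.
--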
reduceWarpSMem
:: forall aenv e. Elt e
=> DeviceProperties
-> IRFun2 PTX aenv (e -> e -> e)
-> IRArray (Vector e)
-> Maybe (IR Int32)
-> IR e
-> CodeGen (IR e)
reduceWarpSMem dev combine smem size = reduce 0
where
log2 :: Double -> Double
log2 = P.logBase 2
steps = P.floor . log2 . P.fromIntegral . CUDA.warpSize $ dev
valid i =
case size of
Nothing -> return (lift True)
Just n -> A.lt scalarType i n
reduce :: Int -> IR e -> CodeGen (IR e)
reduce step x
| step >= steps = return x
| offset <- 1 `P.shiftL` step = do
lane <- laneId
writeArray smem lane x
i <- A.add numType lane (lift offset)
x' <- if valid i
then app2 combine x =<< readArray smem i
else return x
reduce (step+1) x'
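
-- | Reduce the elements in the linear index range [from, to) using the
-- block-wide reduction, then pass the result to the given continuation. The
-- continuation runs in every thread; it is the caller's responsibility to
-- restrict any write to a single thread.
--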
reduceFromTo
:: Elt a
=> DeviceProperties
-> IR Int32
-> IR Int32
-> (IRFun2 PTX aenv (a -> a -> a))
-> (IR Int32 -> CodeGen (IR a))
-> (IR a -> CodeGen ())
-> CodeGen ()
reduceFromTo dev from to combine get set = do
tid <- threadIdx
bd <- blockDim
valid <- A.sub numType to from
i <- A.add numType from tid
_ <- if A.gte scalarType valid bd
then do
x <- get i
r <- reduceBlockSMem dev combine Nothing x
set r
return (IR OP_Unit :: IR ())
else do
when (A.lt scalarType i to) $ do
x <- get i
r <- reduceBlockSMem dev combine (Just valid) x
set r
return (IR OP_Unit :: IR ())
return ()
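
-- Auxiliary functions
-- -------------------

-- convert the head of a shape expression to the 32-bit type used for thread
-- and linear indices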
i32 :: IR Int -> CodeGen (IR Int32)
i32 = A.fromIntegral integralType numType
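
-- | Iterate over segments, distributing them across thread blocks in
-- grid-stride fashion: block @b@ processes segments @b@, @b + gridDim@,
-- @b + 2*gridDim@, and so on.
--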
imapFromTo
:: IR Int32
-> IR Int32
-> (IR Int32 -> CodeGen ())
-> CodeGen ()
imapFromTo start end body = do
bid <- blockIdx
gd <- gridDim
i0 <- A.add numType start bid
imapFromStepTo i0 gd end body