{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.FoldSeg
where
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar ( Array, Segments, Shape(rank), (:.), Elt(..) )
import LLVM.AST.Type.Representation
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop as Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Fold ( reduceBlockSMem, reduceWarpSMem, imapFromTo )
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Foreign.CUDA.Analysis as CUDA
import Control.Applicative ( (<$>), (<*>) )
import Control.Monad ( void )
import Data.String ( fromString )
import Prelude as P
-- | Segmented reduction along the innermost dimension of an array, with an
-- explicit seed element. Generates both kernel variants — one where a whole
-- thread block reduces each segment, and one where a single warp does — and
-- combines them into a single 'IROpenAcc'.
mkFoldSeg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh :. Int) e)
    -> IRDelayed PTX aenv (Segments i)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFoldSeg ptx aenv combine seed arr seg =
  let
      -- device properties of the target GPU
      dev   = deviceProperties (ptxContext ptx)
      -- a 'Just' seed selects the exclusive (seeded) reduction
      mseed = Just seed
  in
  (+++) <$> mkFoldSegP_block dev aenv combine mseed arr seg
        <*> mkFoldSegP_warp  dev aenv combine mseed arr seg
-- | Segmented reduction along the innermost dimension of an array, without a
-- seed element (fold1-style). Generates both the block-wide and warp-wide
-- kernel variants and combines them into a single 'IROpenAcc'.
mkFold1Seg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh :. Int) e)
    -> IRDelayed PTX aenv (Segments i)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFold1Seg ptx aenv combine arr seg =
  let
      -- device properties of the target GPU
      dev = deviceProperties (ptxContext ptx)
  in
  -- 'Nothing' for the seed selects the fold1-style reduction
  (+++) <$> mkFoldSegP_block dev aenv combine Nothing arr seg
        <*> mkFoldSegP_warp  dev aenv combine Nothing arr seg
-- | Segmented reduction along the innermost index of an array, where /all/
-- threads of a thread block cooperate to reduce a single segment at a time.
--
-- For each segment, the (begin, end) offsets are staged through a
-- two-element shared-memory array so that every thread of the block can
-- read them after a barrier. Non-empty segments are then reduced
-- cooperatively in stripes of 'blockDim' elements via 'reduceBlockSMem'.
mkFoldSegP_block
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => DeviceProperties                             -- ^ properties of the target device
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element ('Nothing' for fold1-style)
    -> IRDelayed PTX aenv (Array (sh :. Int) e)     -- ^ input values
    -> IRDelayed PTX aenv (Segments i)              -- ^ segment descriptor
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFoldSegP_block dev aenv combine mseed arr seg =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
      paramEnv                = envParam aenv
      config                  = launchConfig dev (CUDA.decWarp dev) dsmem const [|| const ||]
      -- Dynamic shared memory for a block of @n@ threads: scratch space for
      -- the block-wide tree reduction performed by 'reduceBlockSMem'
      -- (presumably ws + ws/2 elements per warp, plus one — confirm against
      -- the requirements of 'reduceBlockSMem').
      dsmem n                 = warps * (1 + per_warp) * bytes
        where
          ws       = CUDA.warpSize dev
          warps    = n `P.quot` ws
          per_warp = ws + ws `P.quot` 2
          bytes    = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "foldSeg_block" (paramGang ++ paramOut ++ paramEnv) $ do

    -- Statically allocated two-element shared array used to broadcast the
    -- current segment's (begin, end) offsets to every thread of the block.
    smem <- staticSharedMem 2

    -- sz: extent of the innermost dimension of the input array
    sz <- i32 . indexHead =<< delayedExtent arr

    -- ss: number of segments; the segment descriptor has one more entry
    -- than there are segments (offset-style descriptor).
    ss <- do n <- i32 . indexHead =<< delayedExtent seg
             A.sub numType n (lift 1)

    -- Each loop iteration reduces one segment (output index s)
    imapFromTo start end $ \s -> do

      -- The first two threads fetch the begin and end offsets of this
      -- segment. For rank > 0 the segment index repeats every 'ss' outputs.
      tid <- threadIdx
      when (A.lt scalarType tid (lift 2)) $ do
        i <- case rank (undefined::sh) of
               0 -> return s
               _ -> A.rem integralType s ss
        j <- A.add numType i tid
        v <- app1 (delayedLinearIndex seg) =<< A.fromIntegral integralType numType j
        writeArray smem tid =<< i32 v

      -- make the offsets visible to every thread of the block
      __syncthreads

      u <- readArray smem (lift 0 :: IR Int32)
      v <- readArray smem (lift 1 :: IR Int32)

      -- For multidimensional input, shift the segment offsets into the row
      -- of the input array that this output element belongs to.
      (inf,sup) <- A.unpair <$> case rank (undefined::sh) of
                     0 -> return (A.pair u v)
                     _ -> do q <- A.quot integralType s ss
                             a <- A.mul numType q sz
                             A.pair <$> A.add numType u a
                                    <*> A.add numType v a

      void $
        if A.eq scalarType inf sup
          -- Empty segment: thread 0 writes the seed (if any); with no seed
          -- the output element is left untouched.
          then do
            case mseed of
              Nothing -> return (IR OP_Unit :: IR ())
              Just z  -> do
                when (A.eq scalarType tid (lift 0)) $ writeArray arrOut s =<< z
                return (IR OP_Unit)

          -- Non-empty segment: reduce the elements in [inf,sup), one stripe
          -- of 'blockDim' elements at a time.
          else do
            bd <- blockDim

            -- First stripe: threads with no element supply an 'undef'
            -- placeholder (NOTE(review): presumably ignored by the bounded
            -- reduction below — confirm against 'reduceBlockSMem').
            i0 <- A.add numType inf tid
            x0 <- if A.lt scalarType i0 sup
                    then app1 (delayedLinearIndex arr) =<< A.fromIntegral integralType numType i0
                    else let
                             -- build an undefined value of the element type
                             go :: TupleType a -> Operands a
                             go UnitTuple       = OP_Unit
                             go (PairTuple a b) = OP_Pair (go a) (go b)
                             go (SingleTuple t) = ir' t (undef t)
                         in
                         return . IR $ go (eltType (undefined::e))

            -- use the unbounded (faster) reduction when the stripe is full
            v0 <- A.sub numType sup inf
            r0 <- if A.gte scalarType v0 bd
                    then reduceBlockSMem dev combine Nothing   x0
                    else reduceBlockSMem dev combine (Just v0) x0

            -- Remaining stripes; the running total 'r' is only meaningful
            -- in thread 0 (see the combine step at the end of the loop).
            nxt <- A.add numType inf bd
            r   <- iterFromStepTo nxt bd sup r0 $ \offset r -> do

                     -- wait before reusing the shared-memory scratch area
                     __syncthreads

                     i' <- A.add numType offset tid
                     v' <- A.sub numType sup offset
                     r' <- if A.gte scalarType v' bd
                             -- fully populated stripe
                             then do
                               x <- app1 (delayedLinearIndex arr) =<< A.fromIntegral integralType numType i'
                               y <- reduceBlockSMem dev combine Nothing x
                               return y

                             -- Partial (final) stripe: out-of-range threads
                             -- pass the running total through as a dummy
                             -- input to the bounded reduction.
                             else do
                               x <- if A.lt scalarType i' sup
                                      then app1 (delayedLinearIndex arr) =<< A.fromIntegral integralType numType i'
                                      else return r
                               y <- reduceBlockSMem dev combine (Just v') x
                               return y

                     -- thread 0 folds this stripe into the running total
                     if A.eq scalarType tid (lift 0)
                       then app2 combine r r'
                       else return r'

            -- Thread 0 writes the segment's result, combined with the seed
            -- when one was given.
            when (A.eq scalarType tid (lift 0)) $
              writeArray arrOut s =<<
                case mseed of
                  Nothing -> return r
                  Just z  -> flip (app2 combine) r =<< z

            return (IR OP_Unit)

    return_
-- | Segmented reduction along the innermost index of an array, where a
-- single warp reduces one segment at a time.
--
-- Dynamic shared memory is partitioned per block: a leading region of two
-- elements per warp holds each warp's current segment offsets, followed by
-- one reduction scratch area per warp for 'reduceWarpSMem'.
mkFoldSegP_warp
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => DeviceProperties                             -- ^ properties of the target device
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element ('Nothing' for fold1-style)
    -> IRDelayed PTX aenv (Array (sh :. Int) e)     -- ^ input values
    -> IRDelayed PTX aenv (Segments i)              -- ^ segment descriptor
    -> CodeGen (IROpenAcc PTX aenv (Array (sh :. Int) e))
mkFoldSegP_warp dev aenv combine mseed arr seg =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
      paramEnv                = envParam aenv
      config                  = launchConfig dev (CUDA.decWarp dev) dsmem grid gridQ
      -- Dynamic shared memory for a block of @n@ threads: two offset
      -- elements plus 'per_warp_elems' scratch elements for every warp.
      dsmem n                 = warps * (2 + per_warp_elems) * bytes
        where
          warps = n `P.quot` ws
      -- one warp per segment: grid size covers 'n' segments with warps
      grid n m                = multipleOf n (m `P.quot` ws)
      gridQ                   = [|| \n m -> $$multipleOfQ n (m `P.quot` ws) ||]

      per_warp_bytes          = per_warp_elems * bytes
      per_warp_elems          = ws + (ws `P.quot` 2)
      ws                      = CUDA.warpSize dev
      bytes                   = sizeOf (eltType (undefined :: e))

      -- embed a compile-time constant as a 32-bit integer in the IR
      int32 :: Integral a => a -> IR Int32
      int32                   = lift . P.fromIntegral
  in
  makeOpenAccWith config "foldSeg_warp" (paramGang ++ paramOut ++ paramEnv) $ do

    -- this thread's warp index within the block, and the warp's global
    -- index within the grid (gwid = bid * warps-per-block + wid)
    tid  <- threadIdx
    wid  <- A.quot integralType tid (int32 ws)
    bd   <- blockDim
    wpb  <- A.quot integralType bd (int32 ws)
    bid  <- blockIdx
    gwid <- do a <- A.mul numType bid wpb
               b <- A.add numType wid a
               return b

    -- This warp's two-element slice of the offsets region at the front of
    -- dynamic shared memory (byte offset wid * 2 * bytes).
    lim <- do
      a <- A.mul numType wid (int32 (2 * bytes))
      b <- dynamicSharedMem (lift 2) a
      return b

    -- This warp's reduction scratch area, placed after the offsets region
    -- (wpb * 2 elements) at byte offset wid * per_warp_bytes within the
    -- scratch region.
    smem <- do
      a <- A.mul numType wpb (int32 (2 * bytes))
      b <- A.mul numType wid (int32 per_warp_bytes)
      c <- A.add numType a b
      d <- dynamicSharedMem (int32 per_warp_elems) c
      return d

    -- sz: extent of the innermost dimension of the input array
    sz <- i32 . indexHead =<< delayedExtent arr

    -- ss: number of segments; the descriptor has one more entry than there
    -- are segments (offset-style descriptor).
    ss <- do a <- i32 . indexHead =<< delayedExtent seg
             b <- A.sub numType a (lift 1)
             return b

    -- Each warp starts at its global warp index and strides by the total
    -- number of warps in the grid.
    s0   <- A.add numType start gwid
    gd   <- gridDim
    step <- A.mul numType wpb gd
    imapFromStepTo s0 step end $ \s -> do

      -- The first two lanes of the warp fetch this segment's begin and end
      -- offsets. For rank > 0 the segment index repeats every 'ss' outputs.
      lane <- laneId
      when (A.lt scalarType lane (lift 2)) $ do
        a <- case rank (undefined::sh) of
               0 -> return s
               _ -> A.rem integralType s ss
        b <- A.add numType a lane
        c <- app1 (delayedLinearIndex seg) =<< A.fromIntegral integralType numType b
        writeArray lim lane =<< i32 c

      -- Read the offsets back. NOTE(review): read before any barrier —
      -- presumably relies on warp-synchronous execution, since only lanes
      -- of the same warp wrote 'lim'; confirm for the targeted devices.
      (inf,sup) <- do
        u <- readArray lim (lift 0 :: IR Int32)
        v <- readArray lim (lift 1 :: IR Int32)
        -- for multidimensional input, shift the offsets into this
        -- segment's row of the input array
        A.unpair <$> case rank (undefined::sh) of
                       0 -> return (A.pair u v)
                       _ -> do q <- A.quot integralType s ss
                               a <- A.mul numType q sz
                               A.pair <$> A.add numType u a <*> A.add numType v a

      -- block-wide barrier before the shared-memory areas are reused
      __syncthreads

      void $
        if A.eq scalarType inf sup
          -- Empty segment: lane 0 writes the seed (if any); with no seed
          -- the output element is left untouched.
          then do
            case mseed of
              Nothing -> return (IR OP_Unit :: IR ())
              Just z  -> do
                when (A.eq scalarType lane (lift 0)) $ writeArray arrOut s =<< z
                return (IR OP_Unit)

          -- Non-empty segment: the warp reduces the elements in [inf,sup),
          -- one stripe of 'ws' elements at a time.
          else do
            -- First stripe: lanes with no element supply an 'undef'
            -- placeholder (NOTE(review): presumably ignored by the bounded
            -- reduction below — confirm against 'reduceWarpSMem').
            i0 <- A.add numType inf lane
            x0 <- if A.lt scalarType i0 sup
                    then app1 (delayedLinearIndex arr) =<< A.fromIntegral integralType numType i0
                    else let
                             -- build an undefined value of the element type
                             go :: TupleType a -> Operands a
                             go UnitTuple       = OP_Unit
                             go (PairTuple a b) = OP_Pair (go a) (go b)
                             go (SingleTuple t) = ir' t (undef t)
                         in
                         return . IR $ go (eltType (undefined::e))

            -- use the unbounded (faster) reduction when the stripe is full
            v0 <- A.sub numType sup inf
            r0 <- if A.gte scalarType v0 (int32 ws)
                    then reduceWarpSMem dev combine smem Nothing   x0
                    else reduceWarpSMem dev combine smem (Just v0) x0

            -- Remaining stripes; the running total 'r' is only meaningful
            -- in lane 0 (see the combine step at the end of the loop).
            nx <- A.add numType inf (int32 ws)
            r  <- iterFromStepTo nx (int32 ws) sup r0 $ \offset r -> do

                    -- wait before reusing the shared-memory scratch area
                    __syncthreads

                    i' <- A.add numType offset lane
                    v' <- A.sub numType sup offset
                    r' <- if A.gte scalarType v' (int32 ws)
                            -- fully populated stripe
                            then do
                              x <- app1 (delayedLinearIndex arr) =<< A.fromIntegral integralType numType i'
                              y <- reduceWarpSMem dev combine smem Nothing x
                              return y

                            -- Partial (final) stripe: out-of-range lanes
                            -- pass the running total through as a dummy
                            -- input to the bounded reduction.
                            else do
                              x <- if A.lt scalarType i' sup
                                     then app1 (delayedLinearIndex arr) =<< A.fromIntegral integralType numType i'
                                     else return r
                              y <- reduceWarpSMem dev combine smem (Just v') x
                              return y

                    -- lane 0 folds this stripe into the running total
                    if A.eq scalarType lane (lift 0)
                      then app2 combine r r'
                      else return r'

            -- Lane 0 writes the segment's result, combined with the seed
            -- when one was given.
            when (A.eq scalarType lane (lift 0)) $
              writeArray arrOut s =<<
                case mseed of
                  Nothing -> return r
                  Just z  -> flip (app2 combine) r =<< z

            return (IR OP_Unit)

    return_
-- | Convert an integral value in the IR to a signed 32-bit integer; used for
-- array sizes and segment offsets manipulated by the kernels above.
i32 :: IsIntegral a => IR a -> CodeGen (IR Int32)
i32 x = A.fromIntegral integralType numType x