module Futhark.CodeGen.ImpGen.Kernels.Transpose
( TransposeType(..)
, TransposeArgs
, mapTranspose
, mapTransposeKernel
)
where
import qualified Data.Set as S
import Prelude hiding (quot, rem)
import Futhark.CodeGen.ImpCode.Kernels
import Futhark.Representation.AST.Attributes.Types
import Futhark.Representation.AST.Attributes.Names (freeIn)
import Futhark.Util.IntegralExp (IntegralExp, quot, rem, quotRoundingUp)
data TransposeType = TransposeNormal
| TransposeLowWidth
| TransposeLowHeight
| TransposeSmall
deriving (Eq, Ord, Show)
type TransposeArgs = (VName, Exp,
VName, Exp,
Exp, Exp, Exp, Exp,
Exp, Exp, Exp,
VName)
elemsPerThread :: IntegralExp a => a
elemsPerThread = 4
mapTranspose :: Exp -> TransposeArgs -> PrimType -> TransposeType -> KernelCode
mapTranspose block_dim args t kind =
case kind of
TransposeSmall ->
mconcat
[ get_ids
, dec our_array_offset $ v32 get_global_id_0 `quot` (height*width) * (height*width)
, dec x_index $ (v32 get_global_id_0 `rem` (height*width)) `quot` height
, dec y_index $ v32 get_global_id_0 `rem` height
, dec odata_offset $
(basic_odata_offset `quot` primByteSize t) + v32 our_array_offset
, dec idata_offset $
(basic_idata_offset `quot` primByteSize t) + v32 our_array_offset
, dec index_in $ v32 y_index * width + v32 x_index
, dec index_out $ v32 x_index * height + v32 y_index
, If (v32 get_global_id_0 .<. input_size)
(Write odata (bytes $ (v32 odata_offset + v32 index_out) * tsize) t (Space "global") Nonvolatile $
index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize) t (Space "global") Nonvolatile)
mempty
]
TransposeLowWidth ->
mkTranspose $ lowDimBody
(v32 get_group_id_0 * block_dim + (v32 get_local_id_0 `quot` muly))
(v32 get_group_id_1 * block_dim * muly + v32 get_local_id_1 +
(v32 get_local_id_0 `rem` muly) * block_dim)
(v32 get_group_id_1* block_dim * muly + v32 get_local_id_0 +
(v32 get_local_id_1 `rem` muly) * block_dim)
(v32 get_group_id_0 * block_dim + (v32 get_local_id_1 `quot` muly))
TransposeLowHeight ->
mkTranspose $ lowDimBody
(v32 get_group_id_0 * block_dim * mulx + v32 get_local_id_0 +
(v32 get_local_id_1 `rem` mulx) * block_dim)
(v32 get_group_id_1 * block_dim + (v32 get_local_id_1 `quot` mulx))
(v32 get_group_id_1 * block_dim + (v32 get_local_id_0 `quot` mulx))
(v32 get_group_id_0 * block_dim * mulx + v32 get_local_id_1 +
(v32 get_local_id_0 `rem` mulx) * block_dim)
TransposeNormal ->
mkTranspose $ mconcat
[ dec x_index $ v32 get_global_id_0
, dec y_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_1
, when (v32 x_index .<. width) $
For j Int32 elemsPerThread $
let i = v32 j * (tile_dim `quot` elemsPerThread)
in mconcat [ dec index_in $ (v32 y_index + i) * width + v32 x_index
, when (v32 y_index + i .<. height .&&.
v32 index_in .<. input_size) $
Write block (bytes $ ((v32 get_local_id_1 + i) * (tile_dim+1)
+ v32 get_local_id_0) * tsize)
t (Space "local") Nonvolatile $
index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
t (Space "global") Nonvolatile]
, Op LocalBarrier
, SetScalar x_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_0
, SetScalar y_index $ v32 get_group_id_0 * tile_dim + v32 get_local_id_1
, when (v32 x_index .<. height) $
For j Int32 elemsPerThread $
let i = v32 j * (tile_dim `quot` elemsPerThread)
in mconcat [ dec index_out $ (v32 y_index + i) * height + v32 x_index
, when (v32 y_index + i .<. width .&&.
v32 index_out .<. output_size) $
Write odata (bytes $ (v32 odata_offset + v32 index_out) * tsize)
t (Space "global") Nonvolatile $
index block (bytes $ (v32 get_local_id_0 * (tile_dim+1)
+v32 get_local_id_1+i)*tsize)
t (Space "local") Nonvolatile
]
]
where dec v e = DeclareScalar v int32 <> SetScalar v e
v32 = flip var int32
tsize = LeafExp (SizeOf t) int32
tile_dim = 2 * block_dim
when a b = If a b mempty
(odata, basic_odata_offset, idata, basic_idata_offset,
width, height, input_size, output_size,
mulx, muly, _num_arrays, block) = args
[ our_array_offset , x_index , y_index
, odata_offset, idata_offset, index_in, index_out
, get_global_id_0
, get_local_id_0, get_local_id_1
, get_group_id_0, get_group_id_1, get_group_id_2
, j] =
zipWith (flip VName) [30..] $ map nameFromString
[ "our_array_offset" , "x_index" , "y_index"
, "odata_offset", "idata_offset", "index_in", "index_out"
, "get_global_id_0"
, "get_local_id_0", "get_local_id_1"
, "get_group_id_0", "get_group_id_1", "get_group_id_2"
, "j"]
get_ids =
mconcat [ DeclareScalar get_global_id_0 int32
, Op $ GetGlobalId get_global_id_0 0
, DeclareScalar get_local_id_0 int32
, Op $ GetLocalId get_local_id_0 0
, DeclareScalar get_local_id_1 int32
, Op $ GetLocalId get_local_id_1 1
, DeclareScalar get_group_id_0 int32
, Op $ GetGroupId get_group_id_0 0
, DeclareScalar get_group_id_1 int32
, Op $ GetGroupId get_group_id_1 1
, DeclareScalar get_group_id_2 int32
, Op $ GetGroupId get_group_id_2 2
]
mkTranspose body =
mconcat
[ get_ids
, dec our_array_offset $ v32 get_group_id_2 * width * height
, dec odata_offset $
(basic_odata_offset `quot` primByteSize t) + v32 our_array_offset
, dec idata_offset $
(basic_idata_offset `quot` primByteSize t) + v32 our_array_offset
, body
]
lowDimBody x_in_index y_in_index x_out_index y_out_index =
mconcat
[ dec x_index x_in_index
, dec y_index y_in_index
, dec index_in $ v32 y_index * width + v32 x_index
, when (v32 x_index .<. width .&&. v32 y_index .<. height .&&. v32 index_in .<. input_size) $
Write block (bytes $ (v32 get_local_id_1 * (block_dim+1) + v32 get_local_id_0) * tsize)
t (Space "local") Nonvolatile $
index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
t (Space "global") Nonvolatile
, Op LocalBarrier
, SetScalar x_index x_out_index
, SetScalar y_index y_out_index
, dec index_out $ v32 y_index * height + v32 x_index
, when (v32 x_index .<. height .&&. v32 y_index .<. width .&&. v32 index_out .<. output_size) $
Write odata (bytes $ (v32 odata_offset + v32 index_out) * tsize)
t (Space "global") Nonvolatile $
index block (bytes $ (v32 get_local_id_0 * (block_dim+1)
+v32 get_local_id_1)*tsize)
t (Space "local") Nonvolatile
]
mapTransposeKernel :: String -> Integer -> TransposeArgs -> PrimType -> TransposeType
-> Kernel
mapTransposeKernel desc block_dim_int args t kind =
Kernel
{ kernelBody = mapTranspose block_dim args t kind
, kernelLocalMemory = [(block, Right block_size)]
, kernelUses = uses
, kernelNumGroups = num_groups
, kernelGroupSize = group_size
, kernelName = nameFromString name
}
where pad2DBytes k = k * (k + 1) * primByteSize t
block_size =
case kind of TransposeSmall -> 1
TransposeNormal -> fromInteger $ pad2DBytes $ 2*block_dim_int
TransposeLowWidth -> fromInteger $ pad2DBytes block_dim_int
TransposeLowHeight -> fromInteger $ pad2DBytes block_dim_int
block_dim = fromInteger block_dim_int
(odata, basic_odata_offset, idata, basic_idata_offset,
width, height, input_size, output_size,
mulx, muly, num_arrays,
block) = args
(num_groups, group_size) =
case kind of
TransposeSmall ->
([(num_arrays * width * height) `quotRoundingUp` (block_dim * block_dim)],
[block_dim * block_dim])
TransposeLowWidth ->
lowDimKernelAndGroupSize block_dim num_arrays width $ height `quotRoundingUp` muly
TransposeLowHeight ->
lowDimKernelAndGroupSize block_dim num_arrays (width `quotRoundingUp` mulx) height
TransposeNormal ->
let actual_dim = block_dim*2
in ( [ width `quotRoundingUp` actual_dim
, height `quotRoundingUp` actual_dim
, num_arrays]
, [actual_dim, actual_dim `quot` elemsPerThread, 1])
uses = map (`ScalarUse` int32)
(S.toList $ mconcat $ map freeIn
[basic_odata_offset, basic_idata_offset, num_arrays,
width, height, input_size, output_size, mulx, muly]) ++
map MemoryUse [odata, idata]
name =
case kind of TransposeSmall -> desc ++ "_small"
TransposeLowHeight -> desc ++ "_low_height"
TransposeLowWidth -> desc ++ "_low_width"
TransposeNormal -> desc
lowDimKernelAndGroupSize :: Exp -> Exp -> Exp -> Exp -> ([Exp], [Exp])
lowDimKernelAndGroupSize block_dim num_arrays x_elems y_elems =
([x_elems `quotRoundingUp` block_dim,
y_elems `quotRoundingUp` block_dim,
num_arrays],
[block_dim, block_dim, 1])