-- | Carefully optimised implementations of GPU transpositions.
-- Written in ImpCode so we can compile it to both CUDA and OpenCL.
module Futhark.CodeGen.ImpGen.Kernels.Transpose
  ( TransposeType (..),
    TransposeArgs,
    mapTransposeKernel,
  )
where

import Futhark.CodeGen.ImpCode.Kernels
import Futhark.IR.Prop.Types
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | Which form of transposition to generate code for.  The
-- constructor selects which kernel body 'mapTranspose' emits.
data TransposeType
  = -- | The general case.  The generated kernel stages elements
    -- through the @block@ array in the @"local"@ space, with a
    -- barrier between the read phase and the write phase.
    TransposeNormal
  | -- | Variant whose index computation uses the @muly@ factor
    -- from the 'TransposeArgs' (presumably for inputs with a
    -- small width).
    TransposeLowWidth
  | -- | Variant whose index computation uses the @mulx@ factor
    -- from the 'TransposeArgs' (presumably for inputs with a
    -- small height).
    TransposeLowHeight
  | -- | For small arrays that do not
    -- benefit from coalescing.
    TransposeSmall
  deriving (Eq, Ord, Show)

-- | The types of the arguments accepted by a transposition function.
--
-- The components are positional.  The names given below are the ones
-- that 'mapTranspose' binds them to when it destructures the tuple,
-- so the order here must match that pattern exactly.
type TransposeArgs =
  ( VName, -- odata: the array that the kernel writes to ("global" space).
    TExp Int32, -- basic_odata_offset: divided by the element byte size before use, so presumably a byte offset.
    VName, -- idata: the array that the kernel reads from ("global" space).
    TExp Int32, -- basic_idata_offset: divided by the element byte size before use, so presumably a byte offset.
    TExp Int32, -- width: bound on x_index when reading the input.
    TExp Int32, -- height: bound on y_index when reading the input.
    TExp Int32, -- mulx: factor used by the TransposeLowHeight index computation.
    TExp Int32, -- muly: factor used by the TransposeLowWidth index computation.
    TExp Int32, -- num_arrays: used in the TransposeSmall bounds check (width * height * num_arrays).
    VName -- block: scratch array accessed in the "local" space by TransposeNormal.
  )

-- | Trip count of the per-thread @j@ loop in the 'TransposeNormal'
-- kernel.  The same constant divides @tile_dim@ to give the stride
-- between the rows handled by successive iterations of that loop.
elemsPerThread :: TExp Int32
elemsPerThread = 4

mapTranspose :: TExp Int32 -> TransposeArgs -> PrimType -> TransposeType -> KernelCode
mapTranspose :: TExp Int32
-> TransposeArgs -> PrimType -> TransposeType -> KernelCode
mapTranspose TExp Int32
block_dim TransposeArgs
args PrimType
t TransposeType
kind =
  case TransposeType
kind of
    TransposeType
TransposeSmall ->
      [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
        [ KernelCode
get_ids,
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
our_array_offset (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
get_global_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32
height TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
width) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (TExp Int32
height TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
width),
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
x_index (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ (VName -> TExp Int32
vi32 VName
get_global_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` (TExp Int32
height TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
width)) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
height,
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
y_index (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
get_global_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
height,
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
odata_offset (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$
            (TExp Int32
basic_odata_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TExp Int32
forall a. Num a => PrimType -> a
primByteSize PrimType
t) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
our_array_offset,
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
idata_offset (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$
            (TExp Int32
basic_idata_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TExp Int32
forall a. Num a => PrimType -> a
primByteSize PrimType
t) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
our_array_offset,
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
index_in (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
y_index TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
width TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
x_index,
          VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
index_out (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
x_index TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
height TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
y_index,
          TExp Bool -> KernelCode -> KernelCode
forall a. TExp Bool -> Code a -> Code a
when
            (VName -> TExp Int32
vi32 VName
get_global_id_0 TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
width TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
height TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
num_arrays)
            ( VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write VName
odata (TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
odata_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
index_out) PrimType
t (String -> Space
Space String
"global") Volatility
Nonvolatile (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$
                VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
index VName
idata (TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
idata_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
index_in) PrimType
t (String -> Space
Space String
"global") Volatility
Nonvolatile
            )
        ]
    TransposeType
TransposeLowWidth ->
      KernelCode -> KernelCode
mkTranspose (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
        TExp Int32 -> TExp Int32 -> TExp Int32 -> TExp Int32 -> KernelCode
forall t t t t.
TPrimExp t ExpLeaf
-> TPrimExp t ExpLeaf
-> TPrimExp t ExpLeaf
-> TPrimExp t ExpLeaf
-> KernelCode
lowDimBody
          (VName -> TExp Int32
vi32 VName
get_group_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
muly))
          ( VName -> TExp Int32
vi32 VName
get_group_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
muly TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_1
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
muly) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim
          )
          ( VName -> TExp Int32
vi32 VName
get_group_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
muly TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_0
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
muly) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim
          )
          (VName -> TExp Int32
vi32 VName
get_group_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
muly))
    TransposeType
TransposeLowHeight ->
      KernelCode -> KernelCode
mkTranspose (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
        TExp Int32 -> TExp Int32 -> TExp Int32 -> TExp Int32 -> KernelCode
forall t t t t.
TPrimExp t ExpLeaf
-> TPrimExp t ExpLeaf
-> TPrimExp t ExpLeaf
-> TPrimExp t ExpLeaf
-> KernelCode
lowDimBody
          ( VName -> TExp Int32
vi32 VName
get_group_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
mulx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_0
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
mulx) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim
          )
          (VName -> TExp Int32
vi32 VName
get_group_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
mulx))
          (VName -> TExp Int32
vi32 VName
get_group_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
mulx))
          ( VName -> TExp Int32
vi32 VName
get_group_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
mulx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_1
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ (VName -> TExp Int32
vi32 VName
get_local_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
mulx) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim
          )
    TransposeType
TransposeNormal ->
      KernelCode -> KernelCode
mkTranspose (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
        [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
          [ VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
x_index (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
get_global_id_0,
            VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
y_index (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
get_group_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_1,
            TExp Bool -> KernelCode -> KernelCode
forall a. TExp Bool -> Code a -> Code a
when (VName -> TExp Int32
vi32 VName
x_index TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
width) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
              VName -> Exp -> KernelCode -> KernelCode
forall a. VName -> Exp -> Code a -> Code a
For VName
j (TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
elemsPerThread) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
                let i :: TExp Int32
i = VName -> TExp Int32
vi32 VName
j TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
elemsPerThread)
                 in [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
                      [ VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
index_in (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ (VName -> TExp Int32
vi32 VName
y_index TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
width TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
x_index,
                        TExp Bool -> KernelCode -> KernelCode
forall a. TExp Bool -> Code a -> Code a
when (VName -> TExp Int32
vi32 VName
y_index TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
height) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
                          VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write
                            VName
block
                            ( TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$
                                TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
                                  (VName -> TExp Int32
vi32 VName
get_local_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1)
                                    TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_0
                            )
                            PrimType
t
                            (String -> Space
Space String
"local")
                            Volatility
Nonvolatile
                            (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
index
                              VName
idata
                              (TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
idata_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
index_in)
                              PrimType
t
                              (String -> Space
Space String
"global")
                              Volatility
Nonvolatile
                      ],
            KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Barrier Fence
FenceLocal,
            VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
x_index (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
get_group_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_0,
            VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
y_index (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
get_group_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_1,
            TExp Bool -> KernelCode -> KernelCode
forall a. TExp Bool -> Code a -> Code a
when (VName -> TExp Int32
vi32 VName
x_index TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
height) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
              VName -> Exp -> KernelCode -> KernelCode
forall a. VName -> Exp -> Code a -> Code a
For VName
j (TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
elemsPerThread) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
                let i :: TExp Int32
i = VName -> TExp Int32
vi32 VName
j TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
elemsPerThread)
                 in [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
                      [ VName -> TExp Int32 -> KernelCode
forall t a. VName -> TPrimExp t ExpLeaf -> Code a
dec VName
index_out (TExp Int32 -> KernelCode) -> TExp Int32 -> KernelCode
forall a b. (a -> b) -> a -> b
$ (VName -> TExp Int32
vi32 VName
y_index TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
height TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
x_index,
                        TExp Bool -> KernelCode -> KernelCode
forall a. TExp Bool -> Code a -> Code a
when (VName -> TExp Int32
vi32 VName
y_index TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
width) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
                          VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write
                            VName
odata
                            (TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
vi32 VName
odata_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
index_out)
                            PrimType
t
                            (String -> Space
Space String
"global")
                            Volatility
Nonvolatile
                            (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
index
                              VName
block
                              ( TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$
                                  TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
                                    VName -> TExp Int32
vi32 VName
get_local_id_0 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (TExp Int32
tile_dim TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
vi32 VName
get_local_id_1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i
                              )
                              PrimType
t
                              (String -> Space
Space String
"local")
                              Volatility
Nonvolatile
                      ]
          ]
  where
    dec :: VName -> TPrimExp t ExpLeaf -> Code a
dec VName
v (TPrimExp Exp
e) =
      VName -> Volatility -> PrimType -> Code a
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
v Volatility
Nonvolatile (Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
e) Code a -> Code a -> Code a
forall a. Semigroup a => a -> a -> a
<> VName -> Exp -> Code a
forall a. VName -> Exp -> Code a
SetScalar VName
v Exp
e
    tile_dim :: TExp Int32
tile_dim = TExp Int32
2 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_dim

    when :: TExp Bool -> Code a -> Code a
when TExp Bool
a Code a
b = TExp Bool -> Code a -> Code a -> Code a
forall a. TExp Bool -> Code a -> Code a -> Code a
If TExp Bool
a Code a
b Code a
forall a. Monoid a => a
mempty

    ( VName
odata,
      TExp Int32
basic_odata_offset,
      VName
idata,
      TExp Int32
basic_idata_offset,
      TExp Int32
width,
      TExp Int32
height,
      TExp Int32
mulx,
      TExp Int32
muly,
      TExp Int32
num_arrays,
      VName
block
      ) = TransposeArgs
args

    -- Be extremely careful when editing this list to ensure that
    -- the names match up.  Also, be careful that the tags on
    -- these names do not conflict with the tags of the
    -- surrounding code.  We accomplish the latter by using very
    -- low tags (normal variables start at least in the low
    -- hundreds).
    [ VName
our_array_offset,
      VName
x_index,
      VName
y_index,
      VName
odata_offset,
      VName
idata_offset,
      VName
index_in,
      VName
index_out,
      VName
get_global_id_0,
      VName
get_local_id_0,
      VName
get_local_id_1,
      VName
get_group_id_0,
      VName
get_group_id_1,
      VName
get_group_id_2,
      VName
j
      ] =
        (Int -> Name -> VName) -> [Int] -> [Name] -> [VName]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith ((Name -> Int -> VName) -> Int -> Name -> VName
forall a b c. (a -> b -> c) -> b -> a -> c
flip Name -> Int -> VName
VName) [Int
30 ..] ([Name] -> [VName]) -> [Name] -> [VName]
forall a b. (a -> b) -> a -> b
$
          (String -> Name) -> [String] -> [Name]
forall a b. (a -> b) -> [a] -> [b]
map
            String -> Name
nameFromString
            [ String
"our_array_offset",
              String
"x_index",
              String
"y_index",
              String
"odata_offset",
              String
"idata_offset",
              String
"index_in",
              String
"index_out",
              String
"get_global_id_0",
              String
"get_local_id_0",
              String
"get_local_id_1",
              String
"get_group_id_0",
              String
"get_group_id_1",
              String
"get_group_id_2",
              String
"j"
            ]

    get_ids :: KernelCode
get_ids =
      [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
        [ VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_global_id_0 Volatility
Nonvolatile PrimType
int32,
          KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGlobalId VName
get_global_id_0 Int
0,
          VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_local_id_0 Volatility
Nonvolatile PrimType
int32,
          KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetLocalId VName
get_local_id_0 Int
0,
          VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_local_id_1 Volatility
Nonvolatile PrimType
int32,
          KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetLocalId VName
get_local_id_1 Int
1,
          VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_group_id_0 Volatility
Nonvolatile PrimType
int32,
          KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGroupId VName
get_group_id_0 Int
0,
          VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_group_id_1 Volatility
Nonvolatile PrimType
int32,
          KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGroupId VName
get_group_id_1 Int
1,
          VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_group_id_2 Volatility
Nonvolatile PrimType
int32,
          KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGroupId VName
get_group_id_2 Int
2
        ]

    -- Wrap a transposition body with the common prologue: run 'get_ids'
    -- and declare the element offsets that the body reads.
    mkTranspose body =
      mconcat
        [ get_ids,
          -- Start of the array selected by the group id in dimension 2;
          -- consecutive arrays are width * height elements apart.
          dec our_array_offset $ vi32 get_group_id_2 * width * height,
          -- The basic offsets are divided by the byte size of an element of
          -- type t before the per-array offset is added (presumably they are
          -- given in bytes -- confirm against the caller).
          dec odata_offset $
            (basic_odata_offset `quot` primByteSize t) + vi32 our_array_offset,
          dec idata_offset $
            (basic_idata_offset `quot` primByteSize t) + vi32 our_array_offset,
          body
        ]

    -- Kernel body shared by the low-width and low-height transpositions.
    --
    -- The first two arguments are the x/y indices used for reading the
    -- input; the last two are the x/y indices used for writing the output.
    -- An element is first copied from idata into the local array block,
    -- and after a local barrier an element of block is copied to odata.
    lowDimBody x_in_index y_in_index x_out_index y_out_index =
      mconcat
        [ dec x_index x_in_index,
          dec y_index y_in_index,
          dec index_in $ vi32 y_index * width + vi32 x_index,
          -- Only threads whose input index is inside the width x height
          -- array perform the read.
          when (vi32 x_index .<. width .&&. vi32 y_index .<. height) $
            Write
              block
              -- Rows of block have a stride of block_dim + 1 elements.
              (elements $ sExt64 $ vi32 get_local_id_1 * (block_dim + 1) + vi32 get_local_id_0)
              t
              (Space "local")
              Nonvolatile
              $ index
                idata
                (elements $ sExt64 $ vi32 idata_offset + vi32 index_in)
                t
                (Space "global")
                Nonvolatile,
          -- All writes to block must be visible before any thread reads
          -- an element written by another thread.
          Op $ Barrier FenceLocal,
          -- Reuse x_index/y_index for the output position.
          SetScalar x_index $ untyped x_out_index,
          SetScalar y_index $ untyped y_out_index,
          -- The output has the dimensions swapped, so its row length is
          -- height and the bounds check below is mirrored.
          dec index_out $ vi32 y_index * height + vi32 x_index,
          when (vi32 x_index .<. height .&&. vi32 y_index .<. width) $
            Write
              odata
              (elements $ sExt64 (vi32 odata_offset + vi32 index_out))
              t
              (Space "global")
              Nonvolatile
              $ index
                block
                -- Local ids are swapped relative to the write into block,
                -- which is what performs the transposition.
                (elements $ sExt64 $ vi32 get_local_id_0 * (block_dim + 1) + vi32 get_local_id_1)
                t
                (Space "local")
                Nonvolatile
        ]

-- | Generate a transpose kernel.  There is special support to handle
-- input arrays with low width, low height, or both.
--
-- Normally when transposing a @[2][n]@ array we would use a @FUT_BLOCK_DIM x
-- FUT_BLOCK_DIM@ group to process a @[2][FUT_BLOCK_DIM]@ slice of the input
-- array. This would mean that many of the threads in a group would be inactive.
-- We try to remedy this by using a special kernel that will process a larger
-- part of the input, by using more complex indexing. In our example, we could
-- use all threads in a group if we process a slice of each row that is
-- @(FUT_BLOCK_DIM/2)@ times as large per group. The variable @mulx@ contains
-- this factor for the kernel to handle input arrays with low height.
--
-- See issue #308 on GitHub for more details.
--
-- These kernels are optimized to ensure all global reads and writes
-- are coalesced, and to avoid bank conflicts in shared memory.  Each
-- thread group transposes a 2D tile of block_dim*2 by block_dim*2
-- elements. The size of a thread group is block_dim/2 by
-- block_dim*2, meaning that each thread will process 4 elements in a
-- 2D tile.  The shared memory array containing the 2D tile consists
-- of block_dim*2 by block_dim*2+1 elements. Padding each row with
-- an additional element prevents bank conflicts from occurring when
-- the tile is accessed column-wise.
mapTransposeKernel ::
  String ->
  Integer ->
  TransposeArgs ->
  PrimType ->
  TransposeType ->
  Kernel
mapTransposeKernel desc block_dim_int args t kind =
  Kernel
    { -- Declare and allocate the local array, then run the transposition
      -- code generated for this kind.
      kernelBody =
        DeclareMem block (Space "local")
          <> Op (LocalAlloc block block_size)
          <> mapTranspose block_dim args t kind,
      kernelUses = uses,
      kernelNumGroups = map untyped num_groups,
      kernelGroupSize = map untyped group_size,
      kernelName = nameFromString name,
      kernelFailureTolerant = True
    }
  where
    -- Byte size of a k by (k + 1) array of elements of type t.
    pad2DBytes k = k * (k + 1) * primByteSize t

    -- Size in bytes of the local array allocated for the kernel.
    block_size =
      bytes $
        case kind of
          -- Not used, but AMD's OpenCL
          -- does not like zero-size
          -- local memory.
          TransposeSmall -> 1 :: TExp Int64
          TransposeNormal -> fromInteger $ pad2DBytes $ 2 * block_dim_int
          TransposeLowWidth -> fromInteger $ pad2DBytes block_dim_int
          TransposeLowHeight -> fromInteger $ pad2DBytes block_dim_int

    block_dim = fromInteger block_dim_int :: TExp Int32

    ( odata,
      basic_odata_offset,
      idata,
      basic_idata_offset,
      width,
      height,
      mulx,
      muly,
      num_arrays,
      block
      ) = args

    -- Number of groups and group size, one entry per grid dimension.
    (num_groups, group_size) =
      case kind of
        -- One-dimensional grid with one thread per element of all arrays.
        TransposeSmall ->
          ( [(num_arrays * width * height) `divUp` (block_dim * block_dim)],
            [block_dim * block_dim]
          )
        TransposeLowWidth ->
          lowDimKernelAndGroupSize block_dim num_arrays width $ height `divUp` muly
        TransposeLowHeight ->
          lowDimKernelAndGroupSize block_dim num_arrays (width `divUp` mulx) height
        -- Tiles are actual_dim by actual_dim; each thread handles
        -- elemsPerThread elements, so the second group dimension shrinks.
        TransposeNormal ->
          let actual_dim = block_dim * 2
           in ( [ width `divUp` actual_dim,
                  height `divUp` actual_dim,
                  num_arrays
                ],
                [actual_dim, actual_dim `quot` elemsPerThread, 1]
              )

    -- Every variable free in the scalar arguments is passed as an int32
    -- scalar; the two arrays are passed as memory.
    uses =
      map
        (`ScalarUse` int32)
        ( namesToList $
            foldMap
              freeIn
              [ basic_odata_offset,
                basic_idata_offset,
                num_arrays,
                width,
                height,
                mulx,
                muly
              ]
        )
        ++ map MemoryUse [odata, idata]

    -- Kernel name: the description plus a suffix identifying the variant.
    name =
      case kind of
        TransposeSmall -> desc ++ "_small"
        TransposeLowHeight -> desc ++ "_low_height"
        TransposeLowWidth -> desc ++ "_low_width"
        TransposeNormal -> desc

-- | Compute the number of groups and the group size used for the
-- low-width and low-height kernels.  The arguments are the block
-- dimension, the number of arrays, and the element counts in the x
-- and y directions.  The grid has one group per @block_dim@ by
-- @block_dim@ tile (rounding up) in x and y and one layer per array;
-- each group is @block_dim@ by @block_dim@ by 1.
lowDimKernelAndGroupSize ::
  TExp Int32 ->
  TExp Int32 ->
  TExp Int32 ->
  TExp Int32 ->
  ([TExp Int32], [TExp Int32])
lowDimKernelAndGroupSize block_dim num_arrays x_elems y_elems =
  ( [ x_elems `divUp` block_dim,
      y_elems `divUp` block_dim,
      num_arrays
    ],
    [block_dim, block_dim, 1]
  )