module Futhark.CodeGen.ImpGen.GPU.Transpose
( TransposeType (..),
TransposeArgs,
mapTransposeKernel,
)
where
import Futhark.CodeGen.ImpCode.GPU
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | Which variant of the transposition kernel body 'mapTranspose'
-- should generate.  The constructor order is significant for the
-- derived 'Ord' instance.
data TransposeType
  = -- | The general case: elements are staged through the
    -- local-memory @block@ array, with a local barrier between the
    -- read phase and the write phase.
    TransposeNormal
  | -- | Variant whose index computations are scaled by the @muly@
    -- component of the arguments.
    TransposeLowWidth
  | -- | Variant whose index computations are scaled by the @mulx@
    -- component of the arguments.
    TransposeLowHeight
  | -- | Each thread copies one element straight from the input to the
    -- output in global memory, indexed by its global id; the
    -- local-memory @block@ array is not used.
    TransposeSmall
  deriving (Eq, Ord, Show)
-- | The argument tuple consumed by 'mapTranspose', parameterised over
-- the integer type @int@ used for sizes and offsets.  The component
-- names below are the ones 'mapTranspose' binds when it destructures
-- the tuple.
type TransposeArgs int =
  ( VName, -- odata: array the transposed elements are written to
    TExp int, -- basic_odata_offset: divided by the element byte size before use
    VName, -- idata: array the elements are read from
    TExp int, -- basic_idata_offset: divided by the element byte size before use
    TExp int, -- width: bound checked against the x index when reading
    TExp int, -- height: bound checked against the y index when reading
    TExp int, -- mulx: multiplier used by the TransposeLowHeight indexing
    TExp int, -- muly: multiplier used by the TransposeLowWidth indexing
    TExp int, -- num_arrays: factor in the TransposeSmall bounds check
    VName -- block: array accessed in the "local" space by TransposeNormal
  )
-- | Trip count of the per-thread @For@ loops in the 'TransposeNormal'
-- branch of 'mapTranspose'; the same branch advances by
-- @tile_dim `quot` elemsPerThread@ rows on each iteration.  Kept
-- polymorphic so it can be used both as a plain number and as a typed
-- expression.
elemsPerThread :: Num a => a
elemsPerThread = 8
mapTranspose :: forall int. IntExp int => (PrimType, VName -> TExp int) -> TExp int -> TransposeArgs int -> PrimType -> TransposeType -> KernelCode
mapTranspose :: forall {k} (int :: k).
IntExp int =>
(PrimType, VName -> TExp int)
-> TExp int
-> TransposeArgs int
-> PrimType
-> TransposeType
-> KernelCode
mapTranspose (PrimType
int, VName -> TPrimExp int VName
le) TPrimExp int VName
block_dim TransposeArgs int
args PrimType
t TransposeType
kind =
case TransposeType
kind of
TransposeType
TransposeSmall ->
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ KernelCode
get_ids,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
our_array_offset (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_global_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` (TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width),
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
x_index (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ (VName -> TPrimExp int VName
le VName
get_global_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`rem` (TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width)) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
height,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
y_index (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_global_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp int VName
height,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
val Volatility
Nonvolatile PrimType
t,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
odata_offset (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$
(TPrimExp int VName
basic_odata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TPrimExp int VName
forall a. Num a => PrimType -> a
primByteSize PrimType
t) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
our_array_offset,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
idata_offset (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$
(TPrimExp int VName
basic_idata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TPrimExp int VName
forall a. Num a => PrimType -> a
primByteSize PrimType
t) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
our_array_offset,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
index_in (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
x_index,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
index_out (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
x_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
y_index,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when
(VName -> TPrimExp int VName
le VName
get_global_id_0 TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
width TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
num_arrays)
( [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> KernelCode
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Read VName
val VName
idata (TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
idata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
index_in) PrimType
t (String -> Space
Space String
"global") Volatility
Nonvolatile,
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write VName
odata (TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
odata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
index_out) PrimType
t (String -> Space
Space String
"global") Volatility
Nonvolatile (VName -> PrimType -> Exp
var VName
val PrimType
t)
]
)
]
TransposeType
TransposeLowWidth ->
KernelCode -> KernelCode
mkTranspose (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
TPrimExp int VName
-> TPrimExp int VName
-> TPrimExp int VName
-> TPrimExp int VName
-> KernelCode
forall {k} {k} {k} {k} {t :: k} {t :: k} {t :: k} {t :: k}.
TPrimExp t VName
-> TPrimExp t VName
-> TPrimExp t VName
-> TPrimExp t VName
-> KernelCode
lowDimBody
(VName -> TPrimExp int VName
le VName
get_group_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
muly))
( VName -> TPrimExp int VName
le VName
get_group_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
muly
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_1
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp int VName
muly) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim
)
( VName -> TPrimExp int VName
le VName
get_group_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
muly
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_0
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp int VName
muly) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim
)
(VName -> TPrimExp int VName
le VName
get_group_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
muly))
TransposeType
TransposeLowHeight ->
KernelCode -> KernelCode
mkTranspose (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
TPrimExp int VName
-> TPrimExp int VName
-> TPrimExp int VName
-> TPrimExp int VName
-> KernelCode
forall {k} {k} {k} {k} {t :: k} {t :: k} {t :: k} {t :: k}.
TPrimExp t VName
-> TPrimExp t VName
-> TPrimExp t VName
-> TPrimExp t VName
-> KernelCode
lowDimBody
( VName -> TPrimExp int VName
le VName
get_group_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
mulx
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_0
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp int VName
mulx) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim
)
(VName -> TPrimExp int VName
le VName
get_group_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
mulx))
(VName -> TPrimExp int VName
le VName
get_group_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
mulx))
( VName -> TPrimExp int VName
le VName
get_group_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
mulx
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_1
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp int VName
le VName
get_local_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp int VName
mulx) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim
)
TransposeType
TransposeNormal ->
KernelCode -> KernelCode
mkTranspose (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
x_index (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_global_id_0,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
y_index (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_group_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_1,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
val Volatility
Nonvolatile PrimType
t,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when (VName -> TPrimExp int VName
le VName
x_index TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
width) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
VName -> Exp -> KernelCode -> KernelCode
forall a. VName -> Exp -> Code a -> Code a
For VName
j (TPrimExp int VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp int VName
forall a. Num a => a
elemsPerThread :: TExp int)) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
let i :: TPrimExp int VName
i = VName -> TPrimExp int VName
le VName
j TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
forall a. Num a => a
elemsPerThread)
in [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
index_in (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ (VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
i) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
x_index,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when (VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
i TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
height) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> KernelCode
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Read
VName
val
VName
idata
(TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
idata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
index_in)
PrimType
t
(String -> Space
Space String
"global")
Volatility
Nonvolatile,
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write
VName
block
( TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$
(VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
i) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
1)
TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_0
)
PrimType
t
(String -> Space
Space String
"local")
Volatility
Nonvolatile
(VName -> PrimType -> Exp
var VName
val PrimType
t)
]
],
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Barrier Fence
FenceLocal,
VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
x_index (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp int VName -> Exp) -> TPrimExp int VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_group_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_0,
VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
y_index (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp int VName -> Exp) -> TPrimExp int VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_group_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_1,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when (VName -> TPrimExp int VName
le VName
x_index TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
height) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
VName -> Exp -> KernelCode -> KernelCode
forall a. VName -> Exp -> Code a -> Code a
For VName
j (TPrimExp int VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp int VName
forall a. Num a => a
elemsPerThread :: TExp int)) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
let i :: TPrimExp int VName
i = VName -> TPrimExp int VName
le VName
j TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp int VName
forall a. Num a => a
elemsPerThread)
in [KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
index_out (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ (VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
i) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
x_index,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when (VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
i TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
width) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> KernelCode
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Read
VName
val
VName
block
( TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName
-> Count Elements (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> Count Elements (TExp Int64))
-> TPrimExp int VName -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$
VName -> TPrimExp int VName
le VName
get_local_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
tile_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
1) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
i
)
PrimType
t
(String -> Space
Space String
"local")
Volatility
Nonvolatile,
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write
VName
odata
(TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
odata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
index_out)
PrimType
t
(String -> Space
Space String
"global")
Volatility
Nonvolatile
(VName -> PrimType -> Exp
var VName
val PrimType
t)
]
]
]
where
toOffset :: TExp int -> TExp Int64
toOffset :: TPrimExp int VName -> TExp Int64
toOffset = TPrimExp int VName -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64
dec :: VName -> TPrimExp t VName -> Code a
dec VName
v (TPrimExp Exp
e) =
VName -> Volatility -> PrimType -> Code a
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
v Volatility
Nonvolatile (Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
e) Code a -> Code a -> Code a
forall a. Semigroup a => a -> a -> a
<> VName -> Exp -> Code a
forall a. VName -> Exp -> Code a
SetScalar VName
v Exp
e
tile_dim :: TPrimExp int VName
tile_dim = TPrimExp int VName
2 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
block_dim
when :: TExp Bool -> Code a -> Code a
when TExp Bool
a Code a
b = TExp Bool -> Code a -> Code a -> Code a
forall a. TExp Bool -> Code a -> Code a -> Code a
If TExp Bool
a Code a
b Code a
forall a. Monoid a => a
mempty
( VName
odata,
TPrimExp int VName
basic_odata_offset,
VName
idata,
TPrimExp int VName
basic_idata_offset,
TPrimExp int VName
width,
TPrimExp int VName
height,
TPrimExp int VName
mulx,
TPrimExp int VName
muly,
TPrimExp int VName
num_arrays,
VName
block
) = TransposeArgs int
args
[ VName
our_array_offset,
VName
x_index,
VName
y_index,
VName
odata_offset,
VName
idata_offset,
VName
index_in,
VName
index_out,
VName
get_global_id_0,
VName
get_local_id_0,
VName
get_local_id_1,
VName
get_local_size_0,
VName
get_group_id_0,
VName
get_group_id_1,
VName
get_group_id_2,
VName
j,
VName
val
] =
(Int -> Name -> VName) -> [Int] -> [Name] -> [VName]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith ((Name -> Int -> VName) -> Int -> Name -> VName
forall a b c. (a -> b -> c) -> b -> a -> c
flip Name -> Int -> VName
VName) [Int
30 ..] ([Name] -> [VName]) -> [Name] -> [VName]
forall a b. (a -> b) -> a -> b
$
(String -> Name) -> [String] -> [Name]
forall a b. (a -> b) -> [a] -> [b]
map
String -> Name
nameFromString
[ String
"our_array_offset",
String
"x_index",
String
"y_index",
String
"odata_offset",
String
"idata_offset",
String
"index_in",
String
"index_out",
String
"get_global_id_0",
String
"get_local_id_0",
String
"get_local_id_1",
String
"get_local_size_0",
String
"get_group_id_0",
String
"get_group_id_1",
String
"get_group_id_2",
String
"j",
String
"val"
]
get_ids :: KernelCode
get_ids =
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_local_id_0 Volatility
Nonvolatile PrimType
int,
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetLocalId VName
get_local_id_0 Int
0,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_local_id_1 Volatility
Nonvolatile PrimType
int,
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetLocalId VName
get_local_id_1 Int
1,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_group_id_0 Volatility
Nonvolatile PrimType
int,
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGroupId VName
get_group_id_0 Int
0,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_group_id_1 Volatility
Nonvolatile PrimType
int,
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGroupId VName
get_group_id_1 Int
1,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_group_id_2 Volatility
Nonvolatile PrimType
int,
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetGroupId VName
get_group_id_2 Int
2,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_local_size_0 Volatility
Nonvolatile PrimType
int,
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> Int -> KernelOp
GetLocalSize VName
get_local_size_0 Int
0,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
get_global_id_0 Volatility
Nonvolatile PrimType
int,
VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
get_global_id_0 (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp int VName -> Exp) -> TPrimExp int VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_group_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp int VName
le VName
get_local_size_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_0
]
mkTranspose :: KernelCode -> KernelCode
mkTranspose KernelCode
body =
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ KernelCode
get_ids,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
our_array_offset (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_group_id_2 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
height,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
odata_offset (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$
(TPrimExp int VName
basic_odata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TPrimExp int VName
forall a. Num a => PrimType -> a
primByteSize PrimType
t) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
our_array_offset,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
idata_offset (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$
(TPrimExp int VName
basic_idata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TPrimExp int VName
forall a. Num a => PrimType -> a
primByteSize PrimType
t) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
our_array_offset,
KernelCode
body
]
lowDimBody :: TPrimExp t VName
-> TPrimExp t VName
-> TPrimExp t VName
-> TPrimExp t VName
-> KernelCode
lowDimBody TPrimExp t VName
x_in_index TPrimExp t VName
y_in_index TPrimExp t VName
x_out_index TPrimExp t VName
y_out_index =
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName -> TPrimExp t VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
x_index TPrimExp t VName
x_in_index,
VName -> TPrimExp t VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
y_index TPrimExp t VName
y_in_index,
VName -> Volatility -> PrimType -> KernelCode
forall a. VName -> Volatility -> PrimType -> Code a
DeclareScalar VName
val Volatility
Nonvolatile PrimType
t,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
index_in (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
width TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
x_index,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when (VName -> TPrimExp int VName
le VName
x_index TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
width TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
height) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> KernelCode
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Read
VName
val
VName
idata
(TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
idata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
index_in)
PrimType
t
(String -> Space
Space String
"global")
Volatility
Nonvolatile,
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write
VName
block
(TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_local_id_1 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
1) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_0)
PrimType
t
(String -> Space
Space String
"local")
Volatility
Nonvolatile
(VName -> PrimType -> Exp
var VName
val PrimType
t)
],
KernelOp -> KernelCode
forall a. a -> Code a
Op (KernelOp -> KernelCode) -> KernelOp -> KernelCode
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Barrier Fence
FenceLocal,
VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
x_index (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TPrimExp t VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp t VName
x_out_index,
VName -> Exp -> KernelCode
forall a. VName -> Exp -> Code a
SetScalar VName
y_index (Exp -> KernelCode) -> Exp -> KernelCode
forall a b. (a -> b) -> a -> b
$ TPrimExp t VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp t VName
y_out_index,
VName -> TPrimExp int VName -> KernelCode
forall {k} {t :: k} {a}. VName -> TPrimExp t VName -> Code a
dec VName
index_out (TPrimExp int VName -> KernelCode)
-> TPrimExp int VName -> KernelCode
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* TPrimExp int VName
height TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
x_index,
TExp Bool -> KernelCode -> KernelCode
forall {a}. TExp Bool -> Code a -> Code a
when (VName -> TPrimExp int VName
le VName
x_index TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
height TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp int VName
le VName
y_index TPrimExp int VName -> TPrimExp int VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp int VName
width) (KernelCode -> KernelCode) -> KernelCode -> KernelCode
forall a b. (a -> b) -> a -> b
$
[KernelCode] -> KernelCode
forall a. Monoid a => [a] -> a
mconcat
[ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> KernelCode
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Read
VName
val
VName
block
(TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
toOffset (TPrimExp int VName -> TExp Int64)
-> TPrimExp int VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp int VName
le VName
get_local_id_0 TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
* (TPrimExp int VName
block_dim TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ TPrimExp int VName
1) TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
get_local_id_1)
PrimType
t
(String -> Space
Space String
"local")
Volatility
Nonvolatile,
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> KernelCode
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Write
VName
odata
(TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TPrimExp int VName -> TExp Int64
toOffset (VName -> TPrimExp int VName
le VName
odata_offset TPrimExp int VName -> TPrimExp int VName -> TPrimExp int VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp int VName
le VName
index_out))
PrimType
t
(String -> Space
Space String
"global")
Volatility
Nonvolatile
(VName -> PrimType -> Exp
var VName
val PrimType
t)
]
]
-- | Grid and workgroup dimensions shared by the low-width and
-- low-height transposition kernels.
--
-- The arguments are, in order: the block dimension, the number of
-- arrays, and the number of elements to cover along x and along y.
--
-- The result is a pair @(num_groups, group_size)@, each with three
-- dimensions:
--
-- * @num_groups@ is the x and y element counts each divided by the
--   block dimension with 'divUp', followed by the number of arrays.
--
-- * @group_size@ is @block_dim@ by @block_dim@ by 1.
lowDimKernelAndGroupSize ::
  IntExp int =>
  TExp int ->
  TExp int ->
  TExp int ->
  TExp int ->
  ([TExp int], [TExp int])
lowDimKernelAndGroupSize block_dim num_arrays x_elems y_elems =
  ( [ x_elems `divUp` block_dim,
      y_elems `divUp` block_dim,
      num_arrays
    ],
    [block_dim, block_dim, 1]
  )
-- | Build the 'Kernel' that performs a transposition of the given kind.
--
-- Arguments, in order:
--
-- * the index type, paired with a function turning a variable name into
--   an expression of that type;
--
-- * a base name for the kernel;
--
-- * the block dimension as a plain 'Integer';
--
-- * the transposition arguments (see 'TransposeArgs');
--
-- * the element type of the arrays;
--
-- * which kind of transposition to generate.
--
-- The kernel body declares the @block@ memory in the @"local@" space,
-- allocates @block_size@ bytes for it, and then runs the code produced by
-- @mapTranspose@.  The kernel name is the base name, a suffix that
-- depends on the transposition kind, an underscore, and the
-- pretty-printed index type.
mapTransposeKernel ::
  forall int.
  IntExp int =>
  (PrimType, VName -> TExp int) ->
  String ->
  Integer ->
  TransposeArgs int ->
  PrimType ->
  TransposeType ->
  Kernel
mapTransposeKernel (int, le) desc block_dim_int args t kind =
  Kernel
    { kernelBody =
        DeclareMem block (Space "local")
          <> Op (LocalAlloc block block_size)
          <> mapTranspose (int, le) block_dim args t kind,
      kernelUses = uses,
      kernelNumGroups = map untyped num_groups,
      kernelGroupSize = map (Left . untyped) group_size,
      kernelName = nameFromString (name <> "_" <> prettyString int),
      kernelFailureTolerant = True,
      kernelCheckLocalMemory = False
    }
  where
    -- Size in bytes of a k-by-(k+1) tile of elements of type t.
    pad2DBytes k = k * (k + 1) * primByteSize t

    -- How much local memory to allocate for 'block'.
    block_size :: Count Bytes (TExp Int64)
    block_size =
      bytes $
        case kind of
          -- NOTE(review): a nominal one-byte allocation; presumably the
          -- small kernel never touches 'block' -- confirm against
          -- mapTranspose.
          TransposeSmall -> 1
          -- The normal kernel works on tiles twice the block dimension.
          TransposeNormal -> fromInteger $ pad2DBytes $ 2 * block_dim_int
          TransposeLowWidth -> fromInteger $ pad2DBytes block_dim_int
          TransposeLowHeight -> fromInteger $ pad2DBytes block_dim_int

    -- The block dimension as an expression of the index type.
    block_dim = fromInteger block_dim_int :: TExp int

    ( odata,
      basic_odata_offset,
      idata,
      basic_idata_offset,
      width,
      height,
      mulx,
      muly,
      num_arrays,
      block
      ) = args

    -- Grid and workgroup dimensions for each kind of transposition.
    (num_groups, group_size) =
      case kind of
        TransposeSmall ->
          -- One-dimensional launch with block_dim * block_dim threads
          -- per group, covering every element of every array.
          ( [(num_arrays * width * height) `divUp` (block_dim * block_dim)],
            [block_dim * block_dim]
          )
        TransposeLowWidth ->
          lowDimKernelAndGroupSize block_dim num_arrays width $
            height `divUp` muly
        TransposeLowHeight ->
          lowDimKernelAndGroupSize
            block_dim
            num_arrays
            (width `divUp` mulx)
            height
        TransposeNormal ->
          let actual_dim = block_dim * 2
           in ( [ width `divUp` actual_dim,
                  height `divUp` actual_dim,
                  num_arrays
                ],
                -- The second group dimension is divided by
                -- elemsPerThread.
                [actual_dim, actual_dim `quot` elemsPerThread, 1]
              )

    -- Every variable free in the offset and size expressions is passed
    -- to the kernel as an Int64 scalar; the two arrays are passed as
    -- memory.
    uses =
      map
        (`ScalarUse` IntType Int64)
        ( namesToList $
            mconcat $
              map
                freeIn
                [ basic_odata_offset,
                  basic_idata_offset,
                  num_arrays,
                  width,
                  height,
                  mulx,
                  muly
                ]
        )
        ++ map MemoryUse [odata, idata]

    -- Base kernel name with a suffix identifying the transposition kind.
    name =
      case kind of
        TransposeSmall -> desc ++ "_small"
        TransposeLowHeight -> desc ++ "_low_height"
        TransposeLowWidth -> desc ++ "_low_width"
        TransposeNormal -> desc