{-# LANGUAGE QuasiQuotes #-}
module Futhark.CodeGen.OpenCL.Kernels
       ( SizeHeuristic (..)
       , DeviceType (..)
       , WhichSize (..)
       , HeuristicValue (..)
       , sizeHeuristicsTable

       , mapTranspose
       , TransposeType(..)
       )
       where

import qualified Language.C.Syntax as C
import qualified Language.C.Quote.OpenCL as C

-- Some OpenCL platforms have a SIMD/warp/wavefront-based execution
-- model that executes groups of threads in lockstep, permitting us to
-- perform cross-thread synchronisation within each such group without
-- the use of barriers.  Unfortunately, there seems to be no reliable
-- way to query these sizes at runtime.  Instead, we use this table to
-- figure out which size we should use for a specific platform and
-- device.  If nothing matches here, the wave size should be set to
-- one.
--
-- We also use this to select reasonable default group sizes and group
-- counts.

-- | The type of OpenCL device that this heuristic applies to.
data DeviceType = DeviceCPU | DeviceGPU

-- | The value supplied by a heuristic can be a constant, or inferred
-- from some device information.
data HeuristicValue = HeuristicConst Int
                    | HeuristicDeviceInfo String

-- | A size that can be assigned a default.
data WhichSize = LockstepWidth | NumGroups | GroupSize | TileSize

-- | A heuristic for setting the default value for something.
data SizeHeuristic =
    SizeHeuristic { platformName :: String
                  , deviceType :: DeviceType
                  , heuristicSize :: WhichSize
                  , heuristicValue :: HeuristicValue
                  }

-- | All of our heuristics.
sizeHeuristicsTable :: [SizeHeuristic]
sizeHeuristicsTable =
  [ SizeHeuristic "NVIDIA CUDA" DeviceGPU LockstepWidth $ HeuristicConst 32
  , SizeHeuristic "AMD Accelerated Parallel Processing" DeviceGPU LockstepWidth $ HeuristicConst 64
  , SizeHeuristic "" DeviceGPU LockstepWidth $ HeuristicConst 1
  , SizeHeuristic "" DeviceGPU NumGroups $ HeuristicConst 128
  , SizeHeuristic "" DeviceGPU GroupSize $ HeuristicConst 256
  , SizeHeuristic "" DeviceGPU TileSize $ HeuristicConst 32

  , SizeHeuristic "" DeviceCPU LockstepWidth $ HeuristicConst 1
  , SizeHeuristic "" DeviceCPU NumGroups $ HeuristicDeviceInfo "MAX_COMPUTE_UNITS"
  , SizeHeuristic "" DeviceCPU GroupSize $ HeuristicConst 32
  , SizeHeuristic "" DeviceCPU TileSize $ HeuristicConst 4
  ]
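
-- The table is consulted when setting up the OpenCL context elsewhere in the
-- compiler.  The following is only a sketch of the intended lookup; the name
-- 'lookupSizeDefault', the prefix-matching rule, and the comparison helpers
-- are illustrative assumptions, not part of this module.  Note that under
-- this reading the empty platform name is a prefix of every platform string
-- and so acts as a wildcard, which is why the generic entries come after the
-- more specific ones.
--
-- > import Data.List (isPrefixOf)
-- > import Data.Maybe (listToMaybe)
-- >
-- > lookupSizeDefault :: String -> DeviceType -> WhichSize -> Maybe HeuristicValue
-- > lookupSizeDefault platform dev size =
-- >   listToMaybe [ heuristicValue h
-- >               | h <- sizeHeuristicsTable
-- >               , platformName h `isPrefixOf` platform
-- >               , sameDevice (deviceType h) dev
-- >               , sameSize (heuristicSize h) size ]
-- >   where sameDevice DeviceGPU DeviceGPU = True
-- >         sameDevice DeviceCPU DeviceCPU = True
-- >         sameDevice _         _         = False
-- >         sameSize LockstepWidth LockstepWidth = True
-- >         sameSize NumGroups     NumGroups     = True
-- >         sameSize GroupSize     GroupSize     = True
-- >         sameSize TileSize      TileSize      = True
-- >         sameSize _             _             = False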

-- | Which form of transposition to generate code for.
data TransposeType = TransposeNormal
                   | TransposeLowWidth
                   | TransposeLowHeight
                   | TransposeSmall -- ^ For small arrays that do not
                                    -- benefit from coalescing.
                   deriving (Eq, Ord, Show)

-- | @mapTranspose name elem_type transpose_type@ generates a transpose
-- kernel with the given @name@ for elements of type @elem_type@.  There is
-- special support for input arrays with low width or low height, selected
-- via @transpose_type@.
--
-- Normally, when transposing a @[2][n]@ array we would use a @FUT_BLOCK_DIM x
-- FUT_BLOCK_DIM@ group to process a @[2][FUT_BLOCK_DIM]@ slice of the input
-- array, meaning that most of the threads in a group would be inactive.  We
-- try to remedy this with a special kernel that processes a larger part of
-- the input per group, at the cost of more complex indexing.  In our example,
-- we can keep all threads in a group busy by letting each group process a
-- slice of each row that is @FUT_BLOCK_DIM/2@ times as wide.  The variable
-- @mulx@ holds this factor for the kernel handling inputs with low height;
-- @muly@ plays the same role for inputs with low width.
--
-- See issue #308 on GitHub for more details.
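--
-- A rough usage sketch, assuming a 'ToIdent' instance for 'String' and the
-- pretty-printer from the mainland-pretty package (neither is otherwise used
-- in this module); the kernel name is arbitrary:
--
-- > import Text.PrettyPrint.Mainland (pretty)
-- > import Text.PrettyPrint.Mainland.Class (ppr)
-- >
-- > transposeF32Source :: String
-- > transposeF32Source =
-- >   pretty 80 $ ppr $
-- >     mapTranspose ("map_transpose_f32" :: String) [C.cty|float|] TransposeNormal
--
-- The generated kernel refers to a @FUT_BLOCK_DIM@ macro, which the caller is
-- expected to define when compiling the resulting OpenCL source.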
mapTranspose :: C.ToIdent a => a -> C.Type -> TransposeType -> C.Func
mapTranspose kernel_name elem_type transpose_type =
  case transpose_type of
    TransposeNormal ->
      bigKernel []
      [C.cexp|global_id_x|]
      [C.cexp|global_id_y|]
      [C.cexp|group_id_y * FUT_BLOCK_DIM + local_id_x|]
      [C.cexp|group_id_x * FUT_BLOCK_DIM + local_id_y|]
      (toNumGroups [C.cexp|width|])
      (toNumGroups [C.cexp|height|])
    TransposeLowWidth ->
      bigKernel [C.cparams|uint muly|]
      [C.cexp|group_id_x * FUT_BLOCK_DIM + (local_id_x / muly)|]
      [C.cexp|group_id_y * FUT_BLOCK_DIM * muly
           + local_id_y
           + (local_id_x % muly) * FUT_BLOCK_DIM
          |]
      [C.cexp|group_id_y * FUT_BLOCK_DIM * muly
           + local_id_x
           + (local_id_y % muly) * FUT_BLOCK_DIM|]
      [C.cexp|group_id_x * FUT_BLOCK_DIM + (local_id_y / muly)|]
      (toNumGroups [C.cexp|width|])
      (toNumGroups [C.cexp|(height + muly - 1) / muly|])
    TransposeLowHeight ->
      bigKernel [C.cparams|uint mulx|]
      [C.cexp|group_id_x * FUT_BLOCK_DIM * mulx
           + local_id_x
           + (local_id_y % mulx) * FUT_BLOCK_DIM
          |]
      [C.cexp|group_id_y * FUT_BLOCK_DIM + (local_id_y / mulx)|]
      [C.cexp|group_id_y * FUT_BLOCK_DIM + (local_id_x / mulx)|]
      [C.cexp|group_id_x * FUT_BLOCK_DIM * mulx
           + local_id_y
           + (local_id_x % mulx) * FUT_BLOCK_DIM
           |]
      (toNumGroups [C.cexp|(width + mulx - 1) / mulx|])
      (toNumGroups [C.cexp|height|])
    TransposeSmall ->
      smallKernel
  where
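    -- Round an extent up to a whole number of FUT_BLOCK_DIM-sized groups.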
    toNumGroups e = [C.cexp|($exp:e + FUT_BLOCK_DIM - 1) / FUT_BLOCK_DIM|]
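    -- 'bigKernel' builds the coalesced transpose kernel.  'extraparams' are
    -- any additional kernel parameters (the 'muly'/'mulx' factor of the
    -- low-width/low-height variants); the first two index expressions give
    -- the element a thread reads and the next two the element it writes;
    -- 'ngrpx' and 'ngrpy' give the number of work groups along each axis.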
    bigKernel extraparams x_in_index y_in_index x_out_index y_out_index ngrpx ngrpy =
      [C.cfun|
       // This kernel is optimized to ensure all global reads and writes are coalesced,
       // and to avoid bank conflicts in shared memory.  The shared memory array is sized
       // to (FUT_BLOCK_DIM+1)*FUT_BLOCK_DIM.  This pads each row of the 2D block in shared memory
       // so that bank conflicts do not occur when threads address the array column-wise.
       //
       // Note that input_size/output_size may not equal width*height if we are dealing with
       // a truncated array - this happens sometimes for coalescing optimisations.
       __kernel void $id:kernel_name($params:params) {
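         // The kernel is launched over a flat one-dimensional index space;
         // recover the x/y/z group and local indices from it.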
         uint num_groups_y = $exp:ngrpy;
         uint num_groups_x = $exp:ngrpx;
         uint num_groups_z = ((uint)get_num_groups(0)) / (num_groups_y * num_groups_x);

         uint local_id_y = ((uint)get_local_id(0)) / FUT_BLOCK_DIM;
         uint local_id_x = ((uint)get_local_id(0)) % FUT_BLOCK_DIM;
         uint group_id_z = ((uint)get_group_id(0)) / (num_groups_y * num_groups_x);
         uint group_id_yx = ((uint)get_group_id(0)) % (num_groups_y * num_groups_x);
         uint group_id_y = group_id_yx / num_groups_x;
         uint group_id_x = group_id_yx % num_groups_x;

         uint global_id_z = group_id_z;
         uint global_id_y = group_id_y * FUT_BLOCK_DIM + local_id_y;
         uint global_id_x = group_id_x * FUT_BLOCK_DIM + local_id_x;

         uint x_index;
         uint y_index;
         uint our_array_offset;

         // Adjust the input and output arrays with the basic offset.
         odata += odata_offset/sizeof($ty:elem_type);
         idata += idata_offset/sizeof($ty:elem_type);

         // Adjust the input and output arrays for the third dimension.
         our_array_offset = global_id_z * width * height;
         odata += our_array_offset;
         idata += our_array_offset;

         // read the matrix tile into shared memory
         x_index = $exp:x_in_index;
         y_index = $exp:y_in_index;

         uint index_in = y_index * width + x_index;

         if(x_index < width && y_index < height && index_in < input_size)
         {
             block[local_id_y*(FUT_BLOCK_DIM+1)+local_id_x] = idata[index_in];
         }

         barrier(CLK_LOCAL_MEM_FENCE);

         // Scatter the transposed matrix tile to global memory.
         x_index = $exp:x_out_index;
         y_index = $exp:y_out_index;

         uint index_out = y_index * height + x_index;

         if(x_index < height && y_index < width && index_out < output_size)
         {
             odata[index_out] = block[local_id_x*(FUT_BLOCK_DIM+1)+local_id_y];
         }
       }|]
           where params = [C.cparams|__global $ty:elem_type *odata,
                                uint odata_offset,
                                __global $ty:elem_type *idata,
                                uint idata_offset,
                                uint width,
                                uint height,
                                uint input_size,
                                uint output_size|] ++ extraparams ++
                          [C.cparams|__local $ty:elem_type* block|]

    smallKernel =
      [C.cfun|
         __kernel void $id:kernel_name(__global $ty:elem_type *odata,
                                      uint odata_offset,
                                      __global $ty:elem_type *idata,
                                      uint idata_offset,
                                      uint num_arrays,
                                      uint width,
                                      uint height,
                                      uint input_size,
                                      uint output_size) {
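           // One work-item per element: split the flat global id into the
           // index of the array being transposed and an (x, y) position
           // within that array.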
           uint our_array_offset = get_global_id(0) / (height*width) * (height*width);
           uint x_index = get_global_id(0) % (height*width) / height;
           uint y_index = get_global_id(0) % height;

           // Adjust the input and output arrays with the basic offset.
           odata += odata_offset/sizeof($ty:elem_type);
           idata += idata_offset/sizeof($ty:elem_type);

           // Adjust the input and output arrays.
           odata += our_array_offset;
           idata += our_array_offset;

           // Read and write the element.
           uint index_in = y_index * width + x_index;
           uint index_out = x_index * height + y_index;
           if (get_global_id(0) < input_size) {
               odata[index_out] = idata[index_in];
           }
}|]