module Futhark.CodeGen.ImpGen.Kernels.Transpose
  ( TransposeType(..)
  , TransposeArgs
  , mapTranspose
  , mapTransposeKernel
  )
  where

import qualified Data.Set as S
import Prelude hiding (quot, rem)

import Futhark.CodeGen.ImpCode.Kernels
import Futhark.Representation.AST.Attributes.Types
import Futhark.Representation.AST.Attributes.Names (freeIn)
import Futhark.Util.IntegralExp (IntegralExp, quot, rem, quotRoundingUp)

-- | Which form of transposition to generate code for.
data TransposeType = TransposeNormal
                   | TransposeLowWidth
                   | TransposeLowHeight
                   | TransposeSmall -- ^ For small arrays that do not
                                    -- benefit from coalescing.
                   deriving (Eq, Ord, Show)
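-- | The arguments to a transpose kernel, in the order in which they
-- are destructured by 'mapTranspose' and 'mapTransposeKernel' below:
-- the destination memory block and its offset in bytes, the source
-- memory block and its offset in bytes, the width, height, input
-- size, and output size of the array being transposed, the 'mulx'
-- and 'muly' factors used by the low-dimension kernels, the number
-- of arrays being transposed at once, and the local memory block
-- holding the shared tile.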
type TransposeArgs = (VName, Exp,
                      VName, Exp,
                      Exp, Exp, Exp, Exp,
                      Exp, Exp, Exp,
                      VName)

elemsPerThread :: IntegralExp a => a
elemsPerThread = 4

-- | Generate a transpose kernel.  There is special support to handle
-- input arrays with low width, low height, or both.
--
-- Normally when transposing a @[2][n]@ array we would use a @FUT_BLOCK_DIM x
-- FUT_BLOCK_DIM@ group to process a @[2][FUT_BLOCK_DIM]@ slice of the input
-- array.  This would mean that many of the threads in a group would be
-- inactive.  We try to remedy this by using a special kernel that will process
-- a larger part of the input, by using more complex indexing.  In our example,
-- we could use all threads in a group if each group processes a slice of each
-- row that is @(FUT_BLOCK_DIM/2)@ times as large.  The variable 'mulx'
-- contains this factor for the kernel to handle input arrays with low height.
--
-- See issue #308 on GitHub for more details.
--
-- These kernels are optimized to ensure all global reads and writes
-- are coalesced, and to avoid bank conflicts in shared memory.  Each
-- thread group transposes a 2D tile of block_dim*2 by block_dim*2
-- elements.  The size of a thread group is block_dim/2 by
-- block_dim*2, meaning that each thread will process 4 elements in a
-- 2D tile.  The shared memory array containing the 2D tile consists
-- of block_dim*2 by block_dim*2+1 elements.  Padding each row with
-- an additional element prevents bank conflicts from occurring when
-- the tile is accessed column-wise.
--
-- Note that input_size and output_size may not equal width*height if
-- we are dealing with a truncated array - this happens sometimes for
-- coalescing optimisations.
mapTranspose :: Exp -> TransposeArgs -> PrimType -> TransposeType -> KernelCode
mapTranspose block_dim args t kind =
  case kind of
    TransposeSmall ->
      mconcat
      [ get_ids
      , dec our_array_offset $
        v32 get_global_id_0 `quot` (height*width) * (height*width)
      , dec x_index $
        (v32 get_global_id_0 `rem` (height*width)) `quot` height
      , dec y_index $
        v32 get_global_id_0 `rem` height
      , dec odata_offset $
        (basic_odata_offset `quot` primByteSize t) + v32 our_array_offset
      , dec idata_offset $
        (basic_idata_offset `quot` primByteSize t) + v32 our_array_offset
      , dec index_in $ v32 y_index * width + v32 x_index
      , dec index_out $ v32 x_index * height + v32 y_index
      , If (v32 get_global_id_0 .<. input_size)
        (Write odata (bytes $ (v32 odata_offset + v32 index_out) * tsize)
         t (Space "global") Nonvolatile $
         index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
         t (Space "global") Nonvolatile)
        mempty
      ]

    TransposeLowWidth ->
      mkTranspose $ lowDimBody
      (v32 get_group_id_0 * block_dim + (v32 get_local_id_0 `quot` muly))
      (v32 get_group_id_1 * block_dim * muly + v32 get_local_id_1 +
       (v32 get_local_id_0 `rem` muly) * block_dim)
      (v32 get_group_id_1 * block_dim * muly + v32 get_local_id_0 +
       (v32 get_local_id_1 `rem` muly) * block_dim)
      (v32 get_group_id_0 * block_dim + (v32 get_local_id_1 `quot` muly))

    TransposeLowHeight ->
      mkTranspose $ lowDimBody
      (v32 get_group_id_0 * block_dim * mulx + v32 get_local_id_0 +
       (v32 get_local_id_1 `rem` mulx) * block_dim)
      (v32 get_group_id_1 * block_dim + (v32 get_local_id_1 `quot` mulx))
      (v32 get_group_id_1 * block_dim + (v32 get_local_id_0 `quot` mulx))
      (v32 get_group_id_0 * block_dim * mulx + v32 get_local_id_1 +
       (v32 get_local_id_0 `rem` mulx) * block_dim)

    TransposeNormal ->
      mkTranspose $ mconcat
      [ dec x_index $ v32 get_global_id_0
      , dec y_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_1
      , when (v32 x_index .<. width) $
        For j Int32 elemsPerThread $
        let i = v32 j * (tile_dim `quot` elemsPerThread)
        in mconcat [ dec index_in $ (v32 y_index + i) * width + v32 x_index
                   , when (v32 y_index + i .<. height .&&.
                           v32 index_in .<. input_size) $
                     Write block (bytes $ ((v32 get_local_id_1 + i) * (tile_dim+1)
                                           + v32 get_local_id_0) * tsize)
                     t (Space "local") Nonvolatile $
                     index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
                     t (Space "global") Nonvolatile
                   ]
      , Op LocalBarrier
      , SetScalar x_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_0
      , SetScalar y_index $ v32 get_group_id_0 * tile_dim + v32 get_local_id_1
      , when (v32 x_index .<. height) $
        For j Int32 elemsPerThread $
        let i = v32 j * (tile_dim `quot` elemsPerThread)
        in mconcat [ dec index_out $ (v32 y_index + i) * height + v32 x_index
                   , when (v32 y_index + i .<. width .&&.
                           v32 index_out .<. output_size) $
                     Write odata (bytes $ (v32 odata_offset + v32 index_out) * tsize)
                     t (Space "global") Nonvolatile $
                     index block (bytes $ (v32 get_local_id_0 * (tile_dim+1)
                                           + v32 get_local_id_1 + i) * tsize)
                     t (Space "local") Nonvolatile
                   ]
      ]
  where dec v e = DeclareScalar v int32 <> SetScalar v e
        v32 = flip var int32
        tsize = LeafExp (SizeOf t) int32
        tile_dim = 2 * block_dim
        when a b = If a b mempty

        (odata, basic_odata_offset, idata, basic_idata_offset,
         width, height, input_size, output_size,
         mulx, muly, _num_arrays, block) = args

        -- Be extremely careful when editing this list to ensure that
        -- the names match up.  Also, be careful that the tags on
        -- these names do not conflict with the tags of the
        -- surrounding code.  We accomplish the latter by using very
        -- low tags (normal variables start at least in the low
        -- hundreds).
        [ our_array_offset
        , x_index
        , y_index
        , odata_offset, idata_offset, index_in, index_out
        , get_global_id_0
        , get_local_id_0, get_local_id_1
        , get_group_id_0, get_group_id_1, get_group_id_2
        , j ] =
          zipWith (flip VName) [30..] $
          map nameFromString
          [ "our_array_offset"
          , "x_index"
          , "y_index"
          , "odata_offset", "idata_offset", "index_in", "index_out"
          , "get_global_id_0"
          , "get_local_id_0", "get_local_id_1"
          , "get_group_id_0", "get_group_id_1", "get_group_id_2"
          , "j" ]

        get_ids =
          mconcat [ DeclareScalar get_global_id_0 int32
                  , Op $ GetGlobalId get_global_id_0 0
                  , DeclareScalar get_local_id_0 int32
                  , Op $ GetLocalId get_local_id_0 0
                  , DeclareScalar get_local_id_1 int32
                  , Op $ GetLocalId get_local_id_1 1
                  , DeclareScalar get_group_id_0 int32
                  , Op $ GetGroupId get_group_id_0 0
                  , DeclareScalar get_group_id_1 int32
                  , Op $ GetGroupId get_group_id_1 1
                  , DeclareScalar get_group_id_2 int32
                  , Op $ GetGroupId get_group_id_2 2
                  ]

        mkTranspose body =
          mconcat
          [ get_ids
          , dec our_array_offset $ v32 get_group_id_2 * width * height
          , dec odata_offset $
            (basic_odata_offset `quot` primByteSize t) + v32 our_array_offset
          , dec idata_offset $
            (basic_idata_offset `quot` primByteSize t) + v32 our_array_offset
          , body
          ]

        lowDimBody x_in_index y_in_index x_out_index y_out_index =
          mconcat
          [ dec x_index x_in_index
          , dec y_index y_in_index
          , dec index_in $ v32 y_index * width + v32 x_index
          , when (v32 x_index .<. width .&&. v32 y_index .<. height .&&.
                  v32 index_in .<. input_size) $
            Write block (bytes $ (v32 get_local_id_1 * (block_dim+1)
                                  + v32 get_local_id_0) * tsize)
            t (Space "local") Nonvolatile $
            index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
            t (Space "global") Nonvolatile
          , Op LocalBarrier
          , SetScalar x_index x_out_index
          , SetScalar y_index y_out_index
          , dec index_out $ v32 y_index * height + v32 x_index
          , when (v32 x_index .<. height .&&. v32 y_index .<. width .&&.
                  v32 index_out .<. output_size) $
            Write odata (bytes $ (v32 odata_offset + v32 index_out) * tsize)
            t (Space "global") Nonvolatile $
            index block (bytes $ (v32 get_local_id_0 * (block_dim+1)
                                  + v32 get_local_id_1) * tsize)
            t (Space "local") Nonvolatile
          ]
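-- As a reading aid for the index arithmetic in the 'TransposeSmall'
-- case above, the following is a purely illustrative model in plain
-- Haskell of the flat index mapping computed for a single global
-- thread id.  It is not part of the compiler, and the name
-- 'smallTransposeIndices' is hypothetical:
--
-- > smallTransposeIndices :: Int -> Int -> Int -> (Int, Int)
-- > smallTransposeIndices gid width height =
-- >   let arr_offset = gid `quot` (height*width) * (height*width)
-- >       x = (gid `rem` (height*width)) `quot` height
-- >       y = gid `rem` height
-- >   in ( arr_offset + y * width + x    -- flat index read from the input
-- >      , arr_offset + x * height + y ) -- flat index written to the output
--
-- For example, with width = 3 and height = 2, thread 1 reads flat
-- input index 3 and writes flat output index 1; consecutive threads
-- write consecutive output elements, while their reads are strided.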
mapTransposeKernel :: String -> Integer -> TransposeArgs -> PrimType
                   -> TransposeType -> Kernel
mapTransposeKernel desc block_dim_int args t kind =
  Kernel
  { kernelBody = mapTranspose block_dim args t kind
  , kernelLocalMemory = [(block, Right block_size)]
  , kernelUses = uses
  , kernelNumGroups = num_groups
  , kernelGroupSize = group_size
  , kernelName = nameFromString name
  }
  where pad2DBytes k = k * (k + 1) * primByteSize t
        block_size =
          case kind of
            TransposeSmall -> 1 -- Not used, but AMD's OpenCL does not
                                -- like zero-size local memory.
            TransposeNormal -> fromInteger $ pad2DBytes $ 2*block_dim_int
            TransposeLowWidth -> fromInteger $ pad2DBytes block_dim_int
            TransposeLowHeight -> fromInteger $ pad2DBytes block_dim_int
        block_dim = fromInteger block_dim_int

        (odata, basic_odata_offset, idata, basic_idata_offset,
         width, height, input_size, output_size,
         mulx, muly, num_arrays, block) = args

        (num_groups, group_size) =
          case kind of
            TransposeSmall ->
              ([(num_arrays * width * height) `quotRoundingUp`
                (block_dim * block_dim)],
               [block_dim * block_dim])
            TransposeLowWidth ->
              lowDimKernelAndGroupSize block_dim num_arrays width $
              height `quotRoundingUp` muly
            TransposeLowHeight ->
              lowDimKernelAndGroupSize block_dim num_arrays
              (width `quotRoundingUp` mulx) height
            TransposeNormal ->
              let actual_dim = block_dim*2
              in ([ width `quotRoundingUp` actual_dim
                  , height `quotRoundingUp` actual_dim
                  , num_arrays ],
                  [actual_dim, actual_dim `quot` elemsPerThread, 1])

        uses = map (`ScalarUse` int32)
               (S.toList $ mconcat $ map freeIn
                [ basic_odata_offset, basic_idata_offset
                , num_arrays, width, height
                , input_size, output_size
                , mulx, muly ]) ++
               map MemoryUse [odata, idata]

        name =
          case kind of
            TransposeSmall -> desc ++ "_small"
            TransposeLowHeight -> desc ++ "_low_height"
            TransposeLowWidth -> desc ++ "_low_width"
            TransposeNormal -> desc
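-- A hedged sanity check of the sizes above (illustrative numbers
-- only): with block_dim_int = 16 and a 4-byte element type, the
-- normal kernel transposes a 32x32 tile per group using a 32x8x1
-- work group (each thread handling elemsPerThread = 4 elements), and
-- the padded tile occupies
--
-- > pad2DBytes 32 == 32 * 33 * 4  -- 4224 bytes of local memory
--
-- where the extra column of padding is what avoids the shared memory
-- bank conflicts described in the documentation of 'mapTranspose'.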
lowDimKernelAndGroupSize :: Exp -> Exp -> Exp -> Exp -> ([Exp], [Exp])
lowDimKernelAndGroupSize block_dim num_arrays x_elems y_elems =
  ([ x_elems `quotRoundingUp` block_dim
   , y_elems `quotRoundingUp` block_dim
   , num_arrays ],
   [block_dim, block_dim, 1])
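-- A hedged sketch of the arithmetic in 'lowDimKernelAndGroupSize',
-- restated with plain 'Int's instead of 'Exp's; the mirror function
-- 'lowDimSizes' exists only for illustration:
--
-- > lowDimSizes :: Int -> Int -> Int -> Int -> ([Int], [Int])
-- > lowDimSizes block_dim num_arrays x_elems y_elems =
-- >   ( [ (x_elems + block_dim - 1) `quot` block_dim  -- ceiling division,
-- >     , (y_elems + block_dim - 1) `quot` block_dim  -- like quotRoundingUp
-- >     , num_arrays ]
-- >   , [block_dim, block_dim, 1] )
--
-- For example, lowDimSizes 16 4 100 200 == ([7, 13, 4], [16, 16, 1]):
-- one 16x16x1 group per block_dim-sized tile, and one grid layer per
-- array being transposed.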