{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpGen.Kernels.SegScan
  ( compileSegScan )
  where

import Control.Monad.Except
import Data.Maybe
import Data.List

import Prelude hiding (quot, rem)

import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)

makeLocalArrays :: SubExp -> SubExp -> [SubExp] -> Lambda InKernel
                -> InKernelGen [VName]
makeLocalArrays group_size num_threads nes scan_op = do
  let (scan_x_params, _scan_y_params) =
        splitAt (length nes) $ lambdaParams scan_op
  forM scan_x_params $ \p ->
    case paramAttr p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "scan_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "scan_arr" pt shape $ Space "local"

type CrossesSegment = Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)

-- | Produce partially scanned intervals; one per workgroup.
scanStage1 :: Pattern ExplicitMemory
           -> KernelSpace
           -> Lambda InKernel -> [SubExp]
           -> KernelBody InKernel
           -> CallKernelGen (Imp.Exp, CrossesSegment)
scanStage1 (Pattern _ pes) space scan_op nes kbody = do
  (base_constants, init_constants) <- kernelInitialisationSetSpace space $ return ()
  let (gtids, dims) = unzip $ spaceDimensions space
  dims' <- mapM toExp dims
  let constants = base_constants { kernelThreadActive = true }
      num_elements = product dims'
      elems_per_thread = num_elements `quotRoundingUp` kernelNumThreads constants
      elems_per_group = kernelGroupSize constants * elems_per_thread

  -- Squirrel away a copy of the operator with unique names that we
  -- can pass to groupScan.
  scan_op_renamed <- renameLambda scan_op

  let crossesSegment =
        case reverse dims' of
          segment_size : _ : _ -> Just $ \from to ->
            (to-from) .>. (to `rem` segment_size)
          _ -> Nothing

  sKernel constants "scan_stage1" $ allThreads constants $ do
    init_constants

    local_arrs <-
      makeLocalArrays (spaceGroupSize space) (spaceNumThreads space)
      nes scan_op

    -- The variables from scan_op will be used for the carry and such
    -- in the big chunking loop.
    dScope Nothing $ scopeOfLParams $ lambdaParams scan_op
    let (scan_x_params, scan_y_params) =
          splitAt (length nes) $ lambdaParams scan_op

    forM_ (zip scan_x_params nes) $ \(p, ne) ->
      copyDWIM (paramName p) [] ne []

    j <- newVName "j"
    sFor j Int32 elems_per_thread $ do
      chunk_offset <- dPrimV "chunk_offset" $
                      kernelGroupSize constants * Imp.var j int32 +
                      kernelGroupId constants * elems_per_group
      flat_idx <- dPrimV "flat_idx" $
                  Imp.var chunk_offset int32 + kernelLocalThreadId constants
      -- Construct segment indices.
      zipWithM_ (<--) gtids $ unflattenIndex dims' $ Imp.var flat_idx int32

      let in_bounds =
            foldl1 (.&&.) $ zipWith (.<.) (map (`Imp.var` int32) gtids) dims'
          when_in_bounds = compileStms mempty (kernelBodyStms kbody) $ do
            let (scan_res, map_res) = splitAt (length nes) $ kernelBodyResult kbody
            sComment "write to-scan values to parameters" $
              forM_ (zip scan_y_params scan_res) $ \(p, se) ->
              copyDWIM (paramName p) [] (kernelResultSubExp se) []
            sComment "write mapped values results to global memory" $
              forM_ (zip (drop (length nes) pes) map_res) $ \(pe, se) ->
              copyDWIM (patElemName pe) (map (`Imp.var` int32) gtids)
              (kernelResultSubExp se) []
          when_out_of_bounds = forM_ (zip scan_y_params nes) $ \(p, ne) ->
            copyDWIM (paramName p) [] ne []

      sComment "threads in bounds read input; others get neutral element" $
        sIf in_bounds when_in_bounds when_out_of_bounds

      sComment "combine with carry and write to local memory" $
        compileStms mempty (bodyStms $ lambdaBody scan_op) $
        forM_ (zip local_arrs $ bodyResult $ lambdaBody scan_op) $ \(arr, se) ->
          copyDWIM arr [kernelLocalThreadId constants] se []

      let crossesSegment' = do
            f <- crossesSegment
            Just $ \from to ->
              let from' = from + Imp.var chunk_offset int32
                  to' = to + Imp.var chunk_offset int32
              in f from' to'

      groupScan constants crossesSegment'
        (kernelGroupSize constants) scan_op_renamed local_arrs

      sComment "threads in bounds write partial scan result" $
        sWhen in_bounds $ forM_ (zip pes local_arrs) $ \(pe, arr) ->
        copyDWIM (patElemName pe) (map (`Imp.var` int32) gtids)
        (Var arr) [kernelLocalThreadId constants]

      sOp Imp.LocalBarrier

      let load_carry =
            forM_ (zip local_arrs scan_x_params) $ \(arr, p) ->
            copyDWIM (paramName p) [] (Var arr) [kernelGroupSize constants - 1]
          load_neutral =
            forM_ (zip nes scan_x_params) $ \(ne, p) ->
            copyDWIM (paramName p) [] ne []

      sComment "first thread reads last element as carry-in for next iteration" $
        sWhen (kernelLocalThreadId constants .==. 0) $
        case crossesSegment of Nothing -> load_carry
                               Just f -> sIf (f (Imp.var chunk_offset int32 +
                                                 kernelGroupSize constants-1)
                                                (Imp.var chunk_offset int32 +
                                                 kernelGroupSize constants))
                                         load_neutral load_carry

      sOp Imp.LocalBarrier

  return (elems_per_group, crossesSegment)

scanStage2 :: Pattern ExplicitMemory
           -> Imp.Exp -> CrossesSegment -> KernelSpace
           -> Lambda InKernel -> [SubExp]
           -> CallKernelGen ()
scanStage2 (Pattern _ pes) elems_per_group crossesSegment space scan_op nes = do
  -- A single group, with one thread for each group in stage 1.
  group_size <- toExp $ spaceNumGroups space
  (constants, init_constants) <-
    kernelInitialisationSimple 1 group_size Nothing

  let (gtids, dims) = unzip $ spaceDimensions space
  dims' <- mapM toExp dims
  let crossesSegment' = do
        f <- crossesSegment
        Just $ \from to ->
          f ((from + 1) * elems_per_group - 1) ((to + 1) * elems_per_group - 1)

  sKernel constants "scan_stage2" $ do
    init_constants

    local_arrs <- makeLocalArrays (spaceNumGroups space) (spaceNumGroups space)
                  nes scan_op

    flat_idx <- dPrimV "flat_idx" $
      (kernelLocalThreadId constants + 1) * elems_per_group - 1
    -- Construct segment indices.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ Imp.var flat_idx int32

    let in_bounds =
          foldl1 (.&&.) $ zipWith (.<.) (map (`Imp.var` int32) gtids) dims'
        when_in_bounds = forM_ (zip local_arrs pes) $ \(arr, pe) ->
          copyDWIM arr [kernelLocalThreadId constants]
          (Var $ patElemName pe) $ map (`Imp.var` int32) gtids
        when_out_of_bounds = forM_ (zip local_arrs nes) $ \(arr, ne) ->
          copyDWIM arr [kernelLocalThreadId constants] ne []

    sComment "threads in bound read carries; others get neutral element" $
      sIf in_bounds when_in_bounds when_out_of_bounds

    groupScan constants crossesSegment'
      (kernelGroupSize constants) scan_op local_arrs

    sComment "threads in bounds write scanned carries" $
      sWhen in_bounds $ forM_ (zip pes local_arrs) $ \(pe, arr) ->
      copyDWIM (patElemName pe) (map (`Imp.var` int32) gtids)
      (Var arr) [kernelLocalThreadId constants]

scanStage3 :: Pattern ExplicitMemory
           -> Imp.Exp -> CrossesSegment -> KernelSpace
           -> Lambda InKernel -> [SubExp]
           -> CallKernelGen ()
scanStage3 (Pattern _ pes) elems_per_group crossesSegment space scan_op nes = do
  let (gtids, dims) = unzip $ spaceDimensions space
  dims' <- mapM toExp dims
  (constants, init_constants) <- simpleKernelConstants (product dims') "scan"
  sKernel constants "scan_stage3" $ do
    init_constants
    -- Compute our logical index.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ kernelGlobalThreadId constants
    -- Figure out which group this element was originally in.
    orig_group <- dPrimV "orig_group" $
                  kernelGlobalThreadId constants `quot` elems_per_group
    -- Then the index of the carry-in of the preceding group.
    carry_in_flat_idx <- dPrimV "carry_in_flat_idx" $
                         Imp.var orig_group int32 * elems_per_group - 1
    -- Figure out the logical index of the carry-in.
    let carry_in_idx = unflattenIndex dims' $ Imp.var carry_in_flat_idx int32

    -- Apply the carry if we are not in the scan results for the first
    -- group, and are not the last element in such a group (because
    -- then the carry was updated in stage 2), and we are not crossing
    -- a segment boundary.
    let crosses_segment = fromMaybe false $
          crossesSegment <*>
            pure (Imp.var carry_in_flat_idx int32) <*>
            pure (kernelGlobalThreadId constants)
        is_a_carry = kernelGlobalThreadId constants .==.
                     (Imp.var orig_group int32 + 1) * elems_per_group - 1
        no_carry_in = Imp.var orig_group int32 .==. 0 .||. is_a_carry .||. crosses_segment

    sWhen (kernelThreadActive constants) $ sUnless no_carry_in $ do
      dScope Nothing $ scopeOfLParams $ lambdaParams scan_op
      let (scan_x_params, scan_y_params) =
            splitAt (length nes) $ lambdaParams scan_op
      forM_ (zip scan_x_params pes) $ \(p, pe) ->
        copyDWIM (paramName p) [] (Var $ patElemName pe) carry_in_idx
      forM_ (zip scan_y_params pes) $ \(p, pe) ->
        copyDWIM (paramName p) [] (Var $ patElemName pe) $ map (`Imp.var` int32) gtids
      compileBody' scan_x_params $ lambdaBody scan_op
      forM_ (zip scan_x_params pes) $ \(p, pe) ->
        copyDWIM (patElemName pe) (map (`Imp.var` int32) gtids) (Var $ paramName p) []

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
compileSegScan :: Pattern ExplicitMemory
               -> KernelSpace
               -> Lambda InKernel -> [SubExp]
               -> KernelBody InKernel
               -> CallKernelGen ()
compileSegScan pat space scan_op nes kbody = do
  (elems_per_group, crossesSegment) <- scanStage1 pat space scan_op nes kbody

  emit $ Imp.DebugPrint "\n# SegScan" Nothing
  emit $ Imp.DebugPrint "elems_per_group" $ Just (int32, elems_per_group)

  scan_op' <- renameLambda scan_op
  scan_op'' <- renameLambda scan_op
  scanStage2 pat elems_per_group crossesSegment space scan_op' nes
  scanStage3 pat elems_per_group crossesSegment space scan_op'' nes