{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.Base
  ( KernelConstants (..)
  , inKernelOperations
  , computeKernelUses
  , keyWithEntryPoint
  , CallKernelGen
  , InKernelGen
  , computeThreadChunkSize
  , kernelInitialisation
  , kernelInitialisationSetSpace
  , setSpaceIndices
  , makeAllMemoryGlobal
  , allThreads
  , compileKernelStms
  , groupReduce
  , groupScan
  , isActive
  , sKernel
  )
  where

import Control.Arrow ((&&&))
import Control.Monad.Except
import Control.Monad.Reader
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.List

import Prelude hiding (quot)

import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpCode.Kernels (bytes)
import qualified Futhark.CodeGen.ImpGen as ImpGen
import Futhark.CodeGen.ImpGen ((<--),
                               sFor, sWhile, sComment, sIf, sWhen, sUnless,
                               sOp,
                               dPrim, dPrim_, dPrimV)
import Futhark.Tools (partitionChunkedKernelLambdaParameters)
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem, IntegralExp)
import Futhark.Util (splitAt3, maybeNth)

-- | The 'ImpGen.ImpM' code generation monad instantiated for the
-- 'ExplicitMemory' representation with 'Imp.HostOp' operations.
type CallKernelGen = ImpGen.ImpM ExplicitMemory Imp.HostOp
-- | The 'ImpGen.ImpM' code generation monad instantiated for the
-- 'InKernel' representation with 'Imp.KernelOp' operations.
type InKernelGen = ImpGen.ImpM InKernel Imp.KernelOp

-- | Information about the kernel currently being generated, as
-- assembled by 'kernelInitialisationSetSpace'.  The three @...Var@
-- fields hold the names of the variables that the corresponding
-- expression fields read.
data KernelConstants = KernelConstants
                       { kernelGlobalThreadId :: Imp.Exp
                       , kernelLocalThreadId :: Imp.Exp
                       , kernelGroupId :: Imp.Exp
                       , kernelGlobalThreadIdVar :: VName
                       , kernelLocalThreadIdVar :: VName
                       , kernelGroupIdVar :: VName
                       , kernelGroupSize :: Imp.Exp
                       , kernelNumGroups :: Imp.Exp
                       , kernelNumThreads :: Imp.Exp
                       , kernelWaveSize :: Imp.Exp
                       -- ^ Reads the variable that is set with
                       -- 'Imp.GetLockstepWidth' during kernel
                       -- initialisation.
                       , kernelDimensions :: [(VName, Imp.Exp)]
                       -- ^ The index variables of the kernel space
                       -- paired with their compiled sizes.
                       , kernelThreadActive :: Imp.Exp
                       -- ^ Whether the thread is within the bounds of
                       -- the kernel space; 'allThreads' overrides
                       -- this with a constant true.
                       , kernelStreamed :: [(VName, Imp.DimSize)]
                       -- ^ Chunk sizes and their maximum size.  Hint
                       -- for unrolling.
                       }

-- | The code generation operations used for code inside a kernel:
-- the defaults, except that copies, expressions and statement
-- sequences are handled by the kernel-specific compilers of this
-- module.
inKernelOperations :: KernelConstants -> ImpGen.Operations InKernel Imp.KernelOp
inKernelOperations constants =
  base_ops { ImpGen.opsCopyCompiler = inKernelCopy
           , ImpGen.opsExpCompiler = inKernelExpCompiler
           , ImpGen.opsStmsCompiler = const $ compileKernelStms constants
           }
  where base_ops = ImpGen.defaultOperations (compileInKernelOp constants)

-- | Qualify a key with the name of a function by joining the two
-- names with a dot, function name first.
keyWithEntryPoint :: Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $ intercalate "." $ map nameToString [fname, key]

-- | The copy compiler used inside kernels.  We have no bulk copy
-- operation (e.g. memmove) inside kernels, so any copy is turned
-- into a loop by delegating to 'ImpGen.copyElementWise'.
inKernelCopy :: ImpGen.CopyCompiler InKernel Imp.KernelOp
inKernelCopy = ImpGen.copyElementWise

-- | Compile an operation occurring inside a kernel.  Allocations are
-- rejected (as a compiler limitation when the pattern binds a single
-- memory block, and as a compiler bug otherwise); every other
-- operation is passed on to 'compileKernelExp'.
compileInKernelOp :: KernelConstants -> Pattern InKernel -> Op InKernel
                  -> InKernelGen ()
compileInKernelOp constants pat op =
  case (op, pat) of
    (Inner kernel_exp, _) ->
      compileKernelExp constants pat kernel_exp
    (Alloc{}, Pattern _ [mem]) ->
      compilerLimitationS $
      "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
    (Alloc{}, _) ->
      compilerBugS $
      "Invalid target for in-kernel allocation: " ++ show pat

-- | Expression compiler for code inside kernels.  Assertions are
-- reported as a compiler limitation, array literals are written one
-- element at a time, and everything else uses the default compiler.
inKernelExpCompiler :: ImpGen.ExpCompiler InKernel Imp.KernelOp
inKernelExpCompiler _ (BasicOp (Assert _ _ (loc, locs))) =
  compilerLimitationS $ unlines [where_it_failed, workaround]
  where backtrace = intercalate " -> " $ reverse $ map locStr $ loc:locs
        where_it_failed =
          "Cannot compile assertion at " ++ backtrace ++
          " inside parallel kernel."
        workaround =
          "As a workaround, surround the expression with 'unsafe'."
-- The static arrays stuff does not work inside kernels, so each
-- element of the literal is copied to its index individually.
inKernelExpCompiler (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
  zipWithM_ write_element [0..] es
  where write_element i e =
          ImpGen.copyDWIM (patElemName dest) [fromIntegral (i::Int32)] e []
inKernelExpCompiler dest e =
  ImpGen.defCompileExp dest e

-- | Compile a kernel-level expression, writing its results to the
-- given pattern.
compileKernelExp :: KernelConstants -> Pattern InKernel -> KernelExp InKernel
                 -> InKernelGen ()

-- A barrier copies each operand to the corresponding pattern name and
-- then emits a local barrier.
compileKernelExp _ pat (Barrier ses) = do
  zipWithM_ copy_operand (patternNames pat) ses
  sOp Imp.LocalBarrier
  where copy_operand d se = ImpGen.copyDWIM d [] se []

-- Splitting a space is delegated to 'computeThreadChunkSize', which
-- stores the chunk size of this thread in the single pattern element.
compileKernelExp _ (Pattern [] [size]) (SplitSpace o w i elems_per_thread) = do
  total_elems <- fmap Imp.elements $ ImpGen.compileSubExp w
  thread_index <- ImpGen.compileSubExp i
  per_thread <- fmap Imp.elements $ ImpGen.compileSubExp elems_per_thread
  computeThreadChunkSize o thread_index per_thread total_elems $ patElemName size

-- A combine runs 'body' for every point of 'cspace' that is also
-- active in 'aspace', distributing the points of 'cspace' over the
-- threads of the group, and ends with a local barrier.
compileKernelExp constants pat (Combine (CombineSpace scatter cspace) _ aspace body) = do
  -- First we compute how many times we have to iterate to cover
  -- cspace with our group size.  It is a fairly common case that
  -- we statically know that this requires 1 iteration (when cspace
  -- is exactly the group size), in which case the iteration count
  -- is the constant 1.  We could also avoid generating a loop in
  -- that case.  However, it seems to have no impact on performance
  -- (an extra conditional jump), so for simplicity we just always
  -- generate the loop.
  let cspace_dims = map (streamBounded . snd) cspace
      num_iters
        | cspace_dims == [kernelGroupSize constants] = 1
        | otherwise = product cspace_dims `quotRoundingUp`
                      kernelGroupSize constants

  iter <- newVName "comb_iter"

  sFor iter Int32 num_iters $ do
    -- Declare the index variables of cspace.
    mapM_ ((`dPrim_` int32) . fst) cspace
    -- Compute the *flat* array index.
    cid <- dPrimV "flat_comb_id" $
      Imp.var iter int32 * kernelGroupSize constants +
      kernelLocalThreadId constants

    -- Turn it into a nested array index.
    zipWithM_ (<--) (map fst cspace) $ unflattenIndex cspace_dims (Imp.var cid int32)

    -- Construct the body.  This is mostly about the book-keeping
    -- for the scatter-like part.  The body result consists of the
    -- scatter indices, then the scatter values, then the ordinary
    -- results; the pattern has the scatter destinations first.
    let (scatter_ws, scatter_ns, _scatter_vs) = unzip3 scatter
        scatter_ws_repl = concat $ zipWith replicate scatter_ns scatter_ws
        (scatter_pes, normal_pes) =
          splitAt (sum scatter_ns) $ patternElements pat
        (res_is, res_vs, res_normal) =
          splitAt3 (sum scatter_ns) (sum scatter_ns) $ bodyResult body

    -- Execute the body if we are within bounds.
    sWhen (isActive cspace .&&. isActive aspace) $ allThreads constants $
      ImpGen.compileStms (freeIn $ bodyResult body) (stmsToList $ bodyStms body) $ do

      forM_ (zip4 scatter_ws_repl res_is res_vs scatter_pes) $
        \(w, res_i, res_v, scatter_pe) -> do
          let res_i' = ImpGen.compileSubExpOfType int32 res_i
              w'     = ImpGen.compileSubExpOfType int32 w
              -- We have to check that 'res_i' is in-bounds wrt. an array of size 'w'.
              in_bounds = 0 .<=. res_i' .&&. res_i' .<. w'
          sWhen in_bounds $ ImpGen.copyDWIM (patElemName scatter_pe) [res_i'] res_v []

      -- Ordinary results are written at the cspace index of the thread.
      forM_ (zip normal_pes res_normal) $ \(pe, res) ->
        ImpGen.copyDWIM (patElemName pe) local_index res []

  sOp Imp.LocalBarrier

  where -- A dimension that is a streamed chunk size is replaced by the
        -- bound recorded for it in 'kernelStreamed'.
        streamBounded (Var v)
          | Just x <- lookup v $ kernelStreamed constants =
              Imp.sizeToExp x
        streamBounded se = ImpGen.compileSubExpOfType int32 se

        local_index = map (ImpGen.compileSubExpOfType int32 . Var . fst) cspace

-- A group reduction.  The first two lambda parameters are the thread
-- index and the offset; the remaining ones belong to the reduction
-- operator proper.
compileKernelExp constants (Pattern _ dests) (GroupReduce w lam input) = do
  let ([my_index_param, offset_param], op_params) =
        splitAt 2 $ lambdaParams lam
      lam' = lam { lambdaParams = op_params }
      my_index = paramName my_index_param
      offset = paramName offset_param
      input_arrs = map snd input

  dPrim_ my_index int32
  dPrim_ offset int32
  my_index <-- kernelGlobalThreadId constants
  w' <- ImpGen.compileSubExp w
  groupReduceWithOffset constants offset w' lam' input_arrs

  sOp Imp.LocalBarrier

  -- The final result will be stored in element 0 of the local memory array.
  forM_ (zip dests input_arrs) $ \(dest, arr) ->
    ImpGen.copyDWIM (patElemName dest) [] (Var arr) [0]

-- A group scan without segment flags is handed directly to 'groupScan'.
compileKernelExp constants _ (GroupScan w lam input) =
  ImpGen.compileSubExp w >>= \w' ->
    groupScan constants Nothing w' lam (map snd input)

-- A group stream.  The accumulator parameters are initialised from
-- 'accs', the body is run over the 'w' input elements in chunks of at
-- most 'maxchunk' elements, and the final accumulator values are
-- copied to the pattern.  'block_size' holds the size of the current
-- chunk and 'block_offset' its starting offset.
compileKernelExp constants (Pattern _ final) (GroupStream w maxchunk lam accs _arrs) = do
  let GroupStreamLambda block_size block_offset acc_params arr_params body = lam
      block_offset' = Imp.var block_offset int32
  w' <- ImpGen.compileSubExp w
  max_block_size <- ImpGen.compileSubExp maxchunk

  ImpGen.dLParams (acc_params++arr_params)
  zipWithM_ ImpGen.compileSubExpTo (map paramName acc_params) accs
  dPrim_ block_size int32

  -- If the GroupStream is morally just a do-loop, generate simpler code.
  -- This requires a maximum chunk size of one and a body that
  -- contains no 'Op' statements.
  case mapM isSimpleThreadInSpace $ stmsToList $ bodyStms body of
    Just stms' | ValueExp x <- max_block_size, oneIsh x -> do
      let body' = body { bodyStms = stmsFromList stms' }
          body'' = allThreads constants $
                   ImpGen.compileLoopBody (map paramName acc_params) body'
      block_size <-- 1

      -- Check if loop is candidate for unrolling.
      let loop =
            case w of
              Var w_var | Just w_bound <- lookup w_var $ kernelStreamed constants,
                          w_bound /= Imp.ConstSize 1 ->
                          -- Candidate for unrolling, so generate two loops.
                          sIf (w' .==. Imp.sizeToExp w_bound)
                          (sFor block_offset Int32 (Imp.sizeToExp w_bound) body'')
                          (sFor block_offset Int32 w' body'')
              _ -> sFor block_offset Int32 w' body''

      -- Only guard the loop when the thread is not statically known
      -- to be active.
      if kernelThreadActive constants == Imp.ValueExp (BoolValue True)
        then loop
        else sWhen (kernelThreadActive constants) loop

    _ -> do
      dPrim_ block_offset int32
      let body' = streaming constants block_size maxchunk $
                  ImpGen.compileBody' acc_params body

      block_offset <-- 0

      let not_at_end = block_offset' .<. w'
          set_block_size =
            sIf (w' - block_offset' .<. max_block_size)
            (block_size <-- (w' - block_offset'))
            (block_size <-- max_block_size)
          increase_offset =
            block_offset <-- block_offset' + max_block_size

      -- Three cases to consider for simpler generated code based
      -- on max block size: (0) if full input size, do not
      -- generate a loop; (1) if one, generate for-loop (2)
      -- otherwise, generate chunked while-loop.
      if max_block_size == w' then
        (block_size <-- w') >> body'
      else if max_block_size == Imp.ValueExp (value (1::Int32)) then do
             -- Each of the w' iterations handles exactly one
             -- element, so the chunk size is one (not w').
             block_size <-- 1
             sFor block_offset Int32 w' body'
           else
             sWhile not_at_end $
             set_block_size >> body' >> increase_offset

  forM_ (zip final acc_params) $ \(pe, p) ->
    ImpGen.copyDWIM (patElemName pe) [] (Var $ paramName p) []

  where -- A statement is "simple" when it is not an 'Op'.
        isSimpleThreadInSpace (Let _ _ Op{}) = Nothing
        isSimpleThreadInSpace bnd = Just bnd

-- A generalised reduction performs an atomic update of the bucket,
-- but only when every bucket index is within the corresponding bound
-- in 'w'.
compileKernelExp _ _ (GroupGenReduce w arrs op bucket values locks) = do
  bucket' <- mapM ImpGen.compileSubExp bucket
  w' <- mapM ImpGen.compileSubExp w
  let within_bound ind bound = 0 .<=. ind .&&. ind .<. bound
      bucket_in_bounds = foldl1 (.&&.) $ zipWith within_bound bucket' w'
  sWhen bucket_in_bounds $
    atomicUpdate arrs bucket op values $ Locking locks 0 1 0

-- Any other combination of pattern and expression is a compiler bug.
compileKernelExp _ dest e =
  compilerBugS $ unlines [ "Invalid target"
                         , indented $ show dest
                         , "for kernel expression"
                         , indented $ pretty e
                         ]
  where indented = ("  " ++)

-- | Run the given code generation action with 'chunksize' and its
-- bound added to the front of 'kernelStreamed', and emit the
-- resulting code.
streaming :: KernelConstants -> VName -> SubExp -> InKernelGen () -> InKernelGen ()
streaming constants chunksize bound m = do
  bound' <- ImpGen.subExpToDimSize bound
  let streamed = (chunksize, bound') : kernelStreamed constants
      ops = inKernelOperations constants { kernelStreamed = streamed }
  code <- ImpGen.subImpM_ ops m
  ImpGen.emit code

-- | Locking strategy used for an atomic update: which array holds
-- the locks, and which values represent the unlocked and locked
-- states.
data Locking = Locking { lockingArray :: VName -- ^ Array containing the lock.
                       , lockingIsUnlocked :: Imp.Exp -- ^ Value for us to consider the lock free.
                       , lockingToLock :: Imp.Exp -- ^ What to write when we lock it.
                       , lockingToUnlock :: Imp.Exp -- ^ What to write when we unlock it.
                       }

-- | Generate code that atomically combines the given values into the
-- given arrays at the index 'bucket', using the given operator.
--
-- If the operator can be split (see 'splitOp') into binary operators
-- that all work on 32-bit values, each array is updated separately,
-- either with a directly supported atomic operation or with a
-- compare-and-swap loop; the 'Locking' argument is then ignored.
-- Otherwise the update is performed while holding a lock.
atomicUpdate :: ExplicitMemorish lore =>
                [VName] -> [SubExp] -> Lambda lore -> [SubExp] -> Locking
             -> ImpGen.ImpM lore Imp.KernelOp ()

atomicUpdate arrs bucket lam values _
  | Just ops_and_ts <- splitOp lam,
    all ((==32) . primBitSize . snd) ops_and_ts =
  -- If the operator is a vectorised binary operator on 32-bit values,
  -- we can use a particularly efficient implementation. If the
  -- operator has an atomic implementation we use that, otherwise it
  -- is still a binary operator which can be implemented by atomic
  -- compare-and-swap if 32 bits.
  forM_ (zip3 arrs ops_and_ts values) $ \(a, (op, t), val) -> do

  -- Common variables.
  old <- dPrim "old" t
  bucket' <- mapM ImpGen.compileSubExp bucket

  (arr', _a_space, bucket_offset) <- ImpGen.fullyIndexArray a bucket'

  val' <- ImpGen.compileSubExp val
  case opHasAtomicSupport old arr' bucket_offset op of
    -- The operator has a direct atomic counterpart.
    Just f -> sOp $ f val'

    Nothing -> do
      -- Code generation target:
      --
      -- old = d_his[idx];
      -- do {
      --   assumed = old;
      --   tmp = OP::apply(val, assumed);
      --   old = atomicCAS(&d_his[idx], assumed, tmp);
      -- } while(assumed != old);
      assumed <- dPrim "assumed" t
      run_loop <- dPrimV "run_loop" 1
      ImpGen.copyDWIM old [] (Var a) bucket'

      -- Critical section
      x <- dPrim "x" t
      y <- dPrim "y" t

      -- While-loop: Try to insert your value
      -- The compare-and-swap is performed on int32 values, so
      -- 32-bit floats are converted to and from their bits.
      let (toBits, fromBits) =
            case t of FloatType Float32 -> (\v -> Imp.FunExp "to_bits32" [v] int32,
                                            \v -> Imp.FunExp "from_bits32" [v] t)
                      _                 -> (id, id)
      sWhile (Imp.var run_loop int32) $ do
        assumed <-- Imp.var old t
        x <-- val'
        y <-- Imp.var assumed t
        x <-- Imp.BinOpExp op (Imp.var x t) (Imp.var y t)
        old_bits <- dPrim "old_bits" int32
        sOp $ Imp.Atomic $
          Imp.AtomicCmpXchg old_bits arr' bucket_offset
          (toBits (Imp.var assumed t)) (toBits (Imp.var x t))
        old <-- fromBits (Imp.var old_bits int32)
        -- Stop once the value we assumed is the value that was
        -- found in memory by the compare-and-swap.
        sWhen (toBits (Imp.var assumed t) .==. Imp.var old_bits int32)
          (run_loop <-- 0)

  where -- The atomic operation corresponding to the binary operator,
        -- if 'Imp.atomicBinOp' provides one.
        opHasAtomicSupport old arr' bucket' bop = do
          let atomic f = Imp.Atomic . f old arr' bucket'
          atomic <$> Imp.atomicBinOp bop

-- The general case: perform the update while holding the lock found
-- at index 'bucket' of the lock array, retrying until the lock has
-- been acquired.
atomicUpdate arrs bucket op values locking = do
  old <- dPrim "old" int32
  continue <- dPrimV "continue" true

  -- Compile the bucket index.
  bucket' <- mapM ImpGen.compileSubExp bucket

  -- Correctly index into locks.
  (locks', _locks_space, locks_offset) <-
    ImpGen.fullyIndexArray (lockingArray locking) bucket'

  -- Preparing parameters
  let (acc_params, arr_params) =
        splitAt (length values) $ lambdaParams op

  -- Critical section
  let try_acquire_lock =
        sOp $ Imp.Atomic $
        Imp.AtomicCmpXchg old locks' locks_offset (lockingIsUnlocked locking) (lockingToLock locking)
      -- We hold the lock if the compare-and-swap observed the
      -- unlocked value.
      lock_acquired = Imp.var old int32 .==. lockingIsUnlocked locking
      release_lock = ImpGen.everythingVolatile $
                     ImpGen.sWrite (lockingArray locking) bucket' $ lockingToUnlock locking
      break_loop = continue <-- false

  -- We copy the current value and the new value to the parameters.
  -- It is important that the right-hand-side is bound first for the
  -- (rare) case when we are dealing with arrays.
  let bind_acc_params =
        ImpGen.sComment "bind lhs" $
        forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
        ImpGen.copyDWIM (paramName acc_p) [] (Var arr) bucket'

  let bind_arr_params =
        ImpGen.sComment "bind rhs" $
        forM_ (zip arr_params values) $ \(arr_p, val) ->
        ImpGen.copyDWIM (paramName arr_p) [] val []

  let op_body = ImpGen.sComment "execute operation" $
                ImpGen.compileBody' acc_params $ lambdaBody op

      do_gen_reduce = ImpGen.sComment "update global result" $
                      zipWithM_ (writeArray bucket') arrs $ map (Var . paramName) acc_params

  -- While-loop: Try to insert your value
  sWhile (Imp.var continue Bool) $ do
    try_acquire_lock
    sWhen lock_acquired $ do
      ImpGen.dLParams $ lambdaParams op
      bind_arr_params
      bind_acc_params
      op_body
      do_gen_reduce
      release_lock
      break_loop
    sOp Imp.MemFence
  where writeArray bucket' arr val =
          ImpGen.copyDWIM arr bucket' val []

-- | Horizontally fission a lambda that models a binary operator.
-- Succeeds only if every result is a variable bound by a single
-- binary operator statement whose operands are exactly the two
-- lambda parameters corresponding to the position of that result.
splitOp :: Attributes lore => Lambda lore -> Maybe [(BinOp, PrimType)]
splitOp lam = mapM splitStm results
  where body = lambdaBody lam
        results = bodyResult body
        params = lambdaParams lam
        n = length $ lambdaReturnType lam
        defines name stm = patternNames (stmPattern stm) == [name]

        splitStm :: SubExp -> Maybe (BinOp, PrimType)
        splitStm (Var res) = do
          Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
            find (defines res) $ stmsToList $ bodyStms body
          i <- elemIndex (Var res) results
          xp <- maybeNth i params
          yp <- maybeNth (n+i) params
          guard $ paramName xp == x
          guard $ paramName yp == y
          Prim t <- Just $ patElemType pe
          return (op, t)
        splitStm _ = Nothing

-- | Compute what a kernel uses from its surroundings: the variables
-- that must be passed to the kernel, and the local memory it uses.
-- Only names free in the kernel body and not bound in the kernel
-- itself are considered, and duplicates are removed.
computeKernelUses :: FreeIn a =>
                     a -> [VName]
                  -> CallKernelGen ([Imp.KernelUse], [Imp.LocalMemoryUse])
computeKernelUses kernel_body bound_in_kernel = do
  -- The variables that we need to pass to the kernel.
  reads_from <- nub <$> readsFromSet actually_free
  -- Any local memory that is used.
  local_memory <- nub <$> computeLocalMemoryUse actually_free
  return (reads_from, local_memory)
  where actually_free =
          freeIn kernel_body `S.difference` S.fromList bound_in_kernel

-- | Determine how each of the given names is used by a kernel.
-- Arrays, local memory blocks and certificates give rise to no use;
-- other memory blocks are memory uses; and remaining scalars are
-- constant uses if 'isConstExp' finds a constant expression for
-- them, and scalar uses otherwise.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free = catMaybes <$> mapM kernelUse (S.toList free)
  where kernelUse var = do
          t <- lookupType var
          case t of
            Array {} -> return Nothing
            Mem _ (Space "local") -> return Nothing
            Mem {} -> return $ Just $ Imp.MemoryUse var
            Prim bt -> do
              const_exp <- isConstExp var
              return $ case const_exp of
                Just ce -> Just $ Imp.ConstUse var ce
                Nothing | bt == Cert -> Nothing
                        | otherwise -> Just $ Imp.ScalarUse var bt

-- | Find the memory blocks in the local memory space among the given
-- names, paired with their sizes as computed by 'localMemSize'.
computeLocalMemoryUse :: Names -> CallKernelGen [Imp.LocalMemoryUse]
computeLocalMemoryUse free = catMaybes <$> mapM localUse (S.toList free)
  where localUse var = do
          t <- lookupType var
          case t of
            Mem memsize (Space "local") -> do
              memsize' <- ImpGen.subExpToDimSize memsize >>= localMemSize
              return $ Just (var, memsize')
            _ -> return Nothing

-- | Classify the size of a local memory block: 'Right' with a
-- constant expression if the size is a constant or a variable with a
-- static constant expression (see 'isStaticExp'), and 'Left' with
-- the original variable size otherwise.
localMemSize :: Imp.MemSize -> CallKernelGen (Either Imp.MemSize Imp.KernelConstExp)
localMemSize (Imp.ConstSize x) =
  return $ Right $ ValueExp $ IntValue $ Int64Value x
localMemSize size@(Imp.VarSize v) = do
  const_exp <- isConstExp v
  return $ case const_exp of
    Just e | isStaticExp e -> Right e
    _ -> Left size

-- | Try to express the given variable as a kernel constant
-- expression, based on the expressions recorded in the variable
-- table.  A 'GetSize' becomes a size constant keyed by the current
-- function name; other expressions are converted with
-- 'primExpFromExp', looking up their variables recursively.
isConstExp :: VName -> CallKernelGen (Maybe Imp.KernelConstExp)
isConstExp v = do
  vtable <- ImpGen.getVTable
  fname <- asks ImpGen.envFunction
  let sizeKey = keyWithEntryPoint fname
      lookupConstExp name = M.lookup name vtable >>= hasExp >>= constExp
      constExp (Op (Inner (GetSize key _))) =
        Just $ LeafExp (Imp.SizeConst $ sizeKey key) int32
      constExp e = primExpFromExp lookupConstExp e
  return $ lookupConstExp v
  where -- The expression (if any) stored with a variable table entry.
        hasExp entry = case entry of
          ImpGen.ArrayVar e _ -> e
          ImpGen.ScalarVar e _ -> e
          ImpGen.MemVar e _ -> e

-- | Only some constant expressions qualify as *static* expressions,
-- which we can use for static memory allocation.  This is a bit of a
-- hack, as it is primarily motivated by what you can put as the size
-- when declaring an array in C.  Leaves and values are static, as
-- are additions, subtractions and multiplications of static
-- expressions.
isStaticExp :: Imp.KernelConstExp -> Bool
isStaticExp LeafExp{} = True
isStaticExp ValueExp{} = True
isStaticExp (BinOpExp op x y) =
  staticOp op && isStaticExp x && isStaticExp y
  where staticOp Add{} = True
        staticOp Sub{} = True
        staticOp Mul{} = True
        staticOp _ = False
isStaticExp _ = False

-- | Generate code that stores in the given variable the number of
-- elements that the thread with the given index must process, given
-- the maximum number of elements per thread and the total number of
-- elements.
computeThreadChunkSize :: SplitOrdering
                       -> Imp.Exp
                       -> Imp.Count Imp.Elements
                       -> Imp.Count Imp.Elements
                       -> VName
                       -> ImpGen.ImpM lore op ()
computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
  stride' <- ImpGen.compileSubExp stride
  -- The smaller of the per-thread maximum and the number of strided
  -- elements from the thread index onwards.
  let remaining = Imp.innerExp num_elements - thread_index
      strided_elements = remaining `quotRoundingUp` stride'
  chunk_var <--
    Imp.BinOpExp (SMin Int32) (Imp.innerExp elements_per_thread) strided_elements

computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
  starting_point <- dPrimV "starting_point" $ thread_index * per_thread
  remaining_elements <- dPrimV "remaining_elements" $
    total - Imp.var starting_point int32

  let nothing_remains = Imp.var remaining_elements int32 .<=. 0
      starts_past_end = total .<=. Imp.var starting_point int32
      -- What is left for a thread that cannot get a full chunk.
      partial_chunk = Imp.innerExp $
        num_elements - Imp.elements thread_index * elements_per_thread
      has_partial_chunk = total .<. (thread_index + 1) * per_thread

  sIf (nothing_remains .||. starts_past_end)
    (chunk_var <-- 0) $
    sIf has_partial_chunk
      (chunk_var <-- partial_chunk)
      (chunk_var <-- per_thread)
  where per_thread = Imp.innerExp elements_per_thread
        total = Imp.innerExp num_elements

-- | Construct the 'KernelConstants' for a kernel space, together
-- with the in-kernel code that declares and initialises the
-- variables they refer to.  The given action is run at the end of
-- that initialisation code.
kernelInitialisationSetSpace :: KernelSpace -> InKernelGen ()
                             -> ImpGen.ImpM lore op (KernelConstants, ImpGen.ImpM InKernel Imp.KernelOp ())
kernelInitialisationSetSpace space set_space = do
  group_size <- ImpGen.compileSubExp $ spaceGroupSize space
  num_threads <- ImpGen.compileSubExp $ spaceNumThreads space
  num_groups <- ImpGen.compileSubExp $ spaceNumGroups space

  wave_size <- newVName "wave_size"
  inner_group_size <- newVName "group_size"

  let dims = spaceDimensions space
  dim_sizes <- mapM (ImpGen.compileSubExp . snd) dims

  let global_tid = spaceGlobalId space
      local_tid = spaceLocalId space
      group_id = spaceGroupId space

      constants = KernelConstants
        { kernelGlobalThreadId = Imp.var global_tid int32
        , kernelLocalThreadId = Imp.var local_tid int32
        , kernelGroupId = Imp.var group_id int32
        , kernelGlobalThreadIdVar = global_tid
        , kernelLocalThreadIdVar = local_tid
        , kernelGroupIdVar = group_id
        , kernelGroupSize = group_size
        , kernelNumGroups = num_groups
        , kernelNumThreads = num_threads
        , kernelWaveSize = Imp.var wave_size int32
        , kernelDimensions = zip (map fst dims) dim_sizes
        , kernelThreadActive = if null dims then true else isActive dims
        , kernelStreamed = mempty
        }

      set_constants = do
        dPrim_ wave_size int32
        dPrim_ inner_group_size int32
        ImpGen.dScope Nothing (scopeOfKernelSpace space)

        mapM_ sOp [ Imp.GetGlobalId global_tid 0
                  , Imp.GetLocalId local_tid 0
                  , Imp.GetLocalSize inner_group_size 0
                  , Imp.GetLockstepWidth wave_size
                  , Imp.GetGroupId group_id 0
                  ]

        set_space

  return (constants, set_constants)

-- | Like 'kernelInitialisationSetSpace', with the space indices
-- computed from the global thread id by 'setSpaceIndices'.
kernelInitialisation :: KernelSpace
                     -> ImpGen.ImpM lore op (KernelConstants, ImpGen.ImpM InKernel Imp.KernelOp ())
kernelInitialisation space =
  kernelInitialisationSetSpace space (setSpaceIndices global_tid space)
  where global_tid = Imp.var (spaceGlobalId space) int32

-- | Assign the index variables of a kernel space by unflattening the
-- given flat thread index according to the structure of the space.
setSpaceIndices :: Imp.Exp -> KernelSpace -> InKernelGen ()
setSpaceIndices gtid space =
  case spaceStructure space of
    FlatThreadSpace is_and_dims -> do
      let (is, dims) = unzip is_and_dims
      dims' <- mapM ImpGen.compileSubExp dims
      zipWithM_ (<--) is $ unflattenIndex dims' gtid
    NestedThreadSpace is_and_dims -> do
      let (gtids, gdims, ltids, ldims) = unzip4 is_and_dims
      gdims' <- mapM ImpGen.compileSubExp gdims
      ldims' <- mapM ImpGen.compileSubExp ldims
      let nested = unflattenNestedIndex gdims' ldims' gtid
      -- All global indices are assigned before the local ones.
      zipWithM_ (<--) gtids $ map fst nested
      zipWithM_ (<--) ltids $ map snd nested

-- | An expression that is true when every given index variable is
-- strictly less than its bound.  With no indices, the expression is
-- the constant true.
--
-- The bounds are sizes that are compared against int32 index
-- variables, so they are compiled as int32 expressions.
isActive :: [(VName, SubExp)] -> Imp.Exp
isActive limit = case actives of
                    [] -> Imp.ValueExp $ BoolValue True
                    x:xs -> foldl' (.&&.) x xs
  where (is, ws) = unzip limit
        actives = zipWith active is $ map (ImpGen.compileSubExpOfType int32) ws
        active i = (Imp.var i int32 .<.)

-- | Split a flat global index into, for each dimension, a pair of
-- the global index and the index local to the group.  The first
-- argument holds the global dimensions and the second the group
-- dimensions.
unflattenNestedIndex :: IntegralExp num => [num] -> [num] -> num -> [(num,num)]
unflattenNestedIndex global_dims group_dims global_id =
  [ (local_i + group_i * group_dim, local_i)
  | (group_i, local_i, group_dim) <- zip3 group_is local_is group_dims ]
  where threads_per_group = product group_dims
        groups_per_dim = zipWith quotRoundingUp global_dims group_dims

        flat_group_id =
          global_id `Futhark.Util.IntegralExp.quot` threads_per_group
        flat_local_id =
          global_id `Futhark.Util.IntegralExp.rem` threads_per_group

        group_is = unflattenIndex groups_per_dim flat_group_id
        local_is = unflattenIndex group_dims flat_local_id


-- | Change every memory block to be in the global address space,
-- except those who are in the local memory space.  This only affects
-- generated code - we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal =
  local globalByDefault . ImpGen.localVTable (M.map globalMemory)
  where globalByDefault env =
          env { ImpGen.envDefaultSpace = Imp.Space "global" }

        globalMemory (ImpGen.MemVar _ entry)
          | ImpGen.entryMemSpace entry /= Space "local" =
              ImpGen.MemVar Nothing entry { ImpGen.entryMemSpace = Imp.Space "global" }
        globalMemory entry =
          entry

-- | Generate the given code with the thread-active condition of the
-- kernel constants replaced by a constant true, and emit it.
allThreads :: KernelConstants -> InKernelGen () -> InKernelGen ()
allThreads constants m = do
  code <- ImpGen.subImpM_ (inKernelOperations always_active) m
  ImpGen.emit code
  where always_active =
          constants { kernelThreadActive = Imp.ValueExp (BoolValue True) }



-- | Emit a volatile write of a primitive parameter to the given
-- element index of a local memory block.  Nothing is emitted for
-- parameters that are not of primitive type.
writeParamToLocalMemory :: Typed (MemBound u) =>
                           Imp.Exp -> (VName, t) -> Param (MemBound u)
                        -> ImpGen.ImpM lore op ()
writeParamToLocalMemory i (mem, _) param =
  case paramType param of
    Prim t ->
      ImpGen.emit $
      Imp.Write mem byte_offset elem_type (Space "local") Imp.Volatile $
      Imp.var (paramName param) t
    _ ->
      return ()
  where elem_type = elemType $ paramType param
        -- The element index scaled by the size of an element.
        byte_offset = bytes $ i * Imp.LeafExp (Imp.SizeOf elem_type) int32

-- | Assign a primitive parameter by a volatile read from the given
-- element index of a local memory block.  For a parameter that is
-- not of primitive type, the element index is instead stored in the
-- given index variable.
readParamFromLocalMemory :: Typed (MemBound u) =>
                            VName -> Imp.Exp -> Param (MemBound u) -> (VName, t)
                         -> ImpGen.ImpM lore op ()
readParamFromLocalMemory index i param (l_mem, _) =
  case paramType param of
    Prim _ ->
      paramName param <--
      Imp.index l_mem byte_offset elem_type (Space "local") Imp.Volatile
    _ ->
      index <-- i
  where elem_type = elemType $ paramType param
        -- The element index scaled by the size of an element.
        byte_offset = bytes $ i * Imp.LeafExp (Imp.SizeOf elem_type) int32

-- | Generate a group-level reduction, using a freshly declared
-- variable for the offset.  See 'groupReduceWithOffset'.
groupReduce :: ExplicitMemorish lore =>
               KernelConstants
            -> Imp.Exp
            -> Lambda lore
            -> [VName]
            -> ImpGen.ImpM lore Imp.KernelOp ()
groupReduce constants w lam arrs =
  dPrim "offset" int32 >>= \offset ->
    groupReduceWithOffset constants offset w lam arrs

-- | Generate a group-level reduction of the first 'w' elements of
-- the given arrays with the given operator.  The 'VName' argument is
-- the (already declared) variable used to hold the offset of the
-- element that a thread combines with its own.
--
-- The first @length arrs@ parameters of the lambda are the
-- accumulators, and the rest are the elements being read.  The
-- reduction has two phases: first within each wave, where no
-- barriers are emitted but all memory accesses are made volatile,
-- and then across waves, with a barrier before every step.
groupReduceWithOffset :: ExplicitMemorish lore =>
                         KernelConstants
                      -> VName
                      -> Imp.Exp
                      -> Lambda lore
                      -> [VName]
                      -> ImpGen.ImpM lore Imp.KernelOp ()
groupReduceWithOffset constants offset w lam arrs = do
  let (reduce_acc_params, reduce_arr_params) = splitAt (length arrs) $ lambdaParams lam

  skip_waves <- dPrim "skip_waves" int32
  ImpGen.dLParams $ lambdaParams lam

  offset <-- 0

  ImpGen.comment "participating threads read initial accumulator" $
    sWhen (local_tid .<. w) $
    zipWithM_ readReduceArgument reduce_acc_params arrs

  let do_reduce = do ImpGen.comment "read array element" $
                       zipWithM_ readReduceArgument reduce_arr_params arrs
                     ImpGen.comment "apply reduction operation" $
                       ImpGen.compileBody' reduce_acc_params $ lambdaBody lam
                     ImpGen.comment "write result of operation" $
                       zipWithM_ writeReduceOpResult reduce_acc_params arrs
      in_wave_reduce = ImpGen.everythingVolatile do_reduce

      wave_size = kernelWaveSize constants
      group_size = kernelGroupSize constants
      wave_id = local_tid `quot` wave_size
      in_wave_id = local_tid - wave_id * wave_size
      -- Number of waves in the group, rounding up.
      num_waves = (group_size + wave_size - 1) `quot` wave_size
      arg_in_bounds = local_tid + Imp.var offset int32 .<. w

      -- Phase one: the offset doubles until it reaches the wave size.
      doing_in_wave_reductions =
        Imp.var offset int32 .<. wave_size
      apply_in_in_wave_iteration =
        (in_wave_id .&. (2 * Imp.var offset int32 - 1)) .==. 0
      in_wave_reductions = do
        offset <-- 1
        sWhile doing_in_wave_reductions $ do
          sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration)
            in_wave_reduce
          offset <-- Imp.var offset int32 * 2

      -- Phase two: the first thread of every wave that is not
      -- skipped combines with the wave 'skip_waves' waves further
      -- on; the number of skipped waves doubles in every step.
      doing_cross_wave_reductions =
        Imp.var skip_waves int32 .<. num_waves
      is_first_thread_in_wave =
        in_wave_id .==. 0
      wave_not_skipped =
        (wave_id .&. (2 * Imp.var skip_waves int32 - 1)) .==. 0
      apply_in_cross_wave_iteration =
        arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
      cross_wave_reductions = do
        skip_waves <-- 1
        sWhile doing_cross_wave_reductions $ do
          barrier
          offset <-- Imp.var skip_waves int32 * wave_size
          sWhen apply_in_cross_wave_iteration
            do_reduce
          skip_waves <-- Imp.var skip_waves int32 * 2

  in_wave_reductions
  cross_wave_reductions
  where local_tid = kernelLocalThreadId constants
        global_tid = kernelGlobalThreadId constants

        -- A local barrier suffices when the operator returns only
        -- primitive values; otherwise a global barrier is emitted.
        barrier
          | all primType $ lambdaReturnType lam = sOp Imp.LocalBarrier
          | otherwise                           = sOp Imp.GlobalBarrier

        -- Primitive parameters are read at the local thread id plus
        -- the offset; other parameters at the global thread id plus
        -- the offset.
        readReduceArgument param arr
          | Prim _ <- paramType param = do
              let i = local_tid + ImpGen.varIndex offset
              ImpGen.copyDWIM (paramName param) [] (Var arr) [i]
          | otherwise = do
              let i = global_tid + ImpGen.varIndex offset
              ImpGen.copyDWIM (paramName param) [] (Var arr) [i]

        -- Only primitive results are written back, at the local
        -- thread id.
        writeReduceOpResult param arr
          | Prim _ <- paramType param =
              ImpGen.copyDWIM arr [local_tid] (Var $ paramName param) []
          | otherwise =
              return ()

-- | Generate code for a scan across the threads of a single group.
--
-- The arguments are, in order: the kernel constants; an optional
-- segment-flag function; the number of elements @w@ to scan (threads
-- whose local thread id is not below @w@ never combine or store
-- values); the scan operator; and the arrays holding the values,
-- which are read from and written back to the memory blocks backing
-- those arrays.
--
-- When a flag function is given, it is applied to a source index and
-- a destination index, and the operator is skipped whenever the
-- resulting expression is true.
--
-- Only operators whose parameters are all of primitive type are
-- supported; anything else is reported via 'compilerLimitationS'.
groupScan :: KernelConstants
          -> Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
          -> Imp.Exp
          -> Lambda InKernel
          -> [VName]
          -> ImpGen.ImpM InKernel Imp.KernelOp ()
groupScan constants seg_flag w lam arrs = do
  when (any (not . primType . paramType) $ lambdaParams lam) $
    compilerLimitationS "Cannot compile parallel scans with array element type."

  -- The operator is compiled twice (once per in-block scan), so the
  -- second use gets a renamed copy with its own parameter names.
  renamed_lam <- renameLambda lam

  -- Names of the memory blocks backing the arrays, paired with a
  -- dummy value to fit the shape the read/write helpers expect.
  acc_local_mem <- flip zip (repeat ()) <$>
                   mapM (fmap (ImpGen.memLocationName . ImpGen.entryArrayLocation) .
                         ImpGen.lookupArray) arrs

  let ltid = kernelLocalThreadId constants
      (lam_i, other_index_param, actual_params) =
        partitionChunkedKernelLambdaParameters $ lambdaParams lam
      (x_params, y_params) = splitAt (length arrs) actual_params

  ImpGen.dLParams (lambdaParams lam++lambdaParams renamed_lam)
  lam_i <-- ltid

  -- The scan works by splitting the group into blocks, which are
  -- scanned separately.  Typically, these blocks are smaller than
  -- the lockstep width, which enables barrier-free execution inside
  -- them.
  --
  -- We hardcode the block size here.  The only requirement is that
  -- it should not be less than the square root of the group size.
  -- With 32, we will work on groups of size 1024 or smaller, which
  -- fits every device Troels has seen.  Still, it would be nicer if
  -- it were a runtime parameter.  Some day.
  let block_size = Imp.ValueExp $ IntValue $ Int32Value 32
      simd_width = kernelWaveSize constants
      block_id = ltid `quot` block_size
      in_block_id = ltid - block_id * block_size
      doInBlockScan seg_flag' active = inBlockScan seg_flag' simd_width block_size active ltid acc_local_mem
      ltid_in_bounds = ltid .<. w

  -- Step 1: scan every block independently.
  doInBlockScan seg_flag ltid_in_bounds lam
  sOp Imp.LocalBarrier

  -- Step 2: gather the last element of every block at the front of
  -- the memory, indexed by block number.
  let last_in_block = in_block_id .==. block_size - 1
  sComment "last thread of block 'i' writes its result to offset 'i'" $
    sWhen (last_in_block .&&. ltid_in_bounds) $
    zipWithM_ (writeParamToLocalMemory block_id) acc_local_mem y_params

  sOp Imp.LocalBarrier

  -- Step 3: scan the gathered per-block results using the first
  -- block.  Element 'i' now stands for the last element of block 'i',
  -- so the segment flag is queried at that element's original index.
  let is_first_block = block_id .==. 0
      first_block_seg_flag = do
        flag_true <- seg_flag
        Just $ \from to ->
          flag_true (from*block_size+block_size-1) (to*block_size+block_size-1)
  ImpGen.comment
    "scan the first block, after which offset 'i' contains carry-in for warp 'i+1'" $
    doInBlockScan first_block_seg_flag (is_first_block .&&. ltid_in_bounds) renamed_lam

  sOp Imp.LocalBarrier

  -- Step 4: every block but the first combines the carry of the
  -- preceding block into its own elements.
  let read_carry_in =
        zipWithM_ (readParamFromLocalMemory
                   (paramName other_index_param) (block_id - 1))
        x_params acc_local_mem

  let op_to_y
        | Nothing <- seg_flag =
            ImpGen.compileBody' y_params $ lambdaBody lam
        | Just flag_true <- seg_flag =
            sUnless (flag_true (block_id*block_size-1) ltid) $
              ImpGen.compileBody' y_params $ lambdaBody lam
      write_final_result =
        zipWithM_ (writeParamToLocalMemory ltid) acc_local_mem y_params

  sComment "carry-in for every block except the first" $
    sUnless (is_first_block .||. Imp.UnOpExp Not ltid_in_bounds) $ do
    sComment "read operands" read_carry_in
    sComment "perform operation" op_to_y
    sComment "write final result" write_final_result

  sOp Imp.LocalBarrier

  -- Step 5: the front of the memory was reused for the per-block
  -- results above, so the threads of the first block store their own
  -- step 1 results again.  As with every other store in this
  -- function, only in-bounds threads take part: a thread at or beyond
  -- 'w' never had its y parameters assigned, and its slot lies
  -- outside the range being scanned.
  sComment "restore correct values for first block" $
    sWhen (is_first_block .&&. ltid_in_bounds) write_final_result

-- | Generate code that scans each run of @block_size@ consecutive
-- threads separately.
--
-- A counter starts at 1 and doubles until it reaches @block_size@.
-- In each round, every thread that satisfies @active@ and whose
-- position inside its block is at least the counter reads the element
-- that many positions before its own, applies the operator (unless
-- the optional segment-flag function says otherwise), and stores the
-- result at its own position.  Barriers are only emitted for rounds
-- in which the counter is at least @lockstep_width@.  All generated
-- code is wrapped in 'ImpGen.everythingVolatile'.
inBlockScan :: Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
            -> Imp.Exp
            -> Imp.Exp
            -> Imp.Exp
            -> Imp.Exp
            -> [(VName, t)]
            -> Lambda InKernel
            -> InKernelGen ()
inBlockScan seg_flag lockstep_width block_size active ltid acc_local_mem scan_lam =
  ImpGen.everythingVolatile $ do
    skip_threads <- dPrim "skip_threads" int32

    let stride = Imp.var skip_threads int32
        -- Does this thread take part in the current round?
        participating = (stride .<=. pos_in_block) .&&. active

        (scan_lam_i, other_index_param, op_params) =
          partitionChunkedKernelLambdaParameters $ lambdaParams scan_lam
        (x_params, y_params) =
          splitAt (length op_params `div` 2) op_params

        load_x =
          zipWithM_
          (readParamFromLocalMemory (paramName other_index_param) (ltid - stride))
          x_params acc_local_mem
        load_y =
          zipWithM_ (readParamFromLocalMemory scan_lam_i ltid)
          y_params acc_local_mem
        store_y =
          zipWithM_ (writeParamToLocalMemory ltid) acc_local_mem y_params

        apply_op =
          ImpGen.compileBody' y_params $ lambdaBody scan_lam
        -- With a segment flag, the operator is skipped whenever the
        -- flag holds between the operand's index and our own.
        combine =
          maybe apply_op
          (\flag_true -> sUnless (flag_true (ltid - stride) ltid) apply_op)
          seg_flag

        barrier_if_needed =
          sWhen (lockstep_width .<=. stride) $ sOp Imp.LocalBarrier

    -- Every active thread starts out with its own element as y.
    sWhen active load_y

    sComment "in-block scan (hopefully no barriers needed)" $ do
      skip_threads <-- 1
      sWhile (stride .<. block_size) $ do
        sWhen participating $ do
          sComment "read operands" load_x
          sComment "perform operation" combine

        barrier_if_needed

        sWhen participating $
          sComment "write result" store_y

        barrier_if_needed

        skip_threads <-- stride * 2

  where pos_in_block = ltid - (ltid `quot` block_size) * block_size

-- | Compile a sequence of kernel statements and then run the given
-- continuation.
--
-- The statements are first partitioned with 'groupStmsByGuard'.  For
-- each group, the names it binds are declared, after which the
-- statements are compiled.  A group with a guard is compiled inside
-- 'allThreads' under a conditional on that guard, except when the
-- guard is absent or the literal true value, in which case the
-- statements are compiled as they are.
compileKernelStms :: KernelConstants -> [Stm InKernel]
                  -> InKernelGen a
                  -> InKernelGen a
compileKernelStms constants ungrouped_bnds m = do
  mapM_ compileGroup $ groupStmsByGuard constants ungrouped_bnds
  m
  where compileGroup (stms_guard, stms) = do
          ImpGen.dScopes [ (Just $ stmExp stm, castScope $ scopeOf stm)
                         | stm <- stms ]
          guardedBy stms_guard $ forM_ stms compileOne

        guardedBy stms_guard body =
          case stms_guard of
            Nothing                                -> body
            Just (Imp.ValueExp (BoolValue True))   -> body
            Just cond                              -> sWhen cond $ allThreads constants body

        compileOne (Let pat _ e) = ImpGen.compileExp pat e

-- | Split a statement list into maximal runs of adjacent statements
-- that share the same guard, preserving order.
--
-- A statement whose expression is an 'Op' has no guard; every other
-- statement is guarded by 'kernelThreadActive' of the given constants.
groupStmsByGuard :: KernelConstants
                     -> [Stm InKernel]
                     -> [(Maybe Imp.Exp, [Stm InKernel])]
groupStmsByGuard constants bnds =
  map merge $ groupBy equalGuards [ (guardFor stm, stm) | stm <- bnds ]
  where guardFor stm =
          case stm of
            Let _ _ Op{} -> Nothing
            _            -> Just $ kernelThreadActive constants

        equalGuards (x, _) (y, _) = x == y

        -- All members of a run share a guard, so the first one's will do.
        merge run =
          case run of
            (g, _) : _ -> (g, map snd run)
            []         -> (Nothing, [])

-- | Build a kernel from the given in-kernel code generator and emit a
-- host-level operation that calls it.
--
-- The body is generated with 'inKernelOperations' and passed through
-- 'makeAllMemoryGlobal'; its uses and local memory come from
-- 'computeKernelUses'.  The kernel name is the given string, an
-- underscore, and the tag of the global thread id variable.
sKernel :: KernelConstants -> String -> ImpGen.ImpM InKernel Imp.KernelOp a -> CallKernelGen ()
sKernel constants name m = do
  body <- makeAllMemoryGlobal $
          ImpGen.subImpM_ (inKernelOperations constants) m
  (uses, local_memory) <- computeKernelUses body mempty
  let tag = baseTag $ kernelGlobalThreadIdVar constants
      kernel = Imp.Kernel
        { Imp.kernelBody = body
        , Imp.kernelLocalMemory = local_memory
        , Imp.kernelUses = uses
        , Imp.kernelNumGroups = [kernelNumGroups constants]
        , Imp.kernelGroupSize = [kernelGroupSize constants]
        , Imp.kernelName = nameFromString $ name ++ "_" ++ show tag
        }
  ImpGen.emit $ Imp.Op $ Imp.CallKernel kernel