{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan
-- Copyright   : [2016..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan (

  mkScanl, mkScanl1, mkScanl',
  mkScanr, mkScanr1, mkScanr',

) where

-- accelerate
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar

import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.Representation

import qualified Foreign.CUDA.Analysis                              as CUDA

import Control.Applicative
import Control.Monad                                                ( (>=>), void )
import Data.String                                                  ( fromString )
import Data.Coerce                                                  as Safe
import Data.Bits                                                    as P
import Prelude                                                      as P hiding ( last )


data Direction = L | R

-- 'Data.List.scanl' style left-to-right exclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation.
--
-- > scanl (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [10,10,11,13,16,20,25,31,38,46,55]
--
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma         aenv
    -> IRFun2    PTX aenv (e -> e -> e)
    -> IRExp     PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 L dev aenv combine (Just seed) arr
                              , mkScanAllP2 L dev aenv combine
                              , mkScanAllP3 L dev aenv combine (Just seed)
                              , mkScanFill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScanDim L dev aenv combine (Just seed) arr
          <*> mkScanFill ptx aenv seed


-- 'Data.List.scanl1' style left-to-right inclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation. The array must not be empty.
--
-- > scanl1 (+) (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 10) [0,1,3,6,10,15,21,28,36,45]
--
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma         aenv
    -> IRFun2    PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl1 (deviceProperties . ptxContext -> dev) aenv combine arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 L dev aenv combine Nothing arr
                              , mkScanAllP2 L dev aenv combine
                              , mkScanAllP3 L dev aenv combine Nothing
                              ]
  --
  | otherwise
  = mkScanDim L dev aenv combine Nothing arr


-- Variant of 'scanl' where the final result is returned in a separate array.
--
-- > scanl' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [10,10,11,13,16,20,25,31,38,46]
-- >     , Array Z [55]
-- >     )
--
mkScanl'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma         aenv
    -> IRFun2    PTX aenv (e -> e -> e)
    -> IRExp     PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanl' ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScan'AllP1 L dev aenv combine seed arr
                              , mkScan'AllP2 L dev aenv combine
                              , mkScan'AllP3 L dev aenv combine
                              , mkScan'Fill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScan'Dim L dev aenv combine seed arr
          <*> mkScan'Fill ptx aenv seed


-- 'Data.List.scanr' style right-to-left exclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation.
--
-- > scanr (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [55,55,54,52,49,45,40,34,27,19,10]
--
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma         aenv
    -> IRFun2    PTX aenv (e -> e -> e)
    -> IRExp     PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 R dev aenv combine (Just seed) arr
                              , mkScanAllP2 R dev aenv combine
                              , mkScanAllP3 R dev aenv combine (Just seed)
                              , mkScanFill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScanDim R dev aenv combine (Just seed) arr
          <*> mkScanFill ptx aenv seed


-- 'Data.List.scanr1' style right-to-left inclusive scan, but with the
-- restriction that the combination function must be associative to enable
-- efficient parallel implementation. The array must not be empty.
--
-- > scanr1 (+) (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 10) [45,45,44,42,39,35,30,24,17,9]
--
mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma         aenv
    -> IRFun2    PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr1 (deviceProperties . ptxContext -> dev) aenv combine arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScanAllP1 R dev aenv combine Nothing arr
                              , mkScanAllP2 R dev aenv combine
                              , mkScanAllP3 R dev aenv combine Nothing
                              ]
  --
  | otherwise
  = mkScanDim R dev aenv combine Nothing arr


-- Variant of 'scanr' where the final result is returned in a separate array.
--
-- > scanr' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [55,54,52,49,45,40,34,27,19,10]
-- >     , Array Z [55]
-- >     )
--
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma         aenv
    -> IRFun2    PTX aenv (e -> e -> e)
    -> IRExp     PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanr' ptx@(deviceProperties . ptxContext -> dev) aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = foldr1 (+++) <$> sequence [ mkScan'AllP1 R dev aenv combine seed arr
                              , mkScan'AllP2 R dev aenv combine
                              , mkScan'AllP3 R dev aenv combine
                              , mkScan'Fill ptx aenv seed
                              ]
  --
  | otherwise
  = (+++) <$> mkScan'Dim R dev aenv combine seed arr
          <*> mkScan'Fill ptx aenv seed


-- Device-wide scans
-- -----------------
--
-- This is a classic two-pass algorithm which requires ~4n data movement to
-- global memory: each thread block first scans its own chunk of the input and
-- records the block-wide aggregate (step 1), a single block then scans those
-- aggregates (step 2), and finally each later chunk is combined with its
-- scanned aggregate, i.e. its carry-in value (step 3). In future we would like
-- to replace this with a single-pass algorithm.
--
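-- As an illustration only (assuming a hypothetical thread block size of four),
-- an exclusive left scan over eight elements proceeds roughly as follows:
--
-- > scanl (+) 10 (use $ fromList (Z :. 8) [0..])
-- >
-- >   step 1: out = [10,10,11,13,16, 4, 9,15,22]    tmp = [16,22]
-- >   step 2: tmp = [16,38]
-- >   step 3: out = [10,10,11,13,16,20,25,31,38]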

-- Parallel scan, step 1.
--
-- Threads scan a stripe of the input into the output array, incorporating the
-- initial element and any fused functions on the way. The block-wide aggregate
-- of each chunk is written to a separate array of partial sums.
--
mkScanAllP1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element, if this is an exclusive scan
    -> IRDelayed PTX aenv (Vector e)                -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP1 dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- A thread block scans a non-empty stripe of the input, storing the final
    -- block-wide aggregate into a separate array
    --
    -- For exclusive scans, thread 0 of segment 0 must incorporate the initial
    -- element into the input and output. Threads shuffle their indices
    -- appropriately.
    --
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid

    -- iterating over thread-block-wide segments
    imapFromStepTo s0 gd end $ \chunk -> do

      bd  <- blockDim
      inf <- A.mul numType chunk bd

      -- index i* is the index that this thread will read data from. Recall that
      -- the supremum index is exclusive
      tid <- threadIdx
      i0  <- case dir of
               L -> A.add numType inf tid
               R -> do x <- A.sub numType sz inf
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      -- index j* is the index that we write to. Recall that for exclusive scans
      -- the output array is one larger than the input; the initial element will
      -- be written into this spot by thread 0 of the first thread block.
      j0  <- case mseed of
               Nothing -> return i0
               Just _  -> case dir of
                            L -> A.add numType i0 (lift 1)
                            R -> return i0

      -- If this thread has input, read data and participate in thread-block scan
      let valid i = case dir of
                      L -> A.lt  scalarType i sz
                      R -> A.gte scalarType i (lift 0)

      when (valid i0) $ do
        x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
        x1 <- case mseed of
                Nothing   -> return x0
                Just seed ->
                  if A.eq scalarType tid (lift 0) `A.land` A.eq scalarType chunk (lift 0)
                    then do
                      z <- seed
                      case dir of
                        L -> writeArray arrOut (lift 0 :: IR Int32) z >> app2 combine z x0
                        R -> writeArray arrOut sz                   z >> app2 combine x0 z
                    else
                      return x0

        n  <- A.sub numType sz inf
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Write this thread's scan result to memory
        writeArray arrOut j0 x2

        -- The last thread also writes its result---the aggregate for this
        -- thread block---to the temporary partial sums array. This is only
        -- necessary for full blocks in a multi-block scan; the final
        -- partially-full tile does not have a successor block.
        last <- A.sub numType bd (lift 1)
        when (A.gt scalarType gd (lift 1) `A.land` A.eq scalarType tid last) $
          case dir of
            L -> writeArray arrTmp chunk x2
            R -> do u <- A.sub numType end chunk
                    v <- A.sub numType u (lift 1)
                    writeArray arrTmp v x2

    return_


-- Parallel scan, step 2
--
-- A single thread block performs a scan of the per-block aggregates computed in
-- step 1. This gives the per-block prefix which must be added to each element
-- in step 3.
--
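-- Continuing the illustration at the start of this section (hypothetical block
-- size of four), the per-block aggregates are scanned inclusively in place:
--
-- >   tmp = [16,22]  ==>  tmp = [16,38]
--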
mkScanAllP2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP2 dir dev aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      grid _ _                  = 1
      gridQ                     = [|| \_ _ -> 1 ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: We could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem 1

    bd    <- blockDim
    imapFromStepTo start bd end $ \offset -> do

      -- Index of the partial sums array that this thread will process.
      tid <- threadIdx
      i0  <- case dir of
               L -> A.add numType offset tid
               R -> do x <- A.sub numType end offset
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      let valid i = case dir of
                      L -> A.lt  scalarType i end
                      R -> A.gte scalarType i start

      when (valid i0) $ do

        __syncthreads

        x0 <- readArray arrTmp i0
        x1 <- if A.gt scalarType offset (lift 0) `A.land` A.eq scalarType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x0
                    R -> app2 combine x0 c
                else do
                  return x0

        n  <- A.sub numType end offset
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Update the temporary array with this thread's result
        writeArray arrTmp i0 x2

        -- The last thread writes the carry-out value. If the last thread is not
        -- active, then this must be the last stripe anyway.
        last <- A.sub numType bd (lift 1)
        when (A.eq scalarType tid last) $
          writeArray carry (lift 0 :: IR Int32) x2

    return_


-- Parallel scan, step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed in step 2.
--
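-- Continuing the illustration above, each element of the second chunk is
-- combined with the carry-in value tmp[0] = 16 computed in step 2:
--
-- >   out = [10,10,11,13,16, 4, 9,15,22]  ==>  [10,10,11,13,16,20,25,31,38]
--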
mkScanAllP3
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element, if this is an exclusive scan
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP3 dir dev aenv combine mseed =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      stride                    = local           scalarType ("ix.stride" :: Name Int32)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int32)
      --
      config                    = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  in
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do

    sz  <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
    tid <- threadIdx

    -- Threads that will never contribute can just exit immediately. The size of
    -- each chunk is set by the block dimension of the step 1 kernel, which may
    -- be different from the block size of this kernel.
    when (A.lt scalarType tid stride) $ do

      -- Iterate over the segments computed in phase 1. Note that we have one
      -- fewer chunk to process because the first has no carry-in.
      bid <- blockIdx
      gd  <- gridDim
      c0  <- A.add numType start bid
      imapFromStepTo c0 gd end $ \chunk -> do

        -- Determine the start and end indices of this chunk, to which we will
        -- apply the carry-in value. The bounds are returned for a left-to-right
        -- traversal, regardless of scan direction.
        (inf,sup) <- case dir of
                       L -> do
                         a <- A.add numType chunk (lift 1)
                         b <- A.mul numType stride a
                         case mseed of
                           Just{}  -> do
                             c <- A.add numType b (lift 1)
                             d <- A.add numType c stride
                             e <- A.min scalarType d sz
                             return (c,e)
                           Nothing -> do
                             c <- A.add numType b stride
                             d <- A.min scalarType c sz
                             return (b,d)
                       R -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         case mseed of
                           Just{}  -> do
                             d <- A.sub numType c (lift 1)
                             e <- A.sub numType d stride
                             f <- A.max scalarType e (lift 0)
                             return (f,d)
                           Nothing -> do
                             d <- A.sub numType c stride
                             e <- A.max scalarType d (lift 0)
                             return (e,c)

        -- Read the carry-in value
        carry     <- case dir of
                       L -> readArray arrTmp chunk
                       R -> do
                         a <- A.add numType chunk (lift 1)
                         b <- readArray arrTmp a
                         return b

        -- Apply the carry-in value to each element in the chunk
        bd        <- blockDim
        i0        <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray arrOut i
          u <- case dir of
                 L -> app2 combine carry v
                 R -> app2 combine v carry
          writeArray arrOut i u

    return_


-- Parallel scan', step 1.
--
-- Similar to 'mkScanAllP1'. Threads scan a stripe of the input into the output
-- array, incorporating the initial element and any fused functions on the way.
-- The block-wide aggregate of each chunk is written to a separate array of
-- partial sums.
--
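-- As an illustration only (again assuming a hypothetical thread block size of
-- four), the scan' variant of the running example proceeds roughly as follows:
--
-- > scanl' (+) 10 (use $ fromList (Z :. 8) [0..])
-- >
-- >   step 1: out = [10,10,11,13,16, 4, 9,15]    tmp = [16,22]
-- >   step 2: tmp = [16,38]                      sum = [38]
-- >   step 3: out = [10,10,11,13,16,20,25,31]
--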
mkScan'AllP1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Vector e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP1 dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- A thread block scans a non-empty stripe of the input, storing the partial
    -- result and the final block-wide aggregate
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid

    -- iterate over thread-block wide segments
    imapFromStepTo s0 gd end $ \seg -> do

      bd  <- blockDim
      inf <- A.mul numType seg bd

      -- i* is the index that this thread will read data from
      tid <- threadIdx
      i0  <- case dir of
               L -> A.add numType inf tid
               R -> do x <- A.sub numType sz inf
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      -- j* is the index this thread will write to. This is just shifted along
      -- by one to make room for the initial element.
      j0  <- case dir of
               L -> A.add numType i0 (lift 1)
               R -> A.sub numType i0 (lift 1)

      -- If this thread has input it participates in the scan
      let valid i = case dir of
                      L -> A.lt  scalarType i sz
                      R -> A.gte scalarType i (lift 0)

      when (valid i0) $ do
        x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0

        -- Thread 0 of the first segment must also evaluate and store the
        -- initial element
        x1 <- if A.eq scalarType tid (lift 0) `A.land` A.eq scalarType seg (lift 0)
                then do
                  z <- seed
                  writeArray arrOut i0 z
                  case dir of
                    L -> app2 combine z x0
                    R -> app2 combine x0 z
                else
                  return x0

        -- Block-wide scan
        n  <- A.sub numType sz inf
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Write this thread's scan result to memory. Recall that we had to make
        -- space for the initial element, so the very last thread does not store
        -- its result here.
        case dir of
          L -> when (A.lt  scalarType j0 sz)       $ writeArray arrOut j0 x2
          R -> when (A.gte scalarType j0 (lift 0)) $ writeArray arrOut j0 x2

        -- Last active thread writes its result to the partial sums array. These
        -- will be used to compute the carry-in value in step 2.
        m  <- do x <- A.min scalarType n bd
                 y <- A.sub numType x (lift 1)
                 return y
        when (A.eq scalarType tid m) $
          case dir of
            L -> writeArray arrTmp seg x2
            R -> do x <- A.sub numType end seg
                    y <- A.sub numType x (lift 1)
                    writeArray arrTmp y x2

    return_


-- Parallel scan', step 2
--
-- A single thread block performs an inclusive scan of the partial sums array to
-- compute the per-block carry-in values, as well as the final reduction result.
--
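-- Continuing the scan' illustration above, the partial sums are scanned in
-- place and the overall total is written to the separate sum array:
--
-- >   tmp = [16,22]  ==>  tmp = [16,38],  sum = [38]
--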
mkScan'AllP2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP2 dir dev aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Scalar e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem grid gridQ
      grid _ _                  = 1
      gridQ                     = [|| \_ _ -> 1 ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramSum ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    carry <- staticSharedMem 1

    -- A single thread block iterates over the per-block partial results from
    -- step 1
    tid <- threadIdx
    bd  <- blockDim
    imapFromStepTo start bd end $ \offset -> do

      i0  <- case dir of
               L -> A.add numType offset tid
               R -> do x <- A.sub numType end offset
                       y <- A.sub numType x tid
                       z <- A.sub numType y (lift 1)
                       return z

      let valid i = case dir of
                      L -> A.lt  scalarType i end
                      R -> A.gte scalarType i start

      when (valid i0) $ do

        -- wait for the carry-in value to be updated
        __syncthreads

        x0 <- readArray arrTmp i0
        x1 <- if A.gt scalarType offset (lift 0) `A.land` A.eq scalarType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x0
                    R -> app2 combine x0 c
                else
                  return x0

        n  <- A.sub numType end offset
        x2 <- if A.gte scalarType n bd
                then scanBlockSMem dir dev combine Nothing  x1
                else scanBlockSMem dir dev combine (Just n) x1

        -- Update the partial results array
        writeArray arrTmp i0 x2

        -- The last active thread saves its result as the carry-out value.
        m  <- do x <- A.min scalarType bd n
                 y <- A.sub numType x (lift 1)
                 return y
        when (A.eq scalarType tid m) $
          writeArray carry (lift 0 :: IR Int32) x2

    -- The first thread stores the final carry-out value as the reduction
    -- result for the entire array
    __syncthreads

    when (A.eq scalarType tid (lift 0)) $
      writeArray arrSum (lift 0 :: IR Int32) =<< readArray carry (lift 0 :: IR Int32)

    return_


-- Parallel scan', step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed in step 2.
--
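-- Continuing the scan' illustration above, each element of the second chunk is
-- combined with the carry-in value tmp[0] = 16:
--
-- >   out = [10,10,11,13,16, 4, 9,15]  ==>  [10,10,11,13,16,20,25,31]
--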
mkScan'AllP3
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP3 dir dev aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      stride                    = local           scalarType ("ix.stride" :: Name Int32)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int32)
      --
      config                    = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
  in
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do

    sz  <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
    tid <- threadIdx

    when (A.lt scalarType tid stride) $ do

      bid <- blockIdx
      gd  <- gridDim
      c0  <- A.add numType start bid
      imapFromStepTo c0 gd end $ \chunk -> do

        (inf,sup) <- case dir of
                       L -> do
                         a <- A.add numType chunk (lift 1)
                         b <- A.mul numType stride a
                         c <- A.add numType b (lift 1)
                         d <- A.add numType c stride
                         e <- A.min scalarType d sz
                         return (c,e)
                       R -> do
                         a <- A.sub numType end chunk
                         b <- A.mul numType stride a
                         c <- A.sub numType sz b
                         d <- A.sub numType c (lift 1)
                         e <- A.sub numType d stride
                         f <- A.max scalarType e (lift 0)
                         return (f,d)

        carry     <- case dir of
                       L -> readArray arrTmp chunk
                       R -> do
                         a <- A.add numType chunk (lift 1)
                         b <- readArray arrTmp a
                         return b

        -- Apply the carry-in value to each element in the chunk
        bd        <- blockDim
        i0        <- A.add numType inf tid
        imapFromStepTo i0 bd sup $ \i -> do
          v <- readArray arrOut i
          u <- case dir of
                 L -> app2 combine carry v
                 R -> app2 combine v carry
          writeArray arrOut i u

    return_


-- Multidimensional scans
-- ----------------------

-- Multidimensional scan along the innermost dimension
--
-- A thread block individually computes along each innermost dimension. This is
-- a single-pass operation.
--
--  * We can assume that the array is non-empty; exclusive scans over an empty
--    innermost dimension will instead be filled with the seed element via
--    'mkScanFill'.
--
--  * Arrays with a small but non-empty innermost dimension (size << thread
--    block size) will have many threads which do no work.
--
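-- For example (illustration only), each innermost row of a two-dimensional
-- array is scanned independently by a single thread block:
--
-- > scanl1 (+) (use $ fromList (Z :. 2 :. 4) [0..])
-- >
-- > ==> Array (Z :. 2 :. 4) [0,1,3,6,4,9,15,22]
--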
mkScanDim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IRExp PTX aenv e)                     -- ^ seed element, if this is an exclusive scan
    -> IRDelayed PTX aenv (Array (sh:.Int) e)       -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanDim dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: we could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem 1

    -- Size of the input array
    sz  <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- Thread blocks iterate over the outer dimensions. Threads in a block
    -- cooperatively scan along one dimension, but thread blocks do not
    -- communicate with each other.
    --
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid
    imapFromStepTo s0 gd end $ \seg -> do

      -- Index this thread reads from
      tid <- threadIdx
      i0  <- case dir of
               L -> do x <- A.mul numType seg sz
                       y <- A.add numType x tid
                       return y

               R -> do x <- A.add numType seg (lift 1)
                       y <- A.mul numType x sz
                       z <- A.sub numType y tid
                       w <- A.sub numType z (lift 1)
                       return w

      -- Index this thread writes to
      j0  <- case mseed of
               Nothing -> return i0
               Just{}  -> do szp1 <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
                             case dir of
                               L -> do x <- A.mul numType seg szp1
                                       y <- A.add numType x tid
                                       return y

                               R -> do x <- A.add numType seg (lift 1)
                                       y <- A.mul numType x szp1
                                       z <- A.sub numType y tid
                                       w <- A.sub numType z (lift 1)
                                       return w

      -- Stride indices by block dimension
      bd <- blockDim
      let next ix = case dir of
                      L -> A.add numType ix bd
                      R -> A.sub numType ix bd

      -- Initialise this scan segment
      --
      -- If this is an exclusive scan then the first thread just evaluates the
      -- seed element and stores this value into the carry-in slot. All threads
      -- shift their write-to index (j) by one, to make space for this element.
      --
      -- If this is an inclusive scan then do a block-wide scan. The last thread
      -- in the block writes the carry-in value.
      --
      r <-
        case mseed of
          Just seed -> do
            when (A.eq scalarType tid (lift 0)) $ do
              z <- seed
              writeArray arrOut j0 z
              writeArray carry (lift 0 :: IR Int32) z
            j1 <- case dir of
                   L -> A.add numType j0 (lift 1)
                   R -> A.sub numType j0 (lift 1)
            return $ A.trip sz i0 j1

          Nothing -> do
            when (A.lt scalarType tid sz) $ do
              x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
              r0 <- if A.gte scalarType sz bd
                      then scanBlockSMem dir dev combine Nothing   x0
                      else scanBlockSMem dir dev combine (Just sz) x0
              writeArray arrOut j0 r0

              ll <- A.sub numType bd (lift 1)
              when (A.eq scalarType tid ll) $
                writeArray carry (lift 0 :: IR Int32) r0

            n1 <- A.sub numType sz bd
            i1 <- next i0
            j1 <- next j0
            return $ A.trip n1 i1 j1

      -- Iterate over the remaining elements in this segment
      void $ while
        (\(A.fst3   -> n)       -> A.gt scalarType n (lift 0))
        (\(A.untrip -> (n,i,j)) -> do

          -- Wait for the carry-in value from the previous iteration to be updated
          __syncthreads

          -- Compute and store the next element of the scan
          --
          -- NOTE: As with 'foldSeg' we require all threads to participate in
          -- every iteration of the loop, otherwise they will die prematurely.
          -- Out-of-bounds threads return 'undef' at this point, which is really
          -- unfortunate ):
          --
          x <- if A.lt scalarType tid n
                 then app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                 else let
                          go :: TupleType a -> Operands a
                          go UnitTuple       = OP_Unit
                          go (PairTuple a b) = OP_Pair (go a) (go b)
                          go (SingleTuple t) = ir' t (undef t)
                      in
                      return . IR $ go (eltType (undefined::e))

          -- Thread zero incorporates the carry-in element
          y <- if A.eq scalarType tid (lift 0)
                 then do
                   c <- readArray carry (lift 0 :: IR Int32)
                   case dir of
                     L -> app2 combine c x
                     R -> app2 combine x c
                 else
                   return x

          -- Perform the scan and write the result to memory
          z <- if A.gte scalarType n bd
                 then scanBlockSMem dir dev combine Nothing  y
                 else scanBlockSMem dir dev combine (Just n) y

          when (A.lt scalarType tid n) $ do
            writeArray arrOut j z

            -- The last thread of the block writes its result as the carry-out
            -- value. If this thread is not active then we are on the last
            -- iteration of the loop and it will not be needed.
            w <- A.sub numType bd (lift 1)
            when (A.eq scalarType tid w) $
              writeArray carry (lift 0 :: IR Int32) z

          -- Update indices for the next iteration
          n' <- A.sub numType n bd
          i' <- next i
          j' <- next j
          return $ A.trip n' i' j')
        r

    return_


-- Multidimensional scan' along the innermost dimension
--
-- A thread block individually computes along each innermost dimension. This is
-- a single-pass operation.
--
--  * We can assume that the array is non-empty; exclusive scans over an empty
--    innermost dimension will instead be filled with the seed element via
--    'mkScan'Fill'.
--
--  * Arrays with a small but non-empty innermost dimension (size << thread
--    block size) will have many threads which do no work.
--
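-- For example (illustration only), each innermost row is scanned independently
-- and the row totals are collected into a separate array:
--
-- > scanl' (+) 0 (use $ fromList (Z :. 2 :. 4) [1..])
-- >
-- > ==> ( Array (Z :. 2 :. 4) [0,1,3,6,0,5,11,18]
-- >     , Array (Z :. 2) [10,26]
-- >     )
--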
mkScan'Dim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties                             -- ^ properties of the target GPU
    -> Gamma aenv                                   -- ^ array environment
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> IRExp PTX aenv e                             -- ^ seed element
    -> IRDelayed PTX aenv (Array (sh:.Int) e)       -- ^ input data
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Dim dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Array sh e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do

    -- The first and last threads of the block need to communicate the
    -- block-wide aggregate as a carry-in value across iterations.
    --
    -- TODO: we could optimise this a bit if we can get access to the shared
    -- memory area used by 'scanBlockSMem', and from there directly read the
    -- value computed by the last thread.
    carry <- staticSharedMem 1

    -- Size of the input array
    sz    <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- If the innermost dimension is smaller than the number of threads in the
    -- block, those threads will never contribute to the output.
    tid   <- threadIdx
    when (A.lte scalarType tid sz) $ do

      -- Thread blocks iterate over the outer dimensions, each thread block
      -- cooperatively scanning along each outermost index.
      bid <- blockIdx
      gd  <- gridDim
      s0  <- A.add numType start bid
      imapFromStepTo s0 gd end $ \seg -> do

        -- Not necessary to wait for threads to catch up before starting this segment
        -- __syncthreads

        -- Linear index bounds for this segment
        inf <- A.mul numType seg sz
        sup <- A.add numType inf sz

        -- Index that this thread will read from. Recall that the supremum index
        -- is exclusive.
        i0  <- case dir of
                 L -> A.add numType inf tid
                 R -> do x <- A.sub numType sup tid
                         y <- A.sub numType x (lift 1)
                         return y

        -- The index that this thread will write to. This is just shifted along
        -- by one to make room for the initial element.
        j0  <- case dir of
                 L -> A.add numType i0 (lift 1)
                 R -> A.sub numType i0 (lift 1)

        -- Evaluate the initial element. Store it into the carry-in slot as well
        -- as to the array as the first element. This is always valid because if
        -- the input array is empty then we will be evaluating via mkScan'Fill.
        when (A.eq scalarType tid (lift 0)) $ do
          z <- seed
          writeArray arrOut i0                   z
          writeArray carry  (lift 0 :: IR Int32) z

        bd  <- blockDim
        let next ix = case dir of
                        L -> A.add numType ix bd
                        R -> A.sub numType ix bd

        -- Now, threads iterate over the elements along the innermost dimension.
        -- At each iteration the first thread incorporates the carry-in value
        -- from the previous step.
        --
        -- The index tracks how many elements remain for the thread block, since
        -- indices i* and j* are local to each thread
        n0  <- A.sub numType sup inf
        void $ while
          (\(A.fst3   -> n)       -> A.gt scalarType n (lift 0))
          (\(A.untrip -> (n,i,j)) -> do

            -- Wait for threads to catch up to ensure the carry-in value from
            -- the last iteration has been updated
            __syncthreads

            -- If all threads in the block will participate this round we can
            -- avoid (almost) all bounds checks.
            _ <- if A.gte scalarType n bd
                    -- All threads participate. No bounds checks required but
                    -- the last thread needs to update the carry-in value.
                    then do
                      x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                      y <- if A.eq scalarType tid (lift 0)
                              then do
                                c <- readArray carry (lift 0 :: IR Int32)
                                case dir of
                                  L -> app2 combine c x
                                  R -> app2 combine x c
                              else
                                return x
                      z <- scanBlockSMem dir dev combine Nothing y

                      -- Write results to the output array. Note that if we
                      -- align directly on the boundary of the array this is not
                      -- valid for the last thread.
                      case dir of
                        L -> when (A.lt  scalarType j sup) $ writeArray arrOut j z
                        R -> when (A.gte scalarType j inf) $ writeArray arrOut j z

                      -- Last thread of the block also saves its result as the
                      -- carry-in value
                      bd1 <- A.sub numType bd (lift 1)
                      when (A.eq scalarType tid bd1) $
                        writeArray carry (lift 0 :: IR Int32) z

                      return (IR OP_Unit :: IR ())

                    -- Only threads that are in bounds can participate. This is
                    -- the last iteration of the loop. The last active thread
                    -- still needs to store its value into the carry-in slot.
                    else do
                      when (A.lt scalarType tid n) $ do
                        x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                        y <- if A.eq scalarType tid (lift 0)
                                then do
                                  c <- readArray carry (lift 0 :: IR Int32)
                                  case dir of
                                    L -> app2 combine c x
                                    R -> app2 combine x c
                                else
                                  return x
                        z <- scanBlockSMem dir dev combine (Just n) y

                        m <- A.sub numType n (lift 1)
                        _ <- if A.lt scalarType tid m
                               then writeArray arrOut j                   z >> return (IR OP_Unit :: IR ())
                               else writeArray carry (lift 0 :: IR Int32) z >> return (IR OP_Unit :: IR ())

                        return ()
                      return (IR OP_Unit :: IR ())

            A.trip <$> A.sub numType n bd <*> next i <*> next j)
          (A.trip n0 i0 j0)

        -- Wait for the carry-in value to be updated
        __syncthreads

        -- Store the carry-in value to the separate final results array
        when (A.eq scalarType tid (lift 0)) $
          writeArray arrSum seg =<< readArray carry (lift 0 :: IR Int32)

    return_



-- Parallel scan, auxiliary
--
-- If this is an exclusive scan of an empty array, we just fill the result with
-- the seed element.
--
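-- For example (illustration only), an exclusive scan of an empty vector
-- consists of just the seed element:
--
-- > scanl (+) 10 (use $ fromList (Z :. 0) [])
-- >
-- > ==> Array (Z :. 1) [10]
--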
mkScanFill
    :: (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkScanFill ptx aenv seed =
  mkGenerate ptx aenv (IRFun1 (const seed))

mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill ptx aenv seed =
  Safe.coerce <$> (mkGenerate ptx aenv (IRFun1 (const seed)) :: CodeGen (IROpenAcc PTX aenv (Array sh e)))


-- Block-wide scan
-- ---------------

-- Efficient block-wide (inclusive) scan using the specified operator.
--
-- Each block requires (#warps * (1 + 1.5*warp size)) elements of dynamically
-- allocated shared memory.
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.4/cub/block/specializations/block_scan_warp_scans.cuh
--
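-- For example (illustration only), with a warp size of 32, a thread block of
-- 256 threads (8 warps), and 4-byte elements, the block requires
--
-- >   8 * (1 + 48) * 4 = 1568 bytes
--
-- of shared memory: 48 scratch elements per warp for the warp-wide scans, plus
-- one element per warp in which to share the warp aggregates.
--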
scanBlockSMem
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target device
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> Maybe (IR Int32)                             -- ^ number of valid elements (may be less than block size)
    -> IR e                                         -- ^ calling thread's input element
    -> CodeGen (IR e)
scanBlockSMem dir dev combine nelem = warpScan >=> warpPrefix
  where
    int32 :: Integral a => a -> IR Int32
    int32 = lift . P.fromIntegral

    -- Temporary storage required for each warp
    warp_smem_elems = CUDA.warpSize dev + (CUDA.warpSize dev `P.quot` 2)
    warp_smem_bytes = warp_smem_elems  * sizeOf (eltType (undefined::e))

    -- Step 1: Scan in every warp
    warpScan :: IR e -> CodeGen (IR e)
    warpScan input = do
      -- Allocate (1.5 * warpSize) elements of shared memory for each warp
      -- (individually addressable by each warp)
      wid   <- warpId
      skip  <- A.mul numType wid (int32 warp_smem_bytes)
      smem  <- dynamicSharedMem (int32 warp_smem_elems) skip
      scanWarpSMem dir dev combine smem input

    -- Step 2: Collect the aggregate results of each warp to compute the prefix
    -- values for each warp and combine with the partial result to compute each
    -- thread's final value.
    warpPrefix :: IR e -> CodeGen (IR e)
    warpPrefix input = do
      -- Allocate #warps elements of shared memory
      bd    <- blockDim
      warps <- A.quot integralType bd (int32 (CUDA.warpSize dev))
      skip  <- A.mul numType warps (int32 warp_smem_bytes)
      smem  <- dynamicSharedMem warps skip

      -- Share warp aggregates
      wid   <- warpId
      lane  <- laneId
      when (A.eq scalarType lane (int32 (CUDA.warpSize dev - 1))) $ do
        writeArray smem wid input

      -- Wait for each warp to finish its local scan and share the aggregate
      __syncthreads

      -- Compute the prefix value for this warp and add to the partial result.
      -- This step is not required for the first warp, which has no carry-in.
      if A.eq scalarType wid (lift 0)
        then return input
        else do
          -- Every thread sequentially scans the warp aggregates to compute
          -- their prefix value. We do this sequentially, but could also have
          -- warp 0 do it cooperatively if we limit thread block sizes to
          -- (warp size ^ 2).
          steps  <- case nelem of
                      Nothing -> return wid
                      Just n  -> A.min scalarType wid =<< A.quot integralType n (int32 (CUDA.warpSize dev))

          p0     <- readArray smem (lift 0 :: IR Int32)
          prefix <- iterFromStepTo (lift 1) (lift 1) steps p0 $ \step x -> do
                      y <- readArray smem step
                      case dir of
                        L -> app2 combine x y
                        R -> app2 combine y x

          case dir of
            L -> app2 combine prefix input
            R -> app2 combine input prefix


-- Warp-wide scan
-- --------------

-- Efficient warp-wide (inclusive) scan using the specified operator.
--
-- Each warp requires 48 (1.5 x warp size) elements of shared memory. The
-- routine assumes that this storage is allocated individually per-warp; that
-- is, each warp can index its own buffer in the range [0, 1.5 * warp size).
--
-- Example: https://github.com/NVlabs/cub/blob/1.5.4/cub/warp/specializations/warp_scan_smem.cuh
--
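-- As an illustration only, with a warp size of 32 the scan takes
-- log2 32 = 5 steps, using offsets 1, 2, 4, 8 and 16. At each step a lane
-- stores its partial result at index (lane + 16) of the buffer, and lanes with
-- lane >= offset combine the value stored at (lane + 16 - offset) with their
-- own; e.g. for a left scan with (+):
--
-- >   step 0: x[lane] := x[lane-1]  + x[lane]    (lanes  1..31)
-- >   step 1: x[lane] := x[lane-2]  + x[lane]    (lanes  2..31)
-- >     ...
-- >   step 4: x[lane] := x[lane-16] + x[lane]    (lanes 16..31)
--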
scanWarpSMem
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties                             -- ^ properties of the target device
    -> IRFun2 PTX aenv (e -> e -> e)                -- ^ combination function
    -> IRArray (Vector e)                           -- ^ temporary storage array in shared memory (1.5 x warp size elements)
    -> IR e                                         -- ^ calling thread's input element
    -> CodeGen (IR e)
scanWarpSMem dir dev combine smem = scan 0
  where
    log2 :: Double -> Double
    log2 = P.logBase 2

    -- Number of steps required to scan warp
    steps     = P.floor (log2 (P.fromIntegral (CUDA.warpSize dev)))
    halfWarp  = P.fromIntegral (CUDA.warpSize dev `P.quot` 2)

    -- Unfold the scan as a recursive code generation function
    scan :: Int -> IR e -> CodeGen (IR e)
    scan step x
      | step >= steps               = return x
      | offset <- 1 `P.shiftL` step = do
          -- share partial result through shared memory buffer
          lane <- laneId
          i    <- A.add numType lane (lift halfWarp)
          writeArray smem i x

          -- update partial result if in range
          x'   <- if A.gte scalarType lane (lift offset)
                    then do
                      i' <- A.sub numType i (lift offset)     -- lane + HALF_WARP - offset
                      x' <- readArray smem i'
                      case dir of
                        L -> app2 combine x' x
                        R -> app2 combine x x'

                    else
                      return x

          scan (step+1) x'