{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpGen.Kernels.SegScan
( compileSegScan )
where
import Control.Monad.Except
import Data.Maybe
import Data.List
import Prelude hiding (quot, rem)
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
-- | Create one @scan_arr@ array for each of the first @length nes@
-- parameters of the scan operator and return the array names in
-- parameter order.
--
-- * A parameter whose attribute is a 'MemArray' gets an array declared
--   (via 'sArray') in the memory block the parameter already lives in,
--   with shape @[num_threads] ++ element shape@ and an 'IxFun.iota'
--   index function over that shape.
--
-- * Any other parameter gets a freshly allocated (via 'sAllocArray')
--   array of @group_size@ elements in the @"local"@ space.
makeLocalArrays :: SubExp -> SubExp -> [SubExp] -> Lambda InKernel
                -> InKernelGen [VName]
makeLocalArrays group_size num_threads nes scan_op =
  mapM mkArray acc_params
  where
    -- Only the leading parameters (one per neutral element) get an array.
    acc_params = take (length nes) $ lambdaParams scan_op

    mkArray param
      | MemArray elem_t elem_shape _ (ArrayIn mem _) <- paramAttr param =
          let full_shape = Shape [num_threads] <> elem_shape
              full_dims = map (primExpFromSubExp int32) $ shapeDims full_shape
          in sArray "scan_arr" elem_t full_shape $
             ArrayIn mem $ IxFun.iota full_dims
      | otherwise =
          sAllocArray "scan_arr" (elemType $ paramType param)
          (Shape [group_size]) (Space "local")
-- | An optional function that, given two index expressions (applied
-- in this module as @f from to@), builds a boolean expression.
-- 'scanStage1' constructs it as 'Nothing' when the kernel space has
-- fewer than two dimensions.  Presumably (judging by the name) the
-- resulting expression is true when going from flat index @from@ to
-- flat index @to@ crosses a segment boundary -- confirm against
-- 'groupScan'.
type CrossesSegment = Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
-- | Emit the @scan_stage1@ kernel.
--
-- Every thread runs a sequential loop of @elems_per_thread@
-- iterations.  In each iteration the group processes one chunk of
-- @group size@ consecutive flat indices: threads that are in bounds
-- run the kernel body, the result is combined with the running carry
-- held in the scan operator's first @length nes@ parameters, the chunk
-- is scanned with 'groupScan', and the in-bounds results are written
-- to the pattern.  Thread 0 then reloads the carry for the next
-- iteration.
--
-- Returns the expression for the number of elements handled per group
-- together with the 'CrossesSegment' function that later stages reuse.
scanStage1 :: Pattern ExplicitMemory
           -> KernelSpace
           -> Lambda InKernel -> [SubExp]
           -> KernelBody InKernel
           -> CallKernelGen (Imp.Exp, CrossesSegment)
scanStage1 (Pattern _ pes) space scan_op nes kbody = do
  (consts0, init_constants) <- kernelInitialisationSetSpace space $ return ()
  let (gtids, dims) = unzip $ spaceDimensions space
  dims' <- mapM toExp dims
  let consts = consts0 { kernelThreadActive = true }
      ltid = kernelLocalThreadId consts
      gsize = kernelGroupSize consts
      per_thread = product dims' `quotRoundingUp` kernelNumThreads consts
      per_group = gsize * per_thread
      gtid_exps = map (`Imp.var` int32) gtids
      in_bounds = foldl1 (.&&.) $ zipWith (.<.) gtid_exps dims'
      num_scanned = length nes
      op_body = lambdaBody scan_op
      (acc_params, elem_params) = splitAt num_scanned $ lambdaParams scan_op
      map_pes = drop num_scanned pes

  -- A separately renamed copy of the operator is handed to 'groupScan';
  -- the original is compiled inline in the chunk loop below.
  group_scan_op <- renameLambda scan_op

  -- Only defined when there are at least two dimensions; the innermost
  -- one is used as the segment size.
  let crosses = case reverse dims' of
        segment_size : _ : _ ->
          Just $ \from to -> (to - from) .>. (to `rem` segment_size)
        _ -> Nothing

  sKernel consts "scan_stage1" $ allThreads consts $ do
    init_constants
    scan_arrs <-
      makeLocalArrays (spaceGroupSize space) (spaceNumThreads space) nes scan_op

    -- Declare the operator's parameters; the first ones hold the carry.
    dScope Nothing $ scopeOfLParams $ lambdaParams scan_op
    let setParam p se = copyDWIM (paramName p) [] se []

    -- The carry starts out as the neutral element.
    zipWithM_ setParam acc_params nes

    iter <- newVName "j"
    sFor iter Int32 per_thread $ do
      offset <- dPrimV "chunk_offset" $
        gsize * Imp.var iter int32 + kernelGroupId consts * per_group
      flat <- dPrimV "flat_idx" $ Imp.var offset int32 + ltid
      -- Recover the per-dimension indices from the flat index.
      zipWithM_ (<--) gtids $ unflattenIndex dims' $ Imp.var flat int32

      let offset_exp = Imp.var offset int32
          read_input = compileStms mempty (kernelBodyStms kbody) $ do
            let (scan_res, map_res) =
                  splitAt num_scanned $ kernelBodyResult kbody
            sComment "write to-scan values to parameters" $
              zipWithM_ (\p se -> setParam p $ kernelResultSubExp se)
              elem_params scan_res
            sComment "write mapped values results to global memory" $
              zipWithM_ (\pe se ->
                           copyDWIM (patElemName pe) gtid_exps
                           (kernelResultSubExp se) [])
              map_pes map_res
          read_neutral = zipWithM_ setParam elem_params nes

      sComment "threads in bounds read input; others get neutral element" $
        sIf in_bounds read_input read_neutral

      sComment "combine with carry and write to local memory" $
        compileStms mempty (bodyStms op_body) $
        zipWithM_ (\arr se -> copyDWIM arr [ltid] se [])
        scan_arrs (bodyResult op_body)

      -- Inside the group scan indices are chunk-relative, so shift
      -- them by the chunk offset before asking about segments.
      let crosses_in_chunk =
            (\f from to -> f (from + offset_exp) (to + offset_exp)) <$> crosses

      groupScan consts crosses_in_chunk gsize group_scan_op scan_arrs

      sComment "threads in bounds write partial scan result" $
        sWhen in_bounds $
        zipWithM_ (\pe arr ->
                     copyDWIM (patElemName pe) gtid_exps (Var arr) [ltid])
        pes scan_arrs

      sOp Imp.LocalBarrier

      let load_carry =
            zipWithM_ (\arr p ->
                         copyDWIM (paramName p) [] (Var arr) [gsize - 1])
            scan_arrs acc_params
          load_neutral = zipWithM_ (\ne p -> setParam p ne) nes acc_params
          next_offset = offset_exp + gsize
          -- Reset the carry if the next chunk starts in a new segment.
          reload_carry =
            maybe load_carry
            (\f -> sIf (f (next_offset - 1) next_offset)
                   load_neutral load_carry)
            crosses

      sComment "first thread reads last element as carry-in for next iteration" $
        sWhen (ltid .==. 0) reload_carry

      sOp Imp.LocalBarrier

  return (per_group, crosses)
-- | Emit the @scan_stage2@ kernel.
--
-- The kernel is initialised with a group size equal to the number of
-- groups of the given kernel space.  Thread @i@ addresses flat index
-- @(i + 1) * elems_per_group - 1@: if that index is in bounds its
-- value is read from the pattern, otherwise the neutral element is
-- used.  The values are scanned with 'groupScan' and the in-bounds
-- ones are written back to the same positions of the pattern.
scanStage2 :: Pattern ExplicitMemory
           -> Imp.Exp -> CrossesSegment -> KernelSpace
           -> Lambda InKernel -> [SubExp]
           -> CallKernelGen ()
scanStage2 (Pattern _ pes) elems_per_group crossesSegment space scan_op nes = do
  num_groups <- toExp $ spaceNumGroups space
  (consts, init_constants) <- kernelInitialisationSimple 1 num_groups Nothing

  let (gtids, dims) = unzip $ spaceDimensions space
  dims' <- mapM toExp dims

  let ltid = kernelLocalThreadId consts
      gtid_exps = map (`Imp.var` int32) gtids
      -- Flat index of the last element belonging to the given group.
      lastOfGroup g = (g + 1) * elems_per_group - 1
      -- Translate thread indices to flat indices before the segment check.
      crosses_groups =
        (\f from to -> f (lastOfGroup from) (lastOfGroup to)) <$> crossesSegment

  sKernel consts "scan_stage2" $ do
    init_constants
    carry_arrs <-
      makeLocalArrays (spaceNumGroups space) (spaceNumGroups space) nes scan_op

    flat <- dPrimV "flat_idx" $ lastOfGroup ltid
    -- Recover the per-dimension indices from the flat index.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ Imp.var flat int32

    let in_bounds = foldl1 (.&&.) $ zipWith (.<.) gtid_exps dims'
        read_carry =
          zipWithM_ (\arr pe ->
                       copyDWIM arr [ltid] (Var $ patElemName pe) gtid_exps)
          carry_arrs pes
        read_neutral =
          zipWithM_ (\arr ne -> copyDWIM arr [ltid] ne []) carry_arrs nes

    sComment "threads in bound read carries; others get neutral element" $
      sIf in_bounds read_carry read_neutral

    groupScan consts crosses_groups (kernelGroupSize consts) scan_op carry_arrs

    sComment "threads in bounds write scanned carries" $
      sWhen in_bounds $
      zipWithM_ (\pe arr ->
                   copyDWIM (patElemName pe) gtid_exps (Var arr) [ltid])
      pes carry_arrs
-- | Emit the @scan_stage3@ kernel.
--
-- Each thread handles the element at its global thread id.  It
-- computes which group of @elems_per_group@ elements the id falls in
-- and the flat index just before that group (the carry-in).  Unless
-- the group is the first one, the element is itself the last of its
-- group, or the 'CrossesSegment' function reports a crossing between
-- the carry-in and the element, the scan operator is applied to
-- (carry-in, element) and the result overwrites the element in the
-- pattern.
scanStage3 :: Pattern ExplicitMemory
           -> Imp.Exp -> CrossesSegment -> KernelSpace
           -> Lambda InKernel -> [SubExp]
           -> CallKernelGen ()
scanStage3 (Pattern _ pes) elems_per_group crossesSegment space scan_op nes = do
  let (gtids, dims) = unzip $ spaceDimensions space
  dims' <- mapM toExp dims
  (consts, init_constants) <- simpleKernelConstants (product dims') "scan"

  let gid = kernelGlobalThreadId consts
      gtid_exps = map (`Imp.var` int32) gtids
      (acc_params, elem_params) = splitAt (length nes) $ lambdaParams scan_op
      -- Load a pattern element at the given index into an operator parameter.
      loadParam idx p pe = copyDWIM (paramName p) [] (Var $ patElemName pe) idx

  sKernel consts "scan_stage3" $ do
    init_constants
    -- Per-dimension indices of the element this thread handles.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' gid

    group <- dPrimV "orig_group" $ gid `quot` elems_per_group
    carry_flat <- dPrimV "carry_in_flat_idx" $
      Imp.var group int32 * elems_per_group - 1

    let group_exp = Imp.var group int32
        carry_flat_exp = Imp.var carry_flat int32
        carry_idx = unflattenIndex dims' carry_flat_exp
        crosses_segment =
          maybe false (\f -> f carry_flat_exp gid) crossesSegment
        is_a_carry = gid .==. (group_exp + 1) * elems_per_group - 1
        no_carry_in = group_exp .==. 0 .||. is_a_carry .||. crosses_segment

    sWhen (kernelThreadActive consts) $ sUnless no_carry_in $ do
      dScope Nothing $ scopeOfLParams $ lambdaParams scan_op
      zipWithM_ (loadParam carry_idx) acc_params pes
      zipWithM_ (loadParam gtid_exps) elem_params pes
      -- The operator's result is placed in its first parameters ...
      compileBody' acc_params $ lambdaBody scan_op
      -- ... and copied from there back into the pattern.
      zipWithM_ (\p pe ->
                   copyDWIM (patElemName pe) gtid_exps (Var $ paramName p) [])
        acc_params pes
-- | Compile a segmented scan by running 'scanStage1', 'scanStage2'
-- and 'scanStage3' in sequence.  Two debug prints are emitted after
-- the first stage, and stages two and three each receive their own
-- renamed copy of the scan operator.
compileSegScan :: Pattern ExplicitMemory
               -> KernelSpace
               -> Lambda InKernel -> [SubExp]
               -> KernelBody InKernel
               -> CallKernelGen ()
compileSegScan pat space scan_op nes kbody = do
  (per_group, crosses) <- scanStage1 pat space scan_op nes kbody

  mapM_ emit
    [ Imp.DebugPrint "\n# SegScan" Nothing
    , Imp.DebugPrint "elems_per_group" $ Just (int32, per_group)
    ]

  -- Both copies are renamed before stage two is generated.
  stage2_op <- renameLambda scan_op
  stage3_op <- renameLambda scan_op
  scanStage2 pat per_group crosses space stage2_op nes
  scanStage3 pat per_group crosses space stage3_op nes