{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped,
    streamRed,
    streamMap,
  )
where

import Control.Monad
import Control.Monad.Writer
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Kernels hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    PatElem,
    Pattern,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.Tools
import Prelude hiding (quot)

data KernelSize = KernelSize
  { -- | Number of elements processed per thread; an Int64.
    kernelElementsPerThread :: SubExp,
    -- | Total number of threads; an Int64.
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)

numberOfGroups ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  max_num_groups_key <-
    nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $ SizeOp $ CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $ BinOp (Mul Int64 OverflowUndef) num_groups group_size
  return (num_groups, num_threads)

blockedKernelSize ::
  (MonadBinder m, Lore m ~ Kernels) =>
  String ->
  SubExp ->
  m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  (_, num_threads) <- numberOfGroups desc w group_size

  per_thread_elements <-
    letSubExp "per_thread_elements"
      =<< eBinOp (SDivUp Int64 Unsafe) (eSubExp w) (eSubExp num_threads)

  return $ KernelSize per_thread_elements num_threads

splitArrays ::
  (MonadBinder m, Lore m ~ Kernels) =>
  VName ->
  [VName] ->
  SplitOrdering ->
  SubExp ->
  SubExp ->
  SubExp ->
  [VName] ->
  m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames [chunk_size] $ Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous -> do
      offset <-
        letSubExp "slice_offset" $
          BasicOp $ BinOp (Mul Int64 OverflowUndef) i elems_per_i
      zipWithM_ (contiguousSlice offset) split_bound arrs
    SplitStrided stride ->
      zipWithM_ (stridedSlice stride) split_bound arrs
  where
    contiguousSlice offset slice_name arr = do
      arr_t <- lookupType arr
      let slice =
            fullSlice
              arr_t
              [DimSlice offset (Var chunk_size) (constant (1 :: Int64))]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

    stridedSlice stride slice_name arr = do
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimSlice i (Var chunk_size) stride]
      letBindNames [slice_name] $ BasicOp $ Index arr slice

partitionChunkedKernelFoldParameters ::
  Int ->
  [Param dec] ->
  (VName, Param dec, [Param dec], [Param dec])
partitionChunkedKernelFoldParameters num_accs (i_param : chunk_param : params) =
  let (acc_params, arr_params) = splitAt num_accs params
   in (paramName i_param, chunk_param, acc_params, arr_params)
partitionChunkedKernelFoldParameters _ _ =
  error "partitionChunkedKernelFoldParameters: lambda takes too few parameters"
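-- A note on the parameter convention assumed above: a chunked fold
-- lambda takes a thread index, then a chunk size, then its
-- accumulators, then the chunked input arrays.  For example, with one
-- accumulator and one input array, a parameter list
--
--   [i, chunk, acc, xs_chunk]
--
-- is split into (i, chunk, [acc], [xs_chunk]).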
blockedPerThread ::
  (MonadBinder m, Lore m ~ Kernels) =>
  VName ->
  SubExp ->
  KernelSize ->
  StreamOrd ->
  Lambda (Lore m) ->
  Int ->
  [VName] ->
  m ([PatElemT Type], [PatElemT Type])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  let (_, chunk_size, [], arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam

      ordering' =
        case ordering of
          InOrder -> SplitContiguous
          Disorder -> SplitStrided $ kernelNumThreads kernel_size
      red_ts = take num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType $ drop num_nonconcat $ lambdaReturnType lam

  per_thread <- asIntS Int64 $ kernelElementsPerThread kernel_size
  splitArrays
    (paramName chunk_size)
    (map paramName arr_params)
    ordering'
    w
    (Var thread_gtid)
    per_thread
    arrs

  chunk_red_pes <- forM red_ts $ \red_t -> do
    pe_name <- newVName "chunk_fold_red"
    return $ PatElem pe_name red_t
  chunk_map_pes <- forM map_ts $ \map_t -> do
    pe_name <- newVName "chunk_fold_map"
    return $ PatElem pe_name $ map_t `arrayOfRow` Var (paramName chunk_size)

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  addStms $
    bodyStms (lambdaBody lam)
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip chunk_red_pes chunk_red_ses
        ]
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip chunk_map_pes chunk_map_ses
        ]

  return (chunk_red_pes, chunk_map_pes)

-- | Given a chunked fold lambda that takes its initial accumulator
-- value as parameters, bind those parameters to the neutral element
-- instead.
kerneliseLambda ::
  MonadFreshNames m =>
  [SubExp] ->
  Lambda Kernels ->
  m (Lambda Kernels)
kerneliseLambda nes lam = do
  thread_index <- newVName "thread_index"
  let thread_index_param = Param thread_index $ Prim int64
      (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      mkAccInit p (Var v)
        | not $ primType $ paramType p =
          mkLet [] [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x = mkLet [] [paramIdent p] $ BasicOp $ SubExp x
      acc_init_bnds = stmsFromList $ zipWith mkAccInit fold_acc_params nes

  return
    lam
      { lambdaBody = insertStms acc_init_bnds $ lambdaBody lam,
        lambdaParams =
          thread_index_param :
          fold_chunk_param :
          fold_inp_params
      }

prepareStream ::
  (MonadBinder m, Lore m ~ Kernels) =>
  KernelSize ->
  [(VName, SubExp)] ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream size ispace w comm fold_lam nes arrs = do
  let (KernelSize elems_per_thread num_threads) = size
  let (ordering, split_ordering) =
        case comm of
          Commutative -> (Disorder, SplitStrided num_threads)
          Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]
  kbody <- fmap (uncurry (flip (KernelBody ()))) $
    runBinder $
      localScope (scopeOfSegSpace space) $ do
        (chunk_red_pes, chunk_map_pes) <-
          blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
        let concatReturns pe =
              ConcatReturns split_ordering w elems_per_thread $
                patElemName pe
        return
          ( map (Returns ResultMaySimplify . Var . patElemName) chunk_red_pes
              ++ map concatReturns chunk_map_pes
          )

  let (redout_ts, mapout_ts) = splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  return (num_threads, space, ts, kbody)
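-- A sketch of the decomposition performed by 'streamRed' below, in
-- terms of a hypothetical source program: a stream reduction of the
-- form
--
--   stream_red op f ne xs
--
-- (where 'op' is the reduction operator and 'f' the chunked fold
-- function) becomes a single SegRed over 'kernelNumThreads' threads.
-- Each thread first folds its own chunk of roughly
-- 'kernelElementsPerThread' elements with the kernelised 'f', and the
-- per-thread partial results are then combined with 'op'.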
streamRed ::
  (MonadFreshNames m, HasScope Kernels m) =>
  MkSegLevel Kernels m ->
  Pattern Kernels ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m (Stms Kernels)
streamRed mk_lvl pat w comm red_lam fold_lam nes arrs = runBinderT'_ $ do
  -- The strategy here is to rephrase the stream reduction as a
  -- non-segmented SegRed that does explicit chunking within its body.
  -- First, figure out how many threads to use for this.
  size <- blockedKernelSize "stream_red" w

  let (redout_pes, mapout_pes) = splitAt (length nes) $ patternElements pat
  (redout_pat, ispace, read_dummy) <- dummyDim $ Pattern [] redout_pes
  let pat' = Pattern [] $ patternElements redout_pat ++ mapout_pes

  (_, kspace, ts, kbody) <- prepareStream size ispace w comm fold_lam nes arrs

  lvl <- mk_lvl [w] "stream_red" $ NoRecommendation SegNoVirt

  letBind pat' $
    Op $
      SegOp $
        SegRed lvl kspace [SegBinOp comm red_lam nes mempty] ts kbody

  read_dummy

-- | Similar to 'streamRed', but without the last reduction.
streamMap ::
  (MonadFreshNames m, HasScope Kernels m) =>
  MkSegLevel Kernels m ->
  [String] ->
  [PatElem Kernels] ->
  SubExp ->
  Commutativity ->
  Lambda Kernels ->
  [SubExp] ->
  [VName] ->
  m ((SubExp, [VName]), Stms Kernels)
streamMap mk_lvl out_desc mapout_pes w comm fold_lam nes arrs = runBinderT' $ do
  size <- blockedKernelSize "stream_map" w

  (threads, kspace, ts, kbody) <- prepareStream size [] w comm fold_lam nes arrs

  let redout_ts = take (length nes) ts

  redout_pes <- forM (zip out_desc redout_ts) $ \(desc, t) ->
    PatElem <$> newVName desc <*> pure (t `arrayOfRow` threads)

  let pat = Pattern [] $ redout_pes ++ mapout_pes
  lvl <- mk_lvl [w] "stream_map" $ NoRecommendation SegNoVirt
  letBind pat $ Op $ SegOp $ SegMap lvl kspace ts kbody

  return (threads, map patElemName redout_pes)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
segThreadCapped :: MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped ws desc r = do
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 group_size)
      return $ SegThread (Count usable_groups) (Count group_size) SegNoVirt
    NoRecommendation v -> do
      (num_groups, _) <- numberOfGroups desc w group_size
      return $ SegThread (Count num_groups) (Count group_size) v
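-- A worked example of the capping in the ManyThreads case above, with
-- assumed sizes: for a total input size w = 1000 and a group size of
-- 256, the rounding-up division gives usable_groups = 4, so
-- 4 * 256 = 1024 threads are launched; a larger default group count
-- would leave most threads without any elements to process.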