{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.StreamKernel
( segThreadCapped,
)
where
import Control.Monad
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU hiding
( BasicOp,
Body,
Exp,
FParam,
FunDef,
LParam,
Lambda,
Pat,
PatElem,
Prog,
RetType,
Stm,
)
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Prelude hiding (quot)
-- | A pair of size expressions describing a kernel: how many elements
-- each thread handles, and how many threads there are.
--
-- NOTE(review): nothing else in this module refers to this type and it
-- is not in the export list; confirm whether it is still needed.
data KernelSize = KernelSize
  { -- | Number of elements processed per thread.
    kernelElementsPerThread :: SubExp,
    -- | Total number of threads.
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)
-- | Emit bindings that compute a group count and the corresponding
-- thread count.
--
-- The arguments are, in order: a description string used to name the
-- tuning key, the width passed on to 'CalcNumGroups', and the group
-- size.
--
-- The result is @(num_groups, num_threads)@, where @num_groups@ is
-- bound to a 'CalcNumGroups' size operation and @num_threads@ is the
-- 'Int64' product of @num_groups@ and the group size.
numberOfGroups ::
  (MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  -- The key is the pretty-printed form of a fresh name, so each call
  -- gets its own key derived from @desc@.
  max_num_groups_key <-
    nameFromString . prettyString <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $
        SizeOp $
          CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int64 OverflowUndef) num_groups group_size
  pure (num_groups, num_threads)
-- | Construct a thread-level 'SegLevel' with an explicit 'KernelGrid'.
--
-- The nest size is the 'Int64' product of all the widths in @ws@
-- (starting from 1, so an empty list gives 1).  The group size is
-- obtained with 'getSize' under the name @desc ++ "_group_size"@.
--
-- The number of groups depends on the thread recommendation:
--
-- * 'ManyThreads': the nest size divided by the group size using
--   'SDivUp', with virtualisation set to 'SegNoVirt'.
--
-- * 'NoRecommendation': whatever 'numberOfGroups' computes, keeping
--   the virtualisation setting carried by the recommendation.
segThreadCapped :: MonadFreshNames m => MkSegLevel GPU m
segThreadCapped ws desc r = do
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 group_size)
      let grid = KernelGrid (Count usable_groups) (Count group_size)
      pure $ SegThread SegNoVirt (Just grid)
    NoRecommendation v -> do
      -- Only the group count is needed here; the thread count that
      -- 'numberOfGroups' also returns is ignored.
      (num_groups, _) <- numberOfGroups desc w group_size
      let grid = KernelGrid (Count num_groups) (Count group_size)
      pure $ SegThread v (Just grid)