{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped,
  )
where

import Control.Monad
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    Pat,
    PatElem,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Prelude hiding (quot)

-- | A pair of sub-expressions describing the size of a kernel.  The
-- per-field notes record the integer type each sub-expression is
-- expected to have.
data KernelSize = KernelSize
  { -- | Int64
    kernelElementsPerThread :: SubExp,
    -- | Int32
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)

-- | Emit the statements that compute a thread block count and the
-- corresponding total thread count.
--
-- The arguments are, in order: a description string used to derive
-- the name of the size key, the sub-expression @w@ handed to
-- 'CalcNumBlocks', and the thread block size.
--
-- The result is @(num_tblocks, num_threads)@, where @num_threads@ is
-- bound to the 'Int64' product of @num_tblocks@ and the thread block
-- size.
numberOfBlocks ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfBlocks desc w tblock_size = do
  -- The size key is the pretty-printed form of a fresh variable name,
  -- so every call site gets its own distinct key.
  max_num_tblocks_key <-
    nameFromString . prettyString <$> newVName (desc ++ "_num_tblocks")
  num_tblocks <-
    letSubExp "num_tblocks" $
      Op $
        SizeOp $
          CalcNumBlocks w max_num_tblocks_key tblock_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int64 OverflowUndef) num_tblocks tblock_size
  pure (num_tblocks, num_threads)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The nest size is the 'Int64' product of all the given widths (1
-- when there are none), and the thread block size is obtained from a
-- size named after the description with a @_tblock_size@ suffix.  How
-- the number of blocks is chosen depends on the recommendation:
--
-- * 'ManyThreads': the nest size divided by the thread block size,
--   rounded up, with 'SegNoVirt'.
--
-- * 'NoRecommendation': whatever 'numberOfBlocks' computes, with the
--   virtualisation carried by the recommendation.
segThreadCapped :: (MonadFreshNames m) => MkSegLevel GPU m
segThreadCapped ws desc r = do
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  tblock_size <- getSize (desc ++ "_tblock_size") SizeThreadBlock

  case r of
    ManyThreads -> do
      -- The block size is converted to Int64 before the division, as
      -- the nest size is an Int64 product.
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 tblock_size)
      let grid = KernelGrid (Count usable_groups) (Count tblock_size)
      pure $ SegThread SegNoVirt (Just grid)
    NoRecommendation v -> do
      -- Only the block count is needed here; the total thread count
      -- that 'numberOfBlocks' also binds is ignored.
      (num_tblocks, _) <- numberOfBlocks desc w tblock_size
      let grid = KernelGrid (Count num_tblocks) (Count tblock_size)
      pure $ SegThread v (Just grid)