-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import Control.Monad
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass qualified as SinglePass
import Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass qualified as TwoPass
import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here.
--
-- The result is a single operator whose lambda runs the statements
-- of every input lambda, in the order the operators are given, and
-- returns the concatenation of their results.  Its parameters are
-- laid out as: the leading parameters of every lambda (as many per
-- lambda as it has return types), followed by the remaining
-- parameters of every lambda.  Neutral elements are concatenated in
-- operator order and the commutativity is the 'mconcat' of the
-- individual commutativities.
combineScanOps :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScanOps ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      -- Assumed: the shapes of the inputs are not inspected here.  The
      -- visible caller, canBeSinglePass, only combines operators whose
      -- shape is already 'mempty'.
      segBinOpShape = mempty
    }
  where
    lams = map segBinOpLambda ops
    -- The first k parameters of a lambda, where k is its number of
    -- return types.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    -- Whatever parameters remain after the first k.
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)
    lam' =
      Lambda
        { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }

-- Does any expression in the body satisfy the predicate?  Both the
-- expressions bound directly by the statements of the body and those
-- inside bodies nested within these expressions are considered.
bodyHas :: (Exp GPUMem -> Bool) -> Body GPUMem -> Bool
bodyHas f = any (f' . stmExp) . bodyStms
  where
    -- An expression matches if the predicate holds for it directly, or
    -- if walking it fails, which happens exactly when the walker's
    -- body handler finds a match in a nested body.
    f' e
      | f e = True
      | otherwise = isNothing $ walkExpM walker e
    -- A walker in the Maybe monad: 'guard' yields Nothing as soon as a
    -- nested body contains a matching expression, and Just ()
    -- otherwise.
    walker =
      identityWalker
        { walkOnBody = const $ guard . not . bodyHas f
        }

-- Decide whether the given scan operators can be handled by the
-- single-pass implementation.  If every operator is acceptable, the
-- operators are merged with combineScanOps and returned as a single
-- operator; otherwise the result is Nothing.
--
-- An operator is acceptable when all of the following hold:
--
--   * its shape is 'mempty';
--
--   * every return type of its lambda satisfies 'primType';
--
--   * its lambda body contains no 'Assert', as determined by bodyHas.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass scan_ops
  | all ok scan_ops = Just $ combineScanOps scan_ops
  | otherwise = Nothing
  where
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))
        && not (bodyHas isAssert (lambdaBody (segBinOpLambda op)))
    isAssert (BasicOp Assert {}) = True
    isAssert _ = False

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- All generated code is placed under an 'sWhen' condition requiring
-- the product of the dimensions of the 'SegSpace' to be positive.
-- The single-pass implementation is used when the host target is
-- 'CUDA' or 'HIP' and the operators pass the @canBeSinglePass@
-- check; in every other case the two-pass implementation is used
-- with the original, uncombined operators.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scan_ops map_kbody =
  sWhen (0 .<. n) $ do
    emit $ Imp.DebugPrint "\n# SegScan" Nothing
    target <- hostTarget <$> askEnv

    case (targetSupportsSinglePass target, canBeSinglePass scan_ops) of
      (True, Just scan_ops') ->
        SinglePass.compileSegScan pat lvl space scan_ops' map_kbody
      _ ->
        TwoPass.compileSegScan pat lvl space scan_ops map_kbody
    emit $ Imp.DebugPrint "" Nothing
  where
    -- Product of all dimensions of the segment space.
    n = product $ map pe64 $ segSpaceDims space
    -- Only these two targets are dispatched to the single-pass
    -- implementation; any other target uses the two-pass one.
    targetSupportsSinglePass CUDA = True
    targetSupportsSinglePass HIP = True
    targetSupportsSinglePass _ = False