module Futhark.CodeGen.ImpGen.Kernels.SegScan (compileSegScan) where

import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.CodeGen.ImpGen.Kernels.SegScan.SinglePass as SinglePass
import qualified Futhark.CodeGen.ImpGen.Kernels.SegScan.TwoPass as TwoPass
import Futhark.IR.KernelsMem

-- The single-pass scan does not support multiple operators, so all
-- operators are merged into a single one here.
--
-- The merged operator is assembled as follows:
--
--   * its commutativity is the monoidal combination of the
--     commutativity of every input operator;
--
--   * its neutral elements are the neutral elements of the input
--     operators, concatenated in order;
--
--   * its lambda takes, for each input lambda in order, the first
--     @k@ parameters (where @k@ is the number of return types of that
--     lambda), followed by the remaining parameters of each input
--     lambda in order.  The body is the concatenation of all input
--     bodies' statements, and the result is the concatenation of all
--     input bodies' results.
combineScans :: [SegBinOp KernelsMem] -> SegBinOp KernelsMem
combineScans scanOps =
  SegBinOp
    { segBinOpComm = foldMap segBinOpComm scanOps,
      segBinOpLambda = mergedLambda,
      segBinOpNeutral = concatMap segBinOpNeutral scanOps,
      -- Assumed: the shape of the input operators is not inspected
      -- here; the combined operator always gets the empty shape.
      segBinOpShape = mempty
    }
  where
    opLambdas = map segBinOpLambda scanOps
    opBodies = map lambdaBody opLambdas

    -- Number of values returned by a lambda; this is the point at
    -- which its parameter list is split.
    numResults lam = length (lambdaReturnType lam)
    leadingParams lam = take (numResults lam) (lambdaParams lam)
    trailingParams lam = drop (numResults lam) (lambdaParams lam)

    mergedLambda =
      Lambda
        { lambdaParams =
            concatMap leadingParams opLambdas
              ++ concatMap trailingParams opLambdas,
          lambdaReturnType = concatMap lambdaReturnType opLambdas,
          lambdaBody =
            Body
              ()
              (foldMap bodyStms opBodies)
              (concatMap bodyResult opBodies)
        }

-- | Decide whether the given scan operators can be handled by the
-- single-pass implementation.  Yields the operators merged by
-- 'combineScans' when @unSegSpace space@ has exactly one entry and
-- every operator has an empty shape and only primitive return types
-- (as judged by 'primType'); yields 'Nothing' otherwise.
canBeSinglePass :: SegSpace -> [SegBinOp KernelsMem] -> Maybe (SegBinOp KernelsMem)
canBeSinglePass space ops =
  case unSegSpace space of
    [_]
      | all supported ops -> Just (combineScans ops)
    _ -> Nothing
  where
    supported op = hasEmptyShape op && returnsOnlyPrims op
    hasEmptyShape op = segBinOpShape op == mempty
    returnsOnlyPrims op =
      all primType (lambdaReturnType (segBinOpLambda op))

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- The generated code is wrapped in an 'sWhen' whose condition is
-- that the product of the segment space dimensions is greater than
-- zero.  When the host target is 'CUDA' and 'canBeSinglePass' accepts
-- the operators, "SinglePass" is used with the merged operator; in
-- every other case "TwoPass" is used with the original operator list.
compileSegScan ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody =
  sWhen (0 .<. totalSize) $ do
    emit (Imp.DebugPrint "\n# SegScan" Nothing)
    env <- askEnv
    -- The tuple is matched left to right, so 'canBeSinglePass' is
    -- only evaluated once the target is known to be CUDA.
    case (hostTarget env, canBeSinglePass space scans) of
      (CUDA, Just scan') ->
        SinglePass.compileSegScan pat lvl space scan' kbody
      _ ->
        TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Product of all dimensions of the segment space.
    totalSize = product (map toInt64Exp (segSpaceDims space))