{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.SegMap (compileSegMap) where
import Control.Monad.Except
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.IR.KernelsMem
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)
compileSegMap ::
Pattern KernelsMem ->
SegLevel ->
SegSpace ->
KernelBody KernelsMem ->
CallKernelGen ()
compileSegMap :: Pattern KernelsMem
-> SegLevel
-> SegSpace
-> KernelBody KernelsMem
-> CallKernelGen ()
compileSegMap Pattern KernelsMem
pat SegLevel
lvl SegSpace
space KernelBody KernelsMem
kbody = do
let ([VName]
is, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
num_groups' :: Count NumGroups (TExp Int64)
num_groups' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl
group_size' :: Count GroupSize (TExp Int64)
group_size' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
case SegLevel
lvl of
SegThread {} -> do
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegMap" Maybe Exp
forall a. Maybe a
Nothing
let virt_num_groups :: TPrimExp Int32 ExpLeaf
virt_num_groups =
TExp Int64 -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 ExpLeaf)
-> TExp Int64 -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segmap" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt
-> TPrimExp Int32 ExpLeaf
-> (TPrimExp Int32 ExpLeaf -> InKernelGen ())
-> InKernelGen ()
virtualiseGroups (SegLevel -> SegVirt
segVirt SegLevel
lvl) TPrimExp Int32 ExpLeaf
virt_num_groups ((TPrimExp Int32 ExpLeaf -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int32 ExpLeaf -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
group_id -> do
TPrimExp Int32 ExpLeaf
local_tid <- KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId (KernelConstants -> TPrimExp Int32 ExpLeaf)
-> (KernelEnv -> KernelConstants)
-> KernelEnv
-> TPrimExp Int32 ExpLeaf
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TPrimExp Int32 ExpLeaf)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp (TPrimExp Int32 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
let global_tid :: TExp Int64
global_tid =
TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size')
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
local_tid
(VName -> TExp Int64 -> InKernelGen ())
-> [VName] -> [TExp Int64] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
is ([TExp Int64] -> InKernelGen ()) -> [TExp Int64] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(TExp Int64 -> TExp Int64) -> [TExp Int64] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 ([TExp Int64] -> [TExp Int64]) -> [TExp Int64] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((TExp Int64 -> TExp Int64) -> [TExp Int64] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 [TExp Int64]
dims') TExp Int64
global_tid
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen ([(VName, SubExp)] -> TExp Bool
isActive ([(VName, SubExp)] -> TExp Bool) -> [(VName, SubExp)] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(PatElemT LetDecMem -> KernelResult -> InKernelGen ())
-> [PatElemT LetDecMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem KernelsMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (PatternT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LetDecMem
pat) ([KernelResult] -> InKernelGen ())
-> [KernelResult] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
SegGroup {} ->
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelGroup String
"segmap_intragroup" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let virt_num_groups :: TPrimExp Int32 ExpLeaf
virt_num_groups = TExp Int64 -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 ExpLeaf)
-> TExp Int64 -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims'
Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall a. Stms KernelsMem -> InKernelGen a -> InKernelGen a
precomputeSegOpIDs (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt
-> TPrimExp Int32 ExpLeaf
-> (TPrimExp Int32 ExpLeaf -> InKernelGen ())
-> InKernelGen ()
virtualiseGroups (SegLevel -> SegVirt
segVirt SegLevel
lvl) TPrimExp Int32 ExpLeaf
virt_num_groups ((TPrimExp Int32 ExpLeaf -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int32 ExpLeaf -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
group_id -> do
(VName -> TExp Int64 -> InKernelGen ())
-> [VName] -> [TExp Int64] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
is ([TExp Int64] -> InKernelGen ()) -> [TExp Int64] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
group_id
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(PatElemT LetDecMem -> KernelResult -> InKernelGen ())
-> [PatElemT LetDecMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem KernelsMem -> KernelResult -> InKernelGen ()
compileGroupResult SegSpace
space) (PatternT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LetDecMem
pat) ([KernelResult] -> InKernelGen ())
-> [KernelResult] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody