-- | Multicore code generation for 'SegMap'.
module Futhark.CodeGen.ImpGen.Multicore.SegMap
  ( compileSegMap,
  )
where

import Control.Monad
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename

-- | Write one kernel result into the corresponding pattern element.
--
-- * 'Returns': the value @se@ is copied to position @is@ (the thread's
--   per-dimension indices) of the pattern element's array.
-- * 'WriteReturns': every @(slice, v)@ pair is a scatter-style write of
--   @v@ into @slice@ of the pattern element. The write is performed only
--   when every index it touches lies inside the destination shape @rws@;
--   otherwise it is skipped.
--
-- Any other kind of 'KernelResult' is reported with 'error'.
writeResult ::
  [VName] ->
  PatElemT dec ->
  KernelResult ->
  MulticoreGen ()
writeResult is pe (Returns _ se) =
  copyDWIMFix (patElemName pe) (map Imp.vi64 is) se []
writeResult _ pe (WriteReturns (Shape rws) _ idx_vals) =
  forM_ idx_vals $ \(slice, v) -> do
    let slice' = map (fmap toInt64Exp) slice
        write = copyDWIM (patElemName pe) slice' v []
    -- With no dimensions to check there is nothing that can be out of
    -- bounds, so write unconditionally instead of calling the partial
    -- 'foldl1' on an empty list.
    case zipWith condInBounds slice' rws' of
      [] -> write
      conds -> sWhen (foldl1 (.&&.) conds) write
  where
    rws' = map toInt64Exp rws

    -- A fixed index must satisfy @0 <= i < d@.  A slice starting at @i@
    -- with @n@ elements and stride @s@ touches @i, i+s, ..., i+(n-1)*s@;
    -- its last touched index is @i + (n-1)*s@ (not @i + n*s@, which
    -- wrongly rejected slices ending exactly at the array's end).  Both
    -- endpoints are checked so that negative strides are handled too.
    condInBounds (DimFix i) d =
      0 .<=. i .&&. i .<. d
    condInBounds (DimSlice i n s) d =
      let lst = i + (n - 1) * s
       in 0 .<=. i .&&. i .<. d .&&. 0 .<=. lst .&&. lst .<. d
writeResult _ _ res =
  error $ "writeResult: cannot handle " ++ pretty res

-- | Generate the per-iteration body of a 'SegMap' and return it as code
-- (it is collected, not emitted).
--
-- The flat iteration counter @flat_idx@ is unflattened into one index
-- variable per dimension of @space@. The kernel statements are then
-- compiled with those indices bound, and each kernel result is written to
-- its pattern element with 'writeResult'.
compileSegMapBody ::
  TV Int64 ->
  Pattern MCMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegMapBody flat_idx pat space (KernelBody _ kstms kres) = do
  -- Statements get fresh binder names before compilation
  -- (NOTE(review): presumably to avoid clashes with names already bound
  -- by other copies of this body — confirm).
  renamed <- traverse renameStm kstms
  collect $ do
    emit $ Imp.DebugPrint "SegMap fbody" Nothing
    let gtids = map fst dims
        dim_sizes = map (toInt64Exp . snd) dims
        gtid_vals = map sExt64 (unflattenIndex dim_sizes (tvExp flat_idx))
    forM_ (zip gtids gtid_vals) (uncurry dPrimV_)
    compileStms (freeIn kres) renamed $
      sequence_ (zipWith (writeResult gtids) (patternElements pat) kres)
  where
    dims = unSegSpace space

-- | Compile a 'SegMap' into a single 'Imp.ParLoop' tagged @"segmap"@ and
-- return the collected code.
--
-- A fresh @int64@ variable named @iter@ serves as the flat iteration
-- counter passed to 'compileSegMapBody'. The resulting body is split by
-- 'extractAllocations' into its allocations and the remaining code. Its
-- parameters are computed with 'freeParams', given the segment's flat
-- index and the counter (NOTE(review): presumably these names are excluded
-- from the parameter list — confirm against 'freeParams').
compileSegMap ::
  Pattern MCMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegMap pat space kbody = collect $ do
  iter <- dPrim "iter" int64
  code <- compileSegMapBody iter pat space kbody
  let iter_name = tvVar iter
      flat_name = segFlat space
      (allocs, rest) = extractAllocations code
  params <- freeParams code [flat_name, iter_name]
  emit . Imp.Op $
    Imp.ParLoop "segmap" iter_name allocs rest mempty params flat_name