{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.Base
  ( KernelConstants (..)
  , keyWithEntryPoint
  , CallKernelGen
  , InKernelGen
  , HostEnv (..)
  , KernelEnv (..)
  , computeThreadChunkSize
  , groupReduce
  , groupScan
  , isActive
  , sKernelThread
  , sKernelGroup
  , sReplicate
  , sIota
  , sCopy
  , compileThreadResult
  , compileGroupResult
  , virtualiseGroups
  , groupLoop
  , kernelLoop
  , groupCoverSpace

  , atomicUpdateLocking
  , AtomicBinOp
  , Locking(..)
  , AtomicUpdate(..)
  , DoAtomicUpdate
  )
  where

import Control.Monad.Except
import Data.Maybe
import qualified Data.Map.Strict as M
import Data.List (elemIndex, find, nub, zip4)

import Prelude hiding (quot, rem)

import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
import Futhark.Util (chunks, maybeNth, mapAccumLM, takeLast, dropLast)

-- | Environment carried by 'CallKernelGen'.
newtype HostEnv = HostEnv
  { hostAtomics :: AtomicBinOp
    -- ^ The 'AtomicBinOp' stored in this environment.
  }

-- | Environment carried by 'InKernelGen'.
data KernelEnv = KernelEnv
  { kernelAtomics :: AtomicBinOp
    -- ^ The 'AtomicBinOp' stored in this environment; read by
    -- 'prepareIntraGroupSegHist' when picking an atomic update strategy.
  , kernelConstants :: KernelConstants
    -- ^ The 'KernelConstants' of the kernel being generated.
  }

-- | 'ImpM' instantiated with a 'HostEnv' environment and 'Imp.HostOp'
-- operations.
type CallKernelGen = ImpM ExplicitMemory HostEnv Imp.HostOp
-- | 'ImpM' instantiated with a 'KernelEnv' environment and
-- 'Imp.KernelOp' operations.
type InKernelGen = ImpM ExplicitMemory KernelEnv Imp.KernelOp

-- | Symbolic expressions and variable names that are fixed for a
-- kernel and stored in 'KernelEnv'.  The @...Var@ fields hold a
-- 'VName'; the remaining fields hold an 'Imp.Exp'.
data KernelConstants = KernelConstants
                       { kernelGlobalThreadId :: Imp.Exp
                       , kernelLocalThreadId :: Imp.Exp
                         -- ^ Used by 'groupLoop' as the index of the
                         -- current thread.
                       , kernelGroupId :: Imp.Exp
                       , kernelGlobalThreadIdVar :: VName
                       , kernelLocalThreadIdVar :: VName
                       , kernelGroupIdVar :: VName
                       , kernelNumGroups :: Imp.Exp
                       , kernelGroupSize :: Imp.Exp
                         -- ^ Used by 'groupLoop' as the number of
                         -- threads sharing the iterations.
                       , kernelNumThreads :: Imp.Exp
                       , kernelWaveSize :: Imp.Exp
                       , kernelThreadActive :: Imp.Exp
                       }

-- | Build a key name that is optionally qualified by a function name.
-- With @Just fname@ the result is @fname ++ "." ++ key@; with
-- 'Nothing' it is just @key@.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $ maybe "" ((++".") . nameToString) fname ++ nameToString key

-- | Allocation compiler that emits an 'Imp.LocalAlloc' operation for
-- the given memory block name and size in bytes.
allocLocal :: AllocCompiler ExplicitMemory r Imp.KernelOp
allocLocal mem size =
  sOp $ Imp.LocalAlloc mem size

-- | Compile a memory allocation that occurs inside a kernel.
--
-- * A single-element pattern in a 'ScalarSpace' generates no code.
--
-- * A single-element pattern in @Space "local"@ is allocated with
--   'allocLocal', using the size converted to bytes.
--
-- * A single-element pattern in any other space is reported through
--   'compilerLimitationS'.
--
-- * Any other pattern is an internal error.
kernelAlloc :: Pattern ExplicitMemory
            -> SubExp -> Space
            -> InKernelGen ()
kernelAlloc (Pattern _ [_]) _ ScalarSpace{} =
  -- Handled by the declaration of the memory block, which is then
  -- translated to an actual scalar variable during C code generation.
  return ()
kernelAlloc (Pattern _ [mem]) size (Space "local") = do
  size' <- toExp size
  allocLocal (patElemName mem) $ Imp.bytes size'
kernelAlloc (Pattern _ [mem]) _ _ =
  compilerLimitationS $ "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
kernelAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest

-- | Compile a @SplitSpace@ size operation.  The pattern must have no
-- context elements and exactly one value element, whose name is passed
-- to 'computeThreadChunkSize' as the variable to define.  The width
-- @w@ and @elems_per_thread@ are converted to element counts and @i@
-- to a plain expression before the call.  Any other pattern shape is
-- an internal error.
splitSpace :: (ToExp w, ToExp i, ToExp elems_per_thread) =>
              Pattern ExplicitMemory -> SplitOrdering -> w -> i -> elems_per_thread
           -> ImpM lore r op ()
splitSpace (Pattern [] [size]) o w i elems_per_thread = do
  num_elements <- Imp.elements <$> toExp w
  i' <- toExp i
  elems_per_thread' <- Imp.elements <$> toExp elems_per_thread
  computeThreadChunkSize o i' elems_per_thread' num_elements (patElemName size)
splitSpace pat _ _ _ _ =
  error $ "Invalid target for splitSpace: " ++ pretty pat

-- | Expression compiler for thread-level code.  An array literal bound
-- to a single-element pattern is written element by element with
-- 'copyDWIMFix', using the (zero-based) position as the index.  Every
-- other expression is handled by 'defCompileExp'.
compileThreadExp :: ExpCompiler ExplicitMemory KernelEnv Imp.KernelOp
compileThreadExp (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0..] es) $ \(i,e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i::Int32)] e []
compileThreadExp dest e =
  defCompileExp dest e


-- | Assign iterations of a for-loop to all threads in the kernel.  The
-- passed-in function is invoked with the (symbolic) iteration.  For
-- multidimensional loops, use 'groupCoverSpace'.
--
-- The arguments are, in order: the index of the current thread, the
-- number of threads sharing the loop, the number of iterations, and
-- the loop body.
kernelLoop :: Imp.Exp -> Imp.Exp -> Imp.Exp
           -> (Imp.Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop tid num_threads n f =
  -- This is a compile-time structural comparison of the two
  -- expressions: when they are equal, no loop is generated and the
  -- body is invoked once with the thread index itself.
  if n == num_threads then
    f tid
  else do
    -- Compute how many elements this thread is responsible for.
    -- Formula: (n - tid) / num_threads (rounded up).
    let elems_for_this = (n - tid) `quotRoundingUp` num_threads

    -- Iteration i of this thread handles element i * num_threads + tid.
    sFor "i" elems_for_this $ \i -> f $
      i * num_threads + tid

-- | Assign iterations of a for-loop to threads in the workgroup.  The
-- passed-in function is invoked with the (symbolic) iteration.  For
-- multidimensional loops, use 'groupCoverSpace'.
--
-- This is 'kernelLoop' with the thread index and thread count taken
-- from 'kernelLocalThreadId' and 'kernelGroupSize' of the current
-- 'KernelConstants'.
groupLoop :: Imp.Exp
          -> (Imp.Exp -> InKernelGen ()) -> InKernelGen ()
groupLoop n f = do
  constants <- kernelConstants <$> askEnv
  kernelLoop (kernelLocalThreadId constants) (kernelGroupSize constants) n f

-- | Iterate collectively through a multidimensional space, such that
-- all threads in the group participate.  The passed-in function is
-- invoked with a (symbolic) point in the index space.
--
-- Implemented as a 'groupLoop' over the product of the dimensions,
-- with each flat iteration turned back into a point by
-- 'unflattenIndex'.
groupCoverSpace :: [Imp.Exp]
                -> ([Imp.Exp] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace ds f =
  groupLoop (product ds) $ f . unflattenIndex ds

-- | Copy @from@ to @to@ using 'groupCoverSpace' over the dimensions
-- of @from@.  For each point @is@ in that space, the element at
-- @from_is ++ is@ in the source is written to @to_is ++ is@ in the
-- destination.
groupCopy :: VName -> [Imp.Exp] -> SubExp -> [Imp.Exp] -> InKernelGen ()
groupCopy to to_is from from_is = do
  ds <- mapM toExp . arrayDims =<< subExpType from
  groupCoverSpace ds $ \is ->
    copyDWIMFix to (to_is ++ is) from (from_is ++ is)

-- | Expression compiler for group-level code.  The following
-- expressions, when bound to a single-element pattern, get special
-- treatment:
--
-- * @ArrayLit@: each element is written individually with
--   'copyDWIMFix'.
--
-- * @Copy@ and @Manifest@: performed with 'groupCopy'.
--
-- * @Replicate@: written with 'groupCoverSpace' over the replicated
--   shape.
--
-- * @Iota@: written with 'groupLoop', computing @e + i * s@ for
--   iteration @i@.
--
-- All cases except @ArrayLit@ finish by emitting an
-- @Imp.Barrier Imp.FenceLocal@ operation.  Every other expression is
-- handled by 'defCompileExp'.
compileGroupExp :: ExpCompiler ExplicitMemory KernelEnv Imp.KernelOp
-- The static arrays stuff does not work inside kernels.
compileGroupExp (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0..] es) $ \(i,e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i::Int32)] e []
compileGroupExp (Pattern _ [dest]) (BasicOp (Copy arr)) = do
  groupCopy (patElemName dest) [] (Var arr) []
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pattern _ [dest]) (BasicOp (Manifest _ arr)) = do
  groupCopy (patElemName dest) [] (Var arr) []
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pattern _ [dest]) (BasicOp (Replicate ds se)) = do
  ds' <- mapM toExp $ shapeDims ds
  groupCoverSpace ds' $ \is ->
    -- The source is indexed by whatever indexes remain after dropping
    -- one per replicated dimension.
    copyDWIMFix (patElemName dest) is se (drop (shapeRank ds) is)
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pattern _ [dest]) (BasicOp (Iota n e s _)) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  groupLoop n' $ \i' -> do
    x <- dPrimV "x" $ e' + i' * s'
    copyDWIMFix (patElemName dest) [i'] (Var x) []
  sOp $ Imp.Barrier Imp.FenceLocal

compileGroupExp dest e =
  defCompileExp dest e

-- | Accept a 'SegThread' level and abort with an internal error on a
-- 'SegGroup' level.
sanityCheckLevel :: SegLevel -> InKernelGen ()
sanityCheckLevel SegThread{} = return ()
sanityCheckLevel SegGroup{} =
  error "compileGroupOp: unexpected group-level SegOp."

-- | Declare the index variables of a 'SegSpace' for group-level code.
-- After checking the level with 'sanityCheckLevel', each index
-- variable of the space is defined as the corresponding component of
-- 'kernelLocalThreadId' unflattened over the space's dimensions, and
-- the flat index variable ('segFlat') is defined as
-- 'kernelLocalThreadId' itself.
compileGroupSpace :: SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace lvl space = do
  sanityCheckLevel lvl

  let (ltids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  zipWithM_ dPrimV_ ltids $ unflattenIndex dims' ltid

  dPrimV_ (segFlat space) ltid

-- Construct the necessary lock arrays for an intra-group histogram.
--
-- Returns one update function per histogram operation, each taking
-- the bucket indexes.  The operations are traversed left to right
-- while threading a @Maybe Locking@: the first operation that needs
-- locking allocates a lock array, and later operations that need
-- locking reuse it.
prepareIntraGroupSegHist :: Count GroupSize SubExp
                         -> [HistOp ExplicitMemory]
                         -> InKernelGen [[Imp.Exp] -> InKernelGen ()]
prepareIntraGroupSegHist group_size =
  fmap snd . mapAccumLM onOp Nothing
  where
    -- Handle one operation: 'l' is the locking created so far, if any.
    onOp :: Maybe Locking
         -> HistOp ExplicitMemory
         -> InKernelGen (Maybe Locking, [Imp.Exp] -> InKernelGen ())
    onOp l op = do

      constants <- kernelConstants <$> askEnv
      atomicBinOp <- kernelAtomics <$> askEnv

      let local_subhistos = histDest op

      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        -- No locks are needed for these two strategies.
        (_, AtomicPrim f) -> return (l, f (Space "local") local_subhistos)
        (_, AtomicCAS f) -> return (l, f (Space "local") local_subhistos)
        -- Locking is needed and a lock array already exists: reuse it.
        (Just l', AtomicLocking f) -> return (l, f l' (Space "local") local_subhistos)
        -- Locking is needed and no lock array exists yet: create one
        -- with as many int32 locks as 'group_size'.
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"
          num_locks <- toExp $ unCount group_size

          let dims = map (toExp' int32) $
                       shapeDims (histShape op) ++
                       [histWidth op]
              -- A bucket's lock is its flat index (with respect to
              -- 'dims') modulo the number of locks.
              l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
              locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness

          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "local"
          dArray locks int32 (arrayShape locks_t) $
            ArrayIn locks_mem $ IxFun.iota $
            map (primExpFromSubExp int32) $ arrayDims locks_t

          sComment "All locks start out unlocked" $
            groupCoverSpace [kernelGroupSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

          return (Just l', f l' (Space "local") local_subhistos)

compileGroupOp :: OpCompiler ExplicitMemory KernelEnv Imp.KernelOp

compileGroupOp :: OpCompiler ExplicitMemory KernelEnv KernelOp
compileGroupOp Pattern ExplicitMemory
pat (Alloc size space) =
  Pattern ExplicitMemory -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pattern ExplicitMemory
pat SubExp
size Space
space

compileGroupOp Pattern ExplicitMemory
pat (Inner (SizeOp (SplitSpace o w i elems_per_thread))) =
  Pattern ExplicitMemory
-> SplitOrdering -> SubExp -> SubExp -> SubExp -> InKernelGen ()
forall w i elems_per_thread lore r op.
(ToExp w, ToExp i, ToExp elems_per_thread) =>
Pattern ExplicitMemory
-> SplitOrdering -> w -> i -> elems_per_thread -> ImpM lore r op ()
splitSpace Pattern ExplicitMemory
pat SplitOrdering
o SubExp
w SubExp
i SubExp
elems_per_thread

-- Group-level 'SegMap': set up the segment space, then (guarded by
-- 'isActive' on the space) compile the body statements and write each
-- body result to the matching pattern element with
-- 'compileThreadResult'.  An 'Imp.ErrorSync' with a local fence is
-- emitted afterwards.
compileGroupOp pat (Inner (SegOp (SegMap lvl space _ kbody))) = do
  void (compileGroupSpace lvl space)

  let active = isActive (unSegSpace space)
      writeResults =
        zipWithM_
          (compileThreadResult space)
          (patternElements pat)
          (kernelBodyResult kbody)

  sWhen active $
    compileStms mempty (kernelBodyStms kbody) writeResults

  sOp (Imp.ErrorSync Imp.FenceLocal)

-- Group-level 'SegScan'.  Every active thread first copies its body
-- results into the pattern's arrays at the position given by the
-- segment-space indices; after an 'Imp.ErrorSync', 'groupScan' is run
-- over those same arrays.
compileGroupOp pat (Inner (SegOp (SegScan lvl space scan_op _ _ kbody))) = do
  compileGroupSpace lvl space

  let space_dims = unSegSpace space
      (ltids, dims) = unzip space_dims
      ltid_exps = [ltid `Imp.var` int32 | ltid <- ltids]
      dests = patternNames pat
  dims' <- mapM toExp dims

  sWhen (isActive space_dims) $
    compileStms mempty (kernelBodyStms kbody) $
      zipWithM_
        (\dest res -> copyDWIMFix dest ltid_exps (kernelResultSubExp res) [])
        dests
        (kernelBodyResult kbody)

  sOp (Imp.ErrorSync Imp.FenceLocal)

  -- The innermost dimension is the segment size.  The predicate given
  -- to 'groupScan' compares the distance between two flat indices
  -- with the second index's offset within its segment.
  let segment_size = last dims'
      num_elems = product dims'
      crossesSegment from to = (to - from) .>. (to `rem` segment_size)
  groupScan (Just crossesSegment) num_elems num_elems scan_op dests

-- Group-level 'SegRed'.  Each active thread stores its reduction
-- operands in temporary arrays allocated in the "local" space (and
-- writes any map-style results directly); the temporaries are then
-- combined with 'groupReduce' (one-dimensional space) or 'groupScan'
-- (otherwise), and the results are copied to the reduction pattern
-- elements.
compileGroupOp pat (Inner (SegOp (SegRed lvl space ops _ kbody))) = do
  compileGroupSpace lvl space

  let space_dims = unSegSpace space
      (ltids, dims) = unzip space_dims
      num_red_res = segRedResults ops
      (red_pes, map_pes) = splitAt num_red_res (patternElements pat)

  dims' <- mapM toExp dims

  -- One temporary per reduction result; its shape is the segment
  -- space dimensions followed by the shape of the result type itself.
  let mkTempArr t =
        sAllocArray "red_arr" (elemType t) (Shape dims <> arrayShape t) (Space "local")
  tmp_arrs <- mapM mkTempArr (concatMap (lambdaReturnType . segRedLambda) ops)

  -- Group the temporaries by the operator they belong to.
  let tmps_for_ops = chunks [length (segRedNeutral op) | op <- ops] tmp_arrs
      ops_and_tmps = zip ops tmps_for_ops
      pes_and_arrs = zip red_pes tmp_arrs

  sWhen (isActive space_dims) $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res (kernelBodyResult kbody)
          ltid_exps = map (`Imp.var` int32) ltids
      zipWithM_
        (\dest res -> copyDWIMFix dest ltid_exps (kernelResultSubExp res) [])
        tmp_arrs
        red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

  sOp (Imp.ErrorSync Imp.FenceLocal)

  case dims' of
    -- A single dimension means a single segment, which a group-level
    -- reduction handles directly.  The result is read from index 0 of
    -- each temporary.
    [dim'] -> do
      forM_ ops_and_tmps $ \(op, tmps) ->
        groupReduce dim' (segRedLambda op) tmps

      sOp (Imp.ErrorSync Imp.FenceLocal)

      forM_ pes_and_arrs $ \(pe, arr) ->
        copyDWIMFix (patElemName pe) [] (Var arr) [0]

    -- Segmented intra-group reductions are expressed as (regular)
    -- segmented scans.  This may not be the most efficient approach,
    -- but it is simple.  The result for each segment is read from the
    -- last position of that segment.
    --
    -- NOTE(review): 'last' and 'init' are partial and this branch
    -- also matches an empty dimension list - confirm that an empty
    -- segment space cannot reach this point.
    _ -> do
      let segment_size = last dims'
          num_elems = product dims'
          crossesSegment from to = (to - from) .>. (to `rem` segment_size)

      forM_ ops_and_tmps $ \(op, tmps) ->
        groupScan (Just crossesSegment) num_elems num_elems (segRedLambda op) tmps

      sOp (Imp.ErrorSync Imp.FenceLocal)

      let segment_is = map Imp.vi32 (init ltids)
          last_in_segment = segment_is ++ [last dims' - 1]
      forM_ pes_and_arrs $ \(pe, arr) ->
        copyDWIMFix (patElemName pe) segment_is (Var arr) last_in_segment

      sOp (Imp.Barrier Imp.FenceLocal)

-- Group-level 'SegHist'.  The update functions are obtained from
-- 'prepareIntraGroupSegHist'; each active thread then compiles the
-- body, writes its map-style results, and for every histogram
-- operator applies the update to the computed bin, provided the bin
-- is within bounds.
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileGroupSpace lvl space

  let space_dims = unSegSpace space
      ltids = map fst space_dims
      -- Per operator: one bin index plus one value per neutral element.
      num_red_res = length ops + sum (map (length . histNeutral) ops)
      -- The leading reduction pattern elements are not needed: the
      -- type rules guarantee that they occupy the same memory as the
      -- destinations of the operators.
      map_pes = drop num_red_res (patternElements pat)

  ops' <- prepareIntraGroupSegHist (segGroupSize lvl) ops

  -- Make sure every lock has been initialised before any update runs.
  sOp (Imp.Barrier Imp.FenceLocal)

  sWhen (isActive space_dims) $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res (kernelBodyResult kbody)
          (red_is, red_vs) = splitAt (length ops) (map kernelResultSubExp red_res)
          vs_per_op = chunks (map (length . histDest) ops) red_vs
          outer_is = map (`Imp.var` int32) (init ltids)

          -- Emit the update for one histogram operator.
          updateHist (bin, op_vs, do_op, HistOp dest_w _ _ _ shape lam) =
            sComment "perform atomic updates" $
              sWhen bin_in_bounds $ do
                dLParams lam_params
                sLoopNest shape $ \is -> do
                  -- Load the values to combine into the trailing
                  -- lambda parameters, then run the update.
                  zipWithM_
                    (\p v -> copyDWIMFix (paramName p) [] v is)
                    vs_params
                    op_vs
                  do_op (bin_is ++ is)
            where
              bin' = toExp' int32 bin
              dest_w' = toExp' int32 dest_w
              bin_in_bounds = 0 .<=. bin' .&&. bin' .<. dest_w'
              bin_is = outer_is ++ [bin']
              lam_params = lambdaParams lam
              vs_params = takeLast (length op_vs) lam_params

      zipWithM_ (compileThreadResult space) map_pes map_res

      mapM_ updateHist (zip4 red_is vs_per_op ops' ops)

  sOp (Imp.ErrorSync Imp.FenceLocal)

-- Any other operation is not supported at group level and is
-- reported as a compiler bug, naming the pattern being bound.
compileGroupOp pat _ =
  compilerBugS ("compileGroupOp: cannot compile rhs of binding " ++ pretty pat)

-- | Compile an operation that occurs at thread level inside a kernel.
-- Only allocations and 'SplitSpace' size operations are handled;
-- anything else is reported via 'compilerBugS'.
compileThreadOp :: OpCompiler ExplicitMemory KernelEnv Imp.KernelOp
compileThreadOp pat op =
  case op of
    Alloc mem_size mem_space ->
      kernelAlloc pat mem_size mem_space
    Inner (SizeOp (SplitSpace ordering width idx per_thread)) ->
      splitSpace pat ordering width idx per_thread
    _ ->
      compilerBugS ("compileThreadOp: cannot compile rhs of binding " ++ pretty pat)

-- | Description of the locking scheme used by an atomic update that
-- cannot be performed without explicit locks.
data Locking = Locking
  { lockingArray :: VName
    -- ^ The array that holds the lock.
  , lockingIsUnlocked :: Imp.Exp
    -- ^ The value for which we consider the lock to be free.
  , lockingToLock :: Imp.Exp
    -- ^ The value written in order to take the lock.
  , lockingToUnlock :: Imp.Exp
    -- ^ The value written in order to release the lock.
  , lockingMapping :: [Imp.Exp] -> [Imp.Exp]
    -- ^ Maps a logical lock index to the physical position in
    -- 'lockingArray'.  This can also be used to make the lock array
    -- smaller.
  }

-- | A function for generating code for an atomic update.  Assumes
-- that the bucket is in-bounds.
--
-- The arguments are a memory 'Space', a list of array names, and a
-- list of index expressions; the result is an action in 'ImpM' that
-- emits 'Imp.KernelOp' code.  Elsewhere in this module it is applied
-- to @Space "local"@ and a list of local sub-histogram arrays, and the
-- resulting function is then called with the bucket indices.
-- NOTE(review): the exact contract between the space and the arrays
-- is not visible here - confirm against the implementations.
type DoAtomicUpdate lore r =
  Space -> [VName] -> [Imp.Exp] -> ImpM lore r Imp.KernelOp ()

-- | The mechanism that will be used for performing the atomic update.
-- Approximates how efficient it will be.  Ordered from most to least
-- efficient.
--
-- 'atomicUpdateLocking' picks the constructor: 'AtomicPrim' when every
-- component operator has an atomic implementation, 'AtomicCAS' when
-- the update can be expressed as a compare-and-swap loop on 32-bit
-- values, and 'AtomicLocking' in all remaining cases.
data AtomicUpdate lore r
  = AtomicPrim (DoAtomicUpdate lore r)
    -- ^ Supported directly by primitive.
  | AtomicCAS (DoAtomicUpdate lore r)
    -- ^ Can be done by efficient swaps.
  | AtomicLocking (Locking -> DoAtomicUpdate lore r)
    -- ^ Requires explicit locking.  The caller supplies the 'Locking'
    -- describing the lock array to use.

-- | Is there an atomic 'BinOp' corresponding to this 'BinOp'?
--
-- On success the result is a constructor for the atomic operation.
-- 'atomicUpdateLocking' applies it to a freshly declared variable
-- named @old@ (presumably receiving the previous value -- TODO
-- confirm against 'Imp.AtomicOp'), the array name and element offset
-- produced by 'fullyIndexArray', and finally the operand expression.
type AtomicBinOp =
  BinOp ->
  Maybe (VName -> VName -> Count Imp.Elements Imp.Exp -> Imp.Exp -> Imp.AtomicOp)

-- | 'atomicUpdate', but where it is explicitly visible whether a
-- locking strategy is necessary.
--
-- Three strategies are tried in order:
--
-- 1. If 'splitOp' can decompose the lambda into independent binary
--    operators that all work on 32-bit primitives, every component is
--    updated on its own, either with the atomic operation provided by
--    the 'AtomicBinOp' argument or with a compare-and-swap loop
--    ('atomicUpdateCAS').  The result is 'AtomicPrim' when every
--    operator has an atomic implementation and 'AtomicCAS' otherwise.
--
-- 2. If the lambda takes exactly two parameters and returns a single
--    32-bit primitive, its whole body is run inside a
--    compare-and-swap loop ('AtomicCAS').
--
-- 3. Otherwise the update is guarded by a spin lock described by a
--    'Locking' value ('AtomicLocking').
atomicUpdateLocking :: AtomicBinOp -> Lambda ExplicitMemory
                    -> AtomicUpdate ExplicitMemory KernelEnv

atomicUpdateLocking atomicBinOp lam
  | Just ops_and_ts <- splitOp lam,
    all (\(_, t, _, _) -> primBitSize t == 32) ops_and_ts =
      primOrCas ops_and_ts $ \space arrs bucket ->
        -- If the operator is a vectorised binary operator on 32-bit
        -- values, we can use a particularly efficient implementation.
        -- If the operator has an atomic implementation we use that,
        -- otherwise it is still a binary operator which can be
        -- implemented by atomic compare-and-swap if 32 bits.
        forM_ (zip arrs ops_and_ts) $ \(a, (op, t, x, y)) -> do
          -- Common variables.
          old <- dPrim "old" t

          (arr', _a_space, bucket_offset) <- fullyIndexArray a bucket

          case opHasAtomicSupport space old arr' bucket_offset op of
            Just f -> sOp $ f $ Imp.var y t
            Nothing ->
              atomicUpdateCAS space t a old bucket x $
                x <-- Imp.BinOpExp op (Imp.var x t) (Imp.var y t)
  where
    -- The atomic operation for 'bop', if any, partially applied so
    -- that only the operand expression is still missing.
    opHasAtomicSupport space old arr' bucket' bop = do
      let atomic f = Imp.Atomic space . f old arr' bucket'
      atomic <$> atomicBinOp bop

    -- Classify the update as 'AtomicPrim' only if every component
    -- operator has an atomic implementation.
    primOrCas ops
      | all isPrim ops = AtomicPrim
      | otherwise      = AtomicCAS

    isPrim (op, _, _, _) = isJust $ atomicBinOp op

-- If the operator functions purely on single 32-bit values, we can
-- use an implementation based on CAS, no matter what the operator
-- does.
atomicUpdateLocking _ op
  | [Prim t] <- lambdaReturnType op,
    [xp, _] <- lambdaParams op,
    primBitSize t == 32 = AtomicCAS $ \space arrs bucket ->
      -- A single-valued operator updates exactly one array; anything
      -- else is reported as a compiler bug rather than surfacing as an
      -- anonymous pattern-match failure.
      case arrs of
        [arr] -> do
          old <- dPrim "old" t
          atomicUpdateCAS space t arr old bucket (paramName xp) $
            compileBody' [xp] $ lambdaBody op
        _ ->
          compilerBugS $
            "atomicUpdateLocking: expected exactly one destination array, but got "
              ++ show (length arrs)

atomicUpdateLocking _ op = AtomicLocking $ \locking space arrs bucket -> do
  old <- dPrim "old" int32
  continue <- newVName "continue"
  dPrimVol_ continue Bool
  continue <-- true

  -- Correctly index into locks.
  (locks', _locks_space, locks_offset) <-
    fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket

  -- Critical section
  let try_acquire_lock =
        sOp $ Imp.Atomic space $
        Imp.AtomicCmpXchg int32 old locks' locks_offset
        (lockingIsUnlocked locking) (lockingToLock locking)
      lock_acquired = Imp.var old int32 .==. lockingIsUnlocked locking
      -- Even the releasing is done with an atomic rather than a
      -- simple write, for memory coherency reasons.
      release_lock =
        sOp $ Imp.Atomic space $
        Imp.AtomicCmpXchg int32 old locks' locks_offset
        (lockingToLock locking) (lockingToUnlock locking)
      break_loop = continue <-- false

  -- Preparing parameters. It is assumed that the caller has already
  -- filled the arr_params. We copy the current value to the
  -- accumulator parameters.
  --
  -- Note the use of 'everythingVolatile' when reading and writing the
  -- buckets.  This was necessary to ensure correct execution on a
  -- newer NVIDIA GPU (RTX 2080).  The 'volatile' modifiers likely
  -- make the writes pass through the (SM-local) L1 cache, which is
  -- necessary here, because we are really doing device-wide
  -- synchronisation without atomics (naughty!).
  let (acc_params, _arr_params) = splitAt (length arrs) $ lambdaParams op
      bind_acc_params =
        everythingVolatile $
        sComment "bind lhs" $
        forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
        copyDWIMFix (paramName acc_p) [] (Var arr) bucket

  let op_body = sComment "execute operation" $
                compileBody' acc_params $ lambdaBody op

      do_hist =
        everythingVolatile $
        sComment "update global result" $
        zipWithM_ (writeArray bucket) arrs $ map (Var . paramName) acc_params

      -- The fence matches the memory space the update operates on.
      fence = sOp $ Imp.MemFence $ case space of
        Space "local" -> Imp.FenceLocal
        _             -> Imp.FenceGlobal

  -- While-loop: Try to insert your value
  sWhile (Imp.var continue Bool) $ do
    try_acquire_lock
    sWhen lock_acquired $ do
      dLParams acc_params
      bind_acc_params
      op_body
      do_hist
      fence
      release_lock
      break_loop
    fence
  where
    -- Store 'val' at position 'bucket' of 'arr'.
    writeArray bucket arr val = copyDWIMFix arr bucket val []

-- | Generate a compare-and-swap loop that atomically updates one
-- element of an array.
--
-- The arguments are, in order: the memory space passed to the atomic
-- operation, the element type, the array, a variable used for the
-- most recently observed element value, the index of the element, the
-- variable that holds the value to be written, and the code that
-- computes the new value.  Before that code runs in each iteration,
-- the last variable is set to the value currently assumed to be in
-- the array; afterwards its contents are used as the replacement
-- value of the compare-and-swap.
atomicUpdateCAS :: Space -> PrimType
                -> VName -> VName
                -> [Imp.Exp] -> VName
                -> InKernelGen ()
                -> InKernelGen ()
atomicUpdateCAS space t arr old bucket x do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, x);
  -- } while(assumed != old);
  assumed <- dPrim "assumed" t
  run_loop <- dPrimV "run_loop" 1

  -- XXX: CUDA may generate really bad code if this is not a volatile
  -- read.  Unclear why.  The later reads are volatile, so maybe
  -- that's it.
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket

  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  -- The compare-and-swap is always issued on int32, so 32-bit floats
  -- are converted to and from their bit pattern around it.
  let (toBits, fromBits) =
        case t of
          FloatType Float32 ->
            ( \v -> Imp.FunExp "to_bits32" [v] int32
            , \v -> Imp.FunExp "from_bits32" [v] t
            )
          _ -> (id, id)

  -- While-loop: Try to insert your value
  sWhile (Imp.var run_loop int32) $ do
    assumed <-- Imp.var old t
    x <-- Imp.var assumed t
    do_op
    old_bits <- dPrim "old_bits" int32
    sOp $ Imp.Atomic space $
      Imp.AtomicCmpXchg int32 old_bits arr' bucket_offset
      (toBits (Imp.var assumed t)) (toBits (Imp.var x t))
    old <-- fromBits (Imp.var old_bits int32)
    -- Stop once the element still held the assumed value, i.e. the
    -- swap took effect.
    sWhen (toBits (Imp.var assumed t) .==. Imp.var old_bits int32)
      (run_loop <-- 0)

-- | Horizontally fission a lambda that models a binary operator.
--
-- Succeeds only if every result of the lambda is a variable bound by
-- a single-result statement of the form @res = x `op` y@, where @x@
-- is the parameter at the position of the result and @y@ is the
-- parameter at that position plus the number of results, and the
-- result has primitive type.  For each result the operator, the
-- primitive type and the names of the two parameters are returned.
splitOp :: Attributes lore => Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp lam = mapM splitStm $ bodyResult body
  where
    body = lambdaBody lam
    params = lambdaParams lam
    n = length $ lambdaReturnType lam

    splitStm (Var res) = do
      -- The statement binding exactly this result must be a plain
      -- binary operation on two variables.
      Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patternNames . stmPattern) $
        stmsToList $ bodyStms body
      i <- Var res `elemIndex` bodyResult body
      xp <- maybeNth i params
      yp <- maybeNth (n + i) params
      -- The operands must be exactly the matching parameters.
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      return (op, t, paramName xp, paramName yp)
    splitStm _ = Nothing

-- | Compute what has to be passed to a kernel: the names free in the
-- kernel body, minus the names bound inside the kernel, translated by
-- 'readsFromSet' and with duplicates removed.
computeKernelUses :: FreeIn a =>
                     a -> [VName]
                  -> CallKernelGen [Imp.KernelUse]
computeKernelUses kernel_body bound_in_kernel = do
  let actually_free =
        freeIn kernel_body `namesSubtract` namesFromList bound_in_kernel
  -- Compute the variables that we need to pass to the kernel.
  nub <$> readsFromSet actually_free

-- | Classify each name in the given set as a kernel use, based on the
-- type of the variable:
--
--   * arrays produce no use;
--
--   * memory blocks in the @"local"@ space produce no use, while any
--     other memory block becomes an 'Imp.MemoryUse';
--
--   * primitives that 'isConstExp' can express as a kernel constant
--     become an 'Imp.ConstUse'; otherwise they become an
--     'Imp.ScalarUse', except for values of type 'Cert', which
--     produce no use.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free = do
  -- The variable table is not modified while classifying, so it is
  -- fetched once instead of once per variable.
  vtable <- getVTable
  fmap catMaybes $ forM (namesToList free) $ \var -> do
    t <- lookupType var
    case t of
      Array {} -> return Nothing
      Mem (Space "local") -> return Nothing
      Mem {} -> return $ Just $ Imp.MemoryUse var
      Prim bt -> do
        const_exp <- isConstExp vtable (Imp.var var bt)
        return $ case const_exp of
          Just ce -> Just $ Imp.ConstUse var ce
          Nothing | bt == Cert -> Nothing
                  | otherwise  -> Just $ Imp.ScalarUse var bt

-- | Try to rewrite an expression into a kernel constant expression.
--
-- Every leaf of @size@ must be rewritable, otherwise the result is
-- 'Nothing':
--
--   * a scalar variable is looked up in @vtable@ and must be bound to
--     an expression that is itself constant (checked recursively via
--     'primExpFromExp'); a @GetSize@ binding becomes an
--     'Imp.SizeConst' whose key is qualified with the current function
--     via 'keyWithEntryPoint';
--
--   * @SizeOf@ becomes the byte size of the primitive type;
--
--   * an array index is never constant.
isConstExp :: VTable ExplicitMemory -> Imp.Exp
           -> ImpM lore r op (Maybe Imp.KernelConstExp)
isConstExp vtable size = do
  fname <- askFunction
  let onLeaf (Imp.ScalarVar name) _ = lookupConstExp name
      onLeaf (Imp.SizeOf pt) _ = Just $ primByteSize pt
      onLeaf Imp.Index{} _ = Nothing

      -- A variable is constant only if it is known, has a defining
      -- expression, and that expression is constant.
      lookupConstExp name =
        M.lookup name vtable >>= hasExp >>= constExp

      constExp (Op (Inner (SizeOp (GetSize key _)))) =
        Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
      constExp e =
        primExpFromExp lookupConstExp e
  return $ replaceInPrimExpM onLeaf size
  where -- The defining expression recorded for a variable, if any.
        hasExp (ArrayVar e _) = e
        hasExp (ScalarVar e _) = e
        hasExp (MemVar e _) = e

-- | Emit code that assigns to @chunk_var@ the number of elements to be
-- handled by the thread with index @thread_index@.
--
-- With 'SplitStrided', the chunk size is the signed 32-bit minimum of
-- @elements_per_thread@ and @(num_elements - thread_index)@ divided by
-- the stride, rounding up.
--
-- With 'SplitContiguous', the thread starts at
-- @thread_index * elements_per_thread@.  The chunk size is zero if
-- that starting point is at or beyond @num_elements@, the remainder
-- @num_elements - thread_index * elements_per_thread@ if a full chunk
-- would run past @num_elements@, and @elements_per_thread@ otherwise.
computeThreadChunkSize :: SplitOrdering
                       -> Imp.Exp
                       -> Imp.Count Imp.Elements Imp.Exp
                       -> Imp.Count Imp.Elements Imp.Exp
                       -> VName
                       -> ImpM lore r op ()
computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
  stride' <- toExp stride
  chunk_var <--
    Imp.BinOpExp (SMin Int32)
    (Imp.unCount elements_per_thread)
    ((Imp.unCount num_elements - thread_index) `quotRoundingUp` stride')

computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
  starting_point <- dPrimV "starting_point" $
    thread_index * Imp.unCount elements_per_thread
  remaining_elements <- dPrimV "remaining_elements" $
    Imp.unCount num_elements - Imp.var starting_point int32

  let no_remaining_elements =
        Imp.var remaining_elements int32 .<=. 0
      beyond_bounds =
        Imp.unCount num_elements .<=. Imp.var starting_point int32

  sIf (no_remaining_elements .||. beyond_bounds)
    (chunk_var <-- 0)
    (sIf is_last_thread
       (chunk_var <-- Imp.unCount last_thread_elements)
       (chunk_var <-- Imp.unCount elements_per_thread))
  where -- What is left over for a thread that cannot get a full chunk.
        last_thread_elements =
          num_elements - Imp.elements thread_index * elements_per_thread
        -- True when a full chunk for this thread would extend past
        -- the end of the input.
        is_last_thread =
          Imp.unCount num_elements .<.
          (thread_index + 1) * Imp.unCount elements_per_thread

-- | Construct the 'KernelConstants' for a kernel with the given number
-- of groups and group size, together with an in-kernel action that
-- declares the underlying variables and emits the operations that
-- initialise them (global id, local id, local size, lockstep width
-- and group id).
--
-- The number of threads is computed as @group_size * num_groups@, and
-- the final 'KernelConstants' field is set to @true@.  Note that the
-- variable holding the in-kernel local size is declared and
-- initialised, but the constants refer to the @group_size@ argument
-- rather than to that variable.
kernelInitialisationSimple :: Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                           -> CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple (Count num_groups) (Count group_size) = do
  -- The order of these calls determines the fresh names generated, so
  -- it is kept stable.  The name hints are part of the generated code.
  global_tid <- newVName "global_tid"
  local_tid <- newVName "local_tid"
  group_id <- newVName "group_tid"
  wave_size <- newVName "wave_size"
  inner_group_size <- newVName "group_size"

  let constants =
        KernelConstants
        (Imp.var global_tid int32)
        (Imp.var local_tid int32)
        (Imp.var group_id int32)
        global_tid local_tid group_id
        num_groups group_size (group_size * num_groups)
        (Imp.var wave_size int32)
        true

      set_constants = do
        -- Declare every variable first...
        dPrim_ global_tid int32
        dPrim_ local_tid int32
        dPrim_ inner_group_size int32
        dPrim_ wave_size int32
        dPrim_ group_id int32

        -- ...then emit the operations that fill them in.
        sOp (Imp.GetGlobalId global_tid 0)
        sOp (Imp.GetLocalId local_tid 0)
        sOp (Imp.GetLocalSize inner_group_size 0)
        sOp (Imp.GetLockstepWidth wave_size)
        sOp (Imp.GetGroupId group_id 0)

  return (constants, set_constants)

-- | Build the condition that every index variable is strictly less
-- than its paired bound, i.e. the conjunction of @i < w@ for each
-- @(i, w)@ in the argument.  An empty list yields the constant true
-- expression.
isActive :: [(VName, SubExp)] -> Imp.Exp
isActive limit = case actives of
                    [] -> Imp.ValueExp $ BoolValue True
                    x:xs -> foldl (.&&.) x xs
  where (is, ws) = unzip limit
        -- NOTE(review): the bounds are converted with primitive type
        -- 'Bool' although they are compared against int32 variables;
        -- this looks unintended but is kept as-is - confirm before
        -- changing.
        actives = zipWith active is $ map (toExp' Bool) ws
        active i = (Imp.var i int32 .<.)

-- | Change every memory block to be in the global address space,
-- except those that are in the local memory space.  This only affects
-- generated code - we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal =
  localDefaultSpace (Imp.Space "global") . localVTable (M.map globalMemory)
  where -- Non-local memory entries are moved to the "global" space and
        -- lose their defining expression; everything else is untouched.
        globalMemory (MemVar _ entry)
          | entryMemSpace entry /= Space "local" =
              MemVar Nothing entry { entryMemSpace = Imp.Space "global" }
        globalMemory entry =
          entry

-- | Perform a group-level reduction of the first @w@ elements of the
-- arrays @arrs@ with the operator @lam@.  This declares a fresh int32
-- @offset@ variable and delegates to 'groupReduceWithOffset'.
groupReduce :: Imp.Exp
            -> Lambda ExplicitMemory
            -> [VName]
            -> InKernelGen ()
groupReduce w lam arrs = do
  offset <- dPrim "offset" int32
  groupReduceWithOffset offset w lam arrs

-- | Emit code for a group-level reduction of the arrays @arrs@ with
-- the operator @lam@, using the already-declared int32 variable
-- @offset@ as the distance to the element being combined.
--
-- The generated code has two phases:
--
--   1. In-wave reductions: @offset@ doubles from 1 while it is less
--      than the wave size; the body runs with 'everythingVolatile' and
--      without barriers.
--
--   2. Cross-wave reductions: @skip_waves@ doubles from 1 while it is
--      less than the number of waves; each iteration starts with a
--      barrier, and only the first thread of each non-skipped wave
--      applies the operator.
--
-- The first @length arrs@ parameters of @lam@ are treated as the
-- accumulator parameters and the rest as the element parameters.
groupReduceWithOffset :: VName
                      -> Imp.Exp
                      -> Lambda ExplicitMemory
                      -> [VName]
                      -> InKernelGen ()
groupReduceWithOffset offset w lam arrs = do
  constants <- kernelConstants <$> askEnv

  let local_tid = kernelLocalThreadId constants
      global_tid = kernelGlobalThreadId constants

      -- A local fence suffices when every result is primitive;
      -- otherwise a global fence is used.
      barrier
        | all primType $ lambdaReturnType lam = sOp $ Imp.Barrier Imp.FenceLocal
        | otherwise                           = sOp $ Imp.Barrier Imp.FenceGlobal

      -- Read the element at @tid + offset@ of @arr@ into @param@.
      -- Primitive parameters are indexed by the local thread id, all
      -- others by the global thread id.
      readReduceArgument param arr =
        copyDWIMFix (paramName param) [] (Var arr) [tid + Imp.vi32 offset]
        where tid | Prim _ <- paramType param = local_tid
                  | otherwise                 = global_tid

      -- Only primitive results are written back to @arr@ (at the
      -- local thread id); nothing is emitted for other parameters.
      writeReduceOpResult param arr
        | Prim _ <- paramType param =
            copyDWIMFix arr [local_tid] (Var $ paramName param) []
        | otherwise =
            return ()

      (reduce_acc_params, reduce_arr_params) =
        splitAt (length arrs) $ lambdaParams lam

  skip_waves <- dPrim "skip_waves" int32
  dLParams $ lambdaParams lam

  offset <-- 0

  comment "participating threads read initial accumulator" $
    sWhen (local_tid .<. w) $
    zipWithM_ readReduceArgument reduce_acc_params arrs

  let do_reduce = do
        comment "read array element" $
          zipWithM_ readReduceArgument reduce_arr_params arrs
        comment "apply reduction operation" $
          compileBody' reduce_acc_params $ lambdaBody lam
        comment "write result of operation" $
          zipWithM_ writeReduceOpResult reduce_acc_params arrs
      in_wave_reduce = everythingVolatile do_reduce

      wave_size = kernelWaveSize constants
      group_size = kernelGroupSize constants
      wave_id = local_tid `quot` wave_size
      in_wave_id = local_tid - wave_id * wave_size
      -- Number of waves in the group, rounding up.
      num_waves = (group_size + wave_size - 1) `quot` wave_size
      arg_in_bounds = local_tid + Imp.var offset int32 .<. w

      doing_in_wave_reductions =
        Imp.var offset int32 .<. wave_size
      -- Only threads whose in-wave id has none of the bits of
      -- @2 * offset - 1@ set take part in this iteration.
      apply_in_in_wave_iteration =
        (in_wave_id .&. (2 * Imp.var offset int32 - 1)) .==. 0
      in_wave_reductions = do
        offset <-- 1
        sWhile doing_in_wave_reductions $ do
          sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration)
            in_wave_reduce
          offset <-- Imp.var offset int32 * 2

      doing_cross_wave_reductions =
        Imp.var skip_waves int32 .<. num_waves
      is_first_thread_in_wave =
        in_wave_id .==. 0
      wave_not_skipped =
        (wave_id .&. (2 * Imp.var skip_waves int32 - 1)) .==. 0
      apply_in_cross_wave_iteration =
        arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
      cross_wave_reductions = do
        skip_waves <-- 1
        sWhile doing_cross_wave_reductions $ do
          barrier
          offset <-- Imp.var skip_waves int32 * wave_size
          sWhen apply_in_cross_wave_iteration
            do_reduce
          skip_waves <-- Imp.var skip_waves int32 * 2

  in_wave_reductions
  cross_wave_reductions

-- | Generate the code by which the threads of one group cooperatively
-- scan the arrays @arrs@ with the operator @lam@.
--
-- The arguments, in order:
--
-- * @seg_flag@: an optional predicate over two indexes.  When present,
--   a thread for which it holds (asked about the last element of the
--   preceding block and the thread's own index) does not apply the
--   carry-in from the preceding block.  Presumably this marks a
--   segment boundary between the two indexes - confirm with callers.
--
-- * @arrs_full_size@: used as the base offset of the scratch region in
--   which the first block's array-typed values are parked while the
--   block-level carries are being scanned.
--
-- * @w@: only threads whose local thread id is below @w@ take part.
--
-- * @lam@: the scan operator.  Its first @length arrs@ parameters are
--   treated as the left operands, the remainder as the right operands.
--
-- * @arrs@: the arrays being scanned, one per left operand.
groupScan :: Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
          -> Imp.Exp
          -> Imp.Exp
          -> Lambda ExplicitMemory
          -> [VName]
          -> InKernelGen ()
groupScan seg_flag arrs_full_size w lam arrs = do
  constants <- fmap kernelConstants askEnv
  renamed_lam <- renameLambda lam

  let local_tid = kernelLocalThreadId constants
      (lhs_params, rhs_params) = splitAt (length arrs) (lambdaParams lam)
      isPrim p = primType (paramType p)

  -- Both the operator and its renamed copy are compiled below, so the
  -- parameters of both must be declared.
  dLParams $ concatMap lambdaParams [lam, renamed_lam]

  -- The group is cut into blocks that are scanned independently.
  -- Blocks are typically no wider than the lockstep width, which is
  -- what permits barrier-free execution inside a block.
  --
  -- The block size is fixed at 32.  It must be at least the square
  -- root of the group size, so 32 covers groups of up to 1024 threads.
  -- Making it a runtime parameter would be nicer.
  let block_size = Imp.ValueExp (IntValue (Int32Value 32))
      simd_width = kernelWaveSize constants
      block_idx = local_tid `quot` block_size
      idx_in_block = local_tid - block_idx * block_size
      in_bounds = local_tid .<. w

      -- An operator returning any non-primitive value makes this an
      -- "array scan", which synchronises with a global fence instead
      -- of a local one.
      array_scan = not (all primType (lambdaReturnType lam))
      barrier =
        sOp $ Imp.Barrier $
        if array_scan then Imp.FenceGlobal else Imp.FenceLocal

      scanBlocks flag active op =
        inBlockScan constants flag arrs_full_size
        simd_width block_size active arrs barrier op

      group_offset = kernelGroupId constants * kernelGroupSize constants
      first_block_base = arrs_full_size + group_offset

      -- Where the per-block result for this thread's block is kept:
      -- primitives are indexed by block alone, everything else is
      -- additionally offset by the group.
      carrySlot p
        | isPrim p = block_idx
        | otherwise = group_offset + block_idx

      writeBlockResult p arr =
        copyDWIM arr [DimFix (carrySlot p)] (Var (paramName p)) []

      readPrevBlockResult p arr =
        copyDWIM (paramName p) [] (Var arr) [DimFix (carrySlot p - 1)]

  -- Stage 1: every block scans its own elements.
  scanBlocks seg_flag in_bounds lam
  barrier

  let in_first_block = block_idx .==. 0
      -- The (parameter, array) pairs whose values are not primitive.
      array_slots =
        [ (p, arr) | (p, arr) <- zip lhs_params arrs, not (isPrim p) ]

  when array_scan $ do
    sComment "save correct values for first block" $
      sWhen in_first_block $
      forM_ array_slots $ \(p, arr) ->
        copyDWIM arr [DimFix $ first_block_base + block_size + local_tid]
        (Var (paramName p)) []
    barrier

  -- Stage 2: publish one result per block.
  let last_in_block = idx_in_block .==. (block_size - 1)
  sComment "last thread of block 'i' writes its result to offset 'i'" $
    sWhen (last_in_block .&&. in_bounds) $
    everythingVolatile $
    zipWithM_ writeBlockResult lhs_params arrs
  barrier

  -- Stage 3: the first block scans the per-block results.  The flag
  -- predicate is re-targeted from block indexes to the index of the
  -- last element of each block.
  let lastIndexOf :: Imp.Exp -> Imp.Exp
      lastIndexOf blk = blk * block_size + block_size - 1
      first_block_seg_flag =
        fmap (\flag_true from to ->
                flag_true (lastIndexOf from) (lastIndexOf to))
        seg_flag
  comment
    "scan the first block, after which offset 'i' contains carry-in for block 'i+1'" $
    scanBlocks first_block_seg_flag (in_first_block .&&. in_bounds) renamed_lam
  barrier

  when array_scan $ do
    sComment "move correct values for first block back a block" $
      sWhen in_first_block $
      forM_ array_slots $ \(_, arr) ->
        copyDWIM
        arr [DimFix $ first_block_base + local_tid]
        (Var arr) [DimFix $ first_block_base + block_size + local_tid]
    barrier

  -- Stage 4: blocks after the first fold in the carry of the block
  -- before them.
  let param_pairs = zip lhs_params rhs_params

      -- The thread's own value moves to the right operand; the left
      -- operand is then loaded with the preceding block's result.
      read_carry_in = do
        forM_ param_pairs $ \(x, y) ->
          copyDWIM (paramName y) [] (Var (paramName x)) []
        zipWithM_ readPrevBlockResult lhs_params arrs

      -- Only primitive operands are copied back.
      y_to_x =
        forM_ (filter (isPrim . fst) param_pairs) $ \(x, y) ->
          copyDWIM (paramName x) [] (Var (paramName y)) []

      apply_op = compileBody' lhs_params (lambdaBody lam)

      op_to_x = case seg_flag of
        Nothing -> apply_op
        Just flag_true -> do
          inactive <-
            dPrimVE "inactive" $
            flag_true (block_idx * block_size - 1) local_tid
          sWhen inactive y_to_x
          when array_scan barrier
          sUnless inactive apply_op

      -- Only primitive results are written back here.
      write_final_result =
        forM_ (zip lhs_params arrs) $ \(p, arr) ->
          when (isPrim p) $
          copyDWIM arr [DimFix local_tid] (Var (paramName p)) []

  sComment "carry-in for every block except the first" $
    sUnless (in_first_block .||. Imp.UnOpExp Not in_bounds) $ do
      sComment "read operands" read_carry_in
      sComment "perform operation" op_to_x
      sComment "write final result" write_final_result
  barrier

  -- Stage 5: the first block was overwritten by stage 3; put its
  -- values back.
  sComment "restore correct values for first block" $
    sWhen in_first_block $
    forM_ (zip3 lhs_params rhs_params arrs) $ \(x, y, arr) ->
      if isPrim y
      then copyDWIM arr [DimFix local_tid] (Var (paramName y)) []
      else copyDWIM (paramName x) []
           (Var arr) [DimFix $ first_block_base + local_tid]
  barrier

-- | Emit code for a scan performed independently inside each run of
-- @block_size@ consecutive local thread IDs.  All generated code is
-- wrapped in 'everythingVolatile'.
--
-- The arguments are, in order:
--
--   * the kernel constants, from which the local and global thread
--     IDs are taken;
--
--   * an optional segment-flag predicate; when present it is applied
--     to @ltid - skip_threads@ and @ltid@, and a thread for which it
--     holds keeps its own value instead of applying the operator;
--
--   * an offset added to the index when reading non-primitive
--     operands from the arrays;
--
--   * the lockstep width: the supplied barrier is only executed in
--     rounds where @lockstep_width <= skip_threads@;
--
--   * the block size, which bounds the doubling of @skip_threads@;
--
--   * a condition that a thread must satisfy to read, combine, or
--     write anything;
--
--   * the arrays holding the elements, one per scan value;
--
--   * the action used to emit a barrier;
--
--   * the scan operator.  The first half of its parameters (the "x"
--     parameters) receive the operand read from behind and the
--     result; the second half (the "y" parameters) hold the thread's
--     own running value.
inBlockScan :: KernelConstants
            -> Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
            -> Imp.Exp
            -> Imp.Exp
            -> Imp.Exp
            -> Imp.Exp
            -> [VName]
            -> InKernelGen ()
            -> Lambda ExplicitMemory
            -> InKernelGen ()
inBlockScan constants seg_flag arrs_full_size lockstep_width block_size active arrs barrier scan_lam =
  everythingVolatile $ do
    skip_threads <- dPrim "skip_threads" int32
    let stride = Imp.var skip_threads int32

        -- A thread takes part in a round only when it is at least
        -- @stride@ positions into its block (and is active at all).
        participating = (stride .<=. in_block_id) .&&. active

        applyOperator = compileBody' x_params $ lambdaBody scan_lam

        combine = case seg_flag of
          Nothing -> applyOperator
          Just flag_true -> do
            inactive <- dPrimVE "inactive" $ flag_true (ltid - stride) ltid
            sWhen inactive copyYsToXs
            when array_scan barrier
            sUnless inactive applyOperator

        barrierIfNeeded = sWhen (lockstep_width .<=. stride) barrier

    -- Set initial y values
    sWhen active $ do
      forM_ (zip y_params arrs) $ uncurry readInitial
      -- Since the final result is expected to be in x_params, we may
      -- need to copy it there for the first thread in the block.
      sWhen (in_block_id .==. 0) copyYsToXs

    when array_scan barrier

    sComment "in-block scan (hopefully no barriers needed)" $ do
      skip_threads <-- 1
      sWhile (stride .<. block_size) $ do
        sWhen participating $ do
          sComment "read operands" $
            forM_ (zip x_params arrs) $
            uncurry $ readParam $ Imp.vi32 skip_threads
          sComment "perform operation" combine

        barrierIfNeeded

        sWhen participating $
          sComment "write result" $
          forM_ (zip3 x_params y_params arrs) $ \(x, y, arr) ->
          writeResult x y arr

        barrierIfNeeded

        skip_threads <-- stride * 2

  where
    ltid = kernelLocalThreadId constants
    gtid = kernelGlobalThreadId constants

    -- Position of this thread inside its block.
    in_block_id = ltid - (ltid `quot` block_size) * block_size

    -- True when the operator returns at least one non-primitive value.
    array_scan = any (not . primType) $ lambdaReturnType scan_lam

    (x_params, y_params) = splitAt (length all_params `div` 2) all_params
      where all_params = lambdaParams scan_lam

    isPrim p = primType (paramType p)

    -- Copy every primitive-typed y parameter into its x counterpart.
    copyYsToXs =
      sequence_ [ copyDWIM (paramName x) [] (Var (paramName y)) []
                | (x, y) <- zip x_params y_params
                , isPrim x
                ]

    -- Load element @i@ of @arr@ into parameter @p@.
    loadParam p arr i = copyDWIM (paramName p) [] (Var arr) [DimFix i]

    -- Primitive values are indexed by local thread ID, everything
    -- else by global thread ID.
    readInitial p arr
      | isPrim p = loadParam p arr ltid
      | otherwise = loadParam p arr gtid

    readParam behind p arr
      | isPrim p = loadParam p arr $ ltid - behind
      | otherwise = loadParam p arr $ gtid - behind + arrs_full_size

    -- Primitive results are stored back into the array; in every
    -- case the result becomes the thread's new y value.
    writeResult x y arr = do
      when (isPrim x) $
        copyDWIM arr [DimFix ltid] (Var $ paramName x) []
      copyDWIM (paramName y) [] (Var $ paramName x) []

-- | Given the number of threads a kernel needs, declare a tunable
-- group size (requested via 'Imp.GetSize' with class
-- 'Imp.SizeGroup') and the number of groups required to cover
-- @kernel_size@, rounding up.  Returns @(num_groups, group_size)@ as
-- expressions referring to the freshly declared variables.
computeMapKernelGroups :: Imp.Exp -> CallKernelGen (Imp.Exp, Imp.Exp)
computeMapKernelGroups kernel_size = do
  group_size <- dPrim "group_size" int32
  fname <- askFunction
  let group_size_e = Imp.var group_size int32
      -- The tuning key is derived from the variable name, qualified
      -- by the enclosing function (if any).
      size_key = (keyWithEntryPoint fname . nameFromString . pretty) group_size
  sOp $ Imp.GetSize group_size size_key Imp.SizeGroup
  num_groups <-
    dPrimV "num_groups" $
    kernel_size `quotRoundingUp` Imp.ConvOpExp (SExt Int32 Int32) group_size_e
  return (Imp.var num_groups int32, group_size_e)

-- | Construct the 'KernelConstants' for a one-dimensional kernel of
-- @kernel_size@ threads, together with the in-kernel action that
-- declares and initialises the thread/group ID variables.  The
-- string is used as a prefix for the names of those variables.
simpleKernelConstants :: Imp.Exp -> String
                      -> CallKernelGen (KernelConstants, InKernelGen ())
simpleKernelConstants kernel_size desc = do
  thread_gtid <- newVName $ desc ++ "_gtid"
  thread_ltid <- newVName $ desc ++ "_ltid"
  group_id <- newVName $ desc ++ "_gid"
  (num_groups, group_size) <- computeMapKernelGroups kernel_size

  let gtid_e = Imp.var thread_gtid int32
      ltid_e = Imp.var thread_ltid int32
      gid_e = Imp.var group_id int32

      -- Declare all three variables first, then fetch their values.
      set_constants = do
        mapM_ (`dPrim_` int32) [thread_gtid, thread_ltid, group_id]
        sOp $ Imp.GetGlobalId thread_gtid 0
        sOp $ Imp.GetLocalId thread_ltid 0
        sOp $ Imp.GetGroupId group_id 0

      constants =
        KernelConstants
        gtid_e ltid_e gid_e
        thread_gtid thread_ltid group_id
        num_groups group_size (group_size * num_groups) 0
        (gtid_e .<. kernel_size)

  return (constants, set_constants)

-- | For many kernels, we may not have enough physical groups to cover
-- the logical iteration space.  Some groups thus have to perform
-- double duty; we put an outer loop to accomplish this.  The
-- advantage over just launching a bazillion threads is that the cost
-- of memory expansion should be proportional to the number of
-- *physical* threads (hardware parallelism), not the amount of
-- application parallelism.
--
-- With 'SegNoVirt' the body is simply run once with the variable
-- holding the physical group ID.  With 'SegVirt' the body is run in
-- a loop, each time with a fresh variable holding the virtual group
-- ID @phys_group_id + i * num_groups@.
virtualiseGroups :: SegVirt
                 -> Imp.Exp
                 -> (VName -> InKernelGen ())
                 -> InKernelGen ()
virtualiseGroups SegNoVirt _ m =
  m . kernelGroupIdVar . kernelConstants =<< askEnv
virtualiseGroups SegVirt required_groups m = do
  phys_num_groups <- kernelNumGroups . kernelConstants <$> askEnv
  phys_group_id <- dPrim "phys_group_id" int32
  sOp $ Imp.GetGroupId phys_group_id 0
  let phys_gid = Imp.vi32 phys_group_id
      iterations =
        (required_groups - phys_gid) `quotRoundingUp` phys_num_groups

  sFor "i" iterations $ \i -> do
    virt_group_id <- dPrimV "virt_group_id" $ phys_gid + i * phys_num_groups
    m virt_group_id
    -- Make sure the virtual group is actually done before we let
    -- another virtual group have its way with it.
    sOp $ Imp.Barrier Imp.FenceGlobal

-- | Emit a kernel via 'sKernel' using 'threadOperations', binding
-- the given variable to the global thread ID before running the
-- body.
sKernelThread :: String
              -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
              -> VName
              -> InKernelGen ()
              -> CallKernelGen ()
sKernelThread name num_groups group_size v body =
  sKernel threadOperations kernelGlobalThreadId name num_groups group_size v body

-- | Emit a kernel via 'sKernel' using 'groupOperations', binding the
-- given variable to the group ID before running the body.
sKernelGroup :: String
             -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
             -> VName
             -> InKernelGen ()
             -> CallKernelGen ()
sKernelGroup name num_groups group_size v body =
  sKernel groupOperations kernelGroupId name num_groups group_size v body

-- | Compile an in-kernel computation and emit the resulting kernel
-- call on the host side.  The body is compiled with the given
-- operations in a 'KernelEnv' built from the host's atomics and the
-- supplied constants, under 'makeAllMemoryGlobal'.  The 'Bool'
-- becomes the kernel's failure-tolerance flag, and the launch
-- dimensions are taken from the constants.
sKernelFailureTolerant :: Bool
                       -> Operations ExplicitMemory KernelEnv Imp.KernelOp
                       -> KernelConstants
                       -> Name
                       -> InKernelGen ()
                       -> CallKernelGen ()
sKernelFailureTolerant tol ops constants name m = do
  atomics <- hostAtomics <$> askEnv
  let kernel_env = KernelEnv atomics constants
  body <- makeAllMemoryGlobal $ subImpM_ kernel_env ops m
  uses <- computeKernelUses body mempty
  let kernel =
        Imp.Kernel
        { Imp.kernelBody = body
        , Imp.kernelUses = uses
        , Imp.kernelNumGroups = [kernelNumGroups constants]
        , Imp.kernelGroupSize = [kernelGroupSize constants]
        , Imp.kernelName = name
        , Imp.kernelFailureTolerant = tol
        }
  emit $ Imp.Op $ Imp.CallKernel kernel

-- | Emit a (non-failure-tolerant) kernel with the given launch
-- dimensions.  The kernel is named after the supplied string
-- suffixed with the tag of the variable.  Inside the kernel, the
-- constants are initialised first, then the variable is bound to
-- whatever the projection function extracts from the constants, and
-- finally the body is run.
sKernel :: Operations ExplicitMemory KernelEnv Imp.KernelOp
        -> (KernelConstants -> Imp.Exp)
        -> String
        -> Count NumGroups Imp.Exp
        -> Count GroupSize Imp.Exp
        -> VName
        -> InKernelGen ()
        -> CallKernelGen ()
sKernel ops flatf name num_groups group_size v f = do
  (constants, set_constants) <-
    kernelInitialisationSimple num_groups group_size
  sKernelFailureTolerant False ops constants kernel_name $
    set_constants >> dPrimV_ v (flatf constants) >> f
  where
    kernel_name = nameFromString $ name ++ "_" ++ show (baseTag v)

-- | Copy compiler used by 'groupOperations'.
--
-- When both the destination and the source memory blocks live in a
-- 'ScalarSpace', the copy is emitted as a plain assignment from the
-- source memory name (read as a variable of type @pt@) to the
-- destination memory name.  In every other case the copy is delegated
-- to 'copyElementWise'.
copyInGroup :: CopyCompiler ExplicitMemory KernelEnv Imp.KernelOp
copyInGroup pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocationName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocationName srcloc)

  if isScalarMem dest_space && isScalarMem src_space
    then memLocationName destloc <-- Imp.var (memLocationName srcloc) pt
    else copyElementWise pt destloc srcloc

  where isScalarMem ScalarSpace{} = True
        isScalarMem _ = False

-- | Code generation operations for kernel bodies.
--
-- 'threadOperations' is built on 'compileThreadOp' and
-- 'compileThreadExp' and copies with 'copyElementWise'.
-- 'groupOperations' is built on 'compileGroupOp' and
-- 'compileGroupExp' and copies with 'copyInGroup'.
--
-- Both ignore the first argument handed to the statement compiler and
-- call 'defCompileStms' with 'mempty' in its place, and both register
-- 'allocLocal' as the allocator for the @"local"@ space.
threadOperations, groupOperations :: Operations ExplicitMemory KernelEnv Imp.KernelOp
threadOperations =
  (defaultOperations compileThreadOp)
  { opsCopyCompiler = copyElementWise
  , opsExpCompiler = compileThreadExp
  , opsStmsCompiler = \_ -> defCompileStms mempty
  , opsAllocCompilers =
      M.fromList [ (Space "local", allocLocal) ]
  }
groupOperations =
  (defaultOperations compileGroupOp)
  { opsCopyCompiler = copyInGroup
  , opsExpCompiler = compileGroupExp
  , opsStmsCompiler = \_ -> defCompileStms mempty
  , opsAllocCompilers =
      M.fromList [ (Space "local", allocLocal) ]
  }

-- | Perform a Replicate with a dedicated kernel.
--
-- The kernel is sized by the product of the dimensions in @dims@:
-- the dimensions of @arr@ with the last @arrayRank t@ entries dropped
-- (where @t@ is the type of @se@), followed by the dimensions of @t@
-- itself.  Each active thread unflattens its global thread id over
-- @dims@ and copies one element of @se@ into @arr@.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  t <- subExpType se
  -- Dimensions of 'arr' that are not accounted for by the type of 'se'.
  ds <- dropLast (arrayRank t) . arrayDims <$> lookupType arr

  dims <- mapM toExp $ ds ++ arrayDims t
  (constants, set_constants) <-
    simpleKernelConstants (product dims) "replicate"

  let is' = unflattenIndex dims $ kernelGlobalThreadId constants
      name = nameFromString $ "replicate_" ++
             show (baseTag $ kernelGlobalThreadIdVar constants)

  sKernelFailureTolerant True threadOperations constants name $ do
    set_constants
    sWhen (kernelThreadActive constants) $
      -- The destination is indexed with the full index; the source
      -- only with the part that remains after dropping the 'ds' prefix.
      copyDWIMFix arr is' se $ drop (length ds) is'

-- | Construct a function that fills a one-dimensional array of
-- element type @bt@ with a single value.
--
-- The function takes three parameters: a memory block in the
-- @"device"@ space, an @int32@ number of elements, and the value of
-- type @bt@.  Its body declares an array of @num_elems@ elements in
-- that memory block with an 'IxFun.iota' index function and runs
-- 'sReplicateKernel' on it.
replicateFunction :: PrimType -> CallKernelGen Imp.Function
replicateFunction bt = do
  mem <- newVName "mem"
  num_elems <- newVName "num_elems"
  val <- newVName "val"

  let params = [Imp.MemParam mem (Space "device"),
                Imp.ScalarParam num_elems int32,
                Imp.ScalarParam val bt]
      shape = Shape [Var num_elems]
  -- The first argument to 'function' is the empty list; all three
  -- parameters go in the second.
  function [] params $ do
    arr <- sArray "arr" bt shape $ ArrayIn mem $ IxFun.iota $
           map (primExpFromSubExp int32) $ shapeDims shape
    sReplicateKernel arr $ Var val

-- | The base name of the replicate function for element type @bt@:
-- the string @"replicate_"@ followed by the pretty-printed type.
replicateName :: PrimType -> String
replicateName bt = "replicate_" ++ pretty bt

-- | Return the name of the replicate function for element type @bt@,
-- emitting the function first if 'hasFunction' reports that it does
-- not exist yet.  The name is @"builtin#"@ followed by
-- 'replicateName' of the type.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  let fname = nameFromString $ "builtin#" <> replicateName bt

  exists <- hasFunction fname
  -- Only generate and emit the function when it is not already present.
  unless exists $ emitFunction fname =<< replicateFunction bt

  return fname

-- | Decide whether replicating @v@ into @arr@ can be done by calling
-- the per-type replicate function.
--
-- Returns @Just@ an action when @v@ has a primitive type and the
-- index function of @arr@ satisfies 'IxFun.isLinear'; otherwise
-- returns @Nothing@.  When run, the action obtains the function name
-- from 'replicateForType' and emits a call passing the memory block of
-- @arr@, the product of its shape (as @int32@ expressions), and the
-- value @v@.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLocation arr_mem arr_shape arr_ixfun) _ <- lookupArray arr
  v_t <- subExpType v
  case v_t of
    Prim v_t'
      | IxFun.isLinear arr_ixfun -> return $ Just $ do
          fname <- replicateForType v_t'
          emit $ Imp.Call [] fname
            [Imp.MemArg arr_mem,
             Imp.ExpArg $ product $ map (toExp' int32) arr_shape,
             Imp.ExpArg $ toExp' v_t' v]
    _ -> return Nothing

-- | Perform a Replicate, either by calling a shared fill function
-- (when 'replicateIsFill' provides one) or with a dedicated kernel
-- via 'sReplicateKernel'.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se = do
  -- If the replicate is of a particularly common and simple form
  -- (morally a memset()/fill), then we use a common function.
  is_fill <- replicateIsFill arr se

  case is_fill of
    Just m -> m
    Nothing -> sReplicateKernel arr se

-- | Perform an Iota with a kernel.
sIota :: VName -> Imp.Exp -> Imp.Exp -> Imp.Exp -> IntType
      -> CallKernelGen ()
sIota :: VName
-> Exp
-> Exp
-> Exp
-> IntType
-> ImpM ExplicitMemory HostEnv HostOp ()
sIota VName
arr Exp
n Exp
x Exp
s IntType
et = do
  MemLocation
destloc <- ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM ExplicitMemory HostEnv HostOp ArrayEntry
-> ImpM ExplicitMemory HostEnv HostOp MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM ExplicitMemory HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
arr
  (KernelConstants
constants, InKernelGen ()
set_constants) <- Exp -> String -> CallKernelGen (KernelConstants, InKernelGen ())
simpleKernelConstants Exp
n String
"iota"

  let name :: Name
name = String -> Name
nameFromString (String -> Name) -> String -> Name
forall a b. (a -> b) -> a -> b
$ String
"iota_" String -> String -> String
forall a. [a] -> [a] -> [a]
++
             Int -> String
forall a. Show a => a -> String
show (VName -> Int
baseTag (VName -> Int) -> VName -> Int
forall a b. (a -> b) -> a -> b
$ KernelConstants -> VName
kernelGlobalThreadIdVar KernelConstants
constants)

  Bool
-> Operations ExplicitMemory KernelEnv KernelOp
-> KernelConstants
-> Name
-> InKernelGen ()
-> ImpM ExplicitMemory HostEnv HostOp ()
sKernelFailureTolerant Bool
True Operations ExplicitMemory KernelEnv KernelOp
threadOperations KernelConstants
constants Name
name (InKernelGen () -> ImpM ExplicitMemory HostEnv HostOp ())
-> InKernelGen () -> ImpM ExplicitMemory HostEnv HostOp ()
forall a b. (a -> b) -> a -> b
$ do
    InKernelGen ()
set_constants
    let gtid :: Exp
gtid = KernelConstants -> Exp
kernelGlobalThreadId KernelConstants
constants
    Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (KernelConstants -> Exp
kernelThreadActive KernelConstants
constants) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      (VName
destmem, Space
destspace, Count Elements Exp
destidx) <- MemLocation
-> [Exp]
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     (VName, Space, Count Elements Exp)
forall lore r op.
MemLocation
-> [Exp] -> ImpM lore r op (VName, Space, Count Elements Exp)
fullyIndexArray' MemLocation
destloc [Exp
gtid]

      Code KernelOp -> InKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code KernelOp -> InKernelGen ())
-> Code KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        VName
-> Count Elements Exp
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code KernelOp
forall a.
VName
-> Count Elements Exp
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
destmem Count Elements Exp
destidx (IntType -> PrimType
IntType IntType
et) Space
destspace Volatility
Imp.Nonvolatile (Exp -> Code KernelOp) -> Exp -> Code KernelOp
forall a b. (a -> b) -> a -> b
$
        ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
Int32 IntType
et) Exp
gtid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
s Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
x

-- | Copy an array of element type @bt@ from @srcloc@ to @destloc@
-- with a kernel.
--
-- The kernel is sized by the product of the source shape.  Each
-- thread unflattens its global thread id over that shape, uses the
-- resulting index for both source and destination, and, if the global
-- thread id is below the kernel size, writes the element read from
-- the source to the destination.
sCopy :: PrimType
      -> MemLocation
      -> MemLocation
      -> CallKernelGen ()
sCopy bt
  destloc@(MemLocation destmem _ _)
  srcloc@(MemLocation srcmem srcshape _)
  = do
  -- Note that the shape of the destination and the source are
  -- necessarily the same.
  let shape = map (toExp' int32) srcshape
      kernel_size = product shape

  (constants, set_constants) <- simpleKernelConstants kernel_size "copy"

  let name = nameFromString $ "copy_" ++
             show (baseTag $ kernelGlobalThreadIdVar constants)

  sKernelFailureTolerant True threadOperations constants name $ do
    set_constants

    let gtid = kernelGlobalThreadId constants
        dest_is = unflattenIndex shape gtid
        src_is = dest_is

    (_, destspace, destidx) <- fullyIndexArray' destloc dest_is
    (_, srcspace, srcidx) <- fullyIndexArray' srcloc src_is

    -- Only threads whose id falls inside the array perform the write.
    sWhen (gtid .<. kernel_size) $ emit $
      Imp.Write destmem destidx bt destspace Imp.Nonvolatile $
      Imp.index srcmem srcidx bt srcspace Imp.Nonvolatile

compileGroupResult :: SegSpace
                   -> PatElem ExplicitMemory -> KernelResult
                   -> InKernelGen ()

compileGroupResult :: SegSpace
-> PatElemT (LetAttr ExplicitMemory)
-> KernelResult
-> InKernelGen ()
compileGroupResult SegSpace
_ PatElemT (LetAttr ExplicitMemory)
pe (TileReturns [(SubExp
w,SubExp
per_group_elems)] VName
what) = do
  Exp
n <- SubExp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp (SubExp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> (Type -> SubExp)
-> Type
-> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 (Type -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> ImpM ExplicitMemory KernelEnv KernelOp Type
-> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> ImpM ExplicitMemory KernelEnv KernelOp Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
what

  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
  let ltid :: Exp
ltid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
      offset :: Exp
offset = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
per_group_elems Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupId KernelConstants
constants

  -- Avoid loop for the common case where each thread is statically
  -- known to write at most one element.
  if PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
per_group_elems Exp -> Exp -> Bool
forall a. Eq a => a -> a -> Bool
== KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
    then Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
w) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
         VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr ExplicitMemory)
PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) [Exp
ltid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
offset] (VName -> SubExp
Var VName
what) [Exp
ltid]
    else
    String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" (Exp
n Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quotRoundingUp` KernelConstants -> Exp
kernelGroupSize KernelConstants
constants) ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
      Exp
j <- (VName -> Exp)
-> ImpM ExplicitMemory KernelEnv KernelOp VName
-> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> Exp
Imp.vi32 (ImpM ExplicitMemory KernelEnv KernelOp VName
 -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> ImpM ExplicitMemory KernelEnv KernelOp VName
-> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"j" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
           KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid
      Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
n) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr ExplicitMemory)
PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) [Exp
j Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
offset] (VName -> SubExp
Var VName
what) [Exp
j]

compileGroupResult SegSpace
space PatElemT (LetAttr ExplicitMemory)
pe (TileReturns [(SubExp, SubExp)]
dims VName
what) = do
  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
  let gids :: [VName]
gids = ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst ([(VName, SubExp)] -> [VName]) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      out_tile_sizes :: [Exp]
out_tile_sizes = ((SubExp, SubExp) -> Exp) -> [(SubExp, SubExp)] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp)
-> ((SubExp, SubExp) -> SubExp) -> (SubExp, SubExp) -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(SubExp, SubExp)]
dims
      local_is :: [Exp]
local_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
out_tile_sizes (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
      group_is :: [Exp]
group_is = (Exp -> Exp -> Exp) -> [Exp] -> [Exp] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
(*) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
gids) [Exp]
out_tile_sizes
  [VName]
is_for_thread <- (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> [Exp] -> ImpM ExplicitMemory KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"thread_out_index") ([Exp] -> ImpM ExplicitMemory KernelEnv KernelOp [VName])
-> [Exp] -> ImpM ExplicitMemory KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> [Exp] -> [Exp] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
(+) [Exp]
group_is [Exp]
local_is

  Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen ([(VName, SubExp)] -> Exp
isActive ([(VName, SubExp)] -> Exp) -> [(VName, SubExp)] -> Exp
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
is_for_thread ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ ((SubExp, SubExp) -> SubExp) -> [(SubExp, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp, SubExp) -> SubExp
forall a b. (a, b) -> a
fst [(SubExp, SubExp)]
dims) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr ExplicitMemory)
PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
is_for_thread) (VName -> SubExp
Var VName
what) [Exp]
local_is

-- Plain result of a group: written at the position given by the
-- indices of the segment space.
compileGroupResult space pe (Returns _ what) = do
  local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv
  what_is_local <- arrayInLocalMemory what
  let group_ids = [ Imp.vi32 gid | (gid, _) <- unSegSpace space ]
      dest = patElemName pe

  if what_is_local
    -- An array result residing in local memory is stored by a
    -- collective copy in which all the threads of the group take
    -- part.  TODO: do the same for arrays in global memory (this is
    -- more tricky with respect to synchronisation).
    then groupCopy dest group_ids what []
    -- Anything else is copied only when the local thread ID is zero.
    else sWhen (local_tid .==. 0) $
         copyDWIMFix dest group_ids what []

-- Scatter-style results are not supported at the group level.
compileGroupResult _ _ WriteReturns{} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."

-- Concatenated results are not supported at the group level either.
compileGroupResult _ _ ConcatReturns{} =
  compilerLimitationS "compileGroupResult: ConcatReturns not handled yet."

-- | Generate the code that stores the result of a single thread in
-- the pattern element @pe@.  How the destination is indexed depends
-- on the kind of 'KernelResult':
--
-- * 'Returns': indexed by the variables of the segment space.
--
-- * 'ConcatReturns': a slice of the destination starting at an offset
--   derived from the global thread ID, with either unit stride
--   ('SplitContiguous') or the given stride ('SplitStrided').
--
-- * 'WriteReturns': one guarded write per destination; the write
--   happens only if the thread is active and every index is within
--   the corresponding bound.
--
-- * 'TileReturns': reported as a compiler bug.
compileThreadResult :: SegSpace
                    -> PatElem ExplicitMemory -> KernelResult
                    -> InKernelGen ()
compileThreadResult space pe = \case
  Returns _ what ->
    copyDWIMFix dest [ Imp.vi32 i | (i, _) <- unSegSpace space ] what []

  ConcatReturns SplitContiguous _ per_thread_elems what -> do
    gtid <- globalThreadId
    len <- outerSize what
    -- Each thread owns a contiguous run of per_thread_elems elements.
    let start = toExp' int32 per_thread_elems * gtid
    copyDWIM dest [DimSlice start len 1] (Var what) []

  ConcatReturns (SplitStrided stride) _ _ what -> do
    start <- globalThreadId
    len <- outerSize what
    copyDWIM dest [DimSlice start len $ toExp' int32 stride] (Var what) []

  WriteReturns rws _arr dests -> do
    thread_active <- kernelThreadActive . kernelConstants <$> askEnv
    bounds <- mapM toExp rws
    forM_ dests $ \(is, e) -> do
      is' <- mapM toExp is
      let inBounds i bound = 0 .<=. i .&&. i .<. bound
          -- Conjunction of "thread is active" with one bounds check
          -- per index.
          may_write = foldl (.&&.) thread_active $
                      zipWith inBounds is' bounds
      sWhen may_write $
        copyDWIMFix dest (map (toExp' int32) is) e []

  TileReturns{} ->
    compilerBugS "compileThreadResult: TileReturns unhandled."

  where
    -- Name of the variable being written to.
    dest :: VName
    dest = patElemName pe

    -- The global thread ID expression from the kernel constants.
    globalThreadId :: InKernelGen Imp.Exp
    globalThreadId = kernelGlobalThreadId . kernelConstants <$> askEnv

    -- Size of the outermost dimension of the given array, as an
    -- int32 expression.
    outerSize :: VName -> InKernelGen Imp.Exp
    outerSize arr = toExp' int32 . arraySize 0 <$> lookupType arr

-- | Is the operand a variable bound to an array whose memory block
-- has the space @Space "local"@?  Constants, and variables that are
-- not arrays, give 'False'.
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory Constant{} = return False
arrayInLocalMemory (Var name) =
  lookupVar name >>= \case
    ArrayVar _ entry -> do
      -- Look up the memory block in which the array is stored and
      -- inspect its space.
      mem <- lookupMemory $ memLocationName $ entryArrayLocation entry
      return $ Space "local" == entryMemSpace mem
    _ ->
      return False