{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.CodeGen.Scan
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.CodeGen.Scan
  where

import Data.Array.Accelerate.AST                                    ( Direction(..) )
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache

import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Generate
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )

import Control.Applicative
import Control.Monad
import Data.String                                                  ( fromString )
import Data.Coerce                                                  as Safe
import Prelude                                                      as P


-- 'Data.List.scanl' or 'Data.List.scanl1' style exclusive scan,
-- but with the restriction that the combination function must be associative
-- to enable efficient parallel implementation.
--
-- > scanl (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [10,10,11,13,16,20,25,31,38,46,55]
--
mkScan
    :: UID                                            -- ^ identifier handed to every kernel generator
    -> Gamma             aenv                         -- ^ array environment
    -> ArrayR                   (Array (sh, Int) e)   -- ^ representation of the scanned array
    -> Direction                                      -- ^ scan direction
    -> IRFun2       Native aenv (e -> e -> e)         -- ^ combination function
    -> Maybe (IRExp Native aenv e)                    -- ^ seed element, if the scan has one
    -> MIRDelayed   Native aenv (Array (sh, Int) e)   -- ^ input array
    -> CodeGen      Native      (IROpenAcc Native aenv (Array (sh, Int) e))
mkScan uid aenv aR dir combine seed arr
  -- Generate every applicable kernel and merge them into a single result.
  -- 'foldr1' is safe here because 'codeScanS' always contributes one element.
  = foldr1 (+++) <$> sequence (codeScanS ++ codeScanP ++ codeScanFill)
  where
    -- The sequential kernel is generated for every array rank.
    codeScanS = [ mkScanS dir uid aenv aR combine seed arr ]

    -- The 'mkScanP' kernel is only generated for one-dimensional arrays.
    codeScanP = case aR of
      ArrayR (ShapeRsnoc ShapeRz) eR -> [ mkScanP dir uid aenv eR combine seed arr ]
      _                              -> []

    -- Input can be empty iff a seed is given. We then need to compile a fill kernel
    codeScanFill = case seed of
      Just s  -> [ mkScanFill uid aenv aR s ]
      Nothing -> []

-- Variant of 'scanl' where the final result is returned in a separate array.
--
-- > scanl' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [10,10,11,13,16,20,25,31,38,46]
-- >     , Array Z [55]
-- >     )
--
mkScan'
    :: UID                                          -- ^ identifier handed to every kernel generator
    -> Gamma             aenv                       -- ^ array environment
    -> ArrayR                 (Array (sh, Int) e)   -- ^ representation of the scanned array
    -> Direction                                    -- ^ scan direction
    -> IRFun2     Native aenv (e -> e -> e)         -- ^ combination function
    -> IRExp      Native aenv e                     -- ^ seed element
    -> MIRDelayed Native aenv (Array (sh, Int) e)   -- ^ input array
    -> CodeGen    Native      (IROpenAcc Native aenv (Array (sh, Int) e, Array sh e))
mkScan' uid aenv aR dir combine seed arr
  -- One-dimensional input: generate the 'mkScan'S', 'mkScan'P' and fill
  -- kernels. The list is a non-empty literal, so 'foldr1' cannot fail.
  | ArrayR (ShapeRsnoc ShapeRz) eR <- aR
  = foldr1 (+++) <$> sequence [ mkScan'S dir uid aenv aR combine seed arr
                              , mkScan'P dir uid aenv eR combine seed arr
                              , mkScan'Fill uid aenv aR seed
                              ]
  --
  -- Any other rank: only the 'mkScan'S' and fill kernels are generated.
  | otherwise
  = (+++) <$> mkScan'S dir uid aenv aR combine seed arr
          <*> mkScan'Fill uid aenv aR seed

-- If the innermost dimension of an exclusive scan is empty, then we just fill
-- the result with the seed element.
--
mkScanFill
    :: UID                      -- ^ identifier handed to 'mkGenerate'
    -> Gamma          aenv      -- ^ array environment
    -> ArrayR (Array sh e)      -- ^ representation of the array to fill
    -> IRExp   Native aenv e    -- ^ seed element written at every index
    -> CodeGen Native      (IROpenAcc Native aenv (Array sh e))
mkScanFill uid aenv aR seed =
  -- A generate kernel whose element function ignores the index and always
  -- evaluates the seed.
  mkGenerate uid aenv aR (IRFun1 (const seed))

-- Fill kernel for the two-result scan. It is built with 'mkScanFill' over the
-- array representation with the innermost dimension removed ('reduceRank').
--
mkScan'Fill
    :: UID                            -- ^ identifier handed to 'mkScanFill'
    -> Gamma          aenv            -- ^ array environment
    -> ArrayR (Array (sh, Int) e)     -- ^ representation of the scanned array
    -> IRExp   Native aenv e          -- ^ seed element
    -> CodeGen Native     (IROpenAcc Native aenv (Array (sh, Int) e, Array sh e))
mkScan'Fill uid aenv aR seed =
  -- 'Safe.coerce' only changes the result type of the generated kernel, from
  -- the single reduced-rank array to the pair expected by 'mkScan''.
  Safe.coerce <$> mkScanFill uid aenv (reduceRank aR) seed


-- A single thread sequentially scans along an entire innermost dimension. For
-- inclusive scans we can assume that the innermost-dimension is at least one
-- element.
--
-- Note that we can use this both when there is a single thread, or in parallel
-- where threads are scheduled over the outer dimensions (segments).
--
mkScanS
    :: Direction
    -> UID
    -> Gamma             aenv
    -> ArrayR (Array (sh, Int) e)
    -> IRFun2     Native aenv (e -> e -> e)
    -> MIRExp     Native aenv e
    -> MIRDelayed Native aenv (Array (sh, Int) e)
    -> CodeGen    Native      (IROpenAcc Native aenv (Array (sh, Int) e))
mkScanS :: Direction
-> UID
-> Gamma aenv
-> ArrayR (Array (sh, Int) e)
-> IRFun2 Native aenv (e -> e -> e)
-> MIRExp Native aenv e
-> MIRDelayed Native aenv (Array (sh, Int) e)
-> CodeGen Native (IROpenAcc Native aenv (Array (sh, Int) e))
mkScanS Direction
dir UID
uid Gamma aenv
aenv ArrayR (Array (sh, Int) e)
aR IRFun2 Native aenv (e -> e -> e)
combine MIRExp Native aenv e
mseed MIRDelayed Native aenv (Array (sh, Int) e)
marr =
  let
      (Operands sh
start, Operands sh
end, [Parameter]
paramGang) = ShapeR sh -> (Operands sh, Operands sh, [Parameter])
forall sh. ShapeR sh -> (Operands sh, Operands sh, [Parameter])
gangParam ShapeR sh
shR
      (IRArray (Array (sh, Int) e)
arrOut, [Parameter]
paramOut)      = ArrayR (Array (sh, Int) e)
-> Name (Array (sh, Int) e)
-> (IRArray (Array (sh, Int) e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray ArrayR (Array (sh, Int) e)
aR Name (Array (sh, Int) e)
"out"
      (IRDelayed Native aenv (Array (sh, Int) e)
arrIn,  [Parameter]
paramIn)       = Name (Array (sh, Int) e)
-> MIRDelayed Native aenv (Array (sh, Int) e)
-> (IRDelayed Native aenv (Array (sh, Int) e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray    Name (Array (sh, Int) e)
"in"  MIRDelayed Native aenv (Array (sh, Int) e)
marr
      paramEnv :: [Parameter]
paramEnv                = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      ShapeRsnoc ShapeR sh1
shR          = ArrayR (Array (sh, Int) e) -> ShapeR (sh, Int)
forall sh e. ArrayR (Array sh e) -> ShapeR sh
arrayRshape ArrayR (Array (sh, Int) e)
aR
      --
      next :: Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i                  = case Direction
dir of
                                  Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i (Int -> Operands Int
liftInt Int
1)
                                  Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i (Int -> Operands Int
liftInt Int
1)
  in
  UID
-> Label
-> [Parameter]
-> CodeGen Native ()
-> CodeGen Native (IROpenAcc Native aenv (Array (sh, Int) e))
forall aenv a.
UID
-> Label
-> [Parameter]
-> CodeGen Native ()
-> CodeGen Native (IROpenAcc Native aenv a)
makeOpenAcc UID
uid Label
"scanS" ([Parameter]
paramGang [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen Native ()
 -> CodeGen Native (IROpenAcc Native aenv (Array (sh, Int) e)))
-> CodeGen Native ()
-> CodeGen Native (IROpenAcc Native aenv (Array (sh, Int) e))
forall a b. (a -> b) -> a -> b
$ do

    -- The dimensions of the input and output arrays are (almost) the same
    -- but LLVM can't know that so make it explicit so that we reuse loop
    -- variables and index calculations
    Operands (sh, Int)
shIn <- IRDelayed Native aenv (Array (sh, Int) e)
-> IRExp Native aenv (sh, Int)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRExp arch aenv sh
delayedExtent IRDelayed Native aenv (Array (sh, Int) e)
arrIn
    let sz :: Operands Int
sz    = Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead Operands (sh, Int)
shIn
        shOut :: Operands (sh, Int)
shOut = case MIRExp Native aenv e
mseed of
                  MIRExp Native aenv e
Nothing -> Operands (sh, Int)
shIn
                  Just{}  -> Operands sh -> Operands Int -> Operands (sh, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
indexCons (Operands (sh, Int) -> Operands sh
forall sh sz. Operands (sh, sz) -> Operands sh
indexTail Operands (sh, Int)
shIn) (Operands (sh, Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead (IRArray (Array (sh, Int) e) -> Operands (sh, Int)
forall sh e. IRArray (Array sh e) -> Operands sh
irArrayShape IRArray (Array (sh, Int) e)
arrOut))

    -- Loop over the outer dimensions
    ShapeR sh
-> Operands sh
-> Operands sh
-> Operands sh
-> (Operands sh -> Operands Int -> CodeGen Native ())
-> CodeGen Native ()
forall sh.
ShapeR sh
-> Operands sh
-> Operands sh
-> Operands sh
-> (Operands sh -> Operands Int -> CodeGen Native ())
-> CodeGen Native ()
imapNestFromTo ShapeR sh
shR Operands sh
start Operands sh
end (Operands (sh, Int) -> Operands sh
forall sh sz. Operands (sh, sz) -> Operands sh
indexTail Operands (sh, Int)
shIn) ((Operands sh -> Operands Int -> CodeGen Native ())
 -> CodeGen Native ())
-> (Operands sh -> Operands Int -> CodeGen Native ())
-> CodeGen Native ()
forall a b. (a -> b) -> a -> b
$ \Operands sh
ix Operands Int
_ -> do

      -- index i* is the index that we will read data from. Recall that the
      -- supremum index is exclusive
      Operands Int
i0 <- case Direction
dir of
              Direction
LeftToRight -> Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Int -> Operands Int
liftInt Int
0)
              Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
sz (Int -> Operands Int
liftInt Int
1)

      -- index j* is the index that we write to. Recall that for exclusive scans
      -- the output array inner dimension is one larger than the input.
      Operands Int
j0 <- case MIRExp Native aenv e
mseed of
              MIRExp Native aenv e
Nothing -> Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
i0        -- merge 'i' and 'j' indices whenever we can
              Just{}  -> case Direction
dir of
                           Direction
LeftToRight -> Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
i0
                           Direction
RightToLeft -> Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
sz

      -- Evaluate or read the initial element. Update the read-from index
      -- appropriately.
      (Operands e
v0,Operands Int
i1) <- case MIRExp Native aenv e
mseed of
                   Just IRExp Native aenv e
seed -> (,) (Operands e -> Operands Int -> (Operands e, Operands Int))
-> IRExp Native aenv e
-> CodeGen Native (Operands Int -> (Operands e, Operands Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRExp Native aenv e
seed                                        CodeGen Native (Operands Int -> (Operands e, Operands Int))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands e, Operands Int)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Operands Int
i0
                   MIRExp Native aenv e
Nothing   -> (,) (Operands e -> Operands Int -> (Operands e, Operands Int))
-> IRExp Native aenv e
-> CodeGen Native (Operands Int -> (Operands e, Operands Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IROpenFun1 Native () aenv ((sh, Int) -> e)
-> Operands (sh, Int) -> IRExp Native aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed Native aenv (Array (sh, Int) e)
-> IROpenFun1 Native () aenv ((sh, Int) -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (sh -> e)
delayedIndex IRDelayed Native aenv (Array (sh, Int) e)
arrIn) (Operands sh -> Operands Int -> Operands (sh, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
indexCons Operands sh
ix Operands Int
i0) CodeGen Native (Operands Int -> (Operands e, Operands Int))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands e, Operands Int)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i0

      -- Write first element, then continue looping through the rest of
      -- this innermost dimension
      Operands Int
k0 <- ShapeR (sh, Int)
-> Operands (sh, Int)
-> Operands (sh, Int)
-> CodeGen Native (Operands Int)
forall sh arch.
ShapeR sh
-> Operands sh -> Operands sh -> CodeGen arch (Operands Int)
intOfIndex (ArrayR (Array (sh, Int) e) -> ShapeR (sh, Int)
forall sh e. ArrayR (Array sh e) -> ShapeR sh
arrayRshape ArrayR (Array (sh, Int) e)
aR) Operands (sh, Int)
shOut (Operands sh -> Operands Int -> Operands (sh, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
indexCons Operands sh
ix Operands Int
j0)
      Operands Int
j1 <- Operands Int -> CodeGen Native (Operands Int)
next Operands Int
j0
      IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen Native ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
k0 Operands e
v0

      CodeGen Native (Operands ((((), Int), Int), e))
-> CodeGen Native ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (CodeGen Native (Operands ((((), Int), Int), e))
 -> CodeGen Native ())
-> CodeGen Native (Operands ((((), Int), Int), e))
-> CodeGen Native ()
forall a b. (a -> b) -> a -> b
$ TypeR ((((), Int), Int), e)
-> (Operands ((((), Int), Int), e)
    -> CodeGen Native (Operands Bool))
-> (Operands ((((), Int), Int), e)
    -> CodeGen Native (Operands ((((), Int), Int), e)))
-> Operands ((((), Int), Int), e)
-> CodeGen Native (Operands ((((), Int), Int), e))
forall a arch.
TypeR a
-> (Operands a -> CodeGen arch (Operands Bool))
-> (Operands a -> CodeGen arch (Operands a))
-> Operands a
-> CodeGen arch (Operands a)
while (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit TupR ScalarType ()
-> TupR ScalarType Int -> TupR ScalarType ((), Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType ((), Int)
-> TupR ScalarType Int -> TupR ScalarType (((), Int), Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType (((), Int), Int)
-> TupR ScalarType e -> TypeR ((((), Int), Int), e)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ArrayR (Array (sh, Int) e) -> TupR ScalarType e
forall sh e. ArrayR (Array sh e) -> TypeR e
arrayRtype ArrayR (Array (sh, Int) e)
aR)
                   (\(Operands ((((), Int), Int), e)
-> (Operands Int, Operands Int, Operands e)
forall a b c.
Operands (Tup3 a b c) -> (Operands a, Operands b, Operands c)
A.untrip -> (Operands Int
i,Operands Int
_,Operands e
_)) -> do
                       case Direction
dir of
                         Direction
LeftToRight -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt  SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i Operands Int
sz
                         Direction
RightToLeft -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i (Int -> Operands Int
liftInt Int
0))
                   (\(Operands ((((), Int), Int), e)
-> (Operands Int, Operands Int, Operands e)
forall a b c.
Operands (Tup3 a b c) -> (Operands a, Operands b, Operands c)
A.untrip -> (Operands Int
i,Operands Int
j,Operands e
u)) -> do
                       Operands e
v <- IROpenFun1 Native () aenv ((sh, Int) -> e)
-> Operands (sh, Int) -> IRExp Native aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed Native aenv (Array (sh, Int) e)
-> IROpenFun1 Native () aenv ((sh, Int) -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (sh -> e)
delayedIndex IRDelayed Native aenv (Array (sh, Int) e)
arrIn) (Operands sh -> Operands Int -> Operands (sh, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
indexCons Operands sh
ix Operands Int
i)
                       Operands e
w <- case Direction
dir of
                              Direction
LeftToRight -> IRFun2 Native aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp Native aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 Native aenv (e -> e -> e)
combine Operands e
u Operands e
v
                              Direction
RightToLeft -> IRFun2 Native aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp Native aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 Native aenv (e -> e -> e)
combine Operands e
v Operands e
u
                       Operands Int
k <- ShapeR (sh, Int)
-> Operands (sh, Int)
-> Operands (sh, Int)
-> CodeGen Native (Operands Int)
forall sh arch.
ShapeR sh
-> Operands sh -> Operands sh -> CodeGen arch (Operands Int)
intOfIndex (ArrayR (Array (sh, Int) e) -> ShapeR (sh, Int)
forall sh e. ArrayR (Array sh e) -> ShapeR sh
arrayRshape ArrayR (Array (sh, Int) e)
aR) Operands (sh, Int)
shOut (Operands sh -> Operands Int -> Operands (sh, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
indexCons Operands sh
ix Operands Int
j)
                       IntegralType Int
-> IRArray (Array (sh, Int) e)
-> Operands Int
-> Operands e
-> CodeGen Native ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Array (sh, Int) e)
arrOut Operands Int
k Operands e
w
                       Operands Int
-> Operands Int -> Operands e -> Operands ((((), Int), Int), e)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip (Operands Int
 -> Operands Int -> Operands e -> Operands ((((), Int), Int), e))
-> CodeGen Native (Operands Int)
-> CodeGen
     Native
     (Operands Int -> Operands e -> Operands ((((), Int), Int), e))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i CodeGen
  Native
  (Operands Int -> Operands e -> Operands ((((), Int), Int), e))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands e -> Operands ((((), Int), Int), e))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
j CodeGen Native (Operands e -> Operands ((((), Int), Int), e))
-> IRExp Native aenv e
-> CodeGen Native (Operands ((((), Int), Int), e))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands e -> IRExp Native aenv e
forall (f :: * -> *) a. Applicative f => a -> f a
pure Operands e
w)
                   (Operands Int
-> Operands Int -> Operands e -> Operands ((((), Int), Int), e)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip Operands Int
i1 Operands Int
j1 Operands e
v0)

    CodeGen Native ()
forall arch. HasCallStack => CodeGen arch ()
return_


-- Sequential scan' along the innermost dimension, producing two arrays.
--
-- The kernel takes the gang range, the "out" array, the "sum" array, the
-- (possibly delayed) "in" array and the array environment as parameters. For
-- every index of the outer dimensions in the gang range it walks the innermost
-- dimension in the requested 'Direction':
--
--  * "out" has the same extent as the input. Each position receives the
--    carry-in value, i.e. the accumulated result _before_ the input element
--    at that position is combined in.
--
--  * "sum" has one dimension fewer than the input. It receives the final
--    accumulated value of each innermost row, which is just the seed when the
--    row is empty.
--
mkScan'S
    :: Direction
    -> UID
    -> Gamma             aenv
    -> ArrayR (Array (sh, Int) e)
    -> IRFun2     Native aenv (e -> e -> e)
    -> IRExp      Native aenv e
    -> MIRDelayed Native aenv (Array (sh, Int) e)
    -> CodeGen    Native      (IROpenAcc Native aenv (Array (sh, Int) e, Array sh e))
mkScan'S dir uid aenv aR combine seed marr =
  let
      (start, end, paramGang) = gangParam    shR
      (arrOut, paramOut)      = mutableArray aR              "out"
      (arrSum, paramSum)      = mutableArray (reduceRank aR) "sum"
      (arrIn,  paramIn)       = delayedArray "in" marr
      paramEnv                = envParam aenv
      ShapeRsnoc shR          = arrayRshape aR
      --
      -- Step the innermost index one element in the scan direction
      next i                  = case dir of
                                  LeftToRight -> A.add numType i (liftInt 1)
                                  RightToLeft -> A.sub numType i (liftInt 1)
  in
  makeOpenAcc uid "scanS" (paramGang ++ paramOut ++ paramSum ++ paramIn ++ paramEnv) $ do

    shIn  <- delayedExtent arrIn
    let sz    = indexHead shIn      -- length of the innermost dimension
        shOut = shIn                -- the scanned output has the input's extent

    imapNestFromTo shR start end (indexTail shIn) $ \ix ii -> do

      -- index to read data from
      i0 <- case dir of
              LeftToRight -> return (liftInt 0)
              RightToLeft -> A.sub numType sz (liftInt 1)

      -- initial element
      v0 <- seed

      -- Loop through the input. Only at the top of the loop do we write the
      -- carry-in value (i.e. value from the last loop iteration) to the output
      -- array. This ensures correct behaviour if the input array was empty.
      r  <- while (TupRsingle scalarTypeInt `TupRpair` arrayRtype aR)
                  (\(A.unpair -> (i,_)) -> do
                      case dir of
                        LeftToRight -> A.lt  singleType i sz
                        RightToLeft -> A.gte singleType i (liftInt 0))
                  (\(A.unpair -> (i,u)) -> do
                      k <- intOfIndex (arrayRshape aR) shOut (indexCons ix i)
                      writeArray TypeInt arrOut k u

                      -- The carry is always the operand on the side the scan
                      -- has already traversed.
                      v <- app1 (delayedIndex arrIn) (indexCons ix i)
                      w <- case dir of
                             LeftToRight -> app2 combine u v
                             RightToLeft -> app2 combine v u
                      A.pair <$> next i <*> pure w)
                  (A.pair i0 v0)

      -- The final carry is the reduction of this innermost row
      writeArray TypeInt arrSum ii (A.snd r)

    return_


-- Parallel scan over a vector. The result bundles the three kernels generated
-- by 'mkScanP1', 'mkScanP2' and 'mkScanP3' into a single object.
--
-- The kernels are generated in that order and linked with '(+++)',
-- associating to the right. Combining the three results directly avoids the
-- partial 'foldr1' over a list that is known to be non-empty.
--
mkScanP
    :: Direction
    -> UID
    -> Gamma             aenv
    -> TypeR e
    -> IRFun2     Native aenv (e -> e -> e)
    -> MIRExp     Native aenv e
    -> MIRDelayed Native aenv (Vector e)
    -> CodeGen    Native      (IROpenAcc Native aenv (Vector e))
mkScanP dir uid aenv eR combine mseed marr =
  link <$> mkScanP1 dir uid aenv eR combine mseed marr
       <*> mkScanP2 dir uid aenv eR combine
       <*> mkScanP3 dir uid aenv eR combine mseed
  where
    -- Same association as @foldr1 (+++) [step1, step2, step3]@
    link step1 step2 step3 = step1 +++ (step2 +++ step3)

-- Parallel scan, step 1.
--
-- Threads scan a stripe of the input into a temporary array, incorporating the
-- initial element and any fused functions on the way. The final reduction
-- result of this chunk is written to a separate array.
--
mkScanP1
    :: Direction
    -> UID
    -> Gamma             aenv
    -> TypeR e
    -> IRFun2     Native aenv (e -> e -> e)
    -> MIRExp     Native aenv e
    -> MIRDelayed Native aenv (Vector e)
    -> CodeGen    Native      (IROpenAcc Native aenv (Vector e))
mkScanP1 :: Direction
-> UID
-> Gamma aenv
-> TypeR e
-> IRFun2 Native aenv (e -> e -> e)
-> MIRExp Native aenv e
-> MIRDelayed Native aenv (Vector e)
-> CodeGen Native (IROpenAcc Native aenv (Vector e))
mkScanP1 Direction
dir UID
uid Gamma aenv
aenv TypeR e
eR IRFun2 Native aenv (e -> e -> e)
combine MIRExp Native aenv e
mseed MIRDelayed Native aenv (Vector e)
marr =
  let
      (Operands ((), Int)
start, Operands ((), Int)
end, [Parameter]
paramGang) = ShapeR ((), Int)
-> (Operands ((), Int), Operands ((), Int), [Parameter])
forall sh. ShapeR sh -> (Operands sh, Operands sh, [Parameter])
gangParam    ShapeR ((), Int)
dim1
      (IRArray (Vector e)
arrOut, [Parameter]
paramOut)      = ArrayR (Vector e)
-> Name (Vector e) -> (IRArray (Vector e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray (ShapeR ((), Int) -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ((), Int)
dim1 TypeR e
eR) Name (Vector e)
"out"
      (IRArray (Vector e)
arrTmp, [Parameter]
paramTmp)      = ArrayR (Vector e)
-> Name (Vector e) -> (IRArray (Vector e), [Parameter])
forall sh e.
ArrayR (Array sh e)
-> Name (Array sh e) -> (IRArray (Array sh e), [Parameter])
mutableArray (ShapeR ((), Int) -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ((), Int)
dim1 TypeR e
eR) Name (Vector e)
"tmp"
      (IRDelayed Native aenv (Vector e)
arrIn,  [Parameter]
paramIn)       = Name (Vector e)
-> MIRDelayed Native aenv (Vector e)
-> (IRDelayed Native aenv (Vector e), [Parameter])
forall sh e arch aenv.
Name (Array sh e)
-> MIRDelayed arch aenv (Array sh e)
-> (IRDelayed arch aenv (Array sh e), [Parameter])
delayedArray Name (Vector e)
"in" MIRDelayed Native aenv (Vector e)
marr
      paramEnv :: [Parameter]
paramEnv                = Gamma aenv -> [Parameter]
forall aenv. Gamma aenv -> [Parameter]
envParam Gamma aenv
aenv
      --
      steps :: Operands Int
steps                   = TupR ScalarType Int -> Name Int -> Operands Int
forall a. TypeR a -> Name a -> Operands a
local     (ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt) Name Int
"ix.steps"
      paramSteps :: [Parameter]
paramSteps              = TupR ScalarType Int -> Name Int -> [Parameter]
forall t. TypeR t -> Name t -> [Parameter]
parameter (ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt) Name Int
"ix.steps"
      piece :: Operands Int
piece                   = TupR ScalarType Int -> Name Int -> Operands Int
forall a. TypeR a -> Name a -> Operands a
local     (ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt) Name Int
"ix.piece"
      paramPiece :: [Parameter]
paramPiece              = TupR ScalarType Int -> Name Int -> [Parameter]
forall t. TypeR t -> Name t -> [Parameter]
parameter (ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt) Name Int
"ix.piece"
      --
      next :: Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i                  = case Direction
dir of
                                  Direction
LeftToRight -> NumType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.add NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i (Int -> Operands Int
liftInt Int
1)
                                  Direction
RightToLeft -> NumType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Int)
forall a arch.
NumType a -> Operands a -> Operands a -> CodeGen arch (Operands a)
A.sub NumType Int
forall a. IsNum a => NumType a
numType Operands Int
i (Int -> Operands Int
liftInt Int
1)
      firstPiece :: Operands Int
firstPiece              = case Direction
dir of
                                  Direction
LeftToRight -> Int -> Operands Int
liftInt Int
0
                                  Direction
RightToLeft -> Operands Int
steps
  in
  UID
-> Label
-> [Parameter]
-> CodeGen Native ()
-> CodeGen Native (IROpenAcc Native aenv (Vector e))
forall aenv a.
UID
-> Label
-> [Parameter]
-> CodeGen Native ()
-> CodeGen Native (IROpenAcc Native aenv a)
makeOpenAcc UID
uid Label
"scanP1" ([Parameter]
paramGang [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramPiece [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramSteps [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramOut [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramTmp [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramIn [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ [Parameter]
paramEnv) (CodeGen Native ()
 -> CodeGen Native (IROpenAcc Native aenv (Vector e)))
-> CodeGen Native ()
-> CodeGen Native (IROpenAcc Native aenv (Vector e))
forall a b. (a -> b) -> a -> b
$ do

    -- A thread scans a non-empty stripe of the input, storing the final
    -- reduction result into a separate array.
    --
    -- For exclusive scans the first chunk must incorporate the initial element
    -- into the input and output, while all other chunks increment their output
    -- index by one.
    --
    -- index i* is the index that we read data from. Recall that the supremum
    -- index is exclusive
    Operands Int
i0  <- case Direction
dir of
             Direction
LeftToRight -> Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead Operands ((), Int)
start)
             Direction
RightToLeft -> Operands Int -> CodeGen Native (Operands Int)
next (Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead Operands ((), Int)
end)

    -- index j* is the index that we write to. Recall that for exclusive scan
    -- the output array is one larger than the input; the first piece uses
    -- this spot to write the initial element, all other chunks shift by one.
    Operands Int
j0  <- case MIRExp Native aenv e
mseed of
             MIRExp Native aenv e
Nothing -> Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
i0
             Just IRExp Native aenv e
_  -> case Direction
dir of
                          Direction
LeftToRight -> if (ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt, SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
piece Operands Int
firstPiece)
                                         then Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
i0
                                         else Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i0
                          Direction
RightToLeft -> if (ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt, SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
piece Operands Int
firstPiece)
                                         then Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead Operands ((), Int)
end)
                                         else Operands Int -> CodeGen Native (Operands Int)
forall (m :: * -> *) a. Monad m => a -> m a
return Operands Int
i0

    -- Evaluate/read the initial element for this piece. Update the read-from
    -- index appropriately
    (Operands e
v0,Operands Int
i1) <- Operands (e, Int) -> (Operands e, Operands Int)
forall a b. Operands (a, b) -> (Operands a, Operands b)
A.unpair (Operands (e, Int) -> (Operands e, Operands Int))
-> CodeGen Native (Operands (e, Int))
-> CodeGen Native (Operands e, Operands Int)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> case MIRExp Native aenv e
mseed of
                 Just IRExp Native aenv e
seed -> if (TypeR e
eR TypeR e -> TupR ScalarType Int -> TupR ScalarType (e, Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt, SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.eq SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
piece Operands Int
firstPiece)
                                then Operands e -> Operands Int -> Operands (e, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
A.pair (Operands e -> Operands Int -> Operands (e, Int))
-> IRExp Native aenv e
-> CodeGen Native (Operands Int -> Operands (e, Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IRExp Native aenv e
seed                               CodeGen Native (Operands Int -> Operands (e, Int))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands (e, Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Operands Int
i0
                                else Operands e -> Operands Int -> Operands (e, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
A.pair (Operands e -> Operands Int -> Operands (e, Int))
-> IRExp Native aenv e
-> CodeGen Native (Operands Int -> Operands (e, Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IROpenFun1 Native () aenv (Int -> e)
-> Operands Int -> IRExp Native aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed Native aenv (Vector e)
-> IROpenFun1 Native () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed Native aenv (Vector e)
arrIn) Operands Int
i0 CodeGen Native (Operands Int -> Operands (e, Int))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands (e, Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i0
                 MIRExp Native aenv e
Nothing   ->        Operands e -> Operands Int -> Operands (e, Int)
forall sh sz. Operands sh -> Operands sz -> Operands (sh, sz)
A.pair (Operands e -> Operands Int -> Operands (e, Int))
-> IRExp Native aenv e
-> CodeGen Native (Operands Int -> Operands (e, Int))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IROpenFun1 Native () aenv (Int -> e)
-> Operands Int -> IRExp Native aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed Native aenv (Vector e)
-> IROpenFun1 Native () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed Native aenv (Vector e)
arrIn) Operands Int
i0 CodeGen Native (Operands Int -> Operands (e, Int))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands (e, Int))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i0

    -- Write first element
    IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen Native ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrOut Operands Int
j0 Operands e
v0
    Operands Int
j1  <- Operands Int -> CodeGen Native (Operands Int)
next Operands Int
j0

    -- Continue looping through the rest of the input
    let cont :: Operands Int -> CodeGen Native (Operands Bool)
cont Operands Int
i =
           case Direction
dir of
             Direction
LeftToRight -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.lt  SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i (Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead Operands ((), Int)
end)
             Direction
RightToLeft -> SingleType Int
-> Operands Int -> Operands Int -> CodeGen Native (Operands Bool)
forall a arch.
SingleType a
-> Operands a -> Operands a -> CodeGen arch (Operands Bool)
A.gte SingleType Int
forall a. IsSingle a => SingleType a
singleType Operands Int
i (Operands ((), Int) -> Operands Int
forall sh sz. Operands (sh, sz) -> Operands sz
indexHead Operands ((), Int)
start)

    Operands ((((), Int), Int), e)
r   <- TypeR ((((), Int), Int), e)
-> (Operands ((((), Int), Int), e)
    -> CodeGen Native (Operands Bool))
-> (Operands ((((), Int), Int), e)
    -> CodeGen Native (Operands ((((), Int), Int), e)))
-> Operands ((((), Int), Int), e)
-> CodeGen Native (Operands ((((), Int), Int), e))
forall a arch.
TypeR a
-> (Operands a -> CodeGen arch (Operands Bool))
-> (Operands a -> CodeGen arch (Operands a))
-> Operands a
-> CodeGen arch (Operands a)
while (TupR ScalarType ()
forall (s :: * -> *). TupR s ()
TupRunit TupR ScalarType ()
-> TupR ScalarType Int -> TupR ScalarType ((), Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType ((), Int)
-> TupR ScalarType Int -> TupR ScalarType (((), Int), Int)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` ScalarType Int -> TupR ScalarType Int
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType Int
scalarTypeInt TupR ScalarType (((), Int), Int)
-> TypeR e -> TypeR ((((), Int), Int), e)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` TypeR e
eR)
                 (Operands Int -> CodeGen Native (Operands Bool)
cont (Operands Int -> CodeGen Native (Operands Bool))
-> (Operands ((((), Int), Int), e) -> Operands Int)
-> Operands ((((), Int), Int), e)
-> CodeGen Native (Operands Bool)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Operands ((((), Int), Int), e) -> Operands Int
forall a b c. Operands (Tup3 a b c) -> Operands a
A.fst3)
                 (\(Operands ((((), Int), Int), e)
-> (Operands Int, Operands Int, Operands e)
forall a b c.
Operands (Tup3 a b c) -> (Operands a, Operands b, Operands c)
A.untrip -> (Operands Int
i,Operands Int
j,Operands e
v)) -> do
                     Operands e
u  <- IROpenFun1 Native () aenv (Int -> e)
-> Operands Int -> IRExp Native aenv e
forall arch env aenv a b.
IROpenFun1 arch env aenv (a -> b)
-> Operands a -> IROpenExp arch (env, a) aenv b
app1 (IRDelayed Native aenv (Vector e)
-> IROpenFun1 Native () aenv (Int -> e)
forall arch aenv sh e.
IRDelayed arch aenv (Array sh e) -> IRFun1 arch aenv (Int -> e)
delayedLinearIndex IRDelayed Native aenv (Vector e)
arrIn) Operands Int
i
                     Operands e
v' <- case Direction
dir of
                             Direction
LeftToRight -> IRFun2 Native aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp Native aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 Native aenv (e -> e -> e)
combine Operands e
v Operands e
u
                             Direction
RightToLeft -> IRFun2 Native aenv (e -> e -> e)
-> Operands e -> Operands e -> IRExp Native aenv e
forall arch env aenv a b c.
IROpenFun2 arch env aenv (a -> b -> c)
-> Operands a -> Operands b -> IROpenExp arch ((env, a), b) aenv c
app2 IRFun2 Native aenv (e -> e -> e)
combine Operands e
u Operands e
v
                     IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen Native ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrOut Operands Int
j Operands e
v'
                     Operands Int
-> Operands Int -> Operands e -> Operands ((((), Int), Int), e)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip (Operands Int
 -> Operands Int -> Operands e -> Operands ((((), Int), Int), e))
-> CodeGen Native (Operands Int)
-> CodeGen
     Native
     (Operands Int -> Operands e -> Operands ((((), Int), Int), e))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
i CodeGen
  Native
  (Operands Int -> Operands e -> Operands ((((), Int), Int), e))
-> CodeGen Native (Operands Int)
-> CodeGen Native (Operands e -> Operands ((((), Int), Int), e))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands Int -> CodeGen Native (Operands Int)
next Operands Int
j CodeGen Native (Operands e -> Operands ((((), Int), Int), e))
-> IRExp Native aenv e
-> CodeGen Native (Operands ((((), Int), Int), e))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Operands e -> IRExp Native aenv e
forall (f :: * -> *) a. Applicative f => a -> f a
pure Operands e
v')
                 (Operands Int
-> Operands Int -> Operands e -> Operands ((((), Int), Int), e)
forall a b c.
Operands a -> Operands b -> Operands c -> Operands (Tup3 a b c)
A.trip Operands Int
i1 Operands Int
j1 Operands e
v0)

    -- Final reduction result of this piece
    IntegralType Int
-> IRArray (Vector e)
-> Operands Int
-> Operands e
-> CodeGen Native ()
forall int sh e arch.
IntegralType int
-> IRArray (Array sh e)
-> Operands int
-> Operands e
-> CodeGen arch ()
writeArray IntegralType Int
TypeInt IRArray (Vector e)
arrTmp Operands Int
piece (Operands ((((), Int), Int), e) -> Operands e
forall a b c. Operands (Tup3 a b c) -> Operands c
A.thd3 Operands ((((), Int), Int), e)
r)

    CodeGen Native ()
forall arch. HasCallStack => CodeGen arch ()
return_


-- Parallel scan, step 2.
--
-- A single thread performs an in-place inclusive scan of the partial block
-- sums held in the "tmp" array. These form the carry-in values which are
-- combined with the stripe partial results in the final step.
--
-- The traversal covers the gang range [start, end): left-to-right scans walk
-- upwards from 'start', whereas right-to-left scans walk downwards from
-- (end - 1). The running value is always passed to 'combine' on the side
-- corresponding to the elements already visited.
--
mkScanP2
    :: Direction
    -> UID
    -> Gamma          aenv
    -> TypeR e
    -> IRFun2  Native aenv (e -> e -> e)
    -> CodeGen Native      (IROpenAcc Native aenv (Vector e))
mkScanP2 dir uid aenv eR combine =
  let
      (start, end, paramGang) = gangParam    dim1
      (arrTmp, paramTmp)      = mutableArray (ArrayR dim1 eR) "tmp"
      paramEnv                = envParam aenv
      --
      -- Loop condition: has the index not yet run off the end of the range?
      cont i                  = case dir of
                                  LeftToRight -> A.lt  singleType i (indexHead end)
                                  RightToLeft -> A.gte singleType i (indexHead start)

      -- Step the index one element in the direction of the scan
      next i                  = case dir of
                                  LeftToRight -> A.add numType i (liftInt 1)
                                  RightToLeft -> A.sub numType i (liftInt 1)
  in
  makeOpenAcc uid "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do

    -- Index of the first element to visit. The range is half-open, so a
    -- right-to-left traversal begins one step before 'end'.
    i0 <- case dir of
            LeftToRight -> return (indexHead start)
            RightToLeft -> next (indexHead end)

    -- The first element is already its own inclusive prefix; it only seeds
    -- the running value.
    v0 <- readArray TypeInt arrTmp i0
    i1 <- next i0

    -- Loop state is the pair (current index, running value)
    void $ while (TupRsingle scalarTypeInt `TupRpair` eR)
                 (cont . A.fst)
                 (\(A.unpair -> (i,v)) -> do
                    u  <- readArray TypeInt arrTmp i
                    i' <- next i
                    v' <- case dir of
                            LeftToRight -> app2 combine v u
                            RightToLeft -> app2 combine u v
                    writeArray TypeInt arrTmp i v'
                    return $ A.pair i' v')
                 (A.pair i1 v0)

    return_


-- Parallel scan, step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed from step 2.
--
-- Note that the first chunk does not need extra processing (it has no carry-in
-- value): for left-to-right scans that is piece zero, for right-to-left scans
-- it is the piece whose index equals the "ix.steps" parameter.
--
-- The carry-in for a piece is read from the "tmp" array at the index of the
-- piece preceding it in scan order, and is applied in-place to the elements
-- of the "out" array in this thread's range. For a left-to-right scan with a
-- seed element, the range is shifted up by one to skip over the slot reserved
-- for the initial element.
--
mkScanP3
    :: Direction
    -> UID
    -> Gamma aenv
    -> TypeR e
    -> IRFun2  Native aenv (e -> e -> e)
    -> MIRExp  Native aenv e
    -> CodeGen Native      (IROpenAcc Native aenv (Vector e))
mkScanP3 dir uid aenv eR combine mseed =
  let
      (start, end, paramGang) = gangParam    dim1
      (arrOut, paramOut)      = mutableArray (ArrayR dim1 eR) "out"
      (arrTmp, paramTmp)      = mutableArray (ArrayR dim1 eR) "tmp"
      paramEnv                = envParam aenv
      --
      steps                   = local     (TupRsingle scalarTypeInt) "ix.steps"
      paramSteps              = parameter (TupRsingle scalarTypeInt) "ix.steps"
      piece                   = local     (TupRsingle scalarTypeInt) "ix.piece"
      paramPiece              = parameter (TupRsingle scalarTypeInt) "ix.piece"
      --
      -- One step forwards / backwards in the direction of the scan
      next i                  = case dir of
                                  LeftToRight -> A.add numType i (liftInt 1)
                                  RightToLeft -> A.sub numType i (liftInt 1)
      prev i                  = case dir of
                                  LeftToRight -> A.sub numType i (liftInt 1)
                                  RightToLeft -> A.add numType i (liftInt 1)
      -- Index of the piece which has no carry-in value
      firstPiece              = case dir of
                                  LeftToRight -> liftInt 0
                                  RightToLeft -> steps
  in
  makeOpenAcc uid "scanP3" (paramGang ++ paramPiece ++ paramSteps ++ paramOut ++ paramTmp ++ paramEnv) $ do

    -- TODO: Don't schedule the "first" piece. In the scheduler this corresponds
    -- to the split range with the smallest/largest linear index for left/right
    -- scans respectively. For right scans this is not necessarily the last piece(?).
    --
    A.when (neq singleType piece firstPiece) $ do

      -- Compute start and end indices, leaving space for the initial element
      (inf,sup) <- case (dir, mseed) of
                     (LeftToRight, Just{}) -> (,) <$> next (indexHead start) <*> next (indexHead end)
                     _                     -> pure (indexHead start, indexHead end)

      -- Read in the carry in value for this piece
      c <- readArray TypeInt arrTmp =<< prev piece

      -- Apply the carry-in to every element of this piece, in place. The
      -- carry-in goes on the side of the elements that precede this piece in
      -- scan order.
      imapFromTo inf sup $ \i -> do
        x <- readArray TypeInt arrOut i
        y <- case dir of
               LeftToRight -> app2 combine c x
               RightToLeft -> app2 combine x c
        writeArray TypeInt arrOut i y

    return_


-- Parallel scan' (the variant whose result type pairs a vector with a scalar).
--
-- Generates the three kernels of the parallel algorithm, in order (steps 1, 2
-- and 3), and joins them into a single result with '(+++)'. Only the first
-- step receives the seed expression and the (possibly delayed) input array.
--
mkScan'P
    :: Direction
    -> UID
    -> Gamma             aenv
    -> TypeR e
    -> IRFun2     Native aenv (e -> e -> e)
    -> IRExp      Native aenv e
    -> MIRDelayed Native aenv (Vector e)
    -> CodeGen    Native      (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P dir uid aenv eR combine seed arr = do
  step1 <- mkScan'P1 dir uid aenv eR combine seed arr
  step2 <- mkScan'P2 dir uid aenv eR combine
  step3 <- mkScan'P3 dir uid aenv eR combine
  -- Nest to the right explicitly, so that the grouping does not depend on the
  -- fixity of (+++)
  return (step1 +++ (step2 +++ step3))

-- Parallel scan', step 1
--
-- Threads scan a stripe of the input into a temporary array. Similar to
-- exclusive scan, the output indices are shifted by one relative to the input
-- indices to make space for the initial element.
--
-- The generated kernel is named "scanP1". In addition to the gang range it
-- takes the index of the piece being processed ("ix.piece") and the value
-- "ix.steps", which identifies the first piece of a right-to-left scan. The
-- running total of each piece is stored into the "tmp" array at the piece's
-- index.
--
mkScan'P1
    :: Direction
    -> UID
    -> Gamma             aenv
    -> TypeR e
    -> IRFun2     Native aenv (e -> e -> e)
    -> IRExp      Native aenv e
    -> MIRDelayed Native aenv (Vector e)
    -> CodeGen    Native      (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P1 dir uid aenv eR combine seed marr =
  let
      (start, end, paramGang) = gangParam    dim1
      (arrOut, paramOut)      = mutableArray (ArrayR dim1 eR) "out"
      (arrTmp, paramTmp)      = mutableArray (ArrayR dim1 eR) "tmp"
      (arrIn,  paramIn)       = delayedArray "in" marr
      paramEnv                = envParam aenv
      --
      steps                   = local     (TupRsingle scalarTypeInt) "ix.steps"
      paramSteps              = parameter (TupRsingle scalarTypeInt) "ix.steps"
      piece                   = local     (TupRsingle scalarTypeInt) "ix.piece"
      paramPiece              = parameter (TupRsingle scalarTypeInt) "ix.piece"
      --
      -- Step an index one element in the direction of the scan
      next i                  = case dir of
                                  LeftToRight -> A.add numType i (liftInt 1)
                                  RightToLeft -> A.sub numType i (liftInt 1)

      -- The piece which must include the initial (seed) element
      firstPiece              = case dir of
                                  LeftToRight -> liftInt 0
                                  RightToLeft -> steps
  in
  makeOpenAcc uid "scanP1" (paramGang ++ paramPiece ++ paramSteps ++ paramOut ++ paramTmp ++ paramIn ++ paramEnv) $ do

    -- index i* is the index that we pull data from.
    i0 <- case dir of
            LeftToRight -> return (indexHead start)
            RightToLeft -> next (indexHead end)

    -- index j* is the index that we write results to. The first piece needs to
    -- include the initial element, and all other chunks shift their results
    -- across by one to make space.
    j0      <- if (TupRsingle scalarTypeInt, A.eq singleType piece firstPiece)
                 then pure i0
                 else next i0

    -- Evaluate/read the initial element. Update the read-from index
    -- appropriately: the first piece starts from the seed and so has not yet
    -- consumed an input element, whereas every other piece starts from the
    -- input element at i0 and continues reading from the following index.
    (v0,i1) <- A.unpair <$> if (eR `TupRpair` TupRsingle scalarTypeInt, A.eq singleType piece firstPiece)
                              then A.pair <$> seed                               <*> pure i0
                              else A.pair <$> app1 (delayedLinearIndex arrIn) i0 <*> pure j0

    -- Write the first element
    writeArray TypeInt arrOut j0 v0
    j1 <- next j0

    -- Continue looping through the rest of the input
    let cont i =
           case dir of
             LeftToRight -> A.lt  singleType i (indexHead end)
             RightToLeft -> A.gte singleType i (indexHead start)

    -- Loop state is (read index, write index, running total)
    r  <- while (TupRunit `TupRpair` TupRsingle scalarTypeInt `TupRpair` TupRsingle scalarTypeInt `TupRpair` eR)
                (cont . A.fst3)
                (\(A.untrip -> (i,j,v)) -> do
                    u  <- app1 (delayedLinearIndex arrIn) i
                    -- Keep the argument order consistent with the direction
                    -- of the scan
                    v' <- case dir of
                            LeftToRight -> app2 combine v u
                            RightToLeft -> app2 combine u v
                    writeArray TypeInt arrOut j v'
                    A.trip <$> next i <*> next j <*> pure v')
                (A.trip i1 j1 v0)

    -- Write the final reduction result of this piece
    writeArray TypeInt arrTmp piece (A.thd3 r)

    return_


-- Parallel scan', step 2
--
-- Identical to mkScanP2, except we store the total scan result into a separate
-- array (rather than discard it).
--
-- The generated kernel is named "scanP2". It scans the "tmp" array in place
-- over the gang range, in the given direction, and writes the final
-- accumulated value to index zero of the "sum" array.
--
mkScan'P2
    :: Direction
    -> UID
    -> Gamma          aenv
    -> TypeR e
    -> IRFun2  Native aenv (e -> e -> e)
    -> CodeGen Native      (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P2 dir uid aenv eR combine =
  let
      (start, end, paramGang) = gangParam    dim1
      (arrTmp, paramTmp)      = mutableArray (ArrayR dim1 eR) "tmp"
      (arrSum, paramSum)      = mutableArray (ArrayR dim0 eR) "sum"
      paramEnv                = envParam aenv
      --
      -- Loop while the index remains within the gang range
      cont i                  = case dir of
                                  LeftToRight -> A.lt  singleType i (indexHead end)
                                  RightToLeft -> A.gte singleType i (indexHead start)

      -- Step an index one element in the direction of the scan
      next i                  = case dir of
                                  LeftToRight -> A.add numType i (liftInt 1)
                                  RightToLeft -> A.sub numType i (liftInt 1)
  in
  makeOpenAcc uid "scanP2" (paramGang ++ paramSum ++ paramTmp ++ paramEnv) $ do

    -- Index of the first element to process
    i0 <- case dir of
            LeftToRight -> return (indexHead start)
            RightToLeft -> next (indexHead end)

    -- The first element is left unchanged and seeds the running total
    v0 <- readArray TypeInt arrTmp i0
    i1 <- next i0

    -- Loop state is (index, running total)
    r  <- while (TupRpair (TupRsingle scalarTypeInt) eR)
                (cont . A.fst)
                (\(A.unpair -> (i,v)) -> do
                   u  <- readArray TypeInt arrTmp i
                   i' <- next i
                   -- Keep the argument order consistent with the direction
                   -- of the scan
                   v' <- case dir of
                           LeftToRight -> app2 combine v u
                           RightToLeft -> app2 combine u v
                   writeArray TypeInt arrTmp i v'
                   return $ A.pair i' v')
                (A.pair i1 v0)

    -- Store the overall total separately
    writeArray TypeInt arrSum (liftInt 0) (A.snd r)

    return_


-- Parallel scan', step 3
--
-- Similar to mkScanP3, except that indices are shifted by one since the output
-- array is the same size as the input (despite being an exclusive scan).
--
-- Note that the first chunk does not need to do any extra processing (has no
-- carry-in value).
--
-- The generated kernel is named "scanP3". For every piece other than the
-- first, it reads the carry-in value from the "tmp" array at the index of the
-- preceding piece (relative to the scan direction) and combines it with each
-- element of the "out" array in the shifted gang range.
--
mkScan'P3
    :: Direction
    -> UID
    -> Gamma          aenv
    -> TypeR e
    -> IRFun2  Native aenv (e -> e -> e)
    -> CodeGen Native      (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P3 dir uid aenv eR combine =
  let
      (start, end, paramGang) = gangParam    dim1
      (arrOut, paramOut)      = mutableArray (ArrayR dim1 eR) "out"
      (arrTmp, paramTmp)      = mutableArray (ArrayR dim1 eR) "tmp"
      paramEnv                = envParam aenv
      --
      steps                   = local     (TupRsingle scalarTypeInt) "ix.steps"
      paramSteps              = parameter (TupRsingle scalarTypeInt) "ix.steps"
      piece                   = local     (TupRsingle scalarTypeInt) "ix.piece"
      paramPiece              = parameter (TupRsingle scalarTypeInt) "ix.piece"
      --
      -- Step an index one element in the direction of the scan
      next i                  = case dir of
                                  LeftToRight -> A.add numType i (liftInt 1)
                                  RightToLeft -> A.sub numType i (liftInt 1)
      -- Step an index one element against the direction of the scan
      prev i                  = case dir of
                                  LeftToRight -> A.sub numType i (liftInt 1)
                                  RightToLeft -> A.add numType i (liftInt 1)
      -- The piece which has no carry-in value
      firstPiece              = case dir of
                                  LeftToRight -> liftInt 0
                                  RightToLeft -> steps
  in
  makeOpenAcc uid "scanP3" (paramGang ++ paramPiece ++ paramSteps ++ paramOut ++ paramTmp ++ paramEnv) $ do

    -- TODO: don't schedule the "first" piece.
    --
    A.when (neq singleType piece firstPiece) $ do

      -- Compute start and end indices, leaving space for the initial element
      inf <- next (indexHead start)
      sup <- next (indexHead end)

      -- Read the carry-in value for this piece
      c   <- readArray TypeInt arrTmp =<< prev piece

      -- Apply the carry-in value to all elements of the output
      imapFromTo inf sup $ \i -> do
        x <- readArray TypeInt arrOut i
        -- Keep the argument order consistent with the direction of the scan
        y <- case dir of
               LeftToRight -> app2 combine c x
               RightToLeft -> app2 combine x c
        writeArray TypeInt arrOut i y

    return_