{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Definition of /Second-Order Array Combinators/ (SOACs), which are
-- the main form of parallelism in the early stages of the compiler.
module Futhark.IR.SOACS.SOAC
  ( SOAC (..),
    StreamOrd (..),
    StreamForm (..),
    ScremaForm (..),
    HistOp (..),
    Scan (..),
    scanResults,
    singleScan,
    Reduce (..),
    redResults,
    singleReduce,

    -- * Utility
    scremaType,
    soacType,
    typeCheckSOAC,
    mkIdentityLambda,
    isIdentityLambda,
    nilFn,
    scanomapSOAC,
    redomapSOAC,
    scanSOAC,
    reduceSOAC,
    mapSOAC,
    isScanomapSOAC,
    isRedomapSOAC,
    isScanSOAC,
    isReduceSOAC,
    isMapSOAC,
    ppScrema,
    ppHist,
    groupScatterResults,
    groupScatterResults',
    splitScatterResults,

    -- * Generic traversal
    SOACMapper (..),
    identitySOACMapper,
    mapSOACM,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, removeLambdaAliases)
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Lore
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty (Doc, Pretty, comma, commasep, parens, ppr, text, (<+>), (</>))
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | A second-order array combinator (SOAC).
data SOAC lore
  = -- | A stream over the given input arrays, whose 'StreamForm'
    -- selects between 'Parallel' and 'Sequential' streaming.
    --
    -- NOTE(review): presumably the first t'SubExp' is the length of the
    -- input arrays (as for the other constructors) and the @[SubExp]@
    -- field holds initial accumulator values — confirm.
    Stream SubExp [VName] (StreamForm lore) [SubExp] (Lambda lore)
  | -- | @Scatter <length> <lambda> <inputs> <outputs>@
    --
    -- Scatter maps values from a set of input arrays to indices and values of a
    -- set of output arrays. It is able to write multiple values to multiple
    -- outputs each of which may have multiple dimensions.
    --
    -- <inputs> is a list of input arrays, all having size <length>, elements of
    -- which are applied to the <lambda> function. For instance, if there are
    -- two arrays, <lambda> will get two values as input, one from each array.
    --
    -- <outputs> specifies the result of the <lambda> and which arrays to write
    -- to. Each element of the list is a triple consisting of a <Shape>
    -- describing the shape of the array to scatter to, an <Int> describing
    -- how many elements should be written to that array for each invocation
    -- of the <lambda>, and a <VName> specifying which array to scatter to.
    --
    -- <lambda> is a function that takes inputs from <inputs> and returns values
    -- according to the output-specification in <outputs>. It returns values in
    -- the following manner:
    --
    --     [index_0, index_1, ..., index_n, value_0, value_1, ..., value_m]
    --
    -- For each output in <outputs>, <lambda> returns <i> * <j> index values and
    -- <j> output values, where <i> is the number of dimensions (rank) of the
    -- given output, and <j> is the number of output values written to the given
    -- output.
    --
    -- For example, given the following output specification:
    --
    --     [([x1, y1, z1], 2, arr1), ([x2, y2], 1, arr2)]
    --
    -- <lambda> will produce 6 (3 * 2) index values and 2 output values for
    -- <arr1>, and 2 (2 * 1) index values and 1 output value for
    -- <arr2>. Additionally, the results are grouped, so the first 6 index values
    -- will correspond to the first two output values, and so on. For this
    -- example, <lambda> should return a total of 11 values, 8 index values and
    -- 3 output values.
    Scatter SubExp (Lambda lore) [VName] [(Shape, Int, VName)]
  | -- | @Hist <length> <dest-arrays-and-ops> <bucket fun> <input arrays>@
    --
    -- The first SubExp is the length of the input arrays. The first
    -- list describes the operations to perform.  The t'Lambda' is the
    -- bucket function.  Finally come the input arrays (the images).
    Hist SubExp [HistOp lore] (Lambda lore) [VName]
  | -- | A combination of scan, reduction, and map.  The first
    -- t'SubExp' is the size of the input arrays.
    Screma SubExp [VName] (ScremaForm lore)
  deriving (Eq, Ord, Show)

-- | Everything needed to compute one histogram of a 'Hist' SOAC.
data HistOp lore = HistOp
  { -- | The width of the histogram (presumably its number of bins —
    -- confirm).
    histWidth :: SubExp,
    -- | Race factor @RF@ means that only @1/RF@
    -- bins are used.
    histRaceFactor :: SubExp,
    -- | The destination arrays of this histogram.
    histDest :: [VName],
    -- | Neutral elements for 'histOp' (presumably one per destination
    -- array — confirm).
    histNeutral :: [SubExp],
    -- | The operator lambda of this histogram.
    histOp :: Lambda lore
  }
  deriving (Eq, Ord, Show)

-- | Ordering requirement on the chunks of a stream.  'Disorder'
-- streams can be more efficient, but not all algorithms work with
-- them.
data StreamOrd
  = -- | Each chunk must correspond to a contiguous subsequence of the
    -- original input.
    InOrder
  | -- | Chunks are not required to correspond to contiguous
    -- subsequences of the original input.
    Disorder
  deriving (Eq, Ord, Show)

-- | The kind of a 'Stream': parallel or sequential.
data StreamForm lore
  = -- | A parallel stream.  It carries the chunk-ordering requirement,
    -- a 'Commutativity' flag, and a lambda (presumably the operator
    -- used to combine per-chunk results — confirm).
    Parallel StreamOrd Commutativity (Lambda lore)
  | -- | A sequential stream.
    Sequential
  deriving (Eq, Ord, Show)

-- | Everything in a 'Screma' except its width and input arrays:
-- the scan operators, the reduction operators, and the map lambda.
data ScremaForm lore
  = ScremaForm
      [Scan lore]
      -- ^ The scan operators.
      [Reduce lore]
      -- ^ The reduction operators.
      (Lambda lore)
      -- ^ The lambda forming the map part of the screma.
  deriving (Eq, Ord, Show)

-- | Combine several binary-operator lambdas into a single lambda that
-- evaluates all of them side by side.
--
-- Each input lambda's parameter list is cut at the length of its
-- return type: the leading parameters are its \"x\" operands and the
-- remaining ones its \"y\" operands.  The combined lambda takes all
-- x-operands (in list order) followed by all y-operands.  Its
-- statements, results and return types are the concatenations of
-- those of the input lambdas, in list order.
singleBinOp :: Bindable lore => [Lambda lore] -> Lambda lore
singleBinOp lams =
  Lambda
    { lambdaParams = concatMap fst halves ++ concatMap snd halves,
      lambdaReturnType = [t | lam <- lams, t <- lambdaReturnType lam],
      lambdaBody = mkBody combinedStms combinedRes
    }
  where
    -- Split one lambda's parameters into its x- and y-operands.
    splitOperands lam = splitAt (length (lambdaReturnType lam)) (lambdaParams lam)
    halves = map splitOperands lams
    combinedStms = foldMap (bodyStms . lambdaBody) lams
    combinedRes = [se | lam <- lams, se <- bodyResult (lambdaBody lam)]

-- | A single scan: a binary operator together with its neutral
-- elements.
data Scan lore = Scan
  { -- | The binary operator of the scan.
    scanLambda :: Lambda lore,
    -- | The neutral elements of 'scanLambda'.  'scanResults' counts one
    -- result per element of this list.
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many scan results are produced by these 'Scan's?
--
-- Each 'Scan' contributes one result per neutral element, so this is
-- the total length of all their 'scanNeutral' lists.
scanResults :: [Scan lore] -> Int
scanResults = sum . map (length . scanNeutral)

-- | Fuse a list of 'Scan's into one 'Scan'.  The fused operator is
-- built with 'singleBinOp' from all the operators, and the fused
-- neutral elements are all the neutral elements, in list order.
singleScan :: Bindable lore => [Scan lore] -> Scan lore
singleScan scans =
  Scan
    { scanLambda = singleBinOp [scanLambda s | s <- scans],
      scanNeutral = concatMap scanNeutral scans
    }

-- | A single reduction: a binary operator, its commutativity, and its
-- neutral elements.
data Reduce lore = Reduce
  { -- | Whether the reduction operator is commutative.
    redComm :: Commutativity,
    -- | The binary operator of the reduction.
    redLambda :: Lambda lore,
    -- | The neutral elements of 'redLambda'.  'redResults' counts one
    -- result per element of this list.
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'Reduce's?
--
-- Each 'Reduce' contributes one result per neutral element, so this is
-- the length of all their 'redNeutral' lists put together.
redResults :: [Reduce lore] -> Int
redResults reds = length (concatMap redNeutral reds)

-- | Combine multiple reduction operators to a single operator.
--
-- The combined operator is built as follows:
--
-- * Its neutral elements are those of the individual 'Reduce's,
--   concatenated in order.
-- * Its lambda is built with 'singleBinOp' from the individual lambdas.
-- * Its commutativity is the 'mconcat' of the individual
--   commutativities. This presumably yields commutative only if every
--   operator is commutative; confirm against the 'Commutativity'
--   'Monoid' instance.
singleReduce :: Bindable lore => [Reduce lore] -> Reduce lore
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
   in Reduce (mconcat (map redComm reds)) red_lam red_nes

-- | The types produced by a single 'Screma', given the size of the
-- input array.
--
-- The result is ordered as follows:
--
-- 1. the scan results, as arrays with outer size @w@;
-- 2. the reduction results, which are not arrays;
-- 3. the remaining map results, as arrays with outer size @w@.
scremaType :: SubExp -> ScremaForm lore -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps
  where
    scan_tps =
      map (`arrayOfRow` w) $
        concatMap (lambdaReturnType . scanLambda) scans
    red_tps = concatMap (lambdaReturnType . redLambda) reds
    -- The leading results of the map lambda (one per scan/reduction
    -- result) feed the scans and reductions; only the rest are
    -- returned directly as mapped arrays.
    map_tps =
      drop (length scan_tps + length red_tps) $
        lambdaReturnType map_lam

-- | Construct a lambda that takes parameters of the given types and
-- simply returns them unchanged.
--
-- One fresh parameter (name based on @"x"@) is created per type, and
-- the body has no statements. Its result is the parameters in order.
mkIdentityLambda ::
  (Bindable lore, MonadFreshNames m) =>
  [Type] ->
  m (Lambda lore)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  pure
    Lambda
      { lambdaParams = params,
        lambdaBody = mkBody mempty $ map (Var . paramName) params,
        lambdaReturnType = ts
      }

-- | Is the given lambda an identity lambda?
--
-- This holds when the body result is exactly the lambda's parameters,
-- as variables and in order. Statements in the body are not
-- inspected.
isIdentityLambda :: Lambda lore -> Bool
isIdentityLambda lam =
  bodyResult (lambdaBody lam)
    == map (Var . paramName) (lambdaParams lam)

-- | A lambda with no parameters that returns no values. It has an
-- empty parameter list, a body with no statements and no results, and
-- no return types.
nilFn :: Bindable lore => Lambda lore
nilFn = Lambda mempty (mkBody mempty mempty) mempty

-- | Construct a Screma with possibly multiple scans, and
-- the given map function. The resulting form has no reductions.
scanomapSOAC :: [Scan lore] -> Lambda lore -> ScremaForm lore
scanomapSOAC scans = ScremaForm scans []

-- | Construct a Screma with possibly multiple reductions, and
-- the given map function. The resulting form has no scans.
redomapSOAC :: [Reduce lore] -> Lambda lore -> ScremaForm lore
redomapSOAC = ScremaForm []

-- | Construct a Screma with possibly multiple scans, and identity map
-- function.
--
-- The identity lambda's parameter types are the concatenated return
-- types of the scan operators.
scanSOAC ::
  (Bindable lore, MonadFreshNames m) =>
  [Scan lore] ->
  m (ScremaForm lore)
scanSOAC scans = scanomapSOAC scans <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . scanLambda) scans

-- | Construct a Screma with possibly multiple reductions, and
-- identity map function.
--
-- The identity lambda's parameter types are the concatenated return
-- types of the reduction operators.
reduceSOAC ::
  (Bindable lore, MonadFreshNames m) =>
  [Reduce lore] ->
  m (ScremaForm lore)
reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . redLambda) reds

-- | Construct a Screma corresponding to a map: no scans, no
-- reductions, and the given map function.
mapSOAC :: Lambda lore -> ScremaForm lore
mapSOAC = ScremaForm [] []

-- | Does this Screma correspond to a scan-map composition?
--
-- Succeeds when there are no reductions and at least one scan. It
-- returns the scans and the map function.
isScanomapSOAC :: ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC (ScremaForm scans reds map_lam) = do
  guard $ null reds
  guard $ not $ null scans
  pure (scans, map_lam)

-- | Does this Screma correspond to a pure scan?
--
-- Like 'isScanomapSOAC', but additionally requires the map function
-- to be an identity lambda (see 'isIdentityLambda').
isScanSOAC :: ScremaForm lore -> Maybe [Scan lore]
isScanSOAC form = do
  (scans, map_lam) <- isScanomapSOAC form
  guard $ isIdentityLambda map_lam
  pure scans

-- | Does this Screma correspond to a reduce-map composition?
--
-- Succeeds when there are no scans and at least one reduction. It
-- returns the reductions and the map function.
isRedomapSOAC :: ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC (ScremaForm scans reds map_lam) = do
  guard $ null scans
  guard $ not $ null reds
  pure (reds, map_lam)

-- | Does this Screma correspond to a pure reduce?
--
-- Like 'isRedomapSOAC', but additionally requires the map function
-- to be an identity lambda (see 'isIdentityLambda').
isReduceSOAC :: ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC form = do
  (reds, map_lam) <- isRedomapSOAC form
  guard $ isIdentityLambda map_lam
  pure reds

-- | Does this Screma correspond to a simple map, without any
-- reduction or scan results? If so, return the map function.
isMapSOAC :: ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC (ScremaForm [] [] map_lam) = Just map_lam
isMapSOAC _ = Nothing

-- | @groupScatterResults output_spec results@
--
-- Groups the index values and result values of @results@ according to
-- the output specification @output_spec@.
--
-- This function is used for extracting and grouping the results of a
-- scatter. In the SOAC representation, the lambda inside a 'Scatter'
-- returns all indices and values as one big list. This function groups
-- each value with its corresponding indices, as determined by the
-- 'Shape' of the output array.
--
-- Each element of @output_spec@ is a triple @(shape, n, array)@:
--
-- * @n@ is the number of values written to @array@;
-- * each write uses @length shape@ index components.
--
-- The elements of the resulting list correspond to the shape and name
-- of the output arrays. Each is paired with the list of values written
-- to that array, along with the array indices marking where to write
-- them.
--
-- See 'Scatter' for more information.
groupScatterResults :: [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  let (shapes, ns, arrays) = unzip3 output_spec
   in groupScatterResults' output_spec results
        & chunks ns
        & zip3 shapes arrays

-- | @groupScatterResults' output_spec results@
--
-- Groups the index values and result values of @results@ according to
-- the output specification. This is the simpler version of
-- 'groupScatterResults', which doesn't return any information about
-- shapes or output arrays.
--
-- Each returned pair holds the index components of one write together
-- with the value written.
--
-- See 'groupScatterResults' for more information.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  let (indices, values) = splitScatterResults output_spec results
      (shapes, ns, _) = unzip3 output_spec
      -- Each of the @n@ writes to an array of shape @shp@ consumes
      -- @length shp@ index components.
      chunk_sizes =
        concat $ zipWith (\shp n -> replicate n $ length shp) shapes ns
   in zip (chunks chunk_sizes indices) values

-- | @splitScatterResults output_spec results@
--
-- Splits the results array into indices and values according to the
-- output specification.
--
-- The first component contains all index components, of which there
-- are @sum [n * length shape | (shape, n, _) <- output_spec]@. The
-- second component contains all remaining elements, which are the
-- values.
--
-- See 'groupScatterResults' for more information.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results =
  let (shapes, ns, _) = unzip3 output_spec
      num_indices = sum $ zipWith (*) ns $ map length shapes
   in splitAt num_indices results

-- | Like 'Mapper', but just for 'SOAC's. Used by 'mapSOACM' to
-- transform the immediate children of a SOAC.
data SOACMapper flore tlore m = SOACMapper
  { -- | Applied to 'SubExp' children (e.g. sizes, accumulators and
    -- neutral elements).
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Applied to 'Lambda' children.
    mapOnSOACLambda :: Lambda flore -> m (Lambda tlore),
    -- | Applied to 'VName' children (e.g. input arrays).
    mapOnSOACVName :: VName -> m VName
  }

-- | A mapper that simply returns the SOAC verbatim: every field is
-- 'pure', so each child is returned unchanged.
identitySOACMapper :: Monad m => SOACMapper lore lore m
identitySOACMapper =
  SOACMapper
    { mapOnSOACSubExp = pure,
      mapOnSOACLambda = pure,
      mapOnSOACVName = pure
    }

-- | Map a monadic action across the immediate children of a
-- SOAC.  The mapping does not descend recursively into subexpressions
-- and is done left-to-right.
mapSOACM ::
  (Applicative m, Monad m) =>
  SOACMapper flore tlore m ->
  SOAC flore ->
  m (SOAC tlore)
mapSOACM :: forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
SOACMapper flore tlore m -> SOAC flore -> m (SOAC tlore)
mapSOACM SOACMapper flore tlore m
tv (Stream SubExp
size [VName]
arrs StreamForm flore
form [SubExp]
accs Lambda flore
lam) =
  SubExp
-> [VName]
-> StreamForm tlore
-> [SubExp]
-> Lambda tlore
-> SOAC tlore
forall lore.
SubExp
-> [VName]
-> StreamForm lore
-> [SubExp]
-> Lambda lore
-> SOAC lore
Stream (SubExp
 -> [VName]
 -> StreamForm tlore
 -> [SubExp]
 -> Lambda tlore
 -> SOAC tlore)
-> m SubExp
-> m ([VName]
      -> StreamForm tlore -> [SubExp] -> Lambda tlore -> SOAC tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv SubExp
size
    m ([VName]
   -> StreamForm tlore -> [SubExp] -> Lambda tlore -> SOAC tlore)
-> m [VName]
-> m (StreamForm tlore -> [SubExp] -> Lambda tlore -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> VName -> m VName
mapOnSOACVName SOACMapper flore tlore m
tv) [VName]
arrs
    m (StreamForm tlore -> [SubExp] -> Lambda tlore -> SOAC tlore)
-> m (StreamForm tlore)
-> m ([SubExp] -> Lambda tlore -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> StreamForm flore -> m (StreamForm tlore)
mapOnStreamForm StreamForm flore
form
    m ([SubExp] -> Lambda tlore -> SOAC tlore)
-> m [SubExp] -> m (Lambda tlore -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv) [SubExp]
accs
    m (Lambda tlore -> SOAC tlore)
-> m (Lambda tlore) -> m (SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
lam
  where
    mapOnStreamForm :: StreamForm flore -> m (StreamForm tlore)
mapOnStreamForm (Parallel StreamOrd
o Commutativity
comm Lambda flore
lam0) =
      StreamOrd -> Commutativity -> Lambda tlore -> StreamForm tlore
forall lore.
StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
Parallel StreamOrd
o Commutativity
comm (Lambda tlore -> StreamForm tlore)
-> m (Lambda tlore) -> m (StreamForm tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
lam0
    mapOnStreamForm StreamForm flore
Sequential =
      StreamForm tlore -> m (StreamForm tlore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure StreamForm tlore
forall lore. StreamForm lore
Sequential
mapSOACM SOACMapper flore tlore m
tv (Scatter SubExp
len Lambda flore
lam [VName]
ivs [(Shape, Int, VName)]
as) =
  SubExp
-> Lambda tlore -> [VName] -> [(Shape, Int, VName)] -> SOAC tlore
forall lore.
SubExp
-> Lambda lore -> [VName] -> [(Shape, Int, VName)] -> SOAC lore
Scatter
    (SubExp
 -> Lambda tlore -> [VName] -> [(Shape, Int, VName)] -> SOAC tlore)
-> m SubExp
-> m (Lambda tlore
      -> [VName] -> [(Shape, Int, VName)] -> SOAC tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv SubExp
len
    m (Lambda tlore -> [VName] -> [(Shape, Int, VName)] -> SOAC tlore)
-> m (Lambda tlore)
-> m ([VName] -> [(Shape, Int, VName)] -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
lam
    m ([VName] -> [(Shape, Int, VName)] -> SOAC tlore)
-> m [VName] -> m ([(Shape, Int, VName)] -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> VName -> m VName
mapOnSOACVName SOACMapper flore tlore m
tv) [VName]
ivs
    m ([(Shape, Int, VName)] -> SOAC tlore)
-> m [(Shape, Int, VName)] -> m (SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ((Shape, Int, VName) -> m (Shape, Int, VName))
-> [(Shape, Int, VName)] -> m [(Shape, Int, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
      ( \(Shape
aw, Int
an, VName
a) ->
          (,,) (Shape -> Int -> VName -> (Shape, Int, VName))
-> m Shape -> m (Int -> VName -> (Shape, Int, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> Shape -> m Shape
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv) Shape
aw
            m (Int -> VName -> (Shape, Int, VName))
-> m Int -> m (VName -> (Shape, Int, VName))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Int -> m Int
forall (f :: * -> *) a. Applicative f => a -> f a
pure Int
an
            m (VName -> (Shape, Int, VName))
-> m VName -> m (Shape, Int, VName)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> VName -> m VName
mapOnSOACVName SOACMapper flore tlore m
tv VName
a
      )
      [(Shape, Int, VName)]
as
mapSOACM SOACMapper flore tlore m
tv (Hist SubExp
len [HistOp flore]
ops Lambda flore
bucket_fun [VName]
imgs) =
  SubExp -> [HistOp tlore] -> Lambda tlore -> [VName] -> SOAC tlore
forall lore.
SubExp -> [HistOp lore] -> Lambda lore -> [VName] -> SOAC lore
Hist
    (SubExp -> [HistOp tlore] -> Lambda tlore -> [VName] -> SOAC tlore)
-> m SubExp
-> m ([HistOp tlore] -> Lambda tlore -> [VName] -> SOAC tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv SubExp
len
    m ([HistOp tlore] -> Lambda tlore -> [VName] -> SOAC tlore)
-> m [HistOp tlore] -> m (Lambda tlore -> [VName] -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (HistOp flore -> m (HistOp tlore))
-> [HistOp flore] -> m [HistOp tlore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
      ( \(HistOp SubExp
e SubExp
rf [VName]
arrs [SubExp]
nes Lambda flore
op) ->
          SubExp
-> SubExp -> [VName] -> [SubExp] -> Lambda tlore -> HistOp tlore
forall lore.
SubExp
-> SubExp -> [VName] -> [SubExp] -> Lambda lore -> HistOp lore
HistOp (SubExp
 -> SubExp -> [VName] -> [SubExp] -> Lambda tlore -> HistOp tlore)
-> m SubExp
-> m (SubExp
      -> [VName] -> [SubExp] -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv SubExp
e
            m (SubExp -> [VName] -> [SubExp] -> Lambda tlore -> HistOp tlore)
-> m SubExp
-> m ([VName] -> [SubExp] -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv SubExp
rf
            m ([VName] -> [SubExp] -> Lambda tlore -> HistOp tlore)
-> m [VName] -> m ([SubExp] -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> VName -> m VName
mapOnSOACVName SOACMapper flore tlore m
tv) [VName]
arrs
            m ([SubExp] -> Lambda tlore -> HistOp tlore)
-> m [SubExp] -> m (Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv) [SubExp]
nes
            m (Lambda tlore -> HistOp tlore)
-> m (Lambda tlore) -> m (HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
op
      )
      [HistOp flore]
ops
    m (Lambda tlore -> [VName] -> SOAC tlore)
-> m (Lambda tlore) -> m ([VName] -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
bucket_fun
    m ([VName] -> SOAC tlore) -> m [VName] -> m (SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> VName -> m VName
mapOnSOACVName SOACMapper flore tlore m
tv) [VName]
imgs
mapSOACM SOACMapper flore tlore m
tv (Screma SubExp
w [VName]
arrs (ScremaForm [Scan flore]
scans [Reduce flore]
reds Lambda flore
map_lam)) =
  SubExp -> [VName] -> ScremaForm tlore -> SOAC tlore
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Screma (SubExp -> [VName] -> ScremaForm tlore -> SOAC tlore)
-> m SubExp -> m ([VName] -> ScremaForm tlore -> SOAC tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv SubExp
w
    m ([VName] -> ScremaForm tlore -> SOAC tlore)
-> m [VName] -> m (ScremaForm tlore -> SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> VName -> m VName
mapOnSOACVName SOACMapper flore tlore m
tv) [VName]
arrs
    m (ScremaForm tlore -> SOAC tlore)
-> m (ScremaForm tlore) -> m (SOAC tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ( [Scan tlore] -> [Reduce tlore] -> Lambda tlore -> ScremaForm tlore
forall lore.
[Scan lore] -> [Reduce lore] -> Lambda lore -> ScremaForm lore
ScremaForm
            ([Scan tlore]
 -> [Reduce tlore] -> Lambda tlore -> ScremaForm tlore)
-> m [Scan tlore]
-> m ([Reduce tlore] -> Lambda tlore -> ScremaForm tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Scan flore] -> (Scan flore -> m (Scan tlore)) -> m [Scan tlore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM
              [Scan flore]
scans
              ( \(Scan Lambda flore
red_lam [SubExp]
red_nes) ->
                  Lambda tlore -> [SubExp] -> Scan tlore
forall lore. Lambda lore -> [SubExp] -> Scan lore
Scan (Lambda tlore -> [SubExp] -> Scan tlore)
-> m (Lambda tlore) -> m ([SubExp] -> Scan tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
red_lam
                    m ([SubExp] -> Scan tlore) -> m [SubExp] -> m (Scan tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv) [SubExp]
red_nes
              )
            m ([Reduce tlore] -> Lambda tlore -> ScremaForm tlore)
-> m [Reduce tlore] -> m (Lambda tlore -> ScremaForm tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Reduce flore]
-> (Reduce flore -> m (Reduce tlore)) -> m [Reduce tlore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM
              [Reduce flore]
reds
              ( \(Reduce Commutativity
comm Lambda flore
red_lam [SubExp]
red_nes) ->
                  Commutativity -> Lambda tlore -> [SubExp] -> Reduce tlore
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Reduce lore
Reduce Commutativity
comm (Lambda tlore -> [SubExp] -> Reduce tlore)
-> m (Lambda tlore) -> m ([SubExp] -> Reduce tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
red_lam
                    m ([SubExp] -> Reduce tlore) -> m [SubExp] -> m (Reduce tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper flore tlore m
tv) [SubExp]
red_nes
              )
            m (Lambda tlore -> ScremaForm tlore)
-> m (Lambda tlore) -> m (ScremaForm tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SOACMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSOACLambda SOACMapper flore tlore m
tv Lambda flore
map_lam
        )

-- | Free variables of a SOAC.  Every subexpression, lambda and variable
-- name that 'mapSOACM' hands to the mapper contributes its own
-- @freeIn'@, and the results are accumulated into a single 'FV'.
instance ASTLore lore => FreeIn (SOAC lore) where
  freeIn' = flip execState mempty . mapSOACM free
    where
      -- Add the free variables of a component to the state and return the
      -- component unchanged, so the traversal does not rewrite anything.
      walk f x = modify (<> f x) >> return x
      free =
        SOACMapper
          { mapOnSOACSubExp = walk freeIn',
            mapOnSOACLambda = walk freeIn',
            mapOnSOACVName = walk freeIn'
          }

-- | Apply a name substitution to a SOAC by substituting in every
-- subexpression, lambda and variable name that 'mapSOACM' visits.
instance ASTLore lore => Substitute (SOAC lore) where
  substituteNames subst =
    runIdentity . mapSOACM substitute
    where
      substitute =
        SOACMapper
          { mapOnSOACSubExp = return . substituteNames subst,
            mapOnSOACLambda = return . substituteNames subst,
            mapOnSOACVName = return . substituteNames subst
          }

-- | Rename a SOAC by renaming each subexpression, lambda and variable name
-- that 'mapSOACM' visits, in the 'RenameM' monad.
instance ASTLore lore => Rename (SOAC lore) where
  rename = mapSOACM renamer
    where
      renamer = SOACMapper rename rename rename

-- | The type of a SOAC.
soacType :: SOAC lore -> [Type]
soacType (Stream outersize _ _ accs lam) =
  map (substNamesInType substs) $ lambdaReturnType lam
  where
    -- The first @1 + length accs@ lambda parameters (the chunk parameter
    -- followed by the accumulator parameters) may occur in the lambda's
    -- return types; replace them by the outer size and the accumulator
    -- arguments respectively.
    nms = map paramName $ take (1 + length accs) $ lambdaParams lam
    substs = M.fromList $ zip nms (outersize : accs)
soacType (Scatter _w lam _ivs as) =
  zipWith arrayOfShape val_ts ws
  where
    (ws, ns, _) = unzip3 as
    -- The lambda first returns all index components (n * rank for each
    -- destination), followed by the written values.
    num_indexes = sum $ zipWith (*) ns $ map length ws
    -- The values written to destination i form a contiguous run of n_i
    -- results.  Take one representative per destination so that each value
    -- type is paired with its own destination shape, even when n_i > 1.
    val_ts = map firstOf $ chunks ns $ drop num_indexes $ lambdaReturnType lam
    firstOf (t : _) = t
    firstOf [] = error "soacType: Scatter destination receives no values."
soacType (Hist _len ops _bucket_fun _imgs) = do
  -- One result per operator return value, lifted to the operator's width.
  op <- ops
  map (`arrayOfRow` histWidth op) (lambdaReturnType $ histOp op)
soacType (Screma w _arrs form) =
  scremaType w form

-- | The result types of a SOAC are those given by 'soacType', converted
-- with 'staticShapes'.
instance TypedOp (SOAC lore) where
  opType = pure . staticShapes . soacType

-- | Alias and consumption information for SOACs.  Every SOAC result is
-- reported as aliasing nothing.
instance (ASTLore lore, Aliased lore) => AliasedOp (SOAC lore) where
  opAliases = map (const mempty) . soacType

  -- In a Screma, only the map function can consume anything.  The operands
  -- to scan and reduce functions are always considered "fresh".
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames consumedArray $ consumedByLambda map_lam
    where
      -- A consumed lambda parameter means its input array is consumed;
      -- any other consumed name is reported unchanged.
      consumedArray v = fromMaybe v $ lookup v params_to_arrs
      params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs
  consumedInOp (Stream _ arrs _ accs lam) =
    -- Sequential and parallel streams are treated identically, so the
    -- stream form is not inspected.
    namesFromList $
      subExpVars $
        map consumedArray $
          namesToList $
            consumedByLambda lam
    where
      consumedArray v = fromMaybe (Var v) $ lookup v paramsToInput
      -- Drop the chunk parameter, which cannot alias anything.
      paramsToInput =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
  consumedInOp (Scatter _ _ _ as) =
    -- Every destination array is consumed.
    namesFromList $ map (\(_, _, a) -> a) as
  consumedInOp (Hist _ ops _ _) =
    -- Every histogram destination is consumed.
    namesFromList $ concatMap histDest ops

-- | Transform the operator lambda of a 'HistOp', keeping all of its other
-- fields unchanged.
mapHistOp ::
  (Lambda flore -> Lambda tlore) ->
  HistOp flore ->
  HistOp tlore
mapHistOp f (HistOp w rf dests nes lam) =
  HistOp w rf dests nes $ f lam

-- | Alias annotation of SOACs: every lambda inside the SOAC is analysed
-- with 'Alias.analyseLambda'; all other components are kept as they are.
instance
  ( ASTLore lore,
    ASTLore (Aliases lore),
    CanBeAliased (Op lore)
  ) =>
  CanBeAliased (SOAC lore)
  where
  type OpWithAliases (SOAC lore) = SOAC (Aliases lore)

  addOpAliases aliases (Stream size arr form accs lam) =
    Stream size arr (analyseStreamForm form) accs $
      Alias.analyseLambda aliases lam
    where
      -- A parallel stream carries its own reduction lambda, which must be
      -- analysed as well.
      analyseStreamForm (Parallel o comm lam0) =
        Parallel o comm (Alias.analyseLambda aliases lam0)
      analyseStreamForm Sequential = Sequential
  addOpAliases aliases (Scatter len lam ivs as) =
    Scatter len (Alias.analyseLambda aliases lam) ivs as
  addOpAliases aliases (Hist len ops bucket_fun imgs) =
    Hist
      len
      (map (mapHistOp (Alias.analyseLambda aliases)) ops)
      (Alias.analyseLambda aliases bucket_fun)
      imgs
  addOpAliases aliases (Screma w arrs (ScremaForm scans reds map_lam)) =
    Screma w arrs $
      ScremaForm
        (map onScan scans)
        (map onRed reds)
        (Alias.analyseLambda aliases map_lam)
    where
      onRed red =
        red {redLambda = Alias.analyseLambda aliases $ redLambda red}
      onScan scan =
        scan {scanLambda = Alias.analyseLambda aliases $ scanLambda scan}

  -- Strip alias annotations from every lambda; subexpressions and
  -- variable names are passed through unchanged.
  removeOpAliases = runIdentity . mapSOACM remove
    where
      remove = SOACMapper return (return . removeLambdaAliases) return

-- | For every SOAC, 'safeOp' is 'False' and 'cheapOp' is 'True'.
instance ASTLore lore => IsOp (SOAC lore) where
  safeOp _ = False
  cheapOp _ = True

-- | Substitute variables in the shape of an array type according to the
-- given map.  Only 'Array' shapes are rewritten; primitive, accumulator
-- and memory types are returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType _ t@Prim {} = t
substNamesInType _ t@Acc {} = t
substNamesInType _ (Mem space) = Mem space
substNamesInType subs (Array btp shp u) =
  let shp' = Shape $ map (substNamesInSubExp subs) (shapeDims shp)
   in Array btp shp' u

-- | Replace a variable by its entry in the map, if it has one.  Constants
-- and unmapped variables are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp _ e@(Constant _) = e
substNamesInSubExp subs (Var idd) =
  M.findWithDefault (Var idd) idd subs

-- | Removing wisdom from a SOAC strips it from every lambda that
-- 'mapSOACM' visits; subexpressions and variable names are passed through
-- unchanged.
instance (ASTLore lore, CanBeWise (Op lore)) => CanBeWise (SOAC lore) where
  type OpWithWisdom (SOAC lore) = SOAC (Wise lore)

  removeOpWisdom = runIdentity . mapSOACM remove
    where
      remove = SOACMapper return (return . removeLambdaWisdom) return

-- Indexing a result of a 'Screma' at a single index @i@ can sometimes
-- be expressed in terms of indexing the Screma's input arrays at @i@.
-- Only map results of a 'Screma' are handled, and only one-dimensional
-- indexes; every other case yields 'Nothing'.
instance Decorations lore => ST.IndexOp (SOAC lore) where
  indexOp vtable k soac [i] = do
    (lam, se, arr_params, arrs) <- lambdaAndSubExp soac
    -- Bind each lambda parameter to the element of its input array at
    -- index i, whenever the symbol table can express that element...
    let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs
        -- ...then extend the table through the lambda body's statements
        -- that bind a single name to something expressible as a PrimExp.
        arr_indexes' = foldl expandPrimExpTable arr_indexes $ bodyStms $ lambdaBody lam
    case se of
      Var v -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes'
      _ -> Nothing
    where
      lambdaAndSubExp (Screma _ arrs (ScremaForm scans reds map_lam)) =
        nthMapOut (scanResults scans + redResults reds) map_lam arrs
      lambdaAndSubExp _ =
        Nothing

      -- The map lambda's body returns the scan inputs and reduce inputs
      -- (num_accs values) before the map results.  NOTE(review): k is
      -- assumed to be the position of the indexed value among the
      -- Screma's results, which follow that same order.  Positions
      -- below num_accs are therefore scan/reduce results, which are not
      -- per-element values of the map body, so they cannot be indexed
      -- this way.  Map results are at body position k itself; the
      -- previous `num_accs + k` applied the offset twice.
      --
      -- The map lambda takes exactly one parameter per input array (it
      -- is type-checked against the arrays alone), so all of its
      -- parameters line up with 'arrs'; none must be dropped.
      nthMapOut num_accs lam arrs
        | k < num_accs = Nothing
        | otherwise = do
          se <- maybeNth k $ bodyResult $ lambdaBody lam
          return (lam, se, lambdaParams lam, arrs)

      -- The element of 'arr' at index i, keyed by the parameter name it
      -- is bound to inside the lambda.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        return (paramName p, (pe, cs))

      -- Add a single-name statement to the table if its expression can
      -- be converted to a PrimExp and all of its certificates are in
      -- scope; the statement's certificates are accumulated with those
      -- of the PrimExp's operands.
      expandPrimExpTable table stm
        | [v] <- patternNames $ stmPattern stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
          all (`ST.elem` vtable) (unCertificates $ stmCerts stm) =
          M.insert v (pe, stmCerts stm <> cs) table
        | otherwise =
          table

      -- Translate a variable into a PrimExp: use the table entry (and
      -- record its certificates) if there is one, otherwise a leaf for
      -- variables of primitive type, otherwise fail.
      asPrimExp table v
        | Just (e, cs) <- M.lookup v table = tell cs >> return e
        | Just (Prim pt) <- ST.lookupType v vtable =
          return $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

-- | Type-check a SOAC.
typeCheckSOAC :: TC.Checkable lore => SOAC (Aliases lore) -> TC.TypeM lore ()
typeCheckSOAC :: forall lore. Checkable lore => SOAC (Aliases lore) -> TypeM lore ()
typeCheckSOAC (Stream SubExp
size [VName]
arrexps StreamForm (Aliases lore)
form [SubExp]
accexps Lambda (Aliases lore)
lam) = do
  [Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
  [Arg]
accargs <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
accexps
  [Type]
arrargs <- (VName -> TypeM lore Type) -> [VName] -> TypeM lore [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> TypeM lore Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
arrexps
  [Arg]
_ <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
  let chunk :: Param (LParamInfo lore)
chunk = [Param (LParamInfo lore)] -> Param (LParamInfo lore)
forall a. [a] -> a
head ([Param (LParamInfo lore)] -> Param (LParamInfo lore))
-> [Param (LParamInfo lore)] -> Param (LParamInfo lore)
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases lore) -> [LParam (Aliases lore)]
forall {lore}. LambdaT lore -> [Param (LParamInfo lore)]
lambdaParams Lambda (Aliases lore)
lam
  let asArg :: a -> (a, b)
asArg a
t = (a
t, b
forall a. Monoid a => a
mempty)
      inttp :: TypeBase shape u
inttp = PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      lamarrs' :: [Type]
lamarrs' = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (Param (LParamInfo lore) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo lore)
chunk)) [Type]
arrargs
  let acc_len :: Int
acc_len = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
  let lamrtp :: [Type]
lamrtp = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
acc_len ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam
  Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"Stream with inconsistent accumulator type in lambda."
  -- check reduce's lambda, if any
  ()
_ <- case StreamForm (Aliases lore)
form of
    Parallel StreamOrd
_ Commutativity
_ Lambda (Aliases lore)
lam0 -> do
      let acct :: [Type]
acct = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs
          outerRetType :: [Type]
outerRetType = Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam0
      Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam0 ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
accargs
      Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
acct [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
outerRetType) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
        ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
          String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
            String
"Initial value is of type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
acct
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
", but stream's reduce lambda returns type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
outerRetType
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"."
    StreamForm (Aliases lore)
Sequential -> () -> TypeM lore ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
  -- just get the dflow of lambda on the fakearg, which does not alias
  -- arr, so we can later check that aliases of arr are not used inside lam.
  let fake_lamarrs' :: [Arg]
fake_lamarrs' = (Type -> Arg) -> [Type] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
  Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg Type
forall {shape} {u}. TypeBase shape u
inttp Arg -> [Arg] -> [Arg]
forall a. a -> [a] -> [a]
: [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter SubExp
w Lambda (Aliases lore)
lam [VName]
ivs [(Shape, Int, VName)]
as) = do
  -- Requirements:
  --
  --   0. @lambdaReturnType@ of @lam@ must be a list
  --      [index types..., value types, ...].
  --
  --   1. The number of index types and value types must be equal to the number
  --      of return values from @lam@.
  --
  --   2. Each index type must have the type i64.
  --
  --   3. Each array in @as@ and the value types must have the same type
  --
  --   4. Each array in @as@ is consumed.  This is not really a check, but more
  --      of a requirement, so that e.g. the source is not hoisted out of a
  --      loop, which will mean it cannot be consumed.
  --
  --   5. Each of ivs must be an array matching a corresponding lambda
  --      parameters.
  --
  -- Code:

  -- First check the input size.
  [Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w

  -- 0.
  let ([Shape]
as_ws, [Int]
as_ns, [VName]
_as_vs) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
      indexes :: Int
indexes = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws
      rts :: [Type]
rts = Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam
      rtsI :: [Type]
rtsI = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
indexes [Type]
rts
      rtsV :: [Type]
rtsV = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
indexes [Type]
rts

  -- 1.
  Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
rts Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws)) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"Scatter: number of index types, value types and array outputs do not match."

  -- 2.
  [Type] -> (Type -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtsI ((Type -> TypeM lore ()) -> TypeM lore ())
-> (Type -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \Type
rtI ->
    Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
== Type
rtI) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
      ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"Scatter: Index return type must be i64."

  [([Type], (Shape, Int, VName))]
-> (([Type], (Shape, Int, VName)) -> TypeM lore ())
-> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[Type]]
-> [(Shape, Int, VName)] -> [([Type], (Shape, Int, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ([Int] -> [Type] -> [[Type]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns [Type]
rtsV) [(Shape, Int, VName)]
as) ((([Type], (Shape, Int, VName)) -> TypeM lore ()) -> TypeM lore ())
-> (([Type], (Shape, Int, VName)) -> TypeM lore ())
-> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \([Type]
rtVs, (Shape
aw, Int
_, VName
a)) -> do
    -- All lengths must have type i64.
    (SubExp -> TypeM lore ()) -> Shape -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
aw

    -- 3.
    [Type] -> (Type -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtVs ((Type -> TypeM lore ()) -> TypeM lore ())
-> (Type -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \Type
rtV -> [Type] -> VName -> TypeM lore ()
forall lore. Checkable lore => [Type] -> VName -> TypeM lore ()
TC.requireI [Type -> Shape -> Type
arrayOfShape Type
rtV Shape
aw] VName
a

    -- 4.
    Names -> TypeM lore ()
forall lore. Checkable lore => Names -> TypeM lore ()
TC.consume (Names -> TypeM lore ()) -> TypeM lore Names -> TypeM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM lore Names
forall lore. Checkable lore => VName -> TypeM lore Names
TC.lookupAliases VName
a

  -- 5.
  [Arg]
arrargs <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
ivs
  Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam [Arg]
arrargs
typeCheckSOAC (Hist SubExp
len [HistOp (Aliases lore)]
ops Lambda (Aliases lore)
bucket_fun [VName]
imgs) = do
  [Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
len

  -- Check the operators.
  [HistOp (Aliases lore)]
-> (HistOp (Aliases lore) -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [HistOp (Aliases lore)]
ops ((HistOp (Aliases lore) -> TypeM lore ()) -> TypeM lore ())
-> (HistOp (Aliases lore) -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
dest_w SubExp
rf [VName]
dests [SubExp]
nes Lambda (Aliases lore)
op) -> do
    [Arg]
nes' <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
nes
    [Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
dest_w
    [Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
rf

    -- Operator type must match the type of neutral elements.
    Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
op ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
    let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'
    Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
nes_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
op) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
      ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
        String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
          String
"Operator has return type "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
op)
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
nes_t

    -- Arrays must have proper type.
    [(Type, VName)]
-> ((Type, VName) -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
nes_t [VName]
dests) (((Type, VName) -> TypeM lore ()) -> TypeM lore ())
-> ((Type, VName) -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
dest) -> do
      [Type] -> VName -> TypeM lore ()
forall lore. Checkable lore => [Type] -> VName -> TypeM lore ()
TC.requireI [Type
t Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
dest_w] VName
dest
      Names -> TypeM lore ()
forall lore. Checkable lore => Names -> TypeM lore ()
TC.consume (Names -> TypeM lore ()) -> TypeM lore Names -> TypeM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM lore Names
forall lore. Checkable lore => VName -> TypeM lore Names
TC.lookupAliases VName
dest

  -- Types of input arrays must equal parameter types for bucket function.
  [Arg]
img' <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
len [VName]
imgs
  Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
bucket_fun [Arg]
img'

  -- Return type of bucket function must be an index for each
  -- operation followed by the values to write.
  [Type]
nes_ts <- [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[Type]] -> [Type]) -> TypeM lore [[Type]] -> TypeM lore [Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp (Aliases lore) -> TypeM lore [Type])
-> [HistOp (Aliases lore)] -> TypeM lore [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> TypeM lore Type) -> [SubExp] -> TypeM lore [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType ([SubExp] -> TypeM lore [Type])
-> (HistOp (Aliases lore) -> [SubExp])
-> HistOp (Aliases lore)
-> TypeM lore [Type]
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases lore) -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral) [HistOp (Aliases lore)]
ops
  let bucket_ret_t :: [Type]
bucket_ret_t = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([HistOp (Aliases lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Aliases lore)]
ops) (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
nes_ts
  Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
bucket_ret_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
bucket_fun) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
      String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
        String
"Bucket function has return type "
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
bucket_fun)
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but should have type "
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
bucket_ret_t
typeCheckSOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  -- The width is required to be an int64; the input arrays and the map
  -- function are checked against it.
  TC.require [Prim int64] w
  arrs' <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda map_lam arrs'

  -- All scan operators are checked first, then all reduction operators;
  -- the checked neutral elements of each group are concatenated in order.
  scan_nes' <-
    concat
      <$> mapM
        (\(Scan lam nes) -> checkOperator "Scan function returns type " lam nes)
        scans
  red_nes' <-
    concat
      <$> mapM
        (\(Reduce _ lam nes) -> checkOperator "Reduce function returns type " lam nes)
        reds

  -- The leading results of the map function must have the types of the
  -- scan neutral elements followed by the reduction neutral elements.
  let map_lam_ts = lambdaReturnType map_lam
      op_ts = map TC.argType (scan_nes' ++ red_nes')
  unless (take (length op_ts) map_lam_ts == op_ts) $
    TC.bad $
      TC.TypeError $
        "Map function return type "
          ++ prettyTuple map_lam_ts
          ++ " wrong for given scan and reduction functions."
  where
    -- Check the neutral elements of one operator, then check the
    -- operator's lambda against two copies of them (with aliases dropped).
    -- The lambda's return type must equal the neutral-element types.
    -- Returns the checked neutral elements.  'what' is the prefix of the
    -- error message reported on a mismatch.
    checkOperator what lam nes = do
      nes' <- mapM TC.checkArg nes
      let nes_ts = map TC.argType nes'
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      unless (nes_ts == lambdaReturnType lam) $
        TC.bad $
          TC.TypeError $
            what
              ++ prettyTuple (lambdaReturnType lam)
              ++ " but neutral element has type "
              ++ prettyTuple nes_ts
      pure nes'

-- | The metrics of a SOAC are the metrics of the lambdas it carries,
-- recorded under a label naming the kind of SOAC.
instance OpMetrics (Op lore) => OpMetrics (SOAC lore) where
  opMetrics soac = case soac of
    -- NOTE(review): only the final lambda of a stream is visited; the
    -- lambda carried by a 'Parallel' stream form is not counted.  Confirm
    -- this omission is intended before relying on stream metrics.
    Stream _ _ _ _ lam ->
      inside "Stream" $ lambdaMetrics lam
    Scatter _ lam _ _ ->
      inside "Scatter" $ lambdaMetrics lam
    -- Operator lambdas first, then the bucket function.
    Hist _ ops bucket_fun _ ->
      inside "Hist" $
        mapM_ lambdaMetrics $
          map histOp ops ++ [bucket_fun]
    -- Scan lambdas, then reduction lambdas, then the map lambda.
    Screma _ _ (ScremaForm scans reds map_lam) ->
      inside "Screma" $
        mapM_ lambdaMetrics $
          map scanLambda scans ++ map redLambda reds ++ [map_lam]

-- | SOACs are printed in a call-like notation: a keyword naming the
-- combinator, followed by a parenthesised argument list in which every
-- argument except the last is comma-terminated.  A Screma with no scans
-- and no reductions prints as @map@, one with only reductions as
-- @redomap@, and one with only scans as @scanomap@.  Any other Screma is
-- delegated to 'ppScrema'.
instance PrettyLore lore => PP.Pretty (SOAC lore) where
  ppr soac = case soac of
    Stream size arrs (Parallel o comm lam0) acc lam ->
      callDoc
        (text ("streamPar" ++ orderSuffix o ++ commSuffix comm))
        [ppr size, ppTuple' arrs, ppr lam0, ppTuple' acc, ppr lam]
    Stream size arrs Sequential acc lam ->
      callDoc
        (text "streamSeq")
        [ppr size, ppTuple' arrs, ppTuple' acc, ppr lam]
    Scatter w lam ivs outs ->
      callDoc
        "scatter"
        [ppr w, ppr lam, commasep (ppTuple' ivs : map ppr outs)]
    Hist len ops bucket_fun imgs ->
      ppHist len ops bucket_fun imgs
    Screma w arrs form@(ScremaForm scans reds map_lam)
      | null scans,
        null reds ->
          callDoc (text "map") [ppr w, ppTuple' arrs, ppr map_lam]
      | null scans ->
          callDoc (text "redomap") [ppr w, ppTuple' arrs, opBlock reds, ppr map_lam]
      | null reds ->
          callDoc (text "scanomap") [ppr w, ppTuple' arrs, opBlock scans, ppr map_lam]
      | otherwise ->
          ppScrema w arrs form
    where
      -- @Disorder@ streams are marked with a "Per" suffix.
      orderSuffix :: StreamOrd -> String
      orderSuffix Disorder = "Per"
      orderSuffix _ = ""

      -- Commutative streams are marked with a "Comm" suffix.
      commSuffix :: Commutativity -> String
      commSuffix Commutative = "Comm"
      commSuffix Noncommutative = ""

      -- A keyword applied to a parenthesised argument list.
      callDoc :: Doc -> [Doc] -> Doc
      callDoc hd args = hd <> parens (joinArgs args)

      -- Comma-terminate every argument but the last, joining with '</>'
      -- (right-nested, matching the fixity of '</>').
      joinArgs :: [Doc] -> Doc
      joinArgs [] = mempty
      joinArgs [d] = d
      joinArgs (d : ds) = d <> comma </> joinArgs ds

      -- Operators in braces, separated by a comma and a line break.
      opBlock :: Pretty a => [a] -> Doc
      opBlock = PP.braces . mconcat . intersperse (comma <> PP.line) . map ppr

-- | Prettyprint the given Screma in its general form.  The output is
-- @screma@ applied to five arguments: the width, the input tuple, the
-- block of scan operators, the block of reduction operators, and the
-- map function.
ppScrema ::
  (PrettyLore lore, Pretty inp) => SubExp -> [inp] -> ScremaForm lore -> Doc
ppScrema w arrs (ScremaForm scans reds map_lam) =
  text "screma" <> parens (joinArgs fields)
  where
    fields = [ppr w, ppTuple' arrs, opBlock scans, opBlock reds, ppr map_lam]

    -- Operators in braces, separated by a comma and a line break.
    opBlock :: Pretty a => [a] -> Doc
    opBlock = PP.braces . mconcat . intersperse (comma <> PP.line) . map ppr

    -- Comma-terminate every argument but the last, joining with '</>'
    -- (right-nested, matching the fixity of '</>').
    joinArgs :: [Doc] -> Doc
    joinArgs [] = mempty
    joinArgs [d] = d
    joinArgs (d : ds) = d <> comma </> joinArgs ds

-- | A scan operator is printed as its lambda, a comma, and then its
-- neutral elements in braces.
instance PrettyLore lore => Pretty (Scan lore) where
  ppr (Scan lam nes) = lamDoc </> nesDoc
    where
      lamDoc = ppr lam <> comma
      nesDoc = PP.braces . commasep $ map ppr nes

-- | Prefix printed before an operator to show its commutativity.
-- Commutative operators get the keyword @commutative@ followed by a
-- space; non-commutative operators get nothing.
ppComm :: Commutativity -> Doc
ppComm comm = case comm of
  Commutative -> text "commutative "
  Noncommutative -> mempty

-- | A reduction operator is printed as its commutativity prefix, its
-- lambda, a comma, and then its neutral elements in braces.
instance PrettyLore lore => Pretty (Reduce lore) where
  ppr (Reduce comm lam nes) = header </> nesDoc
    where
      header = ppComm comm <> ppr lam <> comma
      nesDoc = PP.braces $ commasep $ map ppr nes

-- | Prettyprint the given histogram operation.  The output is @hist@
-- applied to four arguments: the length argument, the block of
-- histogram operators, the bucket function, and the input arrays.
ppHist ::
  (PrettyLore lore, Pretty inp) =>
  SubExp ->
  [HistOp lore] ->
  Lambda lore ->
  [inp] ->
  Doc
ppHist len ops bucket_fun imgs =
  text "hist"
    <> parens
      (joinArgs [ppr len, opsBlock, ppr bucket_fun, commasep (map ppr imgs)])
  where
    -- Histogram operators in braces, separated by a comma and a line break.
    opsBlock = PP.braces . mconcat . intersperse (comma <> PP.line) $ map ppOp ops

    -- One histogram operator.  The first line holds its two size
    -- arguments and its destination arrays.  It is followed by the
    -- neutral elements and the operator lambda.
    ppOp (HistOp hist_w rf dest_arrs neutral op_lam) =
      let header =
            ppr hist_w <> comma <+> ppr rf <> comma <+> braced dest_arrs <> comma
       in header </> braced neutral <> comma </> ppr op_lam

    -- Comma-separated items in braces.
    braced :: Pretty a => [a] -> Doc
    braced = PP.braces . commasep . map ppr

    -- Comma-terminate every argument but the last, joining with '</>'
    -- (right-nested, matching the fixity of '</>').
    joinArgs :: [Doc] -> Doc
    joinArgs [] = mempty
    joinArgs [d] = d
    joinArgs (d : ds) = d <> comma </> joinArgs ds