{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Definition of /Second-Order Array Combinators/ (SOACs), which are
-- the main form of parallelism in the early stages of the compiler.
module Futhark.IR.SOACS.SOAC
  ( SOAC (..),
    ScremaForm (..),
    HistOp (..),
    Scan (..),
    scanResults,
    singleScan,
    Reduce (..),
    redResults,
    singleReduce,

    -- * Utility
    scremaType,
    soacType,
    typeCheckSOAC,
    mkIdentityLambda,
    isIdentityLambda,
    nilFn,
    scanomapSOAC,
    redomapSOAC,
    scanSOAC,
    reduceSOAC,
    mapSOAC,
    isScanomapSOAC,
    isRedomapSOAC,
    isScanSOAC,
    isReduceSOAC,
    isMapSOAC,
    ppScrema,
    ppHist,
    ppStream,
    ppScatter,
    groupScatterResults,
    groupScatterResults',
    splitScatterResults,

    -- * Generic traversal
    SOACMapper (..),
    identitySOACMapper,
    mapSOACM,
    traverseSOACStms,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, CanBeAliased (..))
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth, splitAt3)
import Futhark.Util.Pretty (Doc, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | A second-order array combinator (SOAC).
data SOAC rep
  = -- NOTE(review): undocumented in the original.  By analogy with 'Screma',
    -- the 'SubExp' is presumably the length of the input arrays and the
    -- @[SubExp]@ presumably initial accumulator values; confirm against
    -- 'ppStream' before relying on this.
    Stream SubExp [VName] [SubExp] (Lambda rep)
  | -- | @Scatter <length> <inputs> <lambda> <outputs>@
    --
    -- Scatter maps values from a set of input arrays to indices and values of a
    -- set of output arrays. It is able to write multiple values to multiple
    -- outputs each of which may have multiple dimensions.
    --
    -- <inputs> is a list of input arrays, all having size <length>, elements of
    -- which are applied to the <lambda> function. For instance, if there are
    -- two arrays, <lambda> will get two values as input, one from each array.
    --
    -- <outputs> specifies the result of the <lambda> and which arrays to write
    -- to. Each element of the list consists of a <Shape> describing the shape
    -- of the array being written, an <Int> describing how many elements
    -- should be written to that array for each invocation of the <lambda>,
    -- and a <VName> specifying which array to scatter to.
    --
    -- <lambda> is a function that takes inputs from <inputs> and returns values
    -- according to the output-specification in <outputs>. It returns values in
    -- the following manner:
    --
    --     [index_0, index_1, ..., index_n, value_0, value_1, ..., value_m]
    --
    -- For each output in <outputs>, <lambda> returns <i> * <j> index values and
    -- <j> output values, where <i> is the number of dimensions (rank) of the
    -- given output, and <j> is the number of output values written to the given
    -- output.
    --
    -- For example, given the following output specification:
    --
    --     [([x1, y1, z1], 2, arr1), ([x2, y2], 1, arr2)]
    --
    -- <lambda> will produce 6 (3 * 2) index values and 2 output values for
    -- <arr1>, and 2 (2 * 1) index values and 1 output value for
    -- arr2. Additionally, the results are grouped, so the first 6 index values
    -- will correspond to the first two output values, and so on. For this
    -- example, <lambda> should return a total of 11 values, 8 index values and
    -- 3 output values.  See also 'splitScatterResults'.
    Scatter SubExp [VName] (Lambda rep) [(Shape, Int, VName)]
  | -- | @Hist <length> <input arrays> <dest-arrays-and-ops> <bucket fun>@
    --
    -- The final lambda produces indexes and values for the 'HistOp's.
    Hist SubExp [VName] [HistOp rep] (Lambda rep)
  | -- FIXME: this should not be here
    --
    -- NOTE(review): presumably a Jacobian-vector product (forward-mode AD)
    -- of the lambda; the roles of the two @[SubExp]@ lists are not shown here.
    JVP (Lambda rep) [SubExp] [SubExp]
  | -- FIXME: this should not be here
    --
    -- NOTE(review): presumably a vector-Jacobian product (reverse-mode AD)
    -- of the lambda; the roles of the two @[SubExp]@ lists are not shown here.
    VJP (Lambda rep) [SubExp] [SubExp]
  | -- | A combination of scan, reduction, and map.  The first
    -- t'SubExp' is the size of the input arrays.
    Screma SubExp [VName] (ScremaForm rep)
  deriving (Eq, Ord, Show)

-- | Information about computing a single histogram.
data HistOp rep = HistOp
  { -- | The shape of the histogram (its bins).
    histShape :: Shape,
    -- | Race factor @RF@ means that only @1/RF@
    -- bins are used.
    histRaceFactor :: SubExp,
    -- | The destination arrays the histogram is computed into.
    histDest :: [VName],
    -- | Neutral elements, presumably for 'histOp' (one per destination
    -- array; TODO confirm).
    histNeutral :: [SubExp],
    -- | The operator used to update bins with the values produced by the
    -- 'Hist' bucket function.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays).
data ScremaForm rep = ScremaForm
  { -- | Scan operators.  They consume the leading results of
    -- 'scremaLambda', and their results come first in 'scremaType'.
    scremaScans :: [Scan rep],
    -- | Reduction operators.  They consume the results of 'scremaLambda'
    -- that follow those consumed by the scans.
    scremaReduces :: [Reduce rep],
    -- | The "main" lambda of the Screma. For a map, this is the lambda
    -- returned by 'isMapSOAC'. Note that the meaning of the return
    -- value of this lambda depends crucially on exactly which Screma
    -- this is. The parameters will correspond exactly to elements of
    -- the input arrays, however.
    scremaLambda :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | Fuse several binary-operator lambdas into one.
--
-- Each input lambda's parameter list is split into an "x" half, as long
-- as its return type, and a "y" half holding the rest.  The fused lambda
-- takes every "x" half (in input order), then every "y" half.  It runs the
-- bodies' statements in order and returns all results concatenated.
singleBinOp :: (Buildable rep) => [Lambda rep] -> Lambda rep
singleBinOp lams =
  Lambda
    { lambdaParams = concatMap fst halves ++ concatMap snd halves,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody (foldMap bodyStms bodies) (concatMap bodyResult bodies)
    }
  where
    -- Split a lambda's parameters at the arity of its return type.
    splitParams lam = splitAt (length (lambdaReturnType lam)) (lambdaParams lam)
    halves = map splitParams lams
    bodies = map lambdaBody lams

-- | How to compute a single scan result.
data Scan rep = Scan
  { -- | The binary scan operator.  Its parameters are an "x" half followed
    -- by a "y" half, as assumed by 'singleScan'.
    scanLambda :: Lambda rep,
    -- | Neutral elements of the operator.  Their number is the number of
    -- results this scan produces (see 'scanSizes').
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many results each of these 'Scan's produces, in order.  This is
-- the length of each scan's neutral element list.
scanSizes :: [Scan rep] -> [Int]
scanSizes scans = [length (scanNeutral op) | op <- scans]

-- | The total number of scan results produced by these 'Scan's together.
scanResults :: [Scan rep] -> Int
scanResults scans = sum (scanSizes scans)

-- | Merge several scan operators into one.  The merged operator's lambda
-- is built by 'singleBinOp'.  Its neutral elements are those of the input
-- operators, concatenated in order.
singleScan :: (Buildable rep) => [Scan rep] -> Scan rep
singleScan scans =
  Scan
    { scanLambda = singleBinOp (map scanLambda scans),
      scanNeutral = concatMap scanNeutral scans
    }

-- | How to compute a single reduction result.
data Reduce rep = Reduce
  { -- | The commutativity of 'redLambda'.
    redComm :: Commutativity,
    -- | The binary reduction operator.  Its parameters are an "x" half
    -- followed by a "y" half, as assumed by 'singleReduce'.
    redLambda :: Lambda rep,
    -- | Neutral elements of the operator.  Their number is the number of
    -- results this reduction produces (see 'redSizes').
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many results each of these 'Reduce's produces, in order.  This is
-- the length of each reduction's neutral element list.
redSizes :: [Reduce rep] -> [Int]
redSizes reds = [length (redNeutral op) | op <- reds]

-- | The total number of reduction results produced by these 'Reduce's
-- together.
redResults :: [Reduce rep] -> Int
redResults reds = sum (redSizes reds)

-- | Merge several reduction operators into one.  The merged lambda is
-- built by 'singleBinOp' and the neutral elements are concatenated in
-- order.  The commutativity is the 'Monoid' combination of the inputs'
-- commutativities (presumably commutative only if every input is).
singleReduce :: (Buildable rep) => [Reduce rep] -> Reduce rep
singleReduce reds =
  Reduce
    { redComm = foldMap redComm reds,
      redLambda = singleBinOp (map redLambda reds),
      redNeutral = concatMap redNeutral reds
    }

-- | The types produced by a single 'Screma', given the size @w@ of the
-- input array.
--
-- The results are ordered as: scan results, then reduction results, then
-- map results.  Scan and map results are arrays with @w@ rows.  Reduction
-- results keep the reduction operator's return type.  The map results are
-- whatever the main lambda returns beyond the values it feeds to the scans
-- and reductions.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  rowsOf scanRets ++ redRets ++ rowsOf mapRets
  where
    rowsOf = map (`arrayOfRow` w)
    scanRets = concatMap (lambdaReturnType . scanLambda) scans
    redRets = concatMap (lambdaReturnType . redLambda) reds
    -- Skip the lambda results that are consumed by scans and reductions.
    mapRets = drop (length scanRets + length redRets) (lambdaReturnType map_lam)

-- | Build a lambda that takes one fresh parameter per given type and
-- returns those parameters unchanged, with no statements in its body.
mkIdentityLambda ::
  (Buildable rep, MonadFreshNames m) =>
  [Type] ->
  m (Lambda rep)
mkIdentityLambda ts = do
  xs <- traverse (newParam "x") ts
  let identityBody = mkBody mempty (varsRes (map paramName xs))
  pure (Lambda xs ts identityBody)

-- | True when the lambda's results are precisely its own parameters,
-- in order. Only the 'resSubExp' of each body result is compared
-- against the parameter names; the statements of the body are not
-- examined.
isIdentityLambda :: Lambda rep -> Bool
isIdentityLambda lam = results == paramVars
  where
    results = resSubExp <$> bodyResult (lambdaBody lam)
    paramVars = [Var (paramName p) | p <- lambdaParams lam]

-- | The empty lambda: no parameters, no statements, no return types
-- and no results.
nilFn :: (Buildable rep) => Lambda rep
nilFn =
  Lambda
    { lambdaParams = [],
      lambdaReturnType = [],
      lambdaBody = mkBody mempty []
    }

-- | A 'ScremaForm' made of the given scans (possibly several) applied
-- after the given map function, with no reductions.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans map_lam = ScremaForm scans [] map_lam

-- | A 'ScremaForm' made of the given reductions (possibly several)
-- applied after the given map function, with no scans.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC reds map_lam = ScremaForm [] reds map_lam

-- | A 'ScremaForm' that only scans: the map function is an identity
-- lambda whose parameter types are the concatenated return types of
-- all the scan operators.
scanSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Scan rep] ->
  m (ScremaForm rep)
scanSOAC scans = do
  map_lam <- mkIdentityLambda [t | s <- scans, t <- lambdaReturnType (scanLambda s)]
  pure $ scanomapSOAC scans map_lam

-- | A 'ScremaForm' that only reduces: the map function is an identity
-- lambda whose parameter types are the concatenated return types of
-- all the reduction operators.
reduceSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Reduce rep] ->
  m (ScremaForm rep)
reduceSOAC reds = do
  map_lam <- mkIdentityLambda [t | r <- reds, t <- lambdaReturnType (redLambda r)]
  pure $ redomapSOAC reds map_lam

-- | A 'ScremaForm' that is a plain map: neither scans nor reductions.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC map_lam = ScremaForm [] [] map_lam

-- | Recognise a scan-map composition: at least one scan and no
-- reductions. On success, yields the scans and the map function.
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm scans reds map_lam)
  | null reds, not (null scans) = Just (scans, map_lam)
  | otherwise = Nothing

-- | Recognise a pure scan: a scan-map composition (see
-- 'isScanomapSOAC') whose map function is an identity lambda.
-- On success, yields the scans.
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep]
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scans, map_lam) | isIdentityLambda map_lam -> Just scans
    _ -> Nothing

-- | Recognise a reduce-map composition: at least one reduction and no
-- scans. On success, yields the reductions and the map function.
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm scans reds map_lam)
  | null scans, not (null reds) = Just (reds, map_lam)
  | otherwise = Nothing

-- | Recognise a pure reduction: a reduce-map composition (see
-- 'isRedomapSOAC') whose map function is an identity lambda.
-- On success, yields the reductions.
isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC form =
  case isRedomapSOAC form of
    Just (reds, map_lam) | isIdentityLambda map_lam -> Just reds
    _ -> Nothing

-- | Recognise a plain map: a 'ScremaForm' with neither scans nor
-- reductions. On success, yields the map function.
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm scans reds map_lam)
  | null scans && null reds = Just map_lam
  | otherwise = Nothing

-- | @groupScatterResults output_spec results@ groups the flat result
-- list of a scatter lambda by destination array.
--
-- In the SOAC representation, the lambda of a 'Scatter' returns all
-- indices followed by all values as one list. Each entry
-- @(shape, n, array)@ of @output_spec@ describes one destination:
-- @n@ values are written to @array@, and each value is paired with as
-- many indices as the t'Shape' has dimensions.
--
-- The result has one element per @output_spec@ entry, carrying that
-- entry's shape and array together with its @n@ (indices, value) pairs.
--
-- See 'Scatter' for more information.
groupScatterResults :: [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  zip3 shapes arrays $ chunks ns $ groupScatterResults' output_spec results
  where
    (shapes, ns, arrays) = unzip3 output_spec

-- | @groupScatterResults' output_spec results@ pairs every written value
-- with its list of indices, following @output_spec@. This is a simpler
-- variant of 'groupScatterResults' that drops the shapes and
-- destination arrays, returning one flat list of (indices, value) pairs.
--
-- See 'groupScatterResults' for more information.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  zip (chunks index_counts indices) values
  where
    (indices, values) = splitScatterResults output_spec results
    -- One entry per written value: how many indices that value uses,
    -- i.e. the rank of its destination shape, repeated n times.
    index_counts = concatMap (\(shp, n, _) -> replicate n (length shp)) output_spec

-- | @splitScatterResults output_spec results@ splits the flat scatter
-- result list into its index part and its value part. The index part
-- is the first @sum (n * rank shape)@ elements, summed over every
-- @(shape, n, _)@ in @output_spec@; everything after it is values.
--
-- See 'groupScatterResults' for more information.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec = splitAt num_indices
  where
    num_indices = sum [n * length shp | (shp, n, _) <- output_spec]

-- | The 'SOAC' counterpart of 'Mapper': a record of monadic callbacks
-- that 'mapSOACM' applies to the immediate children of a 'SOAC'.
data SOACMapper frep trep m = SOACMapper
  { -- | Applied to each 'SubExp' child.
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Applied to each 'Lambda' child, converting it from
    -- representation @frep@ to @trep@.
    mapOnSOACLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to each 'VName' child.
    mapOnSOACVName :: VName -> m VName
  }

-- | The do-nothing 'SOACMapper': every callback hands its argument back
-- unchanged via 'pure', so mapping with it leaves the SOAC as it was.
identitySOACMapper :: forall rep m. (Monad m) => SOACMapper rep rep m
identitySOACMapper = SOACMapper pure pure pure

-- | Map a monadic action across the immediate children of a
-- SOAC.  The mapping does not descend recursively into subexpressions
-- and is done left-to-right.
mapSOACM ::
  (Monad m) =>
  SOACMapper frep trep m ->
  SOAC frep ->
  m (SOAC trep)
mapSOACM :: forall (m :: * -> *) frep trep.
Monad m =>
SOACMapper frep trep m -> SOAC frep -> m (SOAC trep)
mapSOACM SOACMapper frep trep m
tv (JVP Lambda frep
lam [SubExp]
args [SubExp]
vec) =
  Lambda trep -> [SubExp] -> [SubExp] -> SOAC trep
forall rep. Lambda rep -> [SubExp] -> [SubExp] -> SOAC rep
JVP
    (Lambda trep -> [SubExp] -> [SubExp] -> SOAC trep)
-> m (Lambda trep) -> m ([SubExp] -> [SubExp] -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam
    m ([SubExp] -> [SubExp] -> SOAC trep)
-> m [SubExp] -> m ([SubExp] -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
args
    m ([SubExp] -> SOAC trep) -> m [SubExp] -> m (SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
vec
mapSOACM SOACMapper frep trep m
tv (VJP Lambda frep
lam [SubExp]
args [SubExp]
vec) =
  Lambda trep -> [SubExp] -> [SubExp] -> SOAC trep
forall rep. Lambda rep -> [SubExp] -> [SubExp] -> SOAC rep
VJP
    (Lambda trep -> [SubExp] -> [SubExp] -> SOAC trep)
-> m (Lambda trep) -> m ([SubExp] -> [SubExp] -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam
    m ([SubExp] -> [SubExp] -> SOAC trep)
-> m [SubExp] -> m ([SubExp] -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
args
    m ([SubExp] -> SOAC trep) -> m [SubExp] -> m (SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
vec
mapSOACM SOACMapper frep trep m
tv (Stream SubExp
size [VName]
arrs [SubExp]
accs Lambda frep
lam) =
  SubExp -> [VName] -> [SubExp] -> Lambda trep -> SOAC trep
forall rep. SubExp -> [VName] -> [SubExp] -> Lambda rep -> SOAC rep
Stream
    (SubExp -> [VName] -> [SubExp] -> Lambda trep -> SOAC trep)
-> m SubExp -> m ([VName] -> [SubExp] -> Lambda trep -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
size
    m ([VName] -> [SubExp] -> Lambda trep -> SOAC trep)
-> m [VName] -> m ([SubExp] -> Lambda trep -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
arrs
    m ([SubExp] -> Lambda trep -> SOAC trep)
-> m [SubExp] -> m (Lambda trep -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
accs
    m (Lambda trep -> SOAC trep) -> m (Lambda trep) -> m (SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam
mapSOACM SOACMapper frep trep m
tv (Scatter SubExp
w [VName]
ivs Lambda frep
lam [(Shape, Int, VName)]
as) =
  SubExp
-> [VName] -> Lambda trep -> [(Shape, Int, VName)] -> SOAC trep
forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter
    (SubExp
 -> [VName] -> Lambda trep -> [(Shape, Int, VName)] -> SOAC trep)
-> m SubExp
-> m ([VName] -> Lambda trep -> [(Shape, Int, VName)] -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
w
    m ([VName] -> Lambda trep -> [(Shape, Int, VName)] -> SOAC trep)
-> m [VName]
-> m (Lambda trep -> [(Shape, Int, VName)] -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
ivs
    m (Lambda trep -> [(Shape, Int, VName)] -> SOAC trep)
-> m (Lambda trep) -> m ([(Shape, Int, VName)] -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam
    m ([(Shape, Int, VName)] -> SOAC trep)
-> m [(Shape, Int, VName)] -> m (SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ((Shape, Int, VName) -> m (Shape, Int, VName))
-> [(Shape, Int, VName)] -> m [(Shape, Int, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM
      ( \(Shape
aw, Int
an, VName
a) ->
          (,,)
            (Shape -> Int -> VName -> (Shape, Int, VName))
-> m Shape -> m (Int -> VName -> (Shape, Int, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> Shape -> m Shape
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> ShapeBase a -> m (ShapeBase b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) Shape
aw
            m (Int -> VName -> (Shape, Int, VName))
-> m Int -> m (VName -> (Shape, Int, VName))
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Int -> m Int
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Int
an
            m (VName -> (Shape, Int, VName))
-> m VName -> m (Shape, Int, VName)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv VName
a
      )
      [(Shape, Int, VName)]
as
mapSOACM SOACMapper frep trep m
tv (Hist SubExp
w [VName]
arrs [HistOp frep]
ops Lambda frep
bucket_fun) =
  SubExp -> [VName] -> [HistOp trep] -> Lambda trep -> SOAC trep
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist
    (SubExp -> [VName] -> [HistOp trep] -> Lambda trep -> SOAC trep)
-> m SubExp
-> m ([VName] -> [HistOp trep] -> Lambda trep -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
w
    m ([VName] -> [HistOp trep] -> Lambda trep -> SOAC trep)
-> m [VName] -> m ([HistOp trep] -> Lambda trep -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
arrs
    m ([HistOp trep] -> Lambda trep -> SOAC trep)
-> m [HistOp trep] -> m (Lambda trep -> SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (HistOp frep -> m (HistOp trep))
-> [HistOp frep] -> m [HistOp trep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM
      ( \(HistOp Shape
shape SubExp
rf [VName]
op_arrs [SubExp]
nes Lambda frep
op) ->
          Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp
            (Shape
 -> SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep)
-> m Shape
-> m (SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> Shape -> m Shape
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> ShapeBase a -> m (ShapeBase b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) Shape
shape
            m (SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep)
-> m SubExp
-> m ([VName] -> [SubExp] -> Lambda trep -> HistOp trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
rf
            m ([VName] -> [SubExp] -> Lambda trep -> HistOp trep)
-> m [VName] -> m ([SubExp] -> Lambda trep -> HistOp trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
op_arrs
            m ([SubExp] -> Lambda trep -> HistOp trep)
-> m [SubExp] -> m (Lambda trep -> HistOp trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
nes
            m (Lambda trep -> HistOp trep)
-> m (Lambda trep) -> m (HistOp trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
op
      )
      [HistOp frep]
ops
    m (Lambda trep -> SOAC trep) -> m (Lambda trep) -> m (SOAC trep)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
bucket_fun
-- Screma: map the width and the input arrays, then every scan and
-- reduction operator (its lambda followed by its neutral elements), and
-- finally the map lambda.  Reduction commutativity is kept unchanged.
mapSOACM tv (Screma w arrs (ScremaForm scans reds map_lam)) =
  Screma
    <$> mapOnSOACSubExp tv w
    <*> mapM (mapOnSOACVName tv) arrs
    <*> ( ScremaForm
            <$> mapM onScan scans
            <*> mapM onRed reds
            <*> mapOnSOACLambda tv map_lam
        )
  where
    onScan (Scan red_lam red_nes) =
      Scan
        <$> mapOnSOACLambda tv red_lam
        <*> mapM (mapOnSOACSubExp tv) red_nes
    onRed (Reduce comm red_lam red_nes) =
      Reduce comm
        <$> mapOnSOACLambda tv red_lam
        <*> mapM (mapOnSOACSubExp tv) red_nes

-- | A helper for defining 'TraverseOpStms'.  The given statement
-- traverser is applied (through 'traverseLambdaStms') to every lambda
-- inside the SOAC; all other hooks are those of 'identitySOACMapper'.
traverseSOACStms :: (Monad m) => OpStmsTraverser m (SOAC rep) rep
traverseSOACStms f = mapSOACM mapper
  where
    mapper = identitySOACMapper {mapOnSOACLambda = traverseLambdaStms f}

-- | The free variables of a scan are those of its operator lambda and of
-- its neutral elements.
instance (ASTRep rep) => FreeIn (Scan rep) where
  freeIn' (Scan lam ne) = freeIn' lam <> freeIn' ne

-- | The free variables of a reduction are those of its operator lambda
-- and of its neutral elements; the commutativity flag contributes none.
instance (ASTRep rep) => FreeIn (Reduce rep) where
  freeIn' (Reduce _ lam ne) = freeIn' lam <> freeIn' ne

-- | The free variables of a screma form: those of all its scans, all its
-- reductions, and its map lambda.
instance (ASTRep rep) => FreeIn (ScremaForm rep) where
  freeIn' (ScremaForm scans reds lam) =
    freeIn' scans <> freeIn' reds <> freeIn' lam

-- | The free variables of a histogram operation: those of its shape, its
-- 'SubExp' component, its destination arrays, its neutral elements and
-- its operator lambda.
instance (ASTRep rep) => FreeIn (HistOp rep) where
  freeIn' (HistOp w rf dests nes lam) =
    freeIn' w
      <> freeIn' rf
      <> freeIn' dests
      <> freeIn' nes
      <> freeIn' lam

-- | Free variables of a SOAC.  These are collected by running 'mapSOACM'
-- with a mapper that accumulates, in a 'State' monad, the free variables
-- of every subexpression, lambda and array name it visits.
instance (ASTRep rep) => FreeIn (SOAC rep) where
  freeIn' = flip execState mempty . mapSOACM free
    where
      -- Record the names found by @f@ and hand the input back unchanged.
      walk f x = modify (<> f x) >> pure x
      free =
        SOACMapper
          { mapOnSOACSubExp = walk freeIn',
            mapOnSOACLambda = walk freeIn',
            mapOnSOACVName = walk freeIn'
          }

-- | Apply a name substitution to every subexpression, lambda and array
-- name inside the SOAC.
instance (ASTRep rep) => Substitute (SOAC rep) where
  substituteNames subst = runIdentity . mapSOACM substitute
    where
      substitute =
        SOACMapper
          { mapOnSOACSubExp = pure . substituteNames subst,
            mapOnSOACLambda = pure . substituteNames subst,
            mapOnSOACVName = pure . substituteNames subst
          }

-- | Rename every subexpression, lambda and array name inside the SOAC.
instance (ASTRep rep) => Rename (SOAC rep) where
  rename = mapSOACM renamer
    where
      renamer =
        SOACMapper
          { mapOnSOACSubExp = rename,
            mapOnSOACLambda = rename,
            mapOnSOACVName = rename
          }

-- | The type of a SOAC.
--
-- * 'JVP': the lambda's return types, twice (presumably the primal
--   results followed by their tangents).
--
-- * 'VJP': the lambda's return types followed by the types of its
--   parameters.
--
-- * 'Stream': the lambda's return types, with the chunk-size parameter
--   replaced by the outer size and the accumulator parameters replaced by
--   the initial accumulator values.
--
-- * 'Scatter': one array per destination, of the destination's shape,
--   whose element type is the value type of the first write to it.
--
-- * 'Hist': for each operator, its lambda's return types arrayed by the
--   operator's histogram shape.
--
-- * 'Screma': see 'scremaType'.
soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type]
soacType (JVP lam _ _) =
  lambdaReturnType lam
    ++ lambdaReturnType lam
soacType (VJP lam _ _) =
  lambdaReturnType lam
    ++ map paramType (lambdaParams lam)
soacType (Stream outersize _ accs lam) =
  map (substNamesInType substs) rtp
  where
    -- The first parameter is the chunk size, followed by one per
    -- accumulator.
    nms = map paramName $ take (1 + length accs) params
    substs = M.fromList $ zip nms (outersize : accs)
    Lambda params rtp _ = lam
soacType (Scatter _w _ivs lam dests) =
  zipWith arrayOfShape (map valueType rets) shapes
  where
    (shapes, _, rets) =
      unzip3 $ groupScatterResults dests $ lambdaReturnType lam
    -- The element type is taken from the first write to the destination;
    -- a destination without writes has no element type to report.
    valueType ((_, t) : _) = t
    valueType [] = error "soacType: Scatter destination without any writes."
soacType (Hist _ _ ops _bucket_fun) = do
  op <- ops
  map (`arrayOfShape` histShape op) (lambdaReturnType $ histOp op)
soacType (Screma w _arrs form) =
  scremaType w form

-- | The type of a SOAC operation is its 'soacType', with every size
-- embedded as static via 'staticShapes'.
instance TypedOp SOAC where
  opType = pure . staticShapes . soacType

-- | The results of a SOAC are reported as aliasing nothing (one empty
-- alias set per result type).  What a SOAC consumes depends on its kind.
instance AliasedOp SOAC where
  opAliases = map (const mempty) . soacType

  consumedInOp JVP {} = mempty
  consumedInOp VJP {} = mempty
  -- Only map functions can consume anything.  The operands to scan
  -- and reduce functions are always considered "fresh".
  --
  -- A consumed map-lambda parameter is reported as the input array bound
  -- to it; any other consumed name is reported unchanged.
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames consumedArray $ consumedByLambda map_lam
    where
      consumedArray v = fromMaybe v $ lookup v params_to_arrs
      params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs
  -- A consumed lambda parameter is reported as the accumulator or array
  -- bound to it; 'subExpVars' then drops those bound to constants.
  consumedInOp (Stream _ arrs accs lam) =
    namesFromList . subExpVars . map consumedArray . namesToList $
      consumedByLambda lam
    where
      consumedArray v = fromMaybe (Var v) $ lookup v paramsToInput
      -- Drop the chunk parameter, which cannot alias anything.
      paramsToInput =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
  -- A scatter consumes all of its destination arrays.
  consumedInOp (Scatter _ _ _ dests) =
    namesFromList $ map (\(_, _, a) -> a) dests
  -- A histogram consumes the destination arrays of every operator.
  consumedInOp (Hist _ _ ops _) =
    namesFromList $ concatMap histDest ops

-- | Transform the operator lambda of a histogram operation, leaving every
-- other field unchanged.
mapHistOp ::
  (Lambda frep -> Lambda trep) ->
  HistOp frep ->
  HistOp trep
mapHistOp f (HistOp w rf dests nes lam) =
  HistOp w rf dests nes $ f lam

-- | Run alias analysis ('Alias.analyseLambda' with the given alias table)
-- over every lambda inside the SOAC; all other operands are kept as is.
instance CanBeAliased SOAC where
  addOpAliases aliases (JVP lam args vec) =
    JVP (Alias.analyseLambda aliases lam) args vec
  addOpAliases aliases (VJP lam args vec) =
    VJP (Alias.analyseLambda aliases lam) args vec
  addOpAliases aliases (Stream size arr accs lam) =
    Stream size arr accs $ Alias.analyseLambda aliases lam
  addOpAliases aliases (Scatter len arrs lam dests) =
    Scatter len arrs (Alias.analyseLambda aliases lam) dests
  addOpAliases aliases (Hist w arrs ops bucket_fun) =
    Hist
      w
      arrs
      (map (mapHistOp (Alias.analyseLambda aliases)) ops)
      (Alias.analyseLambda aliases bucket_fun)
  addOpAliases aliases (Screma w arrs (ScremaForm scans reds map_lam)) =
    Screma w arrs $
      ScremaForm
        (map onScan scans)
        (map onRed reds)
        (Alias.analyseLambda aliases map_lam)
    where
      onRed red = red {redLambda = Alias.analyseLambda aliases $ redLambda red}
      onScan scan = scan {scanLambda = Alias.analyseLambda aliases $ scanLambda scan}

instance IsOp SOAC where
  -- No SOAC is treated as safe.
  safeOp _ = False
  -- No SOAC is treated as cheap.
  cheapOp _ = False
  -- Stream: the dependencies of the lambda's results, given the
  -- dependencies of the input arrays and of the initial accumulators.
  --
  -- NOTE(review): the Stream lambda's parameters are [chunk size,
  -- accumulators..., arrays...] (see 'consumedInOp' and 'soacType'), but
  -- the list passed here is arrays followed by accumulators, with nothing
  -- for the chunk.  Confirm that this matches how 'lambdaDependencies'
  -- pairs parameters with argument dependencies.
  opDependencies (Stream w arrs accs lam) =
    let accs_deps = map depsOf' accs
        arrs_deps = depsOfArrays w arrs
     in lambdaDependencies mempty lam (arrs_deps <> accs_deps)
  -- Hist: each result of each operator depends on every index produced
  -- for that operator by the bucket function, on the corresponding bucket
  -- value, and on the operator's own reduction dependencies.
  opDependencies (Hist w arrs ops lam) =
    let bucket_fun_deps' = lambdaDependencies mempty lam (depsOfArrays w arrs)
        -- Bucket function results are indices followed by values.
        -- Reshape this to align with list of histogram operations.
        ranks = map (shapeRank . histShape) ops
        value_lengths = map (length . histNeutral) ops
        (indices, values) = splitAt (sum ranks) bucket_fun_deps'
        bucket_fun_deps =
          zipWith
            concatIndicesToEachValue
            (chunks ranks indices)
            (chunks value_lengths values)
     in mconcat $ zipWith (zipWith (<>)) bucket_fun_deps (map depsOfHistOp ops)
    where
      -- Each destination depends on itself, the destination shape and the
      -- second 'SubExp' field of the operator; these feed the reduction.
      depsOfHistOp (HistOp dest_shape rf dests nes op) =
        let shape_deps = depsOfShape dest_shape
            in_deps = map (\vn -> oneName vn <> shape_deps <> depsOf' rf) dests
         in reductionDependencies mempty op nes in_deps
      -- A histogram operation may use the same index for multiple values.
      concatIndicesToEachValue is vs = map (mconcat is <>) vs
  -- Each Scatter result corresponds to one destination array; it depends on
  -- the array itself and on every index and value that is written into it.
  opDependencies (Scatter w arrs lam outputs) =
    map outputDeps $ groupScatterResults outputs lam_deps
    where
      lam_deps = lambdaDependencies mempty lam $ depsOfArrays w arrs
      outputDeps (_, arr, ivs) =
        mconcat $ oneName arr : map (mconcat . fst) ivs ++ map snd ivs
  -- Every argument is merged with its tangent seed (typeCheckSOAC requires
  -- @vec@ to have the same types as @args@), and the combined dependencies
  -- are pushed through the lambda. The resulting list is emitted twice,
  -- once for each half of the JVP result (presumably the primal outputs
  -- followed by the tangents).
  opDependencies (JVP lam args vec) =
    let seeded = zipWith (\arg seed -> depsOf' arg <> depsOf' seed) args vec
        res_deps = lambdaDependencies mempty lam seeded
     in res_deps <> res_deps
  -- The result of a VJP is the primal outputs of the lambda followed by one
  -- adjoint per lambda parameter. The seed vector @vec@ has one element per
  -- lambda /result/ (typeCheckSOAC checks it against lambdaReturnType), so
  -- it must not be zipped element-wise with @args@: that paired unrelated
  -- operands and, through zipWith truncation, silently dropped dependencies
  -- whenever the number of parameters and results differed.
  --
  -- Primal outputs depend only on the arguments (via the lambda body).
  -- Adjoints are conservatively taken to depend on everything: the
  -- arguments, the seed vector, and the lambda's free variables.
  opDependencies (VJP lam args vec) =
    lambdaDependencies mempty lam (map depsOf' args)
      <> map (const adj_deps) (lambdaParams lam)
    where
      adj_deps = freeIn args <> freeIn vec <> freeIn lam
  -- The map lambda produces the scan inputs, then the reduce inputs, then
  -- the plain map outputs. Scan and reduce inputs are further threaded
  -- through their operator lambdas (together with the neutral elements) to
  -- obtain the dependencies of the corresponding Screma results.
  opDependencies (Screma w arrs (ScremaForm scans reds map_lam)) =
    scan_deps <> red_deps <> map_deps
    where
      (scans_in, reds_in, map_deps) =
        splitAt3 (scanResults scans) (redResults reds) $
          lambdaDependencies mempty map_lam (depsOfArrays w arrs)
      scan_deps =
        concat $
          zipWith
            (\(Scan op nes) deps_in -> reductionDependencies mempty op nes deps_in)
            scans
            (chunks (scanSizes scans) scans_in)
      red_deps =
        concat $
          zipWith
            (\(Reduce _ op nes) deps_in -> reductionDependencies mempty op nes deps_in)
            reds
            (chunks (redSizes reds) reds_in)

-- | Apply a variable substitution to the dimensions of an array type.
-- Variables without an entry in the map are kept as they are; non-array
-- types (scalars, accumulators and memory) are returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array pt shape u ->
      Array pt (Shape $ map (substNamesInSubExp subs) $ shapeDims shape) u
    Mem space -> Mem space
    Prim {} -> t
    Acc {} -> t

-- | Replace a variable by its entry in the substitution map, if it has one.
-- Constants and unmapped variables are returned as-is.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Var v -> fromMaybe se $ M.lookup v subs
    Constant {} -> se

-- | Annotating a SOAC with wisdom means running 'informLambda' on every
-- lambda it contains; subexpressions and variable names are left untouched.
instance CanBeWise SOAC where
  addOpWisdom soac =
    runIdentity $
      mapSOACM (SOACMapper Identity (Identity . informLambda) Identity) soac

-- | Indexing into a SOAC result is only understood for 'Screma' map outputs
-- indexed by a single index. The @k@th map result is traced back through
-- the statements of the map lambda. Each array parameter is seeded with
-- what the symbol table knows about indexing the corresponding input array
-- at @i@. Single-name statements that convert to a 'PrimExp' over that
-- table extend it. Certificates are accumulated along the way.
instance (RepTypes rep) => ST.IndexOp (SOAC rep) where
  indexOp vtable k soac [i] = do
    (lam, res, arr_params, arrs) <- lambdaAndSubExp soac
    let seeded = M.fromList $ mapMaybe (uncurry arrIndex) $ zip arr_params arrs
        table = foldl extendTable seeded $ bodyStms $ lambdaBody lam
    case res of
      SubExpRes _ (Var v) -> do
        (pe, cs) <- M.lookup v table
        pure $ ST.Indexed cs pe
      _ -> Nothing
    where
      -- Locate the k'th map result (skipping the scan and reduce results),
      -- together with the lambda parameters that receive array elements.
      lambdaAndSubExp (Screma _ arrs (ScremaForm scans reds map_lam)) = do
        let num_accs = scanResults scans + redResults reds
        se <- maybeNth (num_accs + k) $ bodyResult $ lambdaBody map_lam
        pure (map_lam, se, drop num_accs $ lambdaParams map_lam, arrs)
      lambdaAndSubExp _ = Nothing

      -- What indexing array @arr@ at @i@ yields, keyed by parameter @p@.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        pure (paramName p, (pe, cs))

      -- Add a statement's result to the table when it binds exactly one
      -- name, its expression converts to a PrimExp, and all of its
      -- certificates are bound in the symbol table.
      extendTable tbl stm =
        fromMaybe tbl $ do
          [v] <- pure $ patNames $ stmPat stm
          (pe, cs) <- runWriterT $ primExpFromExp (asPrimExp tbl) $ stmExp stm
          guard $ all (`ST.elem` vtable) $ unCerts $ stmCerts stm
          pure $ M.insert v (pe, stmCerts stm <> cs) tbl

      -- Leaves are either already-tabulated expressions (whose certificates
      -- are recorded) or scalar variables known to the symbol table.
      asPrimExp tbl v =
        case M.lookup v tbl of
          Just (e, cs) -> e <$ tell cs
          Nothing
            | Just (Prim pt) <- ST.lookupType v vtable -> pure $ LeafExp v pt
            | otherwise -> lift Nothing
  indexOp _ _ _ _ = Nothing

-- | Type-check a SOAC.
typeCheckSOAC :: (TC.Checkable rep) => SOAC (Aliases rep) -> TC.TypeM rep ()
typeCheckSOAC :: forall rep. Checkable rep => SOAC (Aliases rep) -> TypeM rep ()
typeCheckSOAC (VJP Lambda (Aliases rep)
lam [SubExp]
args [SubExp]
vec) = do
  [Arg]
args' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
args
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases [Arg]
args'
  [Type]
vec_ts <- (SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Type
forall rep. Checkable rep => SubExp -> TypeM rep Type
TC.checkSubExp [SubExp]
vec
  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
vec_ts [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Doc Any -> ErrorCase rep) -> Doc Any -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> ErrorCase rep)
-> (Doc Any -> Text) -> Doc Any -> ErrorCase rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> TypeM rep ()) -> Doc Any -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      Doc Any
"Return type"
        Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam))
        Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Doc Any
"does not match type of seed vector"
        Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty [Type]
vec_ts)
typeCheckSOAC (JVP Lambda (Aliases rep)
lam [SubExp]
args [SubExp]
vec) = do
  [Arg]
args' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
args
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases [Arg]
args'
  [Type]
vec_ts <- (SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Type
forall rep. Checkable rep => SubExp -> TypeM rep Type
TC.checkSubExp [SubExp]
vec
  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
vec_ts [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
args') (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Doc Any -> ErrorCase rep) -> Doc Any -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> ErrorCase rep)
-> (Doc Any -> Text) -> Doc Any -> ErrorCase rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> TypeM rep ()) -> Doc Any -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      Doc Any
"Parameter type"
        Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty ([Type] -> Doc Any) -> [Type] -> Doc Any
forall a b. (a -> b) -> a -> b
$ (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
args')
        Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Doc Any
"does not match type of seed vector"
        Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty [Type]
vec_ts)
typeCheckSOAC (Stream SubExp
size [VName]
arrexps [SubExp]
accexps Lambda (Aliases rep)
lam) = do
  [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
  [Arg]
accargs <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
accexps
  [Type]
arrargs <- (VName -> TypeM rep Type) -> [VName] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> TypeM rep Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrexps
  [Arg]
_ <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
  Param (LParamInfo rep)
chunk <- case Lambda (Aliases rep) -> [LParam (Aliases rep)]
forall {rep}. Lambda rep -> [Param (LParamInfo rep)]
lambdaParams Lambda (Aliases rep)
lam of
    LParam (Aliases rep)
chunk : [LParam (Aliases rep)]
_ -> Param (LParamInfo rep) -> TypeM rep (Param (LParamInfo rep))
forall a. a -> TypeM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Param (LParamInfo rep)
LParam (Aliases rep)
chunk
    [] -> ErrorCase rep -> TypeM rep (Param (LParamInfo rep))
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep (Param (LParamInfo rep)))
-> ErrorCase rep -> TypeM rep (Param (LParamInfo rep))
forall a b. (a -> b) -> a -> b
$ Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError Text
"Stream lambda without parameters."
  let asArg :: a -> (a, b)
asArg a
t = (a
t, b
forall a. Monoid a => a
mempty)
      inttp :: TypeBase shape u
inttp = PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      lamarrs' :: [Type]
lamarrs' = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (Param (LParamInfo rep) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo rep)
chunk)) [Type]
arrargs
      acc_len :: Int
acc_len = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
      lamrtp :: [Type]
lamrtp = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
acc_len ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Text -> ErrorCase rep) -> Text -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> TypeM rep ()) -> Text -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      Text
"Stream with inconsistent accumulator type in lambda."
  -- just get the dflow of lambda on the fakearg, which does not alias
  -- arr, so we can later check that aliases of arr are not used inside lam.
  let fake_lamarrs' :: [Arg]
fake_lamarrs' = (Type -> Arg) -> [Type] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg Type
forall {shape} {u}. TypeBase shape u
inttp Arg -> [Arg] -> [Arg]
forall a. a -> [a] -> [a]
: [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter w arrs lam as) = do
  -- Requirements (numbered to match the checks below):
  --
  --   0. The return types of @lam@ are laid out as all index values
  --      first, followed by all the values to write.
  --
  --   1. The number of return values of @lam@ must equal the number of
  --      index values plus the number of written values.
  --
  --   2. Every index value must have type i64.
  --
  --   3. Each destination array in @as@ must have the type of its written
  --      values lifted to the destination shape.
  --
  --   4. Each destination array in @as@ is consumed.  This is less a check
  --      than a requirement, so that e.g. the destination is not hoisted
  --      out of a loop, which would mean it cannot be consumed.
  --
  --   5. The input arrays @arrs@ must match the parameters of @lam@.

  -- The input size must be an i64.
  TC.require [Prim int64] w

  -- 0. Split the lambda's return types into index part and value part.
  -- Each destination contributes (number of values * rank of its shape)
  -- index results.
  let (dest_shapes, dest_counts, _) = unzip3 as
      num_indexes =
        sum $ zipWith (\n shape -> n * length shape) dest_counts dest_shapes
      ret_ts = lambdaReturnType lam
      (index_ts, value_ts) = splitAt num_indexes ret_ts

  -- 1.
  when (length ret_ts /= num_indexes + sum dest_counts) $
    TC.bad $
      TC.TypeError "Scatter: number of index types, value types and array outputs do not match."

  -- 2.
  forM_ index_ts $ \index_t ->
    when (index_t /= Prim int64) $
      TC.bad $
        TC.TypeError "Scatter: Index return type must be i64."

  -- 3 and 4, per destination: the value types for each destination are
  -- the next @n@ entries of the value part.
  forM_ (zip (chunks dest_counts value_ts) as) $
    \(dest_value_ts, (dest_shape, _, dest)) -> do
      -- Every dimension of the destination shape must be an i64.
      mapM_ (TC.require [Prim int64]) dest_shape

      -- 3.
      forM_ dest_value_ts $ \value_t ->
        TC.requireI [value_t `arrayOfShape` dest_shape] dest

      -- 4.
      dest_aliases <- TC.lookupAliases dest
      TC.consume dest_aliases

  -- 5.
  TC.checkLambda lam =<< TC.checkSOACArrayArgs w arrs
typeCheckSOAC (Hist w arrs ops bucket_fun) = do
  -- The input size must be an i64.
  TC.require [Prim int64] w

  -- Check each histogram operator in turn.
  forM_ ops $ \(HistOp dest_shape rf dests nes op) -> do
    nes' <- mapM TC.checkArg nes
    -- Every dimension of the destination shape, and the race factor,
    -- must be i64.
    mapM_ (TC.require [Prim int64]) dest_shape
    TC.require [Prim int64] rf

    -- The operator is checked against two copies of the neutral element
    -- arguments (with aliases stripped), and its return type must match
    -- the type of the neutral elements.
    TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes'
    let nes_t = map TC.argType nes'
    unless (nes_t == lambdaReturnType op) $
      TC.bad . TC.TypeError $
        "Operator has return type "
          <> prettyTuple (lambdaReturnType op)
          <> " but neutral element has type "
          <> prettyTuple nes_t

    -- Each destination array corresponds to exactly one neutral element.
    -- Without this check the 'zip' below would silently drop surplus
    -- destinations (or neutral elements), leaving them neither
    -- type-checked nor consumed.
    unless (length dests == length nes) $
      TC.bad . TC.TypeError $
        "Hist: number of destination arrays does not match number of neutral elements."

    -- Destination arrays must have the neutral element type lifted to the
    -- destination shape, and they are consumed.
    forM_ (zip nes_t dests) $ \(t, dest) -> do
      TC.requireI [t `arrayOfShape` dest_shape] dest
      TC.consume =<< TC.lookupAliases dest

  -- Types of input arrays must equal parameter types for bucket function.
  img' <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda bucket_fun img'

  -- The bucket function must return one i64 index per dimension of each
  -- operator's destination shape, followed by the values to write (typed
  -- like the neutral elements of all operators, in order).
  nes_ts <- concat <$> mapM (mapM subExpType . histNeutral) ops
  let bucket_ret_t =
        concatMap ((`replicate` Prim int64) . shapeRank . histShape) ops
          ++ nes_ts
  unless (bucket_ret_t == lambdaReturnType bucket_fun) $
    TC.bad . TC.TypeError $
      "Bucket function has return type "
        <> prettyTuple (lambdaReturnType bucket_fun)
        <> " but should have type "
        <> prettyTuple bucket_ret_t
typeCheckSOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  -- The input size must be an i64, and the map function is checked
  -- against the elements of the input arrays.
  TC.require [Prim int64] w
  TC.checkLambda map_lam =<< TC.checkSOACArrayArgs w arrs

  -- Check all scan operators, then all reduction operators, collecting
  -- their neutral-element arguments in order.
  scan_args <-
    concat
      <$> mapM (\(Scan scan_lam scan_nes) -> checkOperator "Scan" scan_lam scan_nes) scans
  red_args <-
    concat
      <$> mapM (\(Reduce _ red_lam red_nes) -> checkOperator "Reduce" red_lam red_nes) reds

  -- The leading results of the map function must have the types of the
  -- scan neutral elements followed by the reduction neutral elements.
  let op_ts = map TC.argType (scan_args ++ red_args)
      map_lam_ts = lambdaReturnType map_lam
  when (take (length op_ts) map_lam_ts /= op_ts) $
    TC.bad . TC.TypeError $
      "Map function return type "
        <> prettyTuple map_lam_ts
        <> " wrong for given scan and reduction functions."
  where
    -- Check one scan or reduction operator against its neutral elements:
    -- the operator is checked against two alias-free copies of the
    -- neutral element arguments, and must return their types.  @what@
    -- names the operator kind in the error message.  Returns the checked
    -- neutral element arguments.
    checkOperator what op_lam nes = do
      nes' <- mapM TC.checkArg nes
      let nes_ts = map TC.argType nes'
      TC.checkLambda op_lam $ map TC.noArgAliases $ nes' ++ nes'
      unless (nes_ts == lambdaReturnType op_lam) $
        TC.bad . TC.TypeError $
          what
            <> " function returns type "
            <> prettyTuple (lambdaReturnType op_lam)
            <> " but neutral element has type "
            <> prettyTuple nes_ts
      pure nes'

instance RephraseOp SOAC where
  rephraseInOp :: forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> SOAC from -> m (SOAC to)
rephraseInOp Rephraser m from to
r (VJP Lambda from
lam [SubExp]
args [SubExp]
vec) =
    Lambda to -> [SubExp] -> [SubExp] -> SOAC to
forall rep. Lambda rep -> [SubExp] -> [SubExp] -> SOAC rep
VJP (Lambda to -> [SubExp] -> [SubExp] -> SOAC to)
-> m (Lambda to) -> m ([SubExp] -> [SubExp] -> SOAC to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> Lambda from -> m (Lambda to)
forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda Rephraser m from to
r Lambda from
lam m ([SubExp] -> [SubExp] -> SOAC to)
-> m [SubExp] -> m ([SubExp] -> SOAC to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> m [SubExp]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [SubExp]
args m ([SubExp] -> SOAC to) -> m [SubExp] -> m (SOAC to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> m [SubExp]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [SubExp]
vec
  rephraseInOp Rephraser m from to
r (JVP Lambda from
lam [SubExp]
args [SubExp]
vec) =
    Lambda to -> [SubExp] -> [SubExp] -> SOAC to
forall rep. Lambda rep -> [SubExp] -> [SubExp] -> SOAC rep
JVP (Lambda to -> [SubExp] -> [SubExp] -> SOAC to)
-> m (Lambda to) -> m ([SubExp] -> [SubExp] -> SOAC to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> Lambda from -> m (Lambda to)
forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda Rephraser m from to
r Lambda from
lam m ([SubExp] -> [SubExp] -> SOAC to)
-> m [SubExp] -> m ([SubExp] -> SOAC to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> m [SubExp]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [SubExp]
args m ([SubExp] -> SOAC to) -> m [SubExp] -> m (SOAC to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> m [SubExp]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [SubExp]
vec
  rephraseInOp Rephraser m from to
r (Stream SubExp
w [VName]
arrs [SubExp]
acc Lambda from
lam) =
    SubExp -> [VName] -> [SubExp] -> Lambda to -> SOAC to
forall rep. SubExp -> [VName] -> [SubExp] -> Lambda rep -> SOAC rep
Stream SubExp
w [VName]
arrs [SubExp]
acc (Lambda to -> SOAC to) -> m (Lambda to) -> m (SOAC to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> Lambda from -> m (Lambda to)
forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda Rephraser m from to
r Lambda from
lam
  rephraseInOp Rephraser m from to
r (Scatter SubExp
w [VName]
arrs Lambda from
lam [(Shape, Int, VName)]
dests) =
    SubExp -> [VName] -> Lambda to -> [(Shape, Int, VName)] -> SOAC to
forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
w [VName]
arrs (Lambda to -> [(Shape, Int, VName)] -> SOAC to)
-> m (Lambda to) -> m ([(Shape, Int, VName)] -> SOAC to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> Lambda from -> m (Lambda to)
forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda Rephraser m from to
r Lambda from
lam m ([(Shape, Int, VName)] -> SOAC to)
-> m [(Shape, Int, VName)] -> m (SOAC to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [(Shape, Int, VName)] -> m [(Shape, Int, VName)]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [(Shape, Int, VName)]
dests
  -- Rephrase a histogram: first every operator's lambda (in list order), then
  -- the bucket function.  Width, inputs, shapes, race factors, destinations
  -- and neutral elements are carried over unchanged.
  rephraseInOp r (Hist w arrs ops lam) =
    liftM2 (Hist w arrs) (mapM rephraseHistOp ops) (rephraseLambda r lam)
    where
      -- Only the operator lambda of a 'HistOp' depends on the representation.
      rephraseHistOp (HistOp shape rf dests nes op) =
        HistOp shape rf dests nes <$> rephraseLambda r op
  -- Rephrase a Screma: the scan operators, then the reduction operators, then
  -- the map lambda, in that order.  Neutral elements and commutativity flags
  -- are copied as-is.
  rephraseInOp r (Screma w arrs (ScremaForm scans reds lam)) =
    Screma w arrs
      <$> liftM3
        ScremaForm
        (mapM rephraseScan scans)
        (mapM rephraseRed reds)
        (rephraseLambda r lam)
    where
      rephraseScan (Scan op nes) =
        (`Scan` nes) <$> rephraseLambda r op
      rephraseRed (Reduce comm op nes) =
        (\op' -> Reduce comm op' nes) <$> rephraseLambda r op

-- | Metrics for SOACs: each SOAC opens a scope named after its constructor
-- and records the metrics of every lambda it contains.
instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where
  opMetrics soac =
    case soac of
      VJP lam _ _ ->
        inside "VJP" $ lambdaMetrics lam
      JVP lam _ _ ->
        inside "JVP" $ lambdaMetrics lam
      Stream _ _ _ lam ->
        inside "Stream" $ lambdaMetrics lam
      Scatter _len _ lam _ ->
        inside "Scatter" $ lambdaMetrics lam
      Hist _ _ ops bucket_fun ->
        inside "Hist" $ do
          -- Operator lambdas first, then the bucket function.
          forM_ ops $ lambdaMetrics . histOp
          lambdaMetrics bucket_fun
      Screma _ _ (ScremaForm scans reds map_lam) ->
        inside "Screma" $ do
          forM_ scans $ lambdaMetrics . scanLambda
          forM_ reds $ lambdaMetrics . redLambda
          lambdaMetrics map_lam

-- | Prettyprinting of SOACs.  A Screma made only of a map, only of
-- reductions, or only of scans gets the dedicated keyword @map@, @redomap@ or
-- @scanomap@; a Screma with both scans and reductions is printed by
-- 'ppScrema'.  The other SOACs delegate to their @pp*@ printers.
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  pretty soac =
    case soac of
      VJP lam args vec -> adCall "vjp" lam args vec
      JVP lam args vec -> adCall "jvp" lam args vec
      Stream size arrs acc lam -> ppStream size arrs acc lam
      Scatter w arrs lam dests -> ppScatter w arrs lam dests
      Hist w arrs ops bucket_fun -> ppHist w arrs ops bucket_fun
      Screma w arrs (ScremaForm scans reds map_lam)
        | null scans && null reds ->
            scremaCall "map" w arrs [pretty map_lam]
        | null scans ->
            scremaCall "redomap" w arrs [lineBraces reds, pretty map_lam]
        | null reds ->
            scremaCall "scanomap" w arrs [lineBraces scans, pretty map_lam]
      Screma w arrs form -> ppScrema w arrs form
    where
      -- @name(lambda, {args}, {vec})@ with the contents aligned inside the
      -- parentheses; shared by the vjp and jvp cases.
      adCall name lam args vec =
        name
          <> parens
            ( PP.align $
                pretty lam
                  <> foldMap (comma </>) [commaBraces args, commaBraces vec]
            )
      -- @name(w, inputs, extra...)@ for the specialised Screma forms; every
      -- argument after the width is preceded by a comma and a break.
      scremaCall name w arrs extra =
        name
          <> (parens . align)
            (pretty w <> foldMap (comma </>) (ppTuple' (map pretty arrs) : extra))
      -- Braced, comma-separated list on a single logical line.
      commaBraces xs = PP.braces $ commasep $ map pretty xs
      -- Braced list whose elements are separated by a comma and 'PP.line'.
      lineBraces xs = PP.braces $ mconcat $ intersperse (comma <> PP.line) $ map pretty xs

-- | Prettyprint a general Screma as
-- @screma(w, inputs, {scans}, {reductions}, map_lambda)@.  Within each pair
-- of braces, the operators are separated by a comma and a line break.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann
ppScrema w arrs (ScremaForm scans reds map_lam) =
  "screma" <> (parens . align) (pretty w <> foldMap (comma </>) fields)
  where
    -- Everything after the width; each entry is preceded by a comma and a break.
    fields =
      [ ppTuple' (map pretty arrs),
        lineBraces scans,
        lineBraces reds,
        pretty map_lam
      ]
    lineBraces xs =
      PP.braces . mconcat . intersperse (comma <> PP.line) $ map pretty xs

-- | Prettyprint a sequential stream as
-- @streamSeq(size, inputs, accumulators, lambda)@, where both the inputs and
-- the accumulators are formatted with 'ppTuple''.
ppStream ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> [SubExp] -> Lambda rep -> Doc ann
ppStream size arrs acc lam =
  "streamSeq" <> (parens . align) (pretty size <> foldMap (comma </>) trailing)
  where
    -- Arguments following the size, each preceded by a comma and a break.
    trailing =
      [ ppTuple' (map pretty arrs),
        ppTuple' (map pretty acc),
        pretty lam
      ]

-- | Prettyprint a scatter as @scatter(w, inputs, lambda, destinations...)@.
-- The inputs are formatted with 'ppTuple''.  The destination triples are
-- printed through their own 'Pretty' instance and separated by commas,
-- without an enclosing bracket.
ppScatter ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> Lambda rep -> [(Shape, Int, VName)] -> Doc ann
ppScatter w arrs lam dests =
  "scatter" <> (parens . align) (pretty w <> foldMap (comma </>) fields)
  where
    fields =
      [ ppTuple' $ map pretty arrs,
        pretty lam,
        commasep $ map pretty dests
      ]

-- | A scan operator is printed as @lambda, {neutral elements}@.
instance (PrettyRep rep) => Pretty (Scan rep) where
  pretty (Scan op neutral) =
    (pretty op <> comma) </> PP.braces (commasep (map pretty neutral))

-- | Render a 'Commutativity' flag as a prefix for a reduction operator:
-- @"commutative "@ (including the trailing space) for 'Commutative', or the
-- empty document for 'Noncommutative'.
ppComm :: Commutativity -> Doc ann
ppComm comm =
  case comm of
    Commutative -> "commutative "
    Noncommutative -> mempty

-- | A reduction operator is printed as an optional commutativity marker (see
-- 'ppComm') followed by @lambda, {neutral elements}@.
instance (PrettyRep rep) => Pretty (Reduce rep) where
  pretty (Reduce commutativity op neutral) =
    mconcat [ppComm commutativity, pretty op, comma] </> neutralDoc
    where
      neutralDoc = PP.braces . commasep $ map pretty neutral

-- | Prettyprint the given histogram operation as
-- @hist(w, inputs, {operators}, bucket_function)@.
--
-- Each operator is printed as
-- @shape, race_factor, {destinations}, neutral_tuple, lambda@, and the
-- operators are separated by a comma and a line break.
--
-- The argument list is wrapped in @parens . align@, like the other SOAC
-- printers in this module (@screma@, @streamSeq@, @scatter@, @map@ and so on).
-- Lines broken inside the parentheses therefore line up under the first
-- argument instead of falling back to the enclosing indentation.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [inp] ->
  [HistOp rep] ->
  Lambda rep ->
  Doc ann
ppHist w arrs ops bucket_fun =
  "hist"
    <> (parens . align)
      ( pretty w
          <> comma
          </> ppTuple' (map pretty arrs)
          <> comma
          </> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops)
          <> comma
          </> pretty bucket_fun
      )
  where
    -- One histogram operator.  The shape, race factor and destinations are
    -- kept on one line (joined with '<+>'); the neutral elements and the
    -- lambda may break onto new lines.
    ppOp (HistOp dest_w rf dests nes op) =
      pretty dest_w
        <> comma
        <+> pretty rf
        <> comma
        <+> PP.braces (commasep $ map pretty dests)
        <> comma
        </> ppTuple' (map pretty nes)
        <> comma
        </> pretty op