{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Segmented operations.  These correspond to perfect @map@ nests on
-- top of /something/, except that the @map@s are conceptually only
-- over @iota@s (so there will be explicit indexing inside them).
module Futhark.IR.SegOp
  ( SegOp (..),
    segLevel,
    segBody,
    segSpace,
    typeCheckSegOp,
    SegSpace (..),
    scopeOfSegSpace,
    segSpaceDims,

    -- * Details
    HistOp (..),
    histType,
    splitHistResults,
    SegBinOp (..),
    segBinOpResults,
    segBinOpChunks,
    KernelBody (..),
    aliasAnalyseKernelBody,
    consumedInKernelBody,
    ResultManifest (..),
    KernelResult (..),
    kernelResultCerts,
    kernelResultSubExp,

    -- ** Generic traversal
    SegOpMapper (..),
    identitySegOpMapper,
    mapSegOpM,
    traverseSegOpStms,

    -- * Simplification
    simplifySegOp,
    HasSegOp (..),
    segOpRules,

    -- * Memory
    segOpReturns,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List
  ( elemIndex,
    foldl',
    groupBy,
    intersperse,
    isPrefixOf,
    partition,
  )
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    CanBeAliased (..),
  )
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
  ( Doc,
    apply,
    hsep,
    parens,
    ppTuple',
    pretty,
    (<+>),
    (</>),
  )
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | An operator for 'SegHist'.
data HistOp rep = HistOp
  { -- | The shape of the histogram.  Its rank is the number of index
    -- results that belong to this operator (see 'splitHistResults').
    histShape :: Shape,
    histRaceFactor :: SubExp,
    -- | Destination arrays.  Their number is the number of value
    -- results that belong to this operator (see 'splitHistResults').
    histDest :: [VName],
    histNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    histOpShape :: Shape,
    -- | The operator lambda.  Its return types, lifted to arrays of
    -- shape @'histShape' <> 'histOpShape'@, form the 'histType'.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest's in case we are
-- dealing with a segmented histogram.
--
-- Each return type of 'histOp' is turned into an array type whose
-- outer shape is @'histShape' op <> 'histOpShape' op@.
histType :: HistOp rep -> [Type]
histType op =
  map (`arrayOfShape` (histShape op <> histOpShape op)) $
    lambdaReturnType $
      histOp op

-- | Split reduction results returned by a 'KernelBody' into those
-- that correspond to indexes for the 'HistOp's, and those that
-- correspond to value.
--
-- The first @sum (map (shapeRank . histShape) ops)@ elements of the
-- results are index components, chunked per operator by the rank of
-- its 'histShape'.  The remaining elements are values, chunked per
-- operator by the number of its 'histDest's.  The result pairs each
-- operator's indexes with its values.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  let ranks = map (shapeRank . histShape) ops
      (idxs, vals) = splitAt (sum ranks) res
   in zip
        (chunks ranks idxs)
        (chunks (map (length . histDest) ops) vals)

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp rep = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda rep,
    -- | Neutral elements.  Their number is the number of results this
    -- operator produces (see 'segBinOpResults' and 'segBinOpChunks').
    segBinOpNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- This is the total number of neutral elements across all operators.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp'.  The size of each chunk is the number
-- of neutral elements of the corresponding operator.
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)

-- | The body of a 'SegOp'.  Unlike an ordinary 'Body', it returns a
-- list of 'KernelResult's rather than a 'Result'.
data KernelBody rep = KernelBody
  { kernelBodyDec :: BodyDec rep,
    kernelBodyStms :: Stms rep,
    kernelBodyResult :: [KernelResult]
  }

deriving instance (RepTypes rep) => Ord (KernelBody rep)

deriving instance (RepTypes rep) => Show (KernelBody rep)

deriving instance (RepTypes rep) => Eq (KernelBody rep)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
data ResultManifest
  = -- | Don't simplify this one!
    ResultNoSimplify
  | -- | The simplifier may simplify this result freely.
    ResultMaySimplify
  | -- | The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
    ResultPrivate
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult
  = -- | Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-group depends on where the 'SegOp' occurs.
    Returns ResultManifest Certs SubExp
  | -- | Write into an existing array.  The array is consumed by the
    -- kernel body (see 'consumedInKernelBody').
    WriteReturns
      Certs
      Shape -- Size of array.  Must match number of dims.
      VName -- Which array
      [(Slice SubExp, SubExp)]
  | TileReturns
      Certs
      [(SubExp, SubExp)] -- Total/tile for each dimension
      VName -- Tile written by this worker.
      -- The TileReturns must not expect more than one
      -- result to be written per physical thread.
  | RegTileReturns
      Certs
      -- For each dim of result:
      [ ( SubExp, -- size of this dim.
          SubExp, -- block tile size for this dim.
          SubExp -- reg tile size for this dim.
        )
      ]
      VName -- Tile returned by this worker/group.
  deriving (Eq, Show, Ord)

-- | Get the certs for this 'KernelResult'.  Every constructor carries
-- its certificates as its first 'Certs' field.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts (Returns _ cs _) = cs
kernelResultCerts (WriteReturns cs _ _ _) = cs
kernelResultCerts (TileReturns cs _ _) = cs
kernelResultCerts (RegTileReturns cs _ _) = cs

-- | Get the root t'SubExp' corresponding values for a 'KernelResult'.
-- For 'Returns' this is the returned 'SubExp' itself.  For the other
-- constructors it is the variable naming the written array or tile.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ _ se) = se
kernelResultSubExp (WriteReturns _ _ arr _) = Var arr
kernelResultSubExp (TileReturns _ _ v) = Var v
kernelResultSubExp (RegTileReturns _ _ v) = Var v

-- The free variables of a kernel result are the union of the free
-- variables of all of its fields (except the 'ResultManifest', which
-- contains no names).
instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) = freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs rws arr res) =
    freeIn' cs <> freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (TileReturns cs dims v) =
    freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v

-- The names bound by the body's statements are passed to 'fvBind'
-- around the combined free variables of the decoration, statements
-- and results.  Presumably this keeps uses of statement-bound names
-- from being reported as free; confirm against 'fvBind'.
instance (ASTRep rep) => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      bound_in_stms = foldMap boundByStm stms

-- Substitution is applied componentwise to the decoration, the
-- statements and the results.
instance (ASTRep rep) => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)

-- Substitution is applied to every name-carrying field.  The
-- 'ResultManifest' of a 'Returns' is kept as is.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest cs se) =
    Returns manifest (substituteNames subst cs) (substituteNames subst se)
  substituteNames subst (WriteReturns cs rws arr res) =
    WriteReturns
      (substituteNames subst cs)
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (TileReturns cs dims v) =
    TileReturns
      (substituteNames subst cs)
      (substituteNames subst dims)
      (substituteNames subst v)
  substituteNames subst (RegTileReturns cs dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst cs)
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)

-- The decoration is renamed first.  The results are renamed inside
-- the continuation passed to 'renamingStms'.  Presumably this lets
-- them see the new names given to statement-bound variables.
instance (ASTRep rep) => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res

-- Kernel results bind no names of their own, so renaming is done via
-- 'substituteRename'.
instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.
--
-- This reuses 'Alias.analyseBody' by wrapping the decoration and
-- statements in an ordinary 'Body' with an empty result.  The kernel
-- results are carried over unchanged.
aliasAnalyseKernelBody ::
  (Alias.AliasableRep rep) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res

-- | The variables consumed in the kernel body: those consumed by its
-- statements, plus every array that a 'WriteReturns' result writes
-- into.
consumedInKernelBody ::
  (Aliased rep) =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  -- Reuse the ordinary body analysis by wrapping the statements in a
  -- 'Body' with an empty result.  Kernel results are handled separately.
  consumedInBody (Body dec stms []) <> foldMap consumedByReturn res
  where
    -- Only a 'WriteReturns' consumes anything: the array it writes to.
    consumedByReturn (WriteReturns _ _ a _) = oneName a
    consumedByReturn _ = mempty

-- | Type-check a 'KernelBody' against the element types @ts@ declared
-- for its results.  There must be exactly one 'KernelResult' per
-- type, and each result is checked against its corresponding type.
checkKernelBody ::
  (TC.Checkable rep) =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  -- We consume the kernel results (when applicable) before
  -- type-checking the stms, so we will get an error if a statement
  -- uses an array that is written to in a result.
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad . TC.TypeError $
        "Kernel return type is "
          <> prettyTuple ts
          <> ", but body returns "
          <> prettyText (length kres)
          <> " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Only 'WriteReturns' consumes anything: the destination array
    -- together with everything it aliases.
    consumeKernelResult (WriteReturns _ _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs shape arr res) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      arr_t <- lookupType arr
      -- The destination type does not depend on the individual
      -- writes, so check it once.  Doing it outside the loop also
      -- catches a mismatched destination when there are no writes.
      unless (arr_t == t `arrayOfShape` shape) $
        TC.bad . TC.TypeError $
          "WriteReturns of type "
            <> prettyText t
            <> ", shape="
            <> prettyText shape
            <> ", but destination array has type "
            <> prettyText arr_t
      forM_ res $ \(slice, e) -> do
        traverse_ (TC.require [Prim int64]) slice
        TC.require [t] e
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- The returned array holds a single tile.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad . TC.TypeError $
          "Invalid type for TileReturns " <> prettyText v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles

      -- Assert that arr has element type t and shape
      -- (blk_tiles ++ reg_tiles).
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n  "
            <> prettyText expected
            <> ",\ngot:\n  "
            <> prettyText arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        expected = t `arrayOfShape` Shape (blk_tiles <> reg_tiles)

-- | Record metrics for every statement of a 'KernelBody'.  The kernel
-- results themselves are not inspected.
kernelBodyMetrics :: (OpMetrics (Op rep)) => KernelBody rep -> MetricsM ()
kernelBodyMetrics kbody =
  forM_ (kernelBodyStms kbody) stmMetrics

instance (PrettyRep rep) => Pretty (KernelBody rep) where
  -- The statements are stacked one per line, followed by the keyword
  -- @return@ and the results as a braced, comma-separated list.
  pretty (KernelBody _ stms res) =
    let stmDocs = map pretty $ stmsToList stms
        resDocs = map pretty res
     in PP.stack stmDocs </> "return" <+> PP.braces (PP.commastack resDocs)

-- | Render certificates as a single-element list.  Empty certificates
-- produce no document at all, so they leave no trace when the list is
-- spliced into a larger 'hsep'.
certAnnots :: Certs -> [Doc ann]
certAnnots cs = [pretty cs | cs /= mempty]

instance Pretty KernelResult where
  pretty kres =
    case kres of
      Returns manifest cs what ->
        annotated cs $ manifestLabel manifest <+> pretty what
      WriteReturns cs shape arr res ->
        annotated cs $
          pretty arr
            <+> PP.colon
            <+> pretty shape
            </> "with"
            <+> PP.apply [pretty slice <+> "=" <+> pretty e | (slice, e) <- res]
      TileReturns cs dims v ->
        annotated cs $
          "tile" <> apply [pretty dim <+> "/" <+> pretty tile | (dim, tile) <- dims] <+> pretty v
      RegTileReturns cs dims_n_tiles v ->
        annotated cs $
          "blkreg_tile" <> apply (map regTile dims_n_tiles) <+> pretty v
    where
      -- Prefix a result document with its (possibly absent) certificates.
      annotated cs d = hsep $ certAnnots cs <> [d]

      -- Keyword used for each kind of plain 'Returns'.
      manifestLabel ResultNoSimplify = "returns (manifest)"
      manifestLabel ResultPrivate = "returns (private)"
      manifestLabel ResultMaySimplify = "returns"

      -- A dimension together with its block and register tile sizes.
      regTile (dim, blk_tile, reg_tile) =
        pretty dim <+> "/" <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { -- | Flat physical index corresponding to the
    -- dimensions (at code generation used for a
    -- thread ID or similar).
    segFlat :: VName,
    -- | The index variable and size of each dimension.
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Show, Eq, Ord)

-- | The sizes spanned by the indexes of the 'SegSpace', outermost
-- dimension first.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace': the flat index and every per-dimension index, each
-- bound as an 'Int64' 'IndexName'.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(v, IndexName Int64) | v <- phys : map fst space]

-- | Check that every dimension size of a 'SegSpace' is an @i64@.
checkSegSpace :: (TC.Checkable rep) => SegSpace -> TC.TypeM rep ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, size) -> TC.require [Prim int64] size

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a *level*.  The *level* is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the group-level.
data SegOp lvl rep
  = SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | The 'SegSpace' must always have at least two dimensions,
    -- implying that the result of a SegRed is always an array.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Show, Eq, Ord)

-- | The level of a 'SegOp'; every constructor carries it as its first
-- field.
segLevel :: SegOp lvl rep -> lvl
segLevel op =
  case op of
    SegMap lvl _ _ _ -> lvl
    SegRed lvl _ _ _ _ -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl

-- | The space of a 'SegOp'; every constructor carries it as its second
-- field.
segSpace :: SegOp lvl rep -> SegSpace
segSpace (SegMap _ space _ _) = space
segSpace (SegRed _ space _ _ _) = space
segSpace (SegScan _ space _ _ _) = space
segSpace (SegHist _ space _ _ _) = space

-- | The body of a 'SegOp'; every constructor carries it as its last
-- field.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody

-- | The full type of one 'SegOp' result, given the element type @t@
-- declared for it and the 'KernelResult' that produces it.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    WriteReturns _ shape _ _ ->
      t `arrayOfShape` shape
    Returns {} ->
      -- One outer dimension per dimension of the segment space.
      foldr (\d acc -> acc `arrayOfRow` d) t (segSpaceDims space)
    TileReturns _ dims _ ->
      t `arrayOfShape` Shape (map fst dims)
    RegTileReturns _ dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]

-- | The return type of a 'SegOp'.
--
-- * 'SegMap': one result per declared type, shaped by 'segResultShape'.
-- * 'SegRed': first the operator results, which span every dimension
--   of the space except the innermost, extended by the operator's own
--   shape.  These are followed by the remaining results, shaped as for
--   'SegMap'.
-- * 'SegScan': like 'SegRed', except that the operator results span the
--   whole space.
-- * 'SegHist': only the histogram results; the declared types and the
--   body are not consulted.
segOpType :: SegOp lvl rep -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts (kernelBodyResult kbody)
segOpType (SegRed _ space reds ts kbody) =
  red_ts
    ++ zipWith
      (segResultShape space)
      (drop num_red ts)
      (drop num_red (kernelBodyResult kbody))
  where
    segment_dims = init (segSpaceDims space)
    red_ts = concatMap opResultTypes reds
    num_red = length red_ts
    opResultTypes op =
      map
        (`arrayOfShape` (Shape segment_dims <> segBinOpShape op))
        (lambdaReturnType (segBinOpLambda op))
segOpType (SegScan _ space scans ts kbody) =
  scan_ts
    ++ zipWith
      (segResultShape space)
      (drop num_scan ts)
      (drop num_scan (kernelBodyResult kbody))
  where
    scan_ts = concatMap opResultTypes scans
    num_scan = length scan_ts
    opResultTypes op =
      map
        (`arrayOfShape` (Shape (segSpaceDims space) <> segBinOpShape op))
        (lambdaReturnType (segBinOpLambda op))
segOpType (SegHist _ space ops _ _) =
  concatMap histResultTypes ops
  where
    segment_dims = init (segSpaceDims space)
    histResultTypes op =
      let shape = Shape segment_dims <> histShape op <> histOpShape op
       in map (`arrayOfShape` shape) (lambdaReturnType (histOp op))

-- The result types of a 'SegOp' are all statically known; they are
-- lifted into existential form with 'staticShapes'.
instance TypedOp (SegOp lvl rep) where
  opType op = pure (staticShapes (segOpType op))

instance (ASTConstraints lvl, Aliased rep) => AliasedOp (SegOp lvl rep) where
  -- No result of a 'SegOp' aliases anything: one empty alias set per
  -- result.
  opAliases op = [mempty | _ <- segOpType op]

  -- Every 'SegOp' consumes whatever its body consumes.  A 'SegHist'
  -- additionally consumes its destination arrays.
  consumedInOp op =
    case op of
      SegHist _ _ hist_ops _ _ ->
        namesFromList (concatMap histDest hist_ops) <> in_body
      _ -> in_body
    where
      in_body = consumedInKernelBody (segBody op)

-- | Type check a 'SegOp', given a checker for its level.
--
-- 'SegMap', 'SegRed' and 'SegScan' are delegated to 'checkScanRed'.  A
-- 'SegMap' is passed with no operators.  'SegHist' is checked directly
-- here.
typeCheckSegOp ::
  (TC.Checkable rep) =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) = do
  checkLvl lvl
  checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do
  checkLvl lvl
  checkScanRed
    space
    [(segBinOpLambda op, segBinOpNeutral op, segBinOpShape op) | op <- reds]
    ts
    body
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do
  checkLvl lvl
  checkScanRed
    space
    [(segBinOpLambda op, segBinOpNeutral op, segBinOpShape op) | op <- scans]
    ts
    body
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    per_op_ne_ts <- mapM checkHistOp ops

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let index_ts =
          concatMap (\op -> replicate (shapeRank (histShape op)) (Prim int64)) ops
        bucket_ret_t = index_ts ++ concat per_op_ne_ts
    unless (bucket_ret_t == ts) $
      TC.bad . TC.TypeError $
        "SegHist body has return type "
          <> prettyTuple ts
          <> " but should have type "
          <> prettyTuple bucket_ret_t
  where
    segment_dims = init (segSpaceDims space)

    -- Check a single histogram operation.  Returns its neutral-element
    -- types extended with the operation's vector shape.
    checkHistOp (HistOp dest_shape rf dests nes shape op) = do
      mapM_ (TC.require [Prim int64]) dest_shape
      TC.require [Prim int64] rf
      ne_args <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) $ shapeDims shape

      -- Operator type must match the type of neutral elements.  The
      -- lambda parameters are the neutral-element types with
      -- 'shapeRank shape' outer dimensions stripped.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda op $
        map (TC.noArgAliases . first stripVecDims) $
          ne_args ++ ne_args
      let ne_ts = map TC.argType ne_args
      unless (ne_ts == lambdaReturnType op) $
        TC.bad . TC.TypeError $
          "SegHist operator has return type "
            <> prettyTuple (lambdaReturnType op)
            <> " but neutral element has type "
            <> prettyTuple ne_ts

      -- Arrays must have proper type.
      let dest_shape' = Shape segment_dims <> dest_shape <> shape
      forM_ (zip ne_ts dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape'] dest
        TC.consume =<< TC.lookupAliases dest

      pure $ map (`arrayOfShape` shape) ne_ts

-- | Shared type checker for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each operator is given as its lambda, its neutral elements and its
-- vector shape.  For each operator:
--
-- * every dimension of the shape must be @i64@;
-- * the lambda is checked against the neutral elements, passed twice;
-- * the lambda's return type must equal the neutral-element types.
--
-- The leading entries of @ts@ must then equal the neutral-element types
-- extended with their operator shapes.  Finally, the body is checked
-- against @ts@ with 'checkKernelBody'.
checkScanRed ::
  (TC.Checkable rep) =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'

      -- Report both types, as the corresponding SegHist check does, so
      -- that the mismatch can actually be diagnosed.
      unless (lambdaReturnType lam == nes_t) $
        TC.bad . TC.TypeError $
          "wrong type for operator or neutral elements: operator has return type "
            <> prettyTuple (lambdaReturnType lam)
            <> " but neutral elements have type "
            <> prettyTuple nes_t
            <> "."

      pure $ map (`arrayOfShape` shape) nes_t

    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad . TC.TypeError $
        "Wrong return for body (does not match neutral elements; expected "
          <> prettyText expecting
          <> "; found "
          <> prettyText got
          <> ")"

    checkKernelBody ts kbody

-- | Like 'Mapper', but just for 'SegOp's.  Every field is a monadic
-- transformation for one kind of component.  Lambdas and kernel bodies
-- may change representation from @frep@ to @trep@.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Transformation for 'SubExp's.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Transformation for operator lambdas.
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Transformation for kernel bodies.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Transformation for variable names.
    mapOnSegOpVName :: VName -> m VName,
    -- | Transformation for the level.
    mapOnSegOpLevel :: lvl -> m lvl
  }

-- | A mapper that simply returns the 'SegOp' verbatim: every field is
-- 'pure'.
identitySegOpMapper :: (Monad m) => SegOpMapper lvl rep rep m
identitySegOpMapper = SegOpMapper pure pure pure pure pure

-- | Apply a 'SegOpMapper' to a 'SegSpace'.  The physical index name is
-- mapped first.  Then, for each dimension in order, its index name and
-- its size are mapped.
mapOnSegSpace ::
  (Monad f) => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  phys' <- mapOnSegOpVName tv phys
  dims' <- forM dims $ \(v, d) ->
    (,) <$> mapOnSegOpVName tv v <*> mapOnSegOpSubExp tv d
  pure $ SegSpace phys' dims'

-- | Apply a 'SegOpMapper' to one 'SegBinOp'.  The operator goes through
-- 'mapOnSegOpLambda'; the neutral elements and each dimension of the
-- shape go through 'mapOnSegOpSubExp'.  The commutativity flag is kept.
mapSegBinOp ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- mapM (mapOnSegOpSubExp tv) nes
  dims' <- mapM (mapOnSegOpSubExp tv) (shapeDims shape)
  pure $ SegBinOp comm red_op' nes' (Shape dims')

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- Every constructor is traversed in the same order: level, segment
-- space, operator-specific payload (reductions, scans or histograms),
-- result types, and finally the kernel body.
--
-- Result types are mapped with 'mapOnSegOpType' for every constructor.
-- That function also passes the name of an 'Acc' type through
-- 'mapOnSegOpVName'.  Previously only 'SegMap' did this, so renaming or
-- substituting mappers could leave a stale accumulator name in the
-- types of 'SegRed', 'SegScan' and 'SegHist'.  For non-'Acc' types the
-- result is unchanged.
mapSegOpM ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts body) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Sizes, neutral elements and shape dimensions are subexpressions,
    -- the destination arrays are names, and the operator is a lambda.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapM (mapOnSegOpSubExp tv) w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
        <*> mapOnSegOpLambda tv op

-- | Apply a 'SegOpMapper' to a 'Type'.
--
-- * Primitive and memory types are returned unchanged.
-- * Array shapes have each dimension passed through 'mapOnSegOpSubExp'.
-- * For accumulators, the accumulator name goes through
--   'mapOnSegOpVName' and its index space through 'mapOnSegOpSubExp'.
--   Its element types have only their shapes mapped (via 'bitraverse').
mapOnSegOpType ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType tv t =
  case t of
    Prim {} -> pure t
    Mem {} -> pure t
    Array et shape u ->
      Array et <$> onShape shape <*> pure u
    Acc acc ispace ts u ->
      Acc
        <$> mapOnSegOpVName tv acc
        <*> onShape ispace
        <*> traverse (bitraverse onShape pure) ts
        <*> pure u
  where
    onShape = traverse (mapOnSegOpSubExp tv)

-- | Move a 'SegBinOp' to another representation using a 'Rephraser'.
-- Only the operator lambda needs rephrasing; the commutativity flag,
-- neutral elements and shape are carried over as they are.
rephraseBinOp ::
  (Monad f) =>
  Rephraser f from rep ->
  SegBinOp from ->
  f (SegBinOp rep)
rephraseBinOp r (SegBinOp comm lam nes shape) = do
  lam' <- rephraseLambda r lam
  pure $ SegBinOp comm lam' nes shape

-- | Move a 'KernelBody' to another representation using a 'Rephraser'.
-- The body decoration is rephrased first, then every statement in order.
-- The kernel results are representation-independent and are kept.
rephraseKernelBody ::
  (Monad f) =>
  Rephraser f from rep ->
  KernelBody from ->
  f (KernelBody rep)
rephraseKernelBody r (KernelBody dec stms res) = do
  dec' <- rephraseBodyDec r dec
  stms' <- traverse (rephraseStm r) stms
  pure $ KernelBody dec' stms' res

instance RephraseOp (SegOp lvl) where
  rephraseInOp :: forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> SegOp lvl from -> m (SegOp lvl to)
rephraseInOp Rephraser m from to
r (SegMap lvl
lvl SegSpace
space [Type]
ts KernelBody from
body) =
    lvl -> SegSpace -> [Type] -> KernelBody to -> SegOp lvl to
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap lvl
lvl SegSpace
space [Type]
ts (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
  rephraseInOp Rephraser m from to
r (SegRed lvl
lvl SegSpace
space [SegBinOp from]
reds [Type]
ts KernelBody from
body) =
    lvl
-> SegSpace
-> [SegBinOp to]
-> [Type]
-> KernelBody to
-> SegOp lvl to
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed lvl
lvl SegSpace
space
      ([SegBinOp to] -> [Type] -> KernelBody to -> SegOp lvl to)
-> m [SegBinOp to] -> m ([Type] -> KernelBody to -> SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SegBinOp from -> m (SegBinOp to))
-> [SegBinOp from] -> m [SegBinOp to]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Rephraser m from to -> SegBinOp from -> m (SegBinOp to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> SegBinOp from -> f (SegBinOp rep)
rephraseBinOp Rephraser m from to
r) [SegBinOp from]
reds
      m ([Type] -> KernelBody to -> SegOp lvl to)
-> m [Type] -> m (KernelBody to -> SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ts
      m (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
  rephraseInOp Rephraser m from to
r (SegScan lvl
lvl SegSpace
space [SegBinOp from]
scans [Type]
ts KernelBody from
body) =
    lvl
-> SegSpace
-> [SegBinOp to]
-> [Type]
-> KernelBody to
-> SegOp lvl to
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegScan lvl
lvl SegSpace
space
      ([SegBinOp to] -> [Type] -> KernelBody to -> SegOp lvl to)
-> m [SegBinOp to] -> m ([Type] -> KernelBody to -> SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SegBinOp from -> m (SegBinOp to))
-> [SegBinOp from] -> m [SegBinOp to]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Rephraser m from to -> SegBinOp from -> m (SegBinOp to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> SegBinOp from -> f (SegBinOp rep)
rephraseBinOp Rephraser m from to
r) [SegBinOp from]
scans
      m ([Type] -> KernelBody to -> SegOp lvl to)
-> m [Type] -> m (KernelBody to -> SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ts
      m (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
  rephraseInOp Rephraser m from to
r (SegHist lvl
lvl SegSpace
space [HistOp from]
hists [Type]
ts KernelBody from
body) =
    lvl
-> SegSpace
-> [HistOp to]
-> [Type]
-> KernelBody to
-> SegOp lvl to
forall lvl rep.
lvl
-> SegSpace
-> [HistOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegHist lvl
lvl SegSpace
space
      ([HistOp to] -> [Type] -> KernelBody to -> SegOp lvl to)
-> m [HistOp to] -> m ([Type] -> KernelBody to -> SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp from -> m (HistOp to)) -> [HistOp from] -> m [HistOp to]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM HistOp from -> m (HistOp to)
onOp [HistOp from]
hists
      m ([Type] -> KernelBody to -> SegOp lvl to)
-> m [Type] -> m (KernelBody to -> SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ts
      m (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
    where
      onOp :: HistOp from -> m (HistOp to)
onOp (HistOp ShapeBase SubExp
w SubExp
rf [VName]
arrs [SubExp]
nes ShapeBase SubExp
shape Lambda from
op) =
        ShapeBase SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> ShapeBase SubExp
-> Lambda to
-> HistOp to
forall rep.
ShapeBase SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> ShapeBase SubExp
-> Lambda rep
-> HistOp rep
HistOp ShapeBase SubExp
w SubExp
rf [VName]
arrs [SubExp]
nes ShapeBase SubExp
shape (Lambda to -> HistOp to) -> m (Lambda to) -> m (HistOp to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> Lambda from -> m (Lambda to)
forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda Rephraser m from to
r Lambda from
op

-- | A helper for defining 'TraverseOpStms'.
--
-- The traverser is applied to the statements of the kernel body in the
-- scope bound by the operation's 'SegSpace'.  For every operator lambda,
-- 'traverseLambdaStms' is used, and whatever scope it supplies is
-- extended with the 'SegSpace' scope before the traverser is invoked.
-- The kernel body's decoration and results are passed through unchanged.
traverseSegOpStms :: (Monad m) => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop =
  mapSegOpM
    identitySegOpMapper
      { mapOnSegOpLambda = traverseLambdaStms (f . (spaceScope <>)),
        mapOnSegOpBody = \(KernelBody dec stms res) ->
          (\stms' -> KernelBody dec stms' res) <$> f spaceScope stms
      }
    segop
  where
    -- Names bound by the segmented space (see 'scopeOfSegSpace').
    spaceScope = scopeOfSegSpace (segSpace segop)

-- | Substitution on a 'SegOp' applies the same name substitution to
-- every component visited by 'mapSegOpM': subexpressions, operator
-- lambdas, the kernel body, variable names, and the level.
instance (ASTRep rep, Substitute lvl) => Substitute (SegOp lvl rep) where
  substituteNames subst segop = runIdentity (mapSegOpM substituter segop)
    where
      -- Shared by every field; the signature keeps it polymorphic so it
      -- can be used at each component type.
      sub :: (Substitute a) => a -> Identity a
      sub = Identity . substituteNames subst

      substituter =
        SegOpMapper
          { mapOnSegOpSubExp = sub,
            mapOnSegOpLambda = sub,
            mapOnSegOpBody = sub,
            mapOnSegOpVName = sub,
            mapOnSegOpLevel = sub
          }

-- | Renaming a 'SegOp' runs 'renameBound' over the names in the scope of
-- its 'SegSpace', and within that renames every component visited by
-- 'mapSegOpM'.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op = renameBound spaceNames (mapSegOpM renamer op)
    where
      -- The names brought into scope by the segmented space.
      spaceNames = M.keys (scopeOfSegSpace (segSpace op))

      renamer =
        SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }

-- | The free variables of a 'SegOp' are the union of the free variables
-- of every component visited by 'mapSegOpM'. The names in the scope of
-- the operation's 'SegSpace' are then bound away with 'fvBind'.
instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  freeIn' e = fvBind boundBySpace (execState (mapSegOpM collect e) mempty)
    where
      boundBySpace = namesFromList (M.keys (scopeOfSegSpace (segSpace e)))

      -- Append the free variables of a component to the accumulated
      -- state and hand the component back unchanged.
      note fv x = x <$ modify (<> fv x)

      collect =
        SegOpMapper
          { mapOnSegOpSubExp = note freeIn',
            mapOnSegOpLambda = note freeIn',
            mapOnSegOpBody = note freeIn',
            mapOnSegOpVName = note freeIn',
            mapOnSegOpLevel = note freeIn'
          }

-- | Metrics for a 'SegOp'. Each operation is recorded 'inside' a label
-- named after its constructor. For 'SegRed', 'SegScan', and 'SegHist',
-- the operator lambdas are measured first, then the kernel body.
instance (OpMetrics (Op rep)) => OpMetrics (SegOp lvl rep) where
  opMetrics (SegMap _ _ _ body) =
    inside "SegMap" (kernelBodyMetrics body)
  opMetrics (SegRed _ _ reds _ body) =
    inside "SegRed" $
      traverse_ (lambdaMetrics . segBinOpLambda) reds *> kernelBodyMetrics body
  opMetrics (SegScan _ _ scans _ body) =
    inside "SegScan" $
      traverse_ (lambdaMetrics . segBinOpLambda) scans *> kernelBodyMetrics body
  opMetrics (SegHist _ _ hists _ body) =
    inside "SegHist" $
      traverse_ (lambdaMetrics . histOp) hists *> kernelBodyMetrics body

-- | Render a 'SegSpace' as its dimensions, each written @i < d@ and
-- combined with 'apply', followed by the physical index as @(~phys)@.
instance Pretty SegSpace where
  pretty (SegSpace phys dims) =
    apply [pretty i <+> "<" <+> pretty d | (i, d) <- dims]
      <+> parens ("~" <> pretty phys)

-- | Render a 'SegBinOp' as its neutral elements in braces, then its
-- shape, then its operator lambda. The lambda is prefixed with
-- @commutative@ when the operator is marked 'Commutative'.
instance (PrettyRep rep) => Pretty (SegBinOp rep) where
  pretty (SegBinOp comm lam nes shape) =
    neutral
      <> PP.comma
      </> pretty shape
      <> PP.comma
      </> commDoc
      <> pretty lam
    where
      neutral = PP.braces (PP.commasep (map pretty nes))
      commDoc
        | Commutative <- comm = "commutative "
        | otherwise = mempty

-- | Render a 'SegOp' as a keyword (@segmap@, @segred@, @segscan@ or
-- @seghist@) followed by the level, the aligned 'SegSpace', and the
-- return types. The kernel body is rendered in a braced nested block.
-- Reductions, scans and histograms also list their operators in
-- parentheses, separated by a comma and a line break.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  pretty segop =
    case segop of
      SegMap lvl space ts body ->
        "segmap"
          <> pretty lvl
          </> PP.align (pretty space)
          <+> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> PP.nestedBlock "{" "}" (pretty body)
      SegRed lvl space reds ts body ->
        withOperators "segred" lvl space (map pretty reds) ts body
      SegScan lvl space scans ts body ->
        withOperators "segscan" lvl space (map pretty scans) ts body
      SegHist lvl space hists ts body ->
        withOperators "seghist" lvl space (map histOpDoc hists) ts body
    where
      -- Shared layout for the operations that carry a list of operators.
      withOperators keyword lvl space opDocs ts body =
        keyword
          <> pretty lvl
          </> PP.align (pretty space)
          </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) opDocs)
          </> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> PP.nestedBlock "{" "}" (pretty body)

      -- A histogram operator: width, race factor, destination arrays,
      -- neutral elements, shape, and operator lambda.
      histOpDoc (HistOp width raceFactor destArrs neutral opShape lam) =
        pretty width
          <> PP.comma
          <+> pretty raceFactor
          <> PP.comma
          </> PP.braces (PP.commasep (map pretty destArrs))
          <> PP.comma
          </> PP.braces (PP.commasep (map pretty neutral))
          <> PP.comma
          </> pretty opShape
          <> PP.comma
          </> pretty lam

-- | Alias-annotate a 'SegOp'. Operator lambdas go through
-- 'Alias.analyseLambda', and the kernel body goes through
-- 'aliasAnalyseKernelBody', both using the given 'AliasTable'.
-- Subexpressions, names, and the level are passed through unchanged.
instance CanBeAliased (SegOp lvl) where
  addOpAliases aliases op = runIdentity (mapSegOpM analyse op)
    where
      analyse =
        SegOpMapper
          { mapOnSegOpSubExp = Identity,
            mapOnSegOpLambda = Identity . Alias.analyseLambda aliases,
            mapOnSegOpBody = Identity . aliasAnalyseKernelBody aliases,
            mapOnSegOpVName = Identity,
            mapOnSegOpLevel = Identity
          }

-- | Annotate a 'KernelBody' with wisdom.  The statements are informed
-- with 'informStms'.  The body decoration and kernel results are then
-- handed to 'mkWiseKernelBody' to build the annotated body.
informKernelBody :: (Informing rep) => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody dec stms res) =
  mkWiseKernelBody dec (informStms stms) res

-- | Add wisdom to a 'SegOp'.  Every lambda inside the operation is
-- informed with 'informLambda' and every kernel body with
-- 'informKernelBody'.  Subexpressions, names and the level are passed
-- through unchanged.
instance CanBeWise (SegOp lvl) where
  addOpWisdom = runIdentity . mapSegOpM add
    where
      -- Field order: SubExp, Lambda, KernelBody, VName, lvl.
      add =
        SegOpMapper
          pure
          (pure . informLambda)
          (pure . informKernelBody)
          pure
          pure

-- | Symbolic indexing into the result of a 'SegOp'.
--
-- Only 'SegMap' is handled, and only when its @k@th kernel result is a
-- @'Returns' 'ResultMaySimplify'@ of a variable.  The segment's thread
-- indices are bound to the leading indices.  The kernel body statements
-- are then walked to find an index-expression (or array index) for
-- that variable.  Every other case yields 'Nothing'.
instance (ASTRep rep) => ST.IndexOp (SegOp lvl rep) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    -- A pattern-match failure here (wrong result kind) yields Nothing.
    Returns ResultMaySimplify _ se <- maybeNth k $ kernelBodyResult kbody
    -- Need at least one index per segment dimension.
    guard $ length gtids <= length is
    let idx_table = M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is
        idx_table' = foldl' expandIndexedTable idx_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v idx_table'
      _ -> Nothing
    where
      (gtids, _) = unzip $ unSegSpace space
      -- Indexes in excess of what is used to index through the
      -- segment dimensions.
      excess_is = drop (length gtids) is

      -- Extend the table with the (single) variable bound by a
      -- statement, if it can be expressed in terms of the table.
      expandIndexedTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
            M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
        | [v] <- patNames $ stmPat stm,
          BasicOp (Index arr slice) <- stmExp stm,
          -- The excess indices must exactly cover the sliced dimensions.
          length (sliceDims slice) == length excess_is,
          -- The array must be visible outside the kernel.
          arr `ST.available` vtable,
          Just (slice', cs) <- asPrimExpSlice table slice =
            let idx =
                  ST.IndexedArray
                    (stmCerts stm <> cs)
                    arr
                    (fixSlice (fmap isInt64 slice') excess_is)
             in M.insert v idx table
        | otherwise =
            table

      -- Translate every component of a slice, collecting certificates.
      asPrimExpSlice table =
        runWriterT . traverse (primExpFromSubExpM (asPrimExp table))

      -- A name is known either from the table (recording its
      -- certificates) or as a primitive-typed variable in scope.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable =
            pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

-- | 'IsOp' for 'SegOp'.  For every operation, 'cheapOp' is 'False' and
-- 'safeOp' is 'True'.
instance
  (ASTRep rep, ASTConstraints lvl) =>
  IsOp (SegOp lvl rep)
  where
  cheapOp _ = False
  safeOp _ = True

--- Simplification

-- | Simplify a segment space by simplifying the size of each dimension.
-- The physical index name and the per-dimension index names are kept
-- as they are.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) =
    -- 'traverse' on a pair only touches the second (size) component.
    SegSpace phys <$> mapM (traverse Engine.simplify) dims

-- | Simplify every component of a kernel result.  The only exception is
-- the 'ResultManifest' of 'Returns', which is preserved as-is.
instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest cs what) =
    Returns manifest <$> Engine.simplify cs <*> Engine.simplify what
  simplify (WriteReturns cs ws a res) =
    WriteReturns
      <$> Engine.simplify cs
      <*> Engine.simplify ws
      <*> Engine.simplify a
      <*> Engine.simplify res
  simplify (TileReturns cs dims what) =
    TileReturns
      <$> Engine.simplify cs
      <*> Engine.simplify dims
      <*> Engine.simplify what
  simplify (RegTileReturns cs dims_n_tiles what) =
    RegTileReturns
      <$> Engine.simplify cs
      <*> Engine.simplify dims_n_tiles
      <*> Engine.simplify what

-- | Construct a wisdom-annotated 'KernelBody'.
--
-- The body decoration is taken from 'mkWiseBody' applied to the given
-- statements and to the subexpressions of the kernel results (via
-- 'kernelResultSubExp').  The statements and kernel results themselves
-- are stored unchanged.
mkWiseKernelBody ::
  (Informing rep) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res =
  let Body dec' _ _ = mkWiseBody dec stms $ subExpsRes res_vs
   in KernelBody dec' stms res
  where
    res_vs = map kernelResultSubExp res

-- | Construct a 'KernelBody' in a builder monad.
--
-- The body decoration is obtained from 'mkBodyM' over the given
-- statements and the subexpressions of the kernel results (via
-- 'kernelResultSubExp').  The statements and kernel results are stored
-- unchanged.
mkKernelBodyM ::
  (MonadBuilder m) =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres = do
  Body dec' _ _ <- mkBodyM stms $ subExpsRes res_ses
  pure $ KernelBody dec' stms kres
  where
    res_ses = map kernelResultSubExp kres

simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody :: forall rep.
(SimplifiableRep rep, BodyDec rep ~ ()) =>
SegSpace
-> KernelBody (Wise rep)
-> SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody SegSpace
space (KernelBody BodyDec (Wise rep)
_ Stms (Wise rep)
stms [KernelResult]
res) = do
  BlockPred (Wise rep)
par_blocker <- (Env rep -> BlockPred (Wise rep))
-> SimpleM rep (BlockPred (Wise rep))
forall {k} (rep :: k) a. (Env rep -> a) -> SimpleM rep a
Engine.asksEngineEnv ((Env rep -> BlockPred (Wise rep))
 -> SimpleM rep (BlockPred (Wise rep)))
-> (Env rep -> BlockPred (Wise rep))
-> SimpleM rep (BlockPred (Wise rep))
forall a b. (a -> b) -> a -> b
$ HoistBlockers rep -> BlockPred (Wise rep)
forall {k} (rep :: k). HoistBlockers rep -> BlockPred (Wise rep)
Engine.blockHoistPar (HoistBlockers rep -> BlockPred (Wise rep))
-> (Env rep -> HoistBlockers rep)
-> Env rep
-> BlockPred (Wise rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Env rep -> HoistBlockers rep
forall {k} (rep :: k). Env rep -> HoistBlockers rep
Engine.envHoistBlockers

  let blocker :: BlockPred (Wise rep)
blocker =
        Names -> BlockPred (Wise rep)
forall rep. ASTRep rep => Names -> BlockPred rep
Engine.hasFree Names
bound_here
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. BlockPred rep
Engine.isOp
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
par_blocker
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. BlockPred rep
Engine.isConsumed
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. Aliased rep => BlockPred rep
Engine.isConsuming
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. SimplifiableRep rep => BlockPred (Wise rep)
Engine.isDeviceMigrated

  -- Ensure we do not try to use anything that is consumed in the result.
  ([KernelResult]
body_res, Stms (Wise rep)
body_stms, Stms (Wise rep)
hoisted) <-
    (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable ((SymbolTable (Wise rep) -> [VName] -> SymbolTable (Wise rep))
-> [VName] -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((SymbolTable (Wise rep) -> VName -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep) -> [VName] -> SymbolTable (Wise rep)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' ((VName -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep) -> VName -> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall rep. VName -> SymbolTable rep -> SymbolTable rep
ST.consume)) ((KernelResult -> [VName]) -> [KernelResult] -> [VName]
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap KernelResult -> [VName]
consumedInResult [KernelResult]
res))
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
    -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable (SymbolTable (Wise rep)
-> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a. Semigroup a => a -> a -> a
<> SymbolTable (Wise rep)
scope_vtable)
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
    -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable (\SymbolTable (Wise rep)
vtable -> SymbolTable (Wise rep)
vtable {simplifyMemory :: Bool
ST.simplifyMemory = Bool
True})
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
    -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a. SimpleM rep a -> SimpleM rep a
Engine.enterLoop
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$ BlockPred (Wise rep)
-> Stms (Wise rep)
-> SimpleM rep ([KernelResult], UsageTable)
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
BlockPred (Wise rep)
-> Stms (Wise rep)
-> SimpleM rep (a, UsageTable)
-> SimpleM rep (a, Stms (Wise rep), Stms (Wise rep))
Engine.blockIf BlockPred (Wise rep)
blocker Stms (Wise rep)
stms
      (SimpleM rep ([KernelResult], UsageTable)
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], UsageTable)
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$ do
        [KernelResult]
res' <-
          (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep [KernelResult] -> SimpleM rep [KernelResult]
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable (Names -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall rep. Names -> SymbolTable rep -> SymbolTable rep
ST.hideCertified (Names -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> Names -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Map VName (NameInfo (Wise rep)) -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName (NameInfo (Wise rep)) -> [VName])
-> Map VName (NameInfo (Wise rep)) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stms (Wise rep) -> Map VName (NameInfo (Wise rep))
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Stms (Wise rep)
stms) (SimpleM rep [KernelResult] -> SimpleM rep [KernelResult])
-> SimpleM rep [KernelResult] -> SimpleM rep [KernelResult]
forall a b. (a -> b) -> a -> b
$
            (KernelResult -> SimpleM rep KernelResult)
-> [KernelResult] -> SimpleM rep [KernelResult]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM KernelResult -> SimpleM rep KernelResult
forall rep.
SimplifiableRep rep =>
KernelResult -> SimpleM rep KernelResult
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
Engine.simplify [KernelResult]
res
        ([KernelResult], UsageTable)
-> SimpleM rep ([KernelResult], UsageTable)
forall a. a -> SimpleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([KernelResult]
res', Names -> UsageTable
UT.usages (Names -> UsageTable) -> Names -> UsageTable
forall a b. (a -> b) -> a -> b
$ [KernelResult] -> Names
forall a. FreeIn a => a -> Names
freeIn [KernelResult]
res')

  (KernelBody (Wise rep), Stms (Wise rep))
-> SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
forall a. a -> SimpleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BodyDec rep
-> Stms (Wise rep) -> [KernelResult] -> KernelBody (Wise rep)
forall rep.
Informing rep =>
BodyDec rep
-> Stms (Wise rep) -> [KernelResult] -> KernelBody (Wise rep)
mkWiseKernelBody () Stms (Wise rep)
body_stms [KernelResult]
body_res, Stms (Wise rep)
hoisted)
  where
    scope_vtable :: SymbolTable (Wise rep)
scope_vtable = SegSpace -> SymbolTable (Wise rep)
forall rep. ASTRep rep => SegSpace -> SymbolTable rep
segSpaceSymbolTable SegSpace
space
    bound_here :: Names
bound_here = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Map VName (NameInfo Any) -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName (NameInfo Any) -> [VName])
-> Map VName (NameInfo Any) -> [VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space

    consumedInResult :: KernelResult -> [VName]
consumedInResult (WriteReturns Certs
_ ShapeBase SubExp
_ VName
arr [(Slice SubExp, SubExp)]
_) =
      [VName
arr]
    consumedInResult KernelResult
_ =
      []

-- | Simplify a lambda with 'Engine.simplifyLambda', passing along the
-- given names, and run the whole simplification under
-- 'Engine.blockMigrated'.  Returns the simplified lambda together with
-- the statements hoisted out of it.
simplifyLambda ::
  (Engine.SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda bound lam =
  Engine.blockMigrated $ Engine.simplifyLambda bound lam

-- | Build a symbol table for the names bound by a 'SegSpace'.
--
-- The flat index is entered as an 'Int64' index name.  Each
-- (thread index, dimension) pair is then inserted, left to right, as an
-- 'Int64' loop variable bounded by its dimension.
segSpaceSymbolTable :: (ASTRep rep) => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' insertDim flat_table gtids_and_dims
  where
    flat_table = ST.fromScope (M.singleton flat (IndexName Int64))
    insertDim table (gtid, dim) = ST.insertLoopVar gtid Int64 dim table

-- | Simplify a single 'SegBinOp'.
--
-- The operator lambda is simplified with 'ST.simplifyMemory' switched
-- on, passing @phys_id@ as the bound names to 'simplifyLambda'.  The
-- shape and neutral elements are simplified afterwards.  The returned
-- statements are those hoisted out of the lambda.
simplifySegBinOp ::
  (Engine.SimplifiableRep rep) =>
  VName ->
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp phys_id (SegBinOp comm lam nes shape) = do
  (lam', lam_hoisted) <-
    withMemorySimplification $ simplifyLambda (oneName phys_id) lam
  shape' <- Engine.simplify shape
  nes' <- traverse Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', lam_hoisted)
  where
    withMemorySimplification =
      Engine.localVtable $ \vtable -> vtable {ST.simplifyMemory = True}

-- | Simplify the given 'SegOp'.
--
-- Simplification proceeds in three steps:
--
-- 1. the level, space and result types;
-- 2. any operators (reduction, scan or histogram);
-- 3. the kernel body.
--
-- In the returned statements, those hoisted out of the operators come
-- before those hoisted out of the kernel body.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegMap lvl' space' ts' kbody', body_hoisted)
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- simplifySegBinOpsInSpace space reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegRed lvl' space' reds' ts' kbody', reds_hoisted <> body_hoisted)
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <- simplifySegBinOpsInSpace space scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegScan lvl' space' scans' ts' kbody', scans_hoisted <> body_hoisted)
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (ops', ops_hoisted) <- mapAndUnzipM simplifyHistOp ops
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegHist lvl' space' ops' ts' kbody', mconcat ops_hoisted <> body_hoisted)
  where
    -- Symbol table for the names bound by the segment space.
    space_vtable = ST.fromScope $ scopeOfSegSpace space

    -- Simplify every component of one histogram operator.  The operator
    -- lambda is simplified with the space's names in scope and with
    -- memory simplification switched on.
    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <-
        Engine.localVtable (<> space_vtable)
          . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})
          $ simplifyLambda (oneName (segFlat space)) lam
      pure (HistOp w' rf' arrs' nes' dims' lam', op_hoisted)

-- | Simplify the operators of a 'SegRed' or 'SegScan' in a symbol
-- table extended with the names bound by the given segment space.
-- Returns the simplified operators and the concatenation, in operator
-- order, of the statements hoisted out of each one.
simplifySegBinOpsInSpace ::
  (Engine.SimplifiableRep rep) =>
  SegSpace ->
  [SegBinOp (Wise rep)] ->
  Engine.SimpleM rep ([SegBinOp (Wise rep)], Stms (Wise rep))
simplifySegBinOpsInSpace space ops = do
  (ops', ops_hoisted) <-
    Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
      mapAndUnzipM (simplifySegBinOp (segFlat space)) ops
  pure (ops', mconcat ops_hoisted)

-- | Does this rep contain 'SegOp's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSegOp rep where
  -- | The type used as the level parameter of this rep's 'SegOp's.
  type SegOpLevel rep
  -- | View an 'Op' as a 'SegOp'.  Returns 'Nothing' when the 'Op' is
  -- not a 'SegOp'; the simplification rules then skip it.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)
  -- | Embed a 'SegOp' as an 'Op' of this rep.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep

-- | Simplification rules for simplifying 'SegOp's.
--
-- The rule book has one top-down rule ('segOpRuleTopDown') and one
-- bottom-up rule ('segOpRuleBottomUp').  Both skip any 'Op' that is
-- not a 'SegOp'.
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules = ruleBook top_down_rules bottom_up_rules
  where
    top_down_rules = [RuleOp segOpRuleTopDown]
    bottom_up_rules = [RuleOp segOpRuleBottomUp]

-- | Top-down rule for 'Op's.  A 'SegOp' is handed to 'topDownSegOp';
-- any other 'Op' yields 'Skip'.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec op =
  maybe Skip (topDownSegOp vtable pat dec) (asSegOp op)

-- | Bottom-up rule for 'Op's.  A 'SegOp' is handed to 'bottomUpSegOp';
-- any other 'Op' yields 'Skip'.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec op =
  maybe Skip (bottomUpSegOp vtable pat dec) (asSegOp op)

topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
topDownSegOp :: forall rep.
(HasSegOp rep, BuilderOps rep, Buildable rep) =>
SymbolTable rep
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
topDownSegOp SymbolTable rep
vtable (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec (SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts (KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  ([Type]
ts', [PatElem (LetDec rep)]
kpes', [KernelResult]
kres') <-
    [(Type, PatElem (LetDec rep), KernelResult)]
-> ([Type], [PatElem (LetDec rep)], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Type, PatElem (LetDec rep), KernelResult)]
 -> ([Type], [PatElem (LetDec rep)], [KernelResult]))
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep ([Type], [PatElem (LetDec rep)], [KernelResult])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool)
-> [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult ([Type]
-> [PatElem (LetDec rep)]
-> [KernelResult]
-> [(Type, PatElem (LetDec rep), KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
ts [PatElem (LetDec rep)]
kpes [KernelResult]
kres)

  -- Check if we did anything at all.
  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([KernelResult]
kres [KernelResult] -> [KernelResult] -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult]
kres') RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  KernelBody rep
kbody <- Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms [KernelResult]
kres'
  Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
    Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
      Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
        SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$
          SegOpLevel rep
-> SegSpace
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts' KernelBody rep
kbody
  where
    isInvariant :: SubExp -> Bool
isInvariant Constant {} = Bool
True
    isInvariant (Var VName
v) = Maybe (Entry rep) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry rep) -> Bool) -> Maybe (Entry rep) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> SymbolTable rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
v SymbolTable rep
vtable

    checkForInvarianceResult :: (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult (Type
_, PatElem (LetDec rep)
pe, Returns ResultManifest
rm Certs
cs SubExp
se)
      | Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty,
        ResultManifest
rm ResultManifest -> ResultManifest -> Bool
forall a. Eq a => a -> a -> Bool
== ResultManifest
ResultMaySimplify,
        SubExp -> Bool
isInvariant SubExp
se = do
          [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
              ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space) SubExp
se
          Bool -> RuleM rep Bool
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
False
    checkForInvarianceResult (Type, PatElem (LetDec rep), KernelResult)
_ =
      Bool -> RuleM rep Bool
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True

-- If a SegRed contains two or more reduction operators that have the
-- same vector shape, merge them into a single operator.  This saves on
-- communication overhead, but can in principle lead to more local
-- memory usage.
--
-- 'groupBy' only groups *adjacent* operators with equal
-- 'segBinOpShape', so the order of the reduction results (and of the
-- pattern elements, types and kernel results belonging to them) is
-- preserved.  The map results, which follow the reduction results,
-- are passed through unchanged.
--
-- The statement's 'StmAux' is attached to the rewritten statement, as
-- the other SegOp rules in this module do, so that certificates and
-- attributes on the SegRed survive the rewrite.
topDownSegOp _ (Pat pes) dec (SegRed lvl space ops ts kbody)
  | length ops > 1,
    op_groupings <-
      groupBy sameShape $
        zip ops $
          chunks (map (length . segBinOpNeutral) ops) $
            zip3 red_pes red_ts red_res,
    -- Only fire if some group holds more than one operator; otherwise
    -- the rewrite would not change anything.
    any ((> 1) . length) op_groupings = Simplify $ do
      let (ops', aux) = unzip $ mapMaybe combineOps op_groupings
          (red_pes', red_ts', red_res') = unzip3 $ concat aux
          pes' = red_pes' ++ map_pes
          ts' = red_ts' ++ map_ts
          kbody' = kbody {kernelBodyResult = red_res' ++ map_res}
      -- Reuse the original StmAux rather than discarding it, so that
      -- certificates and attributes are not lost.
      addStm $
        Let (Pat pes') dec $
          Op $ segOp $ SegRed lvl space ops' ts' kbody'
  where
    -- The reduction results come first, followed by the map results.
    (red_pes, map_pes) = splitAt (segBinOpResults ops) pes
    (red_ts, map_ts) = splitAt (segBinOpResults ops) ts
    (red_res, map_res) = splitAt (segBinOpResults ops) $ kernelBodyResult kbody

    -- Two operators are merge candidates if they have the same shape.
    sameShape (op1, _) (op2, _) = segBinOpShape op1 == segBinOpShape op2

    -- Fold a group of operators (each paired with its results) into a
    -- single operator.  'groupBy' never yields empty groups, so the
    -- 'Nothing' case is not expected to occur.
    combineOps [] = Nothing
    combineOps (x : xs) = Just $ foldl' combine x xs

    -- Merge two operators.
    --
    -- The combined lambda takes all x-parameters (op1's, then op2's),
    -- followed by all y-parameters (op1's, then op2's).  Its body runs
    -- op1's statements, then op2's, and returns op1's results followed
    -- by op2's.  The neutral elements and auxiliary results are
    -- concatenated in the same order.
    combine (op1, op1_aux) (op2, op2_aux) =
      let lam1 = segBinOpLambda op1
          lam2 = segBinOpLambda op2
          (op1_xparams, op1_yparams) =
            splitAt (length (segBinOpNeutral op1)) $ lambdaParams lam1
          (op2_xparams, op2_yparams) =
            splitAt (length (segBinOpNeutral op2)) $ lambdaParams lam2
          lam =
            Lambda
              { lambdaParams =
                  op1_xparams
                    ++ op2_xparams
                    ++ op1_yparams
                    ++ op2_yparams,
                lambdaReturnType = lambdaReturnType lam1 ++ lambdaReturnType lam2,
                lambdaBody =
                  mkBody (bodyStms (lambdaBody lam1) <> bodyStms (lambdaBody lam2)) $
                    bodyResult (lambdaBody lam1) <> bodyResult (lambdaBody lam2)
              }
       in ( SegBinOp
              { segBinOpComm = segBinOpComm op1 <> segBinOpComm op2,
                segBinOpLambda = lam,
                segBinOpNeutral = segBinOpNeutral op1 ++ segBinOpNeutral op2,
                segBinOpShape = segBinOpShape op1 -- Same as shape of op2 due to the grouping.
              },
            op1_aux ++ op2_aux
          )
-- Any other statement is left alone by this rule.
topDownSegOp _ _ _ _ = Skip

-- | Take a 'SegOp' apart into the pieces that the simplification rules
-- work on, so that they can manipulate its result types and body
-- without caring which kind of 'SegOp' it is.  The components are:
--
-- 1. the result types;
-- 2. the kernel body;
-- 3. how many leading results are not plain map results.  This is
--    zero for 'SegMap', the 'segBinOpResults' of the operators for
--    'SegScan' and 'SegRed', and the total number of histogram
--    destinations for 'SegHist';
-- 4. a function that rebuilds a 'SegOp' of the same kind, with the
--    same level, space and operators, from new result types and a new
--    body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts segop =
  case segop of
    SegMap lvl space res_ts kbody ->
      (res_ts, kbody, 0, SegMap lvl space)
    SegScan lvl space binops res_ts kbody ->
      (res_ts, kbody, segBinOpResults binops, SegScan lvl space binops)
    SegRed lvl space binops res_ts kbody ->
      (res_ts, kbody, segBinOpResults binops, SegRed lvl space binops)
    SegHist lvl space histops res_ts kbody ->
      ( res_ts,
        kbody,
        sum [length (histDest h) | h <- histops],
        SegHist lvl space histops
      )

bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- Some SegOp results can be moved outside the SegOp, which can
-- simplify further analysis.
bottomUpSegOp :: forall rep.
(Aliased rep, HasSegOp rep, BuilderOps rep) =>
(SymbolTable rep, UsageTable)
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
bottomUpSegOp (SymbolTable rep
vtable, UsageTable
_used) (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec SegOp (SegOpLevel rep) rep
segop = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.  We have to be careful
  -- not to remove anything that is passed on to a scan/map/histogram
  -- operation.  Fortunately, these are always first in the result
  -- list.
  ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') <-
    Scope rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM
   rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
 -> RuleM
      rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep))
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a b. (a -> b) -> a -> b
$
      (([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
 -> Stm rep
 -> RuleM
      rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep))
-> ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stms rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
distribute ([PatElem (LetDec rep)]
kpes, [Type]
kts, [KernelResult]
kres, Stms rep
forall a. Monoid a => a
mempty) Stms rep
kstms

  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([PatElem (LetDec rep)]
kpes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
kpes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  KernelBody rep
kbody' <-
    Scope rep
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep))
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a b. (a -> b) -> a -> b
$ Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms' [KernelResult]
kres'

  Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop [Type]
kts' KernelBody rep
kbody'
  where
    ([Type]
kts, KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres, Int
num_nonmap_results, [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop) =
      SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
forall rep.
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
segOpGuts SegOp (SegOpLevel rep) rep
segop
    free_in_kstms :: Names
free_in_kstms = (Stm rep -> Names) -> Stms rep -> Names
forall m a. Monoid m => (a -> m) -> Seq a -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm rep -> Names
forall a. FreeIn a => a -> Names
freeIn Stms rep
kstms
    space :: SegSpace
space = SegOp (SegOpLevel rep) rep -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp (SegOpLevel rep) rep
segop

    sliceWithGtidsFixed :: Stm rep -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm rep
stm
      | Let Pat (LetDec rep)
_ StmAux (ExpDec rep)
aux (BasicOp (Index VName
arr Slice SubExp
slice)) <- Stm rep
stm,
        [DimIndex SubExp]
space_slice <- ((VName, SubExp) -> DimIndex SubExp)
-> [(VName, SubExp)] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> SubExp
Var (VName -> SubExp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [DimIndex SubExp])
-> [(VName, SubExp)] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
        [DimIndex SubExp]
space_slice [DimIndex SubExp] -> [DimIndex SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice,
        Slice SubExp
remaining_slice <- [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ Int -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Int -> [a] -> [a]
drop ([DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
space_slice) (Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice),
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Maybe (Entry rep) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry rep) -> Bool)
-> (VName -> Maybe (Entry rep)) -> VName -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (VName -> SymbolTable rep -> Maybe (Entry rep))
-> SymbolTable rep -> VName -> Maybe (Entry rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup SymbolTable rep
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
          Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
            VName -> Names
forall a. FreeIn a => a -> Names
freeIn VName
arr Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
remaining_slice Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Certs -> Names
forall a. FreeIn a => a -> Names
freeIn (StmAux (ExpDec rep) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec rep)
aux) =
          (Slice SubExp, VName) -> Maybe (Slice SubExp, VName)
forall a. a -> Maybe a
Just (Slice SubExp
remaining_slice, VName
arr)
      | Bool
otherwise =
          Maybe (Slice SubExp, VName)
forall a. Maybe a
Nothing

    distribute :: ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
distribute ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') Stm rep
stm
      | Let (Pat [PatElem (LetDec rep)
pe]) StmAux (ExpDec rep)
_ Exp rep
_ <- Stm rep
stm,
        Just (Slice [DimIndex SubExp]
remaining_slice, VName
arr) <- Stm rep -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm rep
stm,
        Just (PatElem (LetDec rep)
kpe, [PatElem (LetDec rep)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> PatElem (LetDec rep)
-> Maybe
     (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
      [KernelResult])
isResult [PatElem (LetDec rep)]
kpes' [Type]
kts' [KernelResult]
kres' PatElem (LetDec rep)
pe = do
          let outer_slice :: [DimIndex SubExp]
outer_slice =
                (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map
                  ( \SubExp
d ->
                      SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) SubExp
d (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
                  )
                  ([SubExp] -> [DimIndex SubExp]) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
              index :: PatElem (LetDec rep) -> RuleM rep ()
index PatElem (LetDec rep)
kpe' =
                [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
kpe'] (Exp rep -> RuleM rep ())
-> (Slice SubExp -> Exp rep) -> Slice SubExp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep)
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> RuleM rep ()) -> Slice SubExp -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
                  [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$
                    [DimIndex SubExp]
outer_slice [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Semigroup a => a -> a -> a
<> [DimIndex SubExp]
remaining_slice
          VName
precopy <- String -> RuleM rep VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM rep VName) -> String -> RuleM rep VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString (PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
kpe) String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_precopy"
          PatElem (LetDec rep) -> RuleM rep ()
index PatElem (LetDec rep)
kpe {patElemName :: VName
patElemName = VName
precopy}
          [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
kpe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ShapeBase SubExp
forall a. Monoid a => a
mempty (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
precopy
          ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( [PatElem (LetDec rep)]
kpes'',
              [Type]
kts'',
              [KernelResult]
kres'',
              if PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> Names -> Bool
`nameIn` Names
free_in_kstms
                then Stms rep
kstms' Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Stm rep -> Stms rep
forall rep. Stm rep -> Stms rep
oneStm Stm rep
stm
                else Stms rep
kstms'
            )
    distribute ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') Stm rep
stm =
      ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms' Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Stm rep -> Stms rep
forall rep. Stm rep -> Stms rep
oneStm Stm rep
stm)

    isResult :: [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> PatElem (LetDec rep)
-> Maybe
     (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
      [KernelResult])
isResult [PatElem (LetDec rep)]
kpes' [Type]
kts' [KernelResult]
kres' PatElem (LetDec rep)
pe =
      case ((PatElem (LetDec rep), Type, KernelResult) -> Bool)
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([(PatElem (LetDec rep), Type, KernelResult)],
    [(PatElem (LetDec rep), Type, KernelResult)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElem (LetDec rep), Type, KernelResult) -> Bool
matches ([(PatElem (LetDec rep), Type, KernelResult)]
 -> ([(PatElem (LetDec rep), Type, KernelResult)],
     [(PatElem (LetDec rep), Type, KernelResult)]))
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([(PatElem (LetDec rep), Type, KernelResult)],
    [(PatElem (LetDec rep), Type, KernelResult)])
forall a b. (a -> b) -> a -> b
$ [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
kpes' [Type]
kts' [KernelResult]
kres' of
        ([(PatElem (LetDec rep)
kpe, Type
_, KernelResult
_)], [(PatElem (LetDec rep), Type, KernelResult)]
kpes_and_kres)
          | Just Int
i <- PatElem (LetDec rep) -> [PatElem (LetDec rep)] -> Maybe Int
forall a. Eq a => a -> [a] -> Maybe Int
elemIndex PatElem (LetDec rep)
kpe [PatElem (LetDec rep)]
kpes,
            Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
num_nonmap_results,
            ([PatElem (LetDec rep)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(PatElem (LetDec rep), Type, KernelResult)]
kpes_and_kres ->
              (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
 [KernelResult])
-> Maybe
     (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
      [KernelResult])
forall a. a -> Maybe a
Just (PatElem (LetDec rep)
kpe, [PatElem (LetDec rep)]
kpes'', [Type]
kts'', [KernelResult]
kres'')
        ([(PatElem (LetDec rep), Type, KernelResult)],
 [(PatElem (LetDec rep), Type, KernelResult)])
_ -> Maybe
  (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
   [KernelResult])
forall a. Maybe a
Nothing
      where
        matches :: (PatElem (LetDec rep), Type, KernelResult) -> Bool
matches (PatElem (LetDec rep)
_, Type
_, Returns ResultManifest
_ Certs
_ (Var VName
v)) = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe
        matches (PatElem (LetDec rep), Type, KernelResult)
_ = Bool
False

--- Memory

-- | Refine the 'ExpReturns' of a kernel body's results.
--
-- Results of the body are paired positionally with the given
-- 'ExpReturns' (as with 'zipWithM', surplus elements of the longer
-- list are dropped).  For a 'WriteReturns' result, the returns are
-- recomputed with 'varReturns' on the array it names.  Every other
-- result keeps the 'ExpReturns' it was paired with.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody =
  zipWithM adjust (kernelBodyResult kbody)
  where
    -- Only 'WriteReturns' results are looked up again, using the
    -- array named in the result; the supplied returns are ignored.
    adjust res ret =
      case res of
        WriteReturns _ _ dest _ -> varReturns dest
        _ -> pure ret

-- | Like 'segOpType', but for memory representations.
--
-- For 'SegMap', 'SegRed' and 'SegScan', the extended types of the
-- operation are turned into 'ExpReturns' with 'extReturns' and then
-- refined by 'kernelBodyReturns'.  For 'SegHist', the returns are
-- obtained with 'varReturns' on the destination arrays of every
-- histogram operator, concatenated in operator order.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl somerep ->
  m [ExpReturns]
segOpReturns op =
  case op of
    SegMap _ _ _ kbody -> fromBody kbody
    SegRed _ _ _ _ kbody -> fromBody kbody
    SegScan _ _ _ _ kbody -> fromBody kbody
    SegHist _ _ hist_ops _ _ ->
      fmap concat . forM hist_ops $ mapM varReturns . histDest
  where
    -- Shared path for the operators whose results come from the
    -- kernel body: compute the operator's type, then refine it.
    fromBody kbody = do
      ets <- opType op
      kernelBodyReturns kbody $ extReturns ets