{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Segmented operations.  These correspond to perfect @map@ nests on
-- top of /something/, except that the @map@s are conceptually only
-- over @iota@s (so there will be explicit indexing inside them).
module Futhark.IR.SegOp
  ( SegOp (..),
    SegVirt (..),
    segLevel,
    segSpace,
    typeCheckSegOp,
    SegSpace (..),
    scopeOfSegSpace,
    segSpaceDims,

    -- * Details
    HistOp (..),
    histType,
    SegBinOp (..),
    segBinOpResults,
    segBinOpChunks,
    KernelBody (..),
    aliasAnalyseKernelBody,
    consumedInKernelBody,
    ResultManifest (..),
    KernelResult (..),
    kernelResultSubExp,
    SplitOrdering (..),

    -- ** Generic traversal
    SegOpMapper (..),
    identitySegOpMapper,
    mapSegOpM,

    -- * Simplification
    simplifySegOp,
    HasSegOp (..),
    segOpRules,

    -- * Memory
    segOpReturns,
  )
where

import Control.Category
import Control.Monad.Identity hiding (mapM_)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Data.Bifunctor (first)
import Data.List
  ( elemIndex,
    foldl',
    groupBy,
    intersperse,
    isPrefixOf,
    partition,
  )
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    removeLambdaAliases,
    removeStmAliases,
  )
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
  ( Pretty,
    commasep,
    parens,
    ppr,
    text,
    (<+>),
    (</>),
  )
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | How an array is split into chunks.
data SplitOrdering
  = -- | No stride is given.
    SplitContiguous
  | -- | The split carries an explicit stride.
    SplitStrided SubExp
  deriving (Eq, Ord, Show)

-- | A contiguous split mentions no names; a strided split has exactly
-- the free names of its stride.
instance FreeIn SplitOrdering where
  freeIn' SplitContiguous = mempty
  freeIn' (SplitStrided stride) = freeIn' stride

-- | Only the stride of 'SplitStrided' can contain names, so that is
-- the only place where substitution is applied.
instance Substitute SplitOrdering where
  substituteNames _ SplitContiguous =
    SplitContiguous
  substituteNames subst (SplitStrided stride) =
    SplitStrided $ substituteNames subst stride

-- | Renaming leaves 'SplitContiguous' untouched and renames the
-- stride of 'SplitStrided'.
instance Rename SplitOrdering where
  rename SplitContiguous =
    pure SplitContiguous
  rename (SplitStrided stride) =
    SplitStrided <$> rename stride

-- | An operator for 'SegHist'.
data HistOp lore = HistOp
  { histWidth :: SubExp,
    histRaceFactor :: SubExp,
    histDest :: [VName],
    histNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    histShape :: Shape,
    histOp :: Lambda lore
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest's in case we are
-- dealing with a segmented histogram.
--
-- Each return type of the operator's lambda is first wrapped in the
-- operator's 'histShape' and then given an outer dimension of size
-- 'histWidth'.
histType :: HistOp lore -> [Type]
histType op =
  map
    ( (`arrayOfRow` histWidth op)
        . (`arrayOfShape` histShape op)
    )
    $ lambdaReturnType $ histOp op

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp lore = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda lore,
    segBinOpNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- This is the total number of neutral elements across all operators.
segBinOpResults :: [SegBinOp lore] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp'.  The chunk sizes are the lengths of
-- the operators' neutral element lists, in order.
segBinOpChunks :: [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)

-- | The body of a 'SegOp'.
data KernelBody lore = KernelBody
  { kernelBodyLore :: BodyDec lore,
    kernelBodyStms :: Stms lore,
    kernelBodyResult :: [KernelResult]
  }

-- The instances below are derived standalone so that the
-- 'Decorations' constraint on @lore@ is stated explicitly.

deriving instance Decorations lore => Ord (KernelBody lore)

deriving instance Decorations lore => Show (KernelBody lore)

deriving instance Decorations lore => Eq (KernelBody lore)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
data ResultManifest
  = -- | Don't simplify this one!
    ResultNoSimplify
  | -- | Go nuts.
    ResultMaySimplify
  | -- | The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
    ResultPrivate
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult
  = -- | Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-group depends on where the 'SegOp' occurs.
    Returns ResultManifest SubExp
  | WriteReturns
      Shape -- Size of array.  Must match number of dims.
      VName -- Which array
      [(Slice SubExp, SubExp)] -- Arbitrary number of index/value pairs.
  | ConcatReturns
      SplitOrdering -- Permuted?
      SubExp -- The final size.
      SubExp -- Per-thread/group (max) chunk size.
      VName -- Chunk by this worker.
  | TileReturns
      [(SubExp, SubExp)] -- Total/tile for each dimension
      VName -- Tile written by this worker.
      -- The TileReturns must not expect more than one
      -- result to be written per physical thread.
  | RegTileReturns
      -- For each dim of result:
      [ ( SubExp, -- size of this dim.
          SubExp, -- block tile size for this dim.
          SubExp -- reg tile size for this dim.
        )
      ]
      VName -- Tile returned by this worker/group.
  deriving (Eq, Show, Ord)

-- | Get the root t'SubExp' corresponding to a 'KernelResult'.  For
-- 'Returns' this is the returned t'SubExp' itself; for every other
-- constructor it is the array variable carried by the result.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ se) = se
kernelResultSubExp (WriteReturns _ arr _) = Var arr
kernelResultSubExp (ConcatReturns _ _ _ v) = Var v
kernelResultSubExp (TileReturns _ v) = Var v
kernelResultSubExp (RegTileReturns _ v) = Var v

-- | The free names of a 'KernelResult' are those of all of its
-- components, except for the 'ResultManifest' of 'Returns', which is
-- ignored.
instance FreeIn KernelResult where
  freeIn' (Returns _ what) = freeIn' what
  freeIn' (WriteReturns rws arr res) = freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (ConcatReturns o w per_thread_elems v) =
    freeIn' o <> freeIn' w <> freeIn' per_thread_elems <> freeIn' v
  freeIn' (TileReturns dims v) =
    freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns dims_n_tiles v) =
    freeIn' dims_n_tiles <> freeIn' v

-- | Collects the free names of the body decoration, the statements
-- and the results.  The combined set is passed through 'fvBind'
-- together with the names bound by the body's own statements
-- (presumably so that those locally bound names are not reported as
-- free).
instance ASTLore lore => FreeIn (KernelBody lore) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      -- Every name bound by any statement of the body.
      bound_in_stms = foldMap boundByStm stms

-- | Substitution is applied component-wise to the decoration, the
-- statements and the results.
instance ASTLore lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)

-- | Substitution is applied to every component of a 'KernelResult'
-- that can contain names.  The 'ResultManifest' of 'Returns' is kept
-- unchanged.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest se) =
    Returns manifest (substituteNames subst se)
  substituteNames subst (WriteReturns rws arr res) =
    WriteReturns
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (ConcatReturns o w per_thread_elems v) =
    ConcatReturns
      (substituteNames subst o)
      (substituteNames subst w)
      (substituteNames subst per_thread_elems)
      (substituteNames subst v)
  substituteNames subst (TileReturns dims v) =
    TileReturns (substituteNames subst dims) (substituteNames subst v)
  substituteNames subst (RegTileReturns dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)

-- | Renames the body decoration first, then the statements via
-- 'renamingStms'; the results are renamed inside the continuation
-- that receives the renamed statements.
instance ASTLore lore => Rename (KernelBody lore) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res

-- | Renaming a 'KernelResult' is delegated to its 'Substitute'
-- instance through 'substituteRename'.
instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.
--
-- The decoration and statements are wrapped in a temporary 'Body'
-- with an empty result list so that 'Alias.analyseBody' can be
-- reused.  The kernel results are carried over unchanged.
aliasAnalyseKernelBody ::
  ( ASTLore lore,
    CanBeAliased (Op lore)
  ) =>
  AliasTable ->
  KernelBody lore ->
  KernelBody (Aliases lore)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res

-- | Strip alias information from a 'KernelBody'.
--
-- The aliasing component of the body decoration pair is discarded and
-- 'removeStmAliases' is applied to every statement.  The kernel
-- results are kept as they are.
removeKernelBodyAliases ::
  CanBeAliased (Op lore) =>
  KernelBody (Aliases lore) ->
  KernelBody lore
removeKernelBodyAliases (KernelBody (_, dec) stms res) =
  KernelBody dec (fmap removeStmAliases stms) res

-- | Strip simplifier "wisdom" from a 'KernelBody'.
--
-- The statements are wrapped in a temporary 'Body' with an empty
-- result list so that 'removeBodyWisdom' can be reused; the resulting
-- decoration and statements are then paired with the original kernel
-- results.
removeKernelBodyWisdom ::
  CanBeWise (Op lore) =>
  KernelBody (Wise lore) ->
  KernelBody lore
removeKernelBodyWisdom (KernelBody dec stms res) =
  let Body dec' stms' _ = removeBodyWisdom $ Body dec stms []
   in KernelBody dec' stms' res

-- | The variables consumed in the kernel body.
--
-- This is whatever 'consumedInBody' reports for the statements
-- (wrapped in a 'Body' with no results), together with the
-- destination array of every 'WriteReturns' result.
consumedInKernelBody ::
  Aliased lore =>
  KernelBody lore ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> mconcat (map consumedByReturn res)
  where
    -- Only 'WriteReturns' contributes a consumed name: the array it
    -- writes into.
    consumedByReturn (WriteReturns _ a _) = oneName a
    consumedByReturn _ = mempty

-- | Type-check a 'KernelBody' against the expected per-result types
-- @ts@.
--
-- The body decoration is checked first, then the destination arrays
-- of 'WriteReturns' results are consumed, and finally the statements
-- are checked.  Inside the scope of the statements, the number of
-- results must equal the number of expected types, and every result
-- is checked against its corresponding type.
checkKernelBody ::
  TC.Checkable lore =>
  [Type] ->
  KernelBody (Aliases lore) ->
  TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyLore dec
  -- We consume the kernel results (when applicable) before
  -- type-checking the stms, so we will get an error if a statement
  -- uses an array that is written to in a result.
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad $
        TC.TypeError $
          "Kernel return type is " ++ prettyTuple ts
            ++ ", but body returns "
            ++ show (length kres)
            ++ " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Consume the aliases of the array a 'WriteReturns' writes to;
    -- other kinds of result consume nothing.
    consumeKernelResult (WriteReturns _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    -- Check a single kernel result against its expected type @t@.
    checkKernelResult (Returns _ what) t =
      TC.require [t] what
    checkKernelResult (WriteReturns shape arr res) t = do
      -- Shape dimensions and all slice components must be int64.
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      arr_t <- lookupType arr
      forM_ res $ \(slice, e) -> do
        mapM_ (traverse $ TC.require [Prim int64]) slice
        TC.require [t] e
        -- The destination must be an array of @t@ with the declared
        -- shape.
        unless (arr_t == t `arrayOfShape` shape) $
          TC.bad $
            TC.TypeError $
              "WriteReturns returning "
                ++ pretty e
                ++ " of type "
                ++ pretty t
                ++ ", shape="
                ++ pretty shape
                ++ ", but destination array has type "
                ++ pretty arr_t
    checkKernelResult (ConcatReturns o w per_thread_elems v) t = do
      case o of
        SplitContiguous -> return ()
        SplitStrided stride -> TC.require [Prim int64] stride
      TC.require [Prim int64] w
      TC.require [Prim int64] per_thread_elems
      vt <- lookupType v
      -- @v@ must be an array whose rows have type @t@; its own outer
      -- size is used, so only the row type is constrained here.
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v
    checkKernelResult (TileReturns dims v) t = do
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- @v@ must be shaped by the tile sizes (second pair components).
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v
    checkKernelResult (RegTileReturns dims_n_tiles arr) t = do
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles

      -- assert that arr is of element type t and shape (blk_tiles ++ reg_tiles)
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n  "
            ++ pretty expected
            ++ ",\ngot:\n  "
            ++ pretty arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        expected = t `arrayOfShape` Shape (blk_tiles ++ reg_tiles)

-- | Record metrics for a 'KernelBody' by running 'stmMetrics' on each
-- of its statements.  The kernel results are not inspected.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics = mapM_ stmMetrics . kernelBodyStms

-- | A 'KernelBody' is printed as its statements stacked vertically,
-- followed by a @return {...}@ line listing the comma-separated
-- kernel results.  The body decoration is not printed.
instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr (KernelBody _ stms res) =
    PP.stack (map ppr (stmsToList stms))
      </> text "return" <+> PP.braces (PP.commasep $ map ppr res)

-- | Pretty-printing of the individual kinds of 'KernelResult'.
instance Pretty KernelResult where
  -- Plain returns; the manifest flavour is spelled out unless it is
  -- 'ResultMaySimplify'.
  ppr (Returns ResultNoSimplify what) =
    text "returns (manifest)" <+> ppr what
  ppr (Returns ResultPrivate what) =
    text "returns (private)" <+> ppr what
  ppr (Returns ResultMaySimplify what) =
    text "returns" <+> ppr what
  -- Destination array and shape, followed by the list of
  -- @[slice] = value@ writes.
  ppr (WriteReturns shape arr res) =
    ppr arr <+> PP.colon <+> ppr shape
      </> text "with" <+> PP.apply (map ppRes res)
    where
      ppRes (slice, e) =
        PP.brackets (commasep (map ppr slice)) <+> text "=" <+> ppr e
  ppr (ConcatReturns SplitContiguous w per_thread_elems v) =
    text "concat"
      <> parens (commasep [ppr w, ppr per_thread_elems]) <+> ppr v
  ppr (ConcatReturns (SplitStrided stride) w per_thread_elems v) =
    text "concat_strided"
      <> parens (commasep [ppr stride, ppr w, ppr per_thread_elems]) <+> ppr v
  -- Each dimension is shown as @dim / tile@.
  ppr (TileReturns dims v) =
    "tile" <> parens (commasep $ map onDim dims) <+> ppr v
    where
      onDim (dim, tile) = ppr dim <+> "/" <+> ppr tile
  -- Each dimension is shown as @dim / (blk_tile * reg_tile)@.
  ppr (RegTileReturns dims_n_tiles v) =
    "blkreg_tile" <> parens (commasep $ map onDim dims_n_tiles) <+> ppr v
    where
      onDim (dim, blk_tile, reg_tile) =
        ppr dim <+> "/" <+> parens (ppr blk_tile <+> "*" <+> ppr reg_tile)

-- | Do we need group-virtualisation when generating code for the
-- segmented operation?  In most cases, we do, but for some simple
-- kernels, we compute the full number of groups in advance, and then
-- virtualisation is an unnecessary (but generally very small)
-- overhead.  This only really matters for fairly trivial but very
-- wide @map@ kernels where each thread performs constant-time work on
-- scalars.
data SegVirt
  = SegVirt
  | SegNoVirt
  | -- | Not only do we not need virtualisation, but we _guarantee_
    -- that all physical threads participate in the work.  This can
    -- save some checks in code generation.
    SegNoVirtFull
  deriving (Eq, Ord, Show)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { -- | Flat physical index corresponding to the
    -- dimensions (at code generation used for a
    -- thread ID or similar).
    segFlat :: VName,
    -- | The index variables of the space, each paired with the size
    -- it spans (see 'segSpaceDims').
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The sizes spanned by the indexes of the 'SegSpace', i.e. the
-- second component of every pair in the space, in order.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims (SegSpace _ space) = map snd space

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace'.  Both the flat physical index and every index
-- variable of the space are bound as 'IndexName' of type 'Int64'.
scopeOfSegSpace :: SegSpace -> Scope lore
scopeOfSegSpace (SegSpace phys space) =
  M.fromList $ zip (phys : map fst space) $ repeat $ IndexName Int64

-- | Type-check a 'SegSpace': every dimension size must have type
-- @int64@.  The index variable names themselves are not checked here.
checkSegSpace :: TC.Checkable lore => SegSpace -> TC.TypeM lore ()
checkSegSpace (SegSpace _ dims) =
  mapM_ (TC.require [Prim int64] . snd) dims

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a /level/.  The /level/ is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the group-level.
data SegOp lvl lore
  = SegMap lvl SegSpace [Type] (KernelBody lore)
  | -- | The 'SegSpace' must always have at least two dimensions,
    -- implying that the result of a SegRed is always an array.
    SegRed lvl SegSpace [SegBinOp lore] [Type] (KernelBody lore)
  | SegScan lvl SegSpace [SegBinOp lore] [Type] (KernelBody lore)
  | SegHist lvl SegSpace [HistOp lore] [Type] (KernelBody lore)
  deriving (Eq, Ord, Show)

-- | The level of a 'SegOp'.  This is the first field of every
-- constructor.
segLevel :: SegOp lvl lore -> lvl
segLevel (SegMap lvl _ _ _) = lvl
segLevel (SegRed lvl _ _ _ _) = lvl
segLevel (SegScan lvl _ _ _ _) = lvl
segLevel (SegHist lvl _ _ _ _) = lvl

-- | The space of a 'SegOp'.  This is the second field of every
-- constructor.
segSpace :: SegOp lvl lore -> SegSpace
segSpace (SegMap _ space _ _) = space
segSpace (SegRed _ space _ _ _) = space
segSpace (SegScan _ space _ _ _) = space
segSpace (SegHist _ space _ _ _) = space

-- | Given the space of a 'SegOp', the per-iteration type of one
-- result, and the corresponding 'KernelResult', compute the type of
-- the array that the 'SegOp' as a whole produces for that result.
--
-- Only a plain 'Returns' depends on the 'SegSpace'; the other result
-- forms carry their own outer dimensions.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
-- The result has the shape stored in the 'WriteReturns'.
segResultShape _ t (WriteReturns shape _ _) =
  t `arrayOfShape` shape
-- One outer dimension per dimension of the space, outermost first.
segResultShape space t (Returns _ _) =
  foldr (flip arrayOfRow) t $ segSpaceDims space
-- A single outer dimension of size @w@.
segResultShape _ t (ConcatReturns _ w _ _) =
  t `arrayOfRow` w
-- The outer dimensions are the first components of the pairs.
segResultShape _ t (TileReturns dims _) =
  t `arrayOfShape` Shape (map fst dims)
-- The outer dimensions are the first components of the triples.
segResultShape _ t (RegTileReturns dims_n_tiles _) =
  t `arrayOfShape` Shape (map (\(dim, _, _) -> dim) dims_n_tiles)

-- | The return type of a 'SegOp'.
--
-- * 'SegMap': each body result is shaped by 'segResultShape'.
--
-- * 'SegRed': first the reduction results, whose outer shape is all
--   dimensions of the space except the last, followed by the operator
--   shape; then the remaining (map-like) results, shaped by
--   'segResultShape'.
--
-- * 'SegScan': as 'SegRed', except that the scan results keep every
--   dimension of the space.
--
-- * 'SegHist': for each operator, its return types with outer shape
--   made of all dimensions of the space except the last, then the
--   histogram width, then the operator shape.
--
-- Note that the 'SegRed' and 'SegHist' cases use 'init' on the
-- dimensions of the space, which is an error if the space has no
-- dimensions and the resulting shape is demanded.
segOpType :: SegOp lvl lore -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts $ kernelBodyResult kbody
segOpType (SegRed _ space reds ts kbody) =
  red_ts
    ++ zipWith
      (segResultShape space)
      map_ts
      (drop (length red_ts) $ kernelBodyResult kbody)
  where
    -- Types and results beyond the reduction results are map-like.
    map_ts = drop (length red_ts) ts
    segment_dims = init $ segSpaceDims space
    red_ts = do
      op <- reds
      let shape = Shape segment_dims <> segBinOpShape op
      map (`arrayOfShape` shape) (lambdaReturnType $ segBinOpLambda op)
segOpType (SegScan _ space scans ts kbody) =
  scan_ts
    ++ zipWith
      (segResultShape space)
      map_ts
      (drop (length scan_ts) $ kernelBodyResult kbody)
  where
    -- Types and results beyond the scan results are map-like.
    map_ts = drop (length scan_ts) ts
    scan_ts = do
      op <- scans
      let shape = Shape (segSpaceDims space) <> segBinOpShape op
      map (`arrayOfShape` shape) (lambdaReturnType $ segBinOpLambda op)
segOpType (SegHist _ space ops _ _) = do
  op <- ops
  let shape = Shape (segment_dims <> [histWidth op]) <> histShape op
  map (`arrayOfShape` shape) (lambdaReturnType $ histOp op)
  where
    segment_dims = init $ segSpaceDims space

-- | The type of a 'SegOp' is 'segOpType' converted with
-- 'staticShapes'; it does not consult the scope.
instance TypedOp (SegOp lvl lore) where
  opType = pure . staticShapes . segOpType

instance
  (ASTLore lore, Aliased lore, ASTConstraints lvl) =>
  AliasedOp (SegOp lvl lore)
  where
  -- One empty alias set per result type: the results of a 'SegOp'
  -- are reported as aliasing nothing.
  opAliases = map (const mempty) . segOpType

  -- A 'SegOp' consumes whatever its kernel body consumes; a 'SegHist'
  -- additionally consumes the destination arrays of all its
  -- operators.
  consumedInOp (SegMap _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegRed _ _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegScan _ _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegHist _ _ ops _ kbody) =
    namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody

-- | Type check a 'SegOp', given a checker for its level.
--
-- The level is checked first in every case.  'SegMap', 'SegRed' and
-- 'SegScan' are then delegated to 'checkScanRed' (a 'SegMap' is
-- checked as having no operators).  'SegHist' is checked here: the
-- space, the result types, every histogram operator and its
-- destination arrays, the kernel body, and finally that the declared
-- body return types are one @int64@ bucket index per operator
-- followed by the values to write.
typeCheckSegOp ::
  TC.Checkable lore =>
  (lvl -> TC.TypeM lore ()) ->
  SegOp lvl (Aliases lore) ->
  TC.TypeM lore ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) = do
  checkLvl lvl
  checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do
  checkLvl lvl
  checkScanRed space reds' ts body
  where
    -- Operator, neutral elements and shape of every reduction.
    reds' =
      [ (segBinOpLambda red, segBinOpNeutral red, segBinOpShape red)
        | red <- reds
      ]
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do
  checkLvl lvl
  checkScanRed space scans' ts body
  where
    -- Operator, neutral elements and shape of every scan.
    scans' =
      [ (segBinOpLambda scan, segBinOpNeutral scan, segBinOpShape scan)
        | scan <- scans
      ]
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- forM ops $ \(HistOp dest_w rf dests nes shape op) -> do
      TC.require [Prim int64] dest_w
      TC.require [Prim int64] rf
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) $ shapeDims shape

      -- Operator type must match the type of neutral elements.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda op $
        map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
      let nes_t = map TC.argType nes'
      unless (nes_t == lambdaReturnType op) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has return type "
              ++ prettyTuple (lambdaReturnType op)
              ++ " but neutral element has type "
              ++ prettyTuple nes_t

      -- Arrays must have proper type.
      let dest_shape = Shape (segment_dims <> [dest_w]) <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape] dest
        -- The destination (and everything it aliases) is consumed.
        TC.consume =<< TC.lookupAliases dest

      return $ map (`arrayOfShape` shape) nes_t

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let bucket_ret_t = replicate (length ops) (Prim int64) ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad $
        TC.TypeError $
          "SegHist body has return type "
            ++ prettyTuple ts
            ++ " but should have type "
            ++ prettyTuple bucket_ret_t
  where
    -- All dimensions of the space except the last.  'init' is an
    -- error if the space has no dimensions and this is demanded.
    segment_dims = init $ segSpaceDims space

checkScanRed ::
  TC.Checkable lore =>
  SegSpace ->
  [(Lambda (Aliases lore), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases lore) ->
  TC.TypeM lore ()
checkScanRed :: forall lore.
Checkable lore =>
SegSpace
-> [(Lambda (Aliases lore), [SubExp], ShapeBase SubExp)]
-> [Type]
-> KernelBody (Aliases lore)
-> TypeM lore ()
checkScanRed SegSpace
space [(Lambda (Aliases lore), [SubExp], ShapeBase SubExp)]
ops [Type]
ts KernelBody (Aliases lore)
kbody = do
  SegSpace -> TypeM lore ()
forall lore. Checkable lore => SegSpace -> TypeM lore ()
checkSegSpace SegSpace
space
  (Type -> TypeM lore ()) -> [Type] -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Type -> TypeM lore ()
forall lore u.
Checkable lore =>
TypeBase (ShapeBase SubExp) u -> TypeM lore ()
TC.checkType [Type]
ts

  Scope (Aliases lore) -> TypeM lore () -> TypeM lore ()
forall lore a.
Checkable lore =>
Scope (Aliases lore) -> TypeM lore a -> TypeM lore a
TC.binding (SegSpace -> Scope (Aliases lore)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ do
    [[Type]]
ne_ts <- [(Lambda (Aliases lore), [SubExp], ShapeBase SubExp)]
-> ((Lambda (Aliases lore), [SubExp], ShapeBase SubExp)
    -> TypeM lore [Type])
-> TypeM lore [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Lambda (Aliases lore), [SubExp], ShapeBase SubExp)]
ops (((Lambda (Aliases lore), [SubExp], ShapeBase SubExp)
  -> TypeM lore [Type])
 -> TypeM lore [[Type]])
-> ((Lambda (Aliases lore), [SubExp], ShapeBase SubExp)
    -> TypeM lore [Type])
-> TypeM lore [[Type]]
forall a b. (a -> b) -> a -> b
$ \(Lambda (Aliases lore)
lam, [SubExp]
nes, ShapeBase SubExp
shape) -> do
      (SubExp -> TypeM lore ()) -> [SubExp] -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) ([SubExp] -> TypeM lore ()) -> [SubExp] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
shape
      [Arg]
nes' <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
nes

      -- Operator type must match the type of neutral elements.
      Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
      let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'

      Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
nes_t) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
        ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"wrong type for operator or neutral elements."

      [Type] -> TypeM lore [Type]
forall (m :: * -> *) a. Monad m => a -> m a
return ([Type] -> TypeM lore [Type]) -> [Type] -> TypeM lore [Type]
forall a b. (a -> b) -> a -> b
$ (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> ShapeBase SubExp -> Type
`arrayOfShape` ShapeBase SubExp
shape) [Type]
nes_t

    let expecting :: [Type]
expecting = [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Type]]
ne_ts
        got :: [Type]
got = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
expecting) [Type]
ts
    Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
expecting [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
got) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
      ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
        String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
          String
"Wrong return for body (does not match neutral elements; expected "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => a -> String
pretty [Type]
expecting
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"; found "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => a -> String
pretty [Type]
got
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"

    [Type] -> KernelBody (Aliases lore) -> TypeM lore ()
forall lore.
Checkable lore =>
[Type] -> KernelBody (Aliases lore) -> TypeM lore ()
checkKernelBody [Type]
ts KernelBody (Aliases lore)
kbody

-- | Like 'Mapper', but just for 'SegOp's.
--
-- A value of this type bundles the monadic callbacks that 'mapSegOpM'
-- invokes on the immediate constituents of a 'SegOp'.  The @flore@ and
-- @tlore@ parameters are the lores of the input and the output
-- respectively, so a mapper may change the lore of the lambdas and
-- kernel bodies it visits.
data SegOpMapper lvl flore tlore m = SegOpMapper
  { -- | Applied to every 'SubExp' that 'mapSegOpM' visits directly:
    -- the 'SubExp' of each pair in a 'SegSpace', the neutral elements
    -- and shape dimensions of operators, the leading two 'SubExp'
    -- fields of a 'HistOp', and the dimensions of result types.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to the lambda of every 'SegBinOp' and 'HistOp'.
    mapOnSegOpLambda :: Lambda flore -> m (Lambda tlore),
    -- | Applied to the 'KernelBody' of the 'SegOp'.
    mapOnSegOpBody :: KernelBody flore -> m (KernelBody tlore),
    -- | Applied to each 'VName' in the name list of a 'HistOp'.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level (first field) of the 'SegOp'.
    mapOnSegOpLevel :: lvl -> m lvl
  }

-- | A mapper that simply returns the 'SegOp' verbatim.
--
-- Every callback just injects its argument into the monad unchanged,
-- which makes this a convenient base for record-update syntax when only
-- a few callbacks need custom behaviour.
identitySegOpMapper :: Monad m => SegOpMapper lvl lore lore m
identitySegOpMapper =
  SegOpMapper
    { mapOnSegOpSubExp = pure,
      mapOnSegOpLambda = pure,
      mapOnSegOpBody = pure,
      mapOnSegOpVName = pure,
      mapOnSegOpLevel = pure
    }

-- | Run the 'mapOnSegOpSubExp' callback over a 'SegSpace'.
--
-- Only the 'SubExp' component of each @(VName, SubExp)@ pair is passed
-- to the callback; the leading 'VName' of the space and the 'VName' of
-- every pair are carried over untouched.  Pairs are processed in list
-- order.
mapOnSegSpace ::
  Monad f =>
  SegOpMapper lvl flore tlore f ->
  SegSpace ->
  f SegSpace
mapOnSegSpace mapper (SegSpace flat dims) =
  SegSpace flat <$> mapM onDim dims
  where
    -- Rebuild the pair around the transformed 'SubExp'.
    onDim (v, d) = (,) v <$> mapOnSegOpSubExp mapper d

-- | Apply a 'SegOpMapper' to a single 'SegBinOp'.
--
-- The commutativity annotation is kept as is.  The lambda goes through
-- 'mapOnSegOpLambda'; the neutral elements and then the dimensions of
-- the shape go through 'mapOnSegOpSubExp', in that order.
mapSegBinOp ::
  Monad m =>
  SegOpMapper lvl flore tlore m ->
  SegBinOp flore ->
  m (SegBinOp tlore)
mapSegBinOp mapper (SegBinOp comm lam nes shape) = do
  lam' <- mapOnSegOpLambda mapper lam
  nes' <- mapM onSubExp nes
  dims' <- mapM onSubExp (shapeDims shape)
  pure $ SegBinOp comm lam' nes' (Shape dims')
  where
    onSubExp = mapOnSegOpSubExp mapper

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- For every constructor the fields are visited left to right: level,
-- space, operators (where present), result types, and finally the
-- kernel body.  The traversal is shallow: lambdas and kernel bodies are
-- handed to the corresponding callbacks rather than being descended
-- into here.
mapSegOpM ::
  (Applicative m, Monad m) =>
  SegOpMapper lvl flore tlore m ->
  SegOp lvl flore ->
  m (SegOp lvl tlore)
mapSegOpM mapper segop =
  case segop of
    SegMap lvl space ts kbody ->
      -- NOTE(review): this case maps types with 'mapOnSegOpType' while
      -- the other cases use 'mapOnType'; presumably equivalent, confirm.
      SegMap
        <$> onLevel lvl
        <*> onSpace space
        <*> mapM (mapOnSegOpType mapper) ts
        <*> onBody kbody
    SegRed lvl space reds ts kbody ->
      SegRed
        <$> onLevel lvl
        <*> onSpace space
        <*> onBinOps reds
        <*> onTypes ts
        <*> onBody kbody
    SegScan lvl space scans ts kbody ->
      SegScan
        <$> onLevel lvl
        <*> onSpace space
        <*> onBinOps scans
        <*> onTypes ts
        <*> onBody kbody
    SegHist lvl space hists ts kbody ->
      SegHist
        <$> onLevel lvl
        <*> onSpace space
        <*> mapM onHistOp hists
        <*> onTypes ts
        <*> onBody kbody
  where
    -- Shorthands for the callbacks specialised to this mapper.
    onSubExp = mapOnSegOpSubExp mapper
    onLevel = mapOnSegOpLevel mapper
    onSpace = mapOnSegSpace mapper
    onBody = mapOnSegOpBody mapper
    onBinOps binops = mapM (mapSegBinOp mapper) binops
    onTypes ts = mapM (mapOnType onSubExp) ts

    -- Visit the fields of a 'HistOp' in declaration order.
    onHistOp (HistOp w rf arrs nes shape lam) = do
      w' <- onSubExp w
      rf' <- onSubExp rf
      arrs' <- mapM (mapOnSegOpVName mapper) arrs
      nes' <- mapM onSubExp nes
      dims' <- mapM onSubExp (shapeDims shape)
      lam' <- mapOnSegOpLambda mapper lam
      pure $ HistOp w' rf' arrs' nes' (Shape dims') lam'

-- | Run the 'mapOnSegOpSubExp' callback over the dimensions of a 'Type'.
--
-- Primitive and memory types contain nothing to map and are returned
-- unchanged.  For arrays the element type and uniqueness are kept and
-- each dimension is transformed in order.
mapOnSegOpType ::
  Monad m =>
  SegOpMapper lvl flore tlore m ->
  Type ->
  m Type
mapOnSegOpType mapper t =
  case t of
    Prim {} -> pure t
    Array pt (Shape dims) u -> do
      dims' <- mapM (mapOnSegOpSubExp mapper) dims
      pure $ Array pt (Shape dims') u
    Mem s -> pure (Mem s)

-- | Name substitution on a 'SegOp' is a pure 'mapSegOpM' traversal in
-- which every callback applies 'substituteNames' with the same
-- substitution map to the component it receives.
instance
  (ASTLore lore, Substitute lvl) =>
  Substitute (SegOp lvl lore)
  where
  substituteNames subst segop =
    runIdentity (mapSegOpM mapper segop)
    where
      -- Each field is spelled out separately because the callbacks work
      -- at five different types.
      mapper =
        SegOpMapper
          { mapOnSegOpSubExp = pure . substituteNames subst,
            mapOnSegOpLambda = pure . substituteNames subst,
            mapOnSegOpBody = pure . substituteNames subst,
            mapOnSegOpVName = pure . substituteNames subst,
            mapOnSegOpLevel = pure . substituteNames subst
          }

-- | Renaming a 'SegOp' delegates to the 'Rename' instance of each
-- component via a 'mapSegOpM' traversal in 'RenameM'.
instance
  (ASTLore lore, ASTConstraints lvl) =>
  Rename (SegOp lvl lore)
  where
  rename segop = mapSegOpM mapper segop
    where
      mapper =
        SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }

-- | Free names of a 'SegOp' are gathered by running 'mapSegOpM' in a
-- 'State' monad whose state accumulates the 'FV' of every component
-- visited.  The rebuilt 'SegOp' is discarded; only the final state,
-- starting from 'mempty', is returned.
instance
  (ASTLore lore, FreeIn (LParamInfo lore), FreeIn lvl) =>
  FreeIn (SegOp lvl lore)
  where
  freeIn' segop = execState (mapSegOpM collector segop) mempty
    where
      -- Append the free names of @x@ to the state and hand @x@ back
      -- unchanged.  This binding must stay closed (no reference to
      -- @segop@) so that it generalises over the five component types.
      record fv x = x <$ modify (<> fv x)

      collector =
        SegOpMapper
          { mapOnSegOpSubExp = record freeIn',
            mapOnSegOpLambda = record freeIn',
            mapOnSegOpBody = record freeIn',
            mapOnSegOpVName = record freeIn',
            mapOnSegOpLevel = record freeIn'
          }

-- | Metrics for a 'SegOp' are collected inside a scope named after the
-- constructor.  For 'SegRed', 'SegScan' and 'SegHist' the operator
-- lambdas are visited first, followed by the kernel body; 'SegMap'
-- has only the kernel body.
instance OpMetrics (Op lore) => OpMetrics (SegOp lvl lore) where
  opMetrics (SegMap _ _ _ body) =
    inside "SegMap" $ kernelBodyMetrics body
  opMetrics (SegRed _ _ reds _ body) =
    inside "SegRed" $ do
      mapM_ (lambdaMetrics . segBinOpLambda) reds
      kernelBodyMetrics body
  opMetrics (SegScan _ _ scans _ body) =
    inside "SegScan" $ do
      mapM_ (lambdaMetrics . segBinOpLambda) scans
      kernelBodyMetrics body
  opMetrics (SegHist _ _ ops _ body) =
    inside "SegHist" $ do
      mapM_ (lambdaMetrics . histOp) ops
      kernelBodyMetrics body

-- | A 'SegSpace' is printed as a parenthesised, comma-separated list
-- of @i < d@ entries (one per dimension), followed by the physical
-- thread identifier as @(~phys)@.
instance Pretty SegSpace where
  ppr (SegSpace phys dims) =
    parens
      ( commasep $ do
          (i, d) <- dims
          return $ ppr i <+> "<" <+> ppr d
      )
      <+> parens (text "~" <> ppr phys)

-- | A 'SegBinOp' is printed as the neutral elements in braces, then
-- the shape, then the operator lambda, separated by commas.  The
-- lambda is prefixed with @commutative @ when the operator is marked
-- 'Commutative'; nothing is printed for 'Noncommutative'.
instance PrettyLore lore => Pretty (SegBinOp lore) where
  ppr (SegBinOp comm lam nes shape) =
    PP.braces (PP.commasep $ map ppr nes) <> PP.comma
      </> ppr shape <> PP.comma
      </> comm' <> ppr lam
    where
      comm' = case comm of
        Commutative -> text "commutative "
        Noncommutative -> mempty

-- | Every 'SegOp' is printed as a keyword (@segmap@, @segred@,
-- @segscan@ or @seghist@) immediately followed by the level, then the
-- space, then (for all but @segmap@) a parenthesised list of
-- operators, then a colon, the result types, and finally the kernel
-- body in a braced block.
instance (PrettyLore lore, PP.Pretty lvl) => PP.Pretty (SegOp lvl lore) where
  ppr (SegMap lvl space ts body) =
    text "segmap" <> ppr lvl
      </> PP.align (ppr space)
      <+> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
  ppr (SegRed lvl space reds ts body) =
    text "segred" <> ppr lvl
      </> PP.align (ppr space)
      </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppr reds)
      </> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
  ppr (SegScan lvl space scans ts body) =
    text "segscan" <> ppr lvl
      </> PP.align (ppr space)
      </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppr scans)
      </> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
  ppr (SegHist lvl space ops ts body) =
    text "seghist" <> ppr lvl
      </> PP.align (ppr space)
      </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops)
      </> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
    where
      -- A histogram operator is printed field by field, in the order
      -- the 'HistOp' constructor lists them, with the destination
      -- arrays and the neutral elements each wrapped in braces.
      ppOp (HistOp w rf dests nes shape op) =
        ppr w <> PP.comma <+> ppr rf <> PP.comma
          </> PP.braces (PP.commasep $ map ppr dests) <> PP.comma
          </> PP.braces (PP.commasep $ map ppr nes) <> PP.comma
          </> ppr shape <> PP.comma
          </> ppr op

-- | Alias information is added to, or stripped from, a 'SegOp' by
-- mapping over its lambdas and kernel body with 'mapSegOpM' in the
-- 'Identity' monad.  Sub-expressions, names and the level are passed
-- through unchanged.
instance
  ( ASTLore lore,
    ASTLore (Aliases lore),
    CanBeAliased (Op lore),
    ASTConstraints lvl
  ) =>
  CanBeAliased (SegOp lvl lore)
  where
  type OpWithAliases (SegOp lvl lore) = SegOp lvl (Aliases lore)

  addOpAliases aliases = runIdentity . mapSegOpM alias
    where
      -- Positional mapper fields: SubExp, Lambda, KernelBody, VName,
      -- level.  Only lambdas and kernel bodies are analysed, both
      -- against the given alias table.
      alias =
        SegOpMapper
          return
          (return . Alias.analyseLambda aliases)
          (return . aliasAnalyseKernelBody aliases)
          return
          return

  removeOpAliases = runIdentity . mapSegOpM remove
    where
      -- Same field order as in 'addOpAliases'; only lambdas and
      -- kernel bodies carry alias annotations to remove.
      remove =
        SegOpMapper
          return
          (return . removeLambdaAliases)
          (return . removeKernelBodyAliases)
          return
          return

-- | Simplifier wisdom is stripped from a 'SegOp' by mapping over its
-- lambdas and kernel body with 'mapSegOpM' in the 'Identity' monad.
-- Sub-expressions, names and the level are passed through unchanged.
instance
  (CanBeWise (Op lore), ASTLore lore, ASTConstraints lvl) =>
  CanBeWise (SegOp lvl lore)
  where
  type OpWithWisdom (SegOp lvl lore) = SegOp lvl (Wise lore)

  removeOpWisdom = runIdentity . mapSegOpM remove
    where
      -- Positional mapper fields: SubExp, Lambda, KernelBody, VName,
      -- level.
      remove =
        SegOpMapper
          return
          (return . removeLambdaWisdom)
          (return . removeKernelBodyWisdom)
          return
          return

-- | Symbolic indexing into the result of a 'SegOp'.
--
-- Only 'SegMap' is handled; every other constructor yields 'Nothing'.
-- For a 'SegMap', the @k@th kernel result must be a
-- @'Returns' 'ResultMaySimplify'@ of a variable, and at least as many
-- indices as there are segment dimensions must be supplied.  The
-- kernel body statements are then walked in order, building a table
-- from names to symbolic values, and the result variable is looked up
-- in that table.
instance ASTLore lore => ST.IndexOp (SegOp lvl lore) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    -- A pattern-match failure here (wrong kind of result, or @k@ out
    -- of range) makes the whole computation 'Nothing'.
    Returns ResultMaySimplify se <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    -- Initially each thread index is bound to the corresponding
    -- index expression; 'zip' ignores any indices beyond the segment
    -- dimensions (those are 'excess_is').
    let idx_table = M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is
        idx_table' = foldl' expandIndexedTable idx_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v idx_table'
      _ -> Nothing
    where
      (gtids, _) = unzip $ unSegSpace space
      -- Indexes in excess of what is used to index through the
      -- segment dimensions.
      excess_is = drop (length gtids) is

      -- Extend the table with the name bound by a statement, when the
      -- statement binds exactly one name and is either (1) convertible
      -- to a 'PrimExp' using the names known so far, or (2) an 'Index'
      -- into an array found in the symbol table, whose slice has as
      -- many 'sliceDims' as there are excess indices and whose slice
      -- components are themselves convertible.  Any other statement
      -- leaves the table unchanged.
      expandIndexedTable table stm
        | [v] <- patternNames $ stmPattern stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
          M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
        | [v] <- patternNames $ stmPattern stm,
          BasicOp (Index arr slice) <- stmExp stm,
          length (sliceDims slice) == length excess_is,
          arr `ST.elem` vtable,
          Just (slice', cs) <- asPrimExpSlice table slice =
          let idx =
                ST.IndexedArray
                  (stmCerts stm <> cs)
                  arr
                  (fixSlice (map (fmap isInt64) slice') excess_is)
           in M.insert v idx table
        | otherwise =
          table

      -- Convert every component of a slice to a 'PrimExp', collecting
      -- the certificates of the table entries that were used.
      asPrimExpSlice table =
        runWriterT . mapM (traverse (primExpFromSubExpM (asPrimExp table)))

      -- Resolve a name to a 'PrimExp': prefer a scalar entry in the
      -- table (emitting its certificates), otherwise fall back to a
      -- leaf if the symbol table gives the name a primitive type, and
      -- fail otherwise.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> return e
        | Just (Prim pt) <- ST.lookupType v vtable =
          return $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

-- | Every 'SegOp' is reported as safe ('safeOp' is 'True') and as not
-- cheap ('cheapOp' is 'False'), regardless of its contents.
instance
  (ASTLore lore, ASTConstraints lvl) =>
  IsOp (SegOp lvl lore)
  where
  cheapOp _ = False
  safeOp _ = True

--- Simplification

-- | 'SplitContiguous' carries no operands and is returned unchanged;
-- for 'SplitStrided' the stride operand is simplified.
instance Engine.Simplifiable SplitOrdering where
  simplify ordering =
    case ordering of
      SplitContiguous -> return SplitContiguous
      SplitStrided stride -> SplitStrided <$> Engine.simplify stride

-- | Simplify a 'SegSpace'.  The first field of the space and the
-- name in each @(name, dim)@ pair are kept as they are; only the
-- dimension 'SubExp' of each pair is simplified.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) =
    SegSpace phys <$> mapM simplifyDim dims
    where
      -- Equivalent to traversing the pair: only the second component
      -- is visited.
      simplifyDim (gtid, dim) = (,) gtid <$> Engine.simplify dim

-- | Simplify the operands of a 'KernelResult'.  For 'Returns' the
-- manifest is carried over unchanged and only the returned 'SubExp'
-- is simplified; for every other constructor each field is passed
-- through 'Engine.simplify', in field order.
instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest what) =
    Returns manifest <$> Engine.simplify what
  simplify (WriteReturns ws a res) =
    WriteReturns
      <$> Engine.simplify ws
      <*> Engine.simplify a
      <*> Engine.simplify res
  simplify (ConcatReturns o w pte what) =
    ConcatReturns
      <$> Engine.simplify o
      <*> Engine.simplify w
      <*> Engine.simplify pte
      <*> Engine.simplify what
  simplify (TileReturns dims what) =
    TileReturns
      <$> Engine.simplify dims
      <*> Engine.simplify what
  simplify (RegTileReturns dims_n_tiles what) =
    RegTileReturns
      <$> Engine.simplify dims_n_tiles
      <*> Engine.simplify what

-- | Construct a 'KernelBody' in the 'Wise' lore from a body
-- decoration, already-wise statements, and kernel results.  The
-- decoration of the resulting kernel body is whatever 'mkWiseBody'
-- produces for an ordinary body with the same statements whose
-- results are the 'kernelResultSubExp' of each kernel result; the
-- statements and results themselves are stored unchanged.
mkWiseKernelBody ::
  (ASTLore lore, CanBeWise (Op lore)) =>
  BodyDec lore ->
  Stms (Wise lore) ->
  [KernelResult] ->
  KernelBody (Wise lore)
mkWiseKernelBody dec bnds res =
  -- The ordinary body is built only to obtain its decoration.  The
  -- lazy @let@ pattern is deliberate: the body is not forced until
  -- the decoration is demanded.
  let Body wise_dec _ _ =
        mkWiseBody dec bnds $ map kernelResultSubExp res
   in KernelBody wise_dec bnds res

-- | Monadic construction of a 'KernelBody' from statements and kernel
-- results.  The body decoration is taken from the ordinary body that
-- 'mkBodyM' builds for the same statements, using the
-- 'kernelResultSubExp' of each kernel result as its results.
mkKernelBodyM ::
  MonadBinder m =>
  Stms (Lore m) ->
  [KernelResult] ->
  m (KernelBody (Lore m))
mkKernelBodyM stms kres = do
  -- Only the decoration of the intermediate body is used.
  Body dec' _ _ <- mkBodyM stms $ map kernelResultSubExp kres
  return $ KernelBody dec' stms kres

-- | Simplify the statements and results of a 'KernelBody' that lives
-- inside the given 'SegSpace'.  Returns the simplified kernel body
-- (in the 'Wise' lore) paired with the second component produced by
-- 'Engine.blockIf', i.e. the statements that were not blocked (bound
-- to @hoisted@ below).
simplifyKernelBody ::
  (Engine.SimplifiableLore lore, BodyDec lore ~ ()) =>
  SegSpace ->
  KernelBody lore ->
  Engine.SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv $ Engine.blockHoistPar . Engine.envHoistBlockers

  -- A statement is blocked if it has a free variable bound by the
  -- space, or if any of the other three predicates holds for it.
  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed

  -- Ensure we do not try to use anything that is consumed in the result.
  ((body_stms, body_res), hoisted) <-
    Engine.localVtable consumeResults
      . Engine.localVtable (<> scope_vtable)
      . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})
      . Engine.enterLoop
      $ Engine.blockIf blocker $
        Engine.simplifyStms stms $ do
          -- The results are simplified with the names bound by the
          -- body's own statements passed to 'ST.hideCertified'.
          res' <-
            Engine.localVtable (ST.hideCertified body_names) $
              mapM Engine.simplify res
          return ((res', UT.usages $ freeIn res'), mempty)

  return (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    scope_vtable = segSpaceSymbolTable space
    bound_here = namesFromList $ M.keys $ scopeOfSegSpace space
    body_names = namesFromList $ M.keys $ scopeOf stms

    -- Mark every array written by a 'WriteReturns' result as consumed
    -- in the symbol table.
    consumeResults vtable =
      foldl' (flip ST.consume) vtable $ foldMap consumedInResult res

    consumedInResult (WriteReturns _ arr _) =
      [arr]
    consumedInResult _ =
      []

-- | Build a symbol table describing the names bound by a 'SegSpace'.
-- The table starts from a scope containing just the first field of
-- the space as an 'Int64' 'IndexName'; each @(gtid, dim)@ pair is
-- then added, in order, with 'ST.insertLoopVar' at type 'Int64'.
segSpaceSymbolTable :: ASTLore lore => SegSpace -> ST.SymbolTable lore
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' addGtid flat_vtable gtids_and_dims
  where
    flat_vtable = ST.fromScope $ M.singleton flat $ IndexName Int64
    addGtid vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable

-- | Simplify a 'SegBinOp'.  The commutativity field is kept as is.
-- The lambda is simplified with 'ST.simplifyMemory' switched on in
-- the symbol table, then the shape and the neutral elements are
-- simplified.  The second component of the result is the statements
-- returned alongside the lambda by 'Engine.simplifyLambda'.
simplifySegBinOp ::
  Engine.SimplifiableLore lore =>
  SegBinOp lore ->
  Engine.SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
simplifySegBinOp (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <-
    Engine.localVtable enableMemorySimplification $
      Engine.simplifyLambda lam
  shape' <- Engine.simplify shape
  nes' <- mapM Engine.simplify nes
  return (SegBinOp comm lam' nes' shape', hoisted)
  where
    enableMemorySimplification vtable = vtable {ST.simplifyMemory = True}

-- | Simplify the given 'SegOp'.
simplifySegOp ::
  ( Engine.SimplifiableLore lore,
    BodyDec lore ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl lore ->
  Engine.SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
simplifySegOp :: forall lore lvl.
(SimplifiableLore lore, BodyDec lore ~ (), Simplifiable lvl) =>
SegOp lvl lore
-> SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
simplifySegOp (SegMap lvl
lvl SegSpace
space [Type]
ts KernelBody lore
kbody) = do
  (lvl
lvl', SegSpace
space', [Type]
ts') <- (lvl, SegSpace, [Type]) -> SimpleM lore (lvl, SegSpace, [Type])
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
Engine.simplify (lvl
lvl, SegSpace
space, [Type]
ts)
  (KernelBody (Wise lore)
kbody', Stms (Wise lore)
body_hoisted) <- SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
forall lore.
(SimplifiableLore lore, BodyDec lore ~ ()) =>
SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody SegSpace
space KernelBody lore
kbody
  (SegOp lvl (Wise lore), Stms (Wise lore))
-> SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( lvl
-> SegSpace
-> [Type]
-> KernelBody (Wise lore)
-> SegOp lvl (Wise lore)
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap lvl
lvl' SegSpace
space' [Type]
ts' KernelBody (Wise lore)
kbody',
      Stms (Wise lore)
body_hoisted
    )
simplifySegOp (SegRed lvl
lvl SegSpace
space [SegBinOp lore]
reds [Type]
ts KernelBody lore
kbody) = do
  (lvl
lvl', SegSpace
space', [Type]
ts') <- (lvl, SegSpace, [Type]) -> SimpleM lore (lvl, SegSpace, [Type])
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
Engine.simplify (lvl
lvl, SegSpace
space, [Type]
ts)
  ([SegBinOp (Wise lore)]
reds', [Stms (Wise lore)]
reds_hoisted) <-
    (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
Engine.localVtable (SymbolTable (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a. Semigroup a => a -> a -> a
<> SymbolTable (Wise lore)
scope_vtable) (SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
 -> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)]))
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall a b. (a -> b) -> a -> b
$
      [(SegBinOp (Wise lore), Stms (Wise lore))]
-> ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp (Wise lore), Stms (Wise lore))]
 -> ([SegBinOp (Wise lore)], [Stms (Wise lore)]))
-> SimpleM lore [(SegBinOp (Wise lore), Stms (Wise lore))]
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SegBinOp lore
 -> SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore)))
-> [SegBinOp lore]
-> SimpleM lore [(SegBinOp (Wise lore), Stms (Wise lore))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SegBinOp lore
-> SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
SegBinOp lore
-> SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
simplifySegBinOp [SegBinOp lore]
reds
  (KernelBody (Wise lore)
kbody', Stms (Wise lore)
body_hoisted) <- SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
forall lore.
(SimplifiableLore lore, BodyDec lore ~ ()) =>
SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody SegSpace
space KernelBody lore
kbody

  (SegOp lvl (Wise lore), Stms (Wise lore))
-> SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( lvl
-> SegSpace
-> [SegBinOp (Wise lore)]
-> [Type]
-> KernelBody (Wise lore)
-> SegOp lvl (Wise lore)
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegRed lvl
lvl' SegSpace
space' [SegBinOp (Wise lore)]
reds' [Type]
ts' KernelBody (Wise lore)
kbody',
      [Stms (Wise lore)] -> Stms (Wise lore)
forall a. Monoid a => [a] -> a
mconcat [Stms (Wise lore)]
reds_hoisted Stms (Wise lore) -> Stms (Wise lore) -> Stms (Wise lore)
forall a. Semigroup a => a -> a -> a
<> Stms (Wise lore)
body_hoisted
    )
  where
    scope :: Scope (Wise lore)
scope = SegSpace -> Scope (Wise lore)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space
    scope_vtable :: SymbolTable (Wise lore)
scope_vtable = Scope (Wise lore) -> SymbolTable (Wise lore)
forall lore. ASTLore lore => Scope lore -> SymbolTable lore
ST.fromScope Scope (Wise lore)
scope
simplifySegOp (SegScan lvl
lvl SegSpace
space [SegBinOp lore]
scans [Type]
ts KernelBody lore
kbody) = do
  (lvl
lvl', SegSpace
space', [Type]
ts') <- (lvl, SegSpace, [Type]) -> SimpleM lore (lvl, SegSpace, [Type])
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
Engine.simplify (lvl
lvl, SegSpace
space, [Type]
ts)
  ([SegBinOp (Wise lore)]
scans', [Stms (Wise lore)]
scans_hoisted) <-
    (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
Engine.localVtable (SymbolTable (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a. Semigroup a => a -> a -> a
<> SymbolTable (Wise lore)
scope_vtable) (SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
 -> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)]))
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall a b. (a -> b) -> a -> b
$
      [(SegBinOp (Wise lore), Stms (Wise lore))]
-> ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp (Wise lore), Stms (Wise lore))]
 -> ([SegBinOp (Wise lore)], [Stms (Wise lore)]))
-> SimpleM lore [(SegBinOp (Wise lore), Stms (Wise lore))]
-> SimpleM lore ([SegBinOp (Wise lore)], [Stms (Wise lore)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SegBinOp lore
 -> SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore)))
-> [SegBinOp lore]
-> SimpleM lore [(SegBinOp (Wise lore), Stms (Wise lore))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SegBinOp lore
-> SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
SegBinOp lore
-> SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
simplifySegBinOp [SegBinOp lore]
scans
  (KernelBody (Wise lore)
kbody', Stms (Wise lore)
body_hoisted) <- SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
forall lore.
(SimplifiableLore lore, BodyDec lore ~ ()) =>
SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody SegSpace
space KernelBody lore
kbody

  (SegOp lvl (Wise lore), Stms (Wise lore))
-> SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( lvl
-> SegSpace
-> [SegBinOp (Wise lore)]
-> [Type]
-> KernelBody (Wise lore)
-> SegOp lvl (Wise lore)
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegScan lvl
lvl' SegSpace
space' [SegBinOp (Wise lore)]
scans' [Type]
ts' KernelBody (Wise lore)
kbody',
      [Stms (Wise lore)] -> Stms (Wise lore)
forall a. Monoid a => [a] -> a
mconcat [Stms (Wise lore)]
scans_hoisted Stms (Wise lore) -> Stms (Wise lore) -> Stms (Wise lore)
forall a. Semigroup a => a -> a -> a
<> Stms (Wise lore)
body_hoisted
    )
  where
    scope :: Scope (Wise lore)
scope = SegSpace -> Scope (Wise lore)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space
    scope_vtable :: SymbolTable (Wise lore)
scope_vtable = Scope (Wise lore) -> SymbolTable (Wise lore)
forall lore. ASTLore lore => Scope lore -> SymbolTable lore
ST.fromScope Scope (Wise lore)
scope
-- Simplification of a segmented histogram: the level, space and
-- result types are simplified together, then every histogram
-- operator, and finally the kernel body.  Statements hoisted out of
-- the operator lambdas and out of the body are returned alongside
-- the rebuilt 'SegHist'.
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  (ops', ops_hoisted) <- unzip <$> mapM simplifyHistOp ops

  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  return
    ( SegHist lvl' space' ops' ts' kbody',
      mconcat ops_hoisted <> body_hoisted
    )
  where
    -- Symbol table entries for the names bound by the SegSpace; the
    -- operator lambdas are simplified with these in scope.
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope

    -- Simplify every component of a single histogram operator,
    -- yielding the new operator and whatever was hoisted out of its
    -- lambda.
    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      -- The lambda is simplified with the SegSpace scope appended to
      -- the symbol table and with the table's 'ST.simplifyMemory'
      -- flag switched on.
      (lam', op_hoisted) <-
        Engine.localVtable (<> scope_vtable) $
          Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
            Engine.simplifyLambda lam
      return
        ( HistOp w' rf' arrs' nes' dims' lam',
          op_hoisted
        )

-- | Does this lore contain 'SegOp's in its t'Op's?  A lore must be an
-- instance of this class for the simplification rules to work.
class HasSegOp lore where
  -- | The level type used to instantiate the 'SegOp's of this lore.
  type SegOpLevel lore
  -- | View an operation of the lore as a 'SegOp'; the 'Maybe' result
  -- leaves room for operations that are not 'SegOp's.
  asSegOp :: Op lore -> Maybe (SegOp (SegOpLevel lore) lore)
  -- | Wrap a 'SegOp' as an operation of the lore.
  segOp :: SegOp (SegOpLevel lore) lore -> Op lore

-- | Simplification rules for simplifying 'SegOp's.  The rule book
-- consists of a single top-down rule and a single bottom-up rule,
-- each of which only fires on operations that 'asSegOp' recognises.
segOpRules ::
  (HasSegOp lore, BinderOps lore, Bindable lore) =>
  RuleBook lore
segOpRules =
  ruleBook [RuleOp segOpRuleTopDown] [RuleOp segOpRuleBottomUp]

-- Top-down rule on arbitrary operations: delegate to 'topDownSegOp'
-- when the operation can be viewed as a 'SegOp', and 'Skip'
-- otherwise.
segOpRuleTopDown ::
  (HasSegOp lore, BinderOps lore, Bindable lore) =>
  TopDownRuleOp lore
segOpRuleTopDown vtable pat dec op
  | Just op' <- asSegOp op =
    topDownSegOp vtable pat dec op'
  | otherwise =
    Skip

-- Bottom-up rule on arbitrary operations: delegate to
-- 'bottomUpSegOp' when the operation can be viewed as a 'SegOp', and
-- 'Skip' otherwise.
segOpRuleBottomUp ::
  (HasSegOp lore, BinderOps lore) =>
  BottomUpRuleOp lore
segOpRuleBottomUp vtable pat dec op
  | Just op' <- asSegOp op =
    bottomUpSegOp vtable pat dec op'
  | otherwise =
    Skip

topDownSegOp ::
  (HasSegOp lore, BinderOps lore, Bindable lore) =>
  ST.SymbolTable lore ->
  Pattern lore ->
  StmAux (ExpDec lore) ->
  SegOp (SegOpLevel lore) lore ->
  Rule lore
-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
topDownSegOp vtable (Pattern [] kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres)) = Simplify $ do
  -- Filtering has a side effect: every result that is dropped here
  -- has already been rebound as a replicate by
  -- 'checkForInvarianceResult'.
  (ts', kpes', kres') <-
    unzip3 <$> filterM checkForInvarianceResult (zip3 ts kpes kres)

  -- Check if we did anything at all.
  when
    (kres == kres')
    cannotSimplify

  -- Rebuild the SegMap with only the results that were kept.
  kbody <- mkKernelBodyM kstms kres'
  addStm $
    Let (Pattern [] kpes') dec $
      Op $
        segOp $
          SegMap lvl space ts' kbody
  where
    -- A result is considered invariant when it is a constant, or a
    -- variable that can be found in the symbol table handed to the
    -- rule.
    isInvariant Constant {} = True
    isInvariant (Var v) = isJust $ ST.lookup v vtable

    -- Returns 'False' (drop from the SegMap) for an invariant result
    -- whose manifest is 'ResultMaySimplify', after binding its
    -- pattern element to a replicate over the dimensions of the
    -- space.  Every other result is kept.
    checkForInvarianceResult (_, pe, Returns rm se)
      | rm == ResultMaySimplify,
        isInvariant se = do
        letBindNames [patElemName pe] $
          BasicOp $ Replicate (Shape $ segSpaceDims space) se
        return False
    checkForInvarianceResult _ =
      return True

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more local memory usage.
topDownSegOp _ (Pattern [] pes) _ (SegRed lvl space ops ts kbody)
  | length ops > 1,
    -- Pair every operator with the pattern elements, types and
    -- results it is responsible for, then group neighbouring
    -- operators of equal shape ('groupBy' only merges adjacent
    -- elements).
    op_groupings <-
      groupBy sameShape $
        zip ops $
          chunks (map (length . segBinOpNeutral) ops) $
            zip3 red_pes red_ts red_res,
    -- Only fire if at least one group really contains several
    -- operators.
    any ((> 1) . length) op_groupings = Simplify $ do
    let (ops', aux) = unzip $ mapMaybe combineOps op_groupings
        (red_pes', red_ts', red_res') = unzip3 $ concat aux
        pes' = red_pes' ++ map_pes
        ts' = red_ts' ++ map_ts
        kbody' = kbody {kernelBodyResult = red_res' ++ map_res}
    letBind (Pattern [] pes') $ Op $ segOp $ SegRed lvl space ops' ts' kbody'
  where
    -- The pattern, the types and the body results are all split at
    -- the same position: the number of results of the operators.
    num_red_res = segBinOpResults ops
    (red_pes, map_pes) = splitAt num_red_res pes
    (red_ts, map_ts) = splitAt num_red_res ts
    (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody

    sameShape (op1, _) (op2, _) = segBinOpShape op1 == segBinOpShape op2

    -- Fold a group of operators into a single one; an empty group
    -- yields nothing.
    combineOps [] = Nothing
    combineOps (x : xs) = Just $ foldl' combine x xs

    -- Merge two operators into one.  Each lambda's parameters are
    -- split after as many parameters as the operator has neutral
    -- elements; the merged lambda takes both first halves, then both
    -- second halves, and runs the statements of both bodies in
    -- sequence.
    combine (op1, op1_aux) (op2, op2_aux) =
      let lam1 = segBinOpLambda op1
          lam2 = segBinOpLambda op2
          (op1_xparams, op1_yparams) =
            splitAt (length (segBinOpNeutral op1)) $ lambdaParams lam1
          (op2_xparams, op2_yparams) =
            splitAt (length (segBinOpNeutral op2)) $ lambdaParams lam2
          lam =
            Lambda
              { lambdaParams =
                  op1_xparams ++ op2_xparams
                    ++ op1_yparams
                    ++ op2_yparams,
                lambdaReturnType = lambdaReturnType lam1 ++ lambdaReturnType lam2,
                lambdaBody =
                  mkBody (bodyStms (lambdaBody lam1) <> bodyStms (lambdaBody lam2)) $
                    bodyResult (lambdaBody lam1) <> bodyResult (lambdaBody lam2)
              }
       in ( SegBinOp
              { segBinOpComm = segBinOpComm op1 <> segBinOpComm op2,
                segBinOpLambda = lam,
                segBinOpNeutral = segBinOpNeutral op1 ++ segBinOpNeutral op2,
                segBinOpShape = segBinOpShape op1 -- Same as shape of op2 due to the grouping.
              },
            op1_aux ++ op2_aux
          )
-- No top-down simplification applies to any other SegOp.
topDownSegOp _ _ _ _ = Skip

-- A convenient way of operating on the type and body of a SegOp,
-- without worrying about exactly what kind it is.
--
-- The result consists of the result types, the kernel body, an
-- operator result count, and a function that rebuilds the same kind
-- of SegOp (same level, space and operators) from new types and a
-- new body.  The count is 0 for SegMap, 'segBinOpResults' of the
-- operators for SegScan and SegRed, and the total number of
-- 'histDest' arrays for SegHist.
segOpGuts ::
  SegOp (SegOpLevel lore) lore ->
  ( [Type],
    KernelBody lore,
    Int,
    [Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore
  )
segOpGuts (SegMap lvl space kts body) =
  (kts, body, 0, SegMap lvl space)
segOpGuts (SegScan lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegScan lvl space ops)
segOpGuts (SegRed lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegRed lvl space ops)
segOpGuts (SegHist lvl space ops kts body) =
  (kts, body, sum $ map (length . histDest) ops, SegHist lvl space ops)

-- Some SegOp results can be moved outside the SegOp, which can
-- simplify further analysis.
--
-- The rule only fires for patterns without context elements; any
-- other pattern falls through to the final 'Skip' clause.  If no
-- result could be moved, the rule reports 'cannotSimplify'.
bottomUpSegOp ::
  (HasSegOp lore, BinderOps lore) =>
  (ST.SymbolTable lore, UT.UsageTable) ->
  Pattern lore ->
  StmAux (ExpDec lore) ->
  SegOp (SegOpLevel lore) lore ->
  Rule lore
bottomUpSegOp (vtable, used) (Pattern [] kpes) dec segop = Simplify $ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.  We have to be careful
  -- not to remove anything that is passed on to a scan/map/histogram
  -- operation.  Fortunately, these are always first in the result
  -- list.
  (kpes', kts', kres', kstms') <-
    localScope (scopeOfSegSpace space) $
      foldM distribute (kpes, kts, kres, mempty) kstms

  -- Nothing was hoisted, so there is nothing to rewrite.
  when
    (kpes' == kpes)
    cannotSimplify

  kbody <-
    localScope (scopeOfSegSpace space) $
      mkKernelBodyM kstms' kres'

  -- Re-emit the SegOp with only the pattern elements, types and
  -- results that remain.
  addStm $ Let (Pattern [] kpes') dec $ Op $ segOp $ mk_segop kts' kbody
  where
    (kts, KernelBody _ kstms kres, num_nonmap_results, mk_segop) =
      segOpGuts segop
    -- Names free in the *original* body statements; used to decide
    -- whether a hoisted statement must also stay inside the body.
    free_in_kstms = foldMap freeIn kstms
    space = segSpace segop

    -- Recognise a statement of the form @arr[gtid_0, ..., gtid_n, rest...]@,
    -- i.e. an 'Index' whose slice starts by fixing each of the space's
    -- thread indices in order.  Succeeds only if every name free in
    -- @arr@ and in the remaining slice can be looked up in @vtable@
    -- (presumably meaning they are bound outside the SegOp -- confirm
    -- against the caller).  Returns the remaining slice and the array.
    sliceWithGtidsFixed stm
      | Let _ _ (BasicOp (Index arr slice)) <- stm,
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` slice,
        remaining_slice <- drop (length space_slice) slice,
        all (isJust . flip ST.lookup vtable) $
          namesToList $
            freeIn arr <> freeIn remaining_slice =
        Just (remaining_slice, arr)
      | otherwise =
        Nothing

    -- Fold step over the body statements.  The accumulator holds the
    -- pattern elements, types and results still produced by the SegOp,
    -- plus the statements kept in its body.
    distribute (kpes', kts', kres', kstms') stm
      | Let (Pattern [] [pe]) _ _ <- stm,
        Just (remaining_slice, arr) <- sliceWithGtidsFixed stm,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
        let -- A full slice (offset 0, stride 1) of every dimension of
            -- the segment space.
            outer_slice =
              map
                ( \d ->
                    DimSlice
                      (constant (0 :: Int64))
                      d
                      (constant (1 :: Int64))
                )
                $ segSpaceDims space
            -- Bind the given pattern element outside the SegOp as an
            -- index into @arr@.
            index kpe' =
              letBindNames [patElemName kpe'] . BasicOp . Index arr $
                outer_slice <> remaining_slice
        if patElemName kpe `UT.isConsumed` used
          then do
            -- The result is consumed later, so bind the index to a
            -- fresh name and make the result a copy of it.
            precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
            index kpe {patElemName = precopy}
            letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
          else index kpe
        return
          ( kpes'',
            kts'',
            kres'',
            -- Keep the statement in the body only if something in the
            -- original body still refers to it.
            if patElemName pe `nameIn` free_in_kstms
              then kstms' <> oneStm stm
              else kstms'
          )
    distribute (kpes', kts', kres', kstms') stm =
      return (kpes', kts', kres', kstms' <> oneStm stm)

    -- Check whether @pe@ is returned (via 'Returns') by exactly one of
    -- the remaining results, and that this result sits at or after
    -- position @num_nonmap_results@ in the original pattern.  On
    -- success, yields the matching pattern element together with the
    -- remaining pattern elements, types and results.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], kpes_and_kres)
          | Just i <- elemIndex kpe kpes,
            i >= num_nonmap_results,
            (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
            Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where
        matches (_, _, Returns _ (Var v)) = v == patElemName pe
        matches _ = False
bottomUpSegOp _ _ _ _ = Skip

--- Memory

-- Adjust a list of 'ExpReturns' to the results of a kernel body.
-- The results of the body are paired positionally with the given
-- returns: for a 'WriteReturns' result the given return is replaced
-- by 'varReturns' of the array named in the result, and every other
-- result keeps the return it was given.
kernelBodyReturns ::
  (Mem lore, HasScope lore m, Monad m) =>
  KernelBody lore ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns = zipWithM correct . kernelBodyResult
  where
    correct (WriteReturns _ arr _) _ = varReturns arr
    correct _ ret = return ret

-- | Like 'segOpType', but for memory representations.
--
-- For 'SegMap', 'SegRed' and 'SegScan', the returns are derived from
-- 'opType' of the operation via 'extReturns' and then adjusted by the
-- results of the kernel body.  For 'SegHist', the returns are the
-- 'varReturns' of every 'histDest' array of every operator, in order.
segOpReturns ::
  (Mem lore, Monad m, HasScope lore m) =>
  SegOp lvl lore ->
  m [ExpReturns]
segOpReturns k@(SegMap _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns k@(SegRed _ _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns k@(SegScan _ _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns (SegHist _ _ ops _ _) =
  concat <$> mapM (mapM varReturns . histDest) ops