{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Futhark.IR.Kernels.Kernel
  ( -- * Size operations
    SizeOp (..),

    -- * Host operations
    HostOp (..),
    typeCheckHostOp,

    -- * SegOp refinements
    SegLevel (..),

    -- * Reexports
    module Futhark.IR.Kernels.Sizes,
    module Futhark.IR.SegOp,
  )
where

import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.Kernels.Sizes
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util.Pretty
  ( commasep,
    parens,
    ppr,
    text,
    (<+>),
  )
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | At which level the *body* of a t'SegOp' executes.  The two
-- constructors ('SegThread' and 'SegGroup') carry exactly the same
-- fields, so the record selectors work on either.
data SegLevel
  = SegThread
      { -- | The number of workgroups.
        segNumGroups :: Count NumGroups SubExp,
        -- | The size of each workgroup.
        segGroupSize :: Count GroupSize SubExp,
        -- | The virtualisation mode ('SegNoVirt', 'SegNoVirtFull' or
        -- 'SegVirt').
        segVirt :: SegVirt
      }
  | SegGroup
      { segNumGroups :: Count NumGroups SubExp,
        segGroupSize :: Count GroupSize SubExp,
        segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)

-- | Renders as
-- @(thread; #groups=N; groupsize=M)@ or @(group; ...)@, optionally
-- followed by @; full@ or @; virtualise@ depending on 'segVirt'.
instance PP.Pretty SegLevel where
  ppr lvl =
    PP.parens $
      levelName lvl <> PP.semi
        <+> text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi
        <+> text "groupsize=" <> ppr (segGroupSize lvl) <> virtSuffix (segVirt lvl)
    where
      -- Name of the level, selected by constructor.
      levelName SegThread {} = "thread"
      levelName SegGroup {} = "group"

      -- 'SegNoVirt' prints nothing; the other modes add a suffix.
      virtSuffix SegNoVirt = mempty
      virtSuffix SegNoVirtFull = PP.semi <+> text "full"
      virtSuffix SegVirt = PP.semi <+> text "virtualise"

-- | Simplifies the group count and the group size (in that order),
-- keeping the constructor and the 'segVirt' field as they are.
instance Engine.Simplifiable SegLevel where
  simplify lvl = do
    num_groups' <- traverse Engine.simplify $ segNumGroups lvl
    group_size' <- traverse Engine.simplify $ segGroupSize lvl
    pure lvl {segNumGroups = num_groups', segGroupSize = group_size'}

-- | Substitutes names in the group count and group size; the
-- constructor and 'segVirt' are preserved by the record update.
instance Substitute SegLevel where
  substituteNames substs lvl =
    lvl
      { segNumGroups = substituteNames substs $ segNumGroups lvl,
        segGroupSize = substituteNames substs $ segGroupSize lvl
      }

-- | Renaming is implemented via 'substituteRename', i.e. in terms of
-- the 'Substitute' instance.
instance Rename SegLevel where
  rename lvl = substituteRename lvl

-- | The free variables are those of the group count and the group
-- size; 'segVirt' contributes none.
instance FreeIn SegLevel where
  freeIn' lvl = freeIn' (segNumGroups lvl) <> freeIn' (segGroupSize lvl)

-- | A simple size-level query or computation.
data SizeOp
  = -- | @SplitSpace o w i elems_per_thread@.
    --
    -- Computes how to divide array elements to
    -- threads in a kernel.  Returns the number of
    -- elements in the chunk that the current thread
    -- should take.
    --
    -- @w@ is the length of the outer dimension in
    -- the array. @i@ is the current thread
    -- index. Each thread takes at most
    -- @elems_per_thread@ elements.
    --
    -- If the order @o@ is 'SplitContiguous', thread with index @i@
    -- should receive elements
    -- @i*elems_per_thread, i*elems_per_thread + 1,
    -- ..., i*elems_per_thread + (elems_per_thread-1)@.
    --
    -- If the order @o@ is @'SplitStrided' stride@,
    -- the thread will receive elements @i,
    -- i+stride, i+2*stride, ...,
    -- i+(elems_per_thread-1)*stride@.
    --
    -- @w@, @i@, @elems_per_thread@ and any @stride@ must be of type
    -- @i64@; the result is an @i64@.
    SplitSpace SplitOrdering SubExp SubExp SubExp
  | -- | Produce some runtime-configurable size (an @i64@).
    GetSize Name SizeClass
  | -- | The maximum size of some class (an @i64@).
    GetSizeMax SizeClass
  | -- | Compare size (likely a threshold) with some integer value.
    -- The value must be an @i64@; the result is a @bool@.
    CmpSizeLe Name SizeClass SubExp
  | -- | @CalcNumGroups w max_num_groups group_size@ calculates the
    -- number of GPU workgroups to use for an input of the given size.
    -- The @Name@ is a size name.  Note that @w@ is an i64 to avoid
    -- overflow issues.  @group_size@ must also be an @i64@, and the
    -- result is an @i64@.
    CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)

-- | Substitutes names in every 'SubExp' (and 'SplitOrdering') operand.
-- 'Name's and 'SizeClass'es are left untouched, so 'GetSize' and
-- 'GetSizeMax' are returned unchanged.
instance Substitute SizeOp where
  substituteNames substs op = case op of
    SplitSpace o w i elems_per_thread ->
      SplitSpace (sub o) (sub w) (sub i) (sub elems_per_thread)
    CmpSizeLe name sclass x ->
      CmpSizeLe name sclass (sub x)
    CalcNumGroups w max_num_groups group_size ->
      CalcNumGroups (sub w) max_num_groups (sub group_size)
    GetSize {} -> op
    GetSizeMax {} -> op
    where
      -- Explicit signature needed: used at several operand types.
      sub :: Substitute a => a -> a
      sub = substituteNames substs

-- | Renames the 'SplitOrdering' and 'SubExp' operands left to right;
-- size 'Name's are kept as they are.
instance Rename SizeOp where
  rename op = case op of
    SplitSpace o w i elems_per_thread -> do
      o' <- rename o
      w' <- rename w
      i' <- rename i
      elems_per_thread' <- rename elems_per_thread
      pure $ SplitSpace o' w' i' elems_per_thread'
    CmpSizeLe name sclass x -> do
      x' <- rename x
      pure $ CmpSizeLe name sclass x'
    CalcNumGroups w max_num_groups group_size -> do
      w' <- rename w
      group_size' <- rename group_size
      pure $ CalcNumGroups w' max_num_groups group_size'
    _ -> pure op

-- | Every size operation is considered both safe and cheap.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True

-- | Every 'SizeOp' has exactly one primitive result: a @bool@ for
-- 'CmpSizeLe' and an @i64@ for everything else.
instance TypedOp SizeOp where
  opType op = pure [Prim $ sizeOpResult op]
    where
      sizeOpResult SplitSpace {} = int64
      sizeOpResult GetSize {} = int64
      sizeOpResult GetSizeMax {} = int64
      sizeOpResult CmpSizeLe {} = Bool
      sizeOpResult CalcNumGroups {} = int64

-- | The single result of a 'SizeOp' aliases nothing, and the operation
-- consumes nothing.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty

-- | The free variables of the 'SubExp' operands (and of the split
-- ordering).  'GetSize' and 'GetSizeMax' have none.
instance FreeIn SizeOp where
  freeIn' op = case op of
    SplitSpace o w i elems_per_thread ->
      freeIn' o <> freeIn' [w, i, elems_per_thread]
    CmpSizeLe _ _ x ->
      freeIn' x
    CalcNumGroups w _ group_size ->
      freeIn' w <> freeIn' group_size
    GetSize {} -> mempty
    GetSizeMax {} -> mempty

-- | Each operation is printed as a function-call-like form,
-- e.g. @get_size(name, class)@; 'CmpSizeLe' additionally appends
-- @<= x@.
instance PP.Pretty SizeOp where
  ppr op = case op of
    SplitSpace SplitContiguous w i elems_per_thread ->
      call "split_space" [ppr w, ppr i, ppr elems_per_thread]
    SplitSpace (SplitStrided stride) w i elems_per_thread ->
      call "split_space_strided" [ppr stride, ppr w, ppr i, ppr elems_per_thread]
    GetSize name size_class ->
      call "get_size" [ppr name, ppr size_class]
    GetSizeMax size_class ->
      call "get_size_max" [ppr size_class]
    CmpSizeLe name size_class x ->
      call "cmp_size" [ppr name, ppr size_class] <+> text "<=" <+> ppr x
    CalcNumGroups w max_num_groups group_size ->
      call "calc_num_groups" [ppr w, ppr max_num_groups, ppr group_size]
    where
      -- @fname(arg1, arg2, ...)@
      call fname args = text fname <> parens (commasep args)

-- | Records one occurrence of the constructor name.
instance OpMetrics SizeOp where
  opMetrics op = seen $ case op of
    SplitSpace {} -> "SplitSpace"
    GetSize {} -> "GetSize"
    GetSizeMax {} -> "GetSizeMax"
    CmpSizeLe {} -> "CmpSizeLe"
    CalcNumGroups {} -> "CalcNumGroups"

-- | Type-check a 'SizeOp': every 'SubExp' operand it carries must be of
-- type @i64@.  Operands are checked left to right; for a strided
-- 'SplitSpace' the stride is checked first.  'GetSize' and
-- 'GetSizeMax' carry no operands and always succeed.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp op = mapM_ (TC.require [Prim int64]) (i64Operands op)
  where
    i64Operands (SplitSpace o w i elems_per_thread) =
      strideOperand o ++ [w, i, elems_per_thread]
    i64Operands GetSize {} = []
    i64Operands GetSizeMax {} = []
    i64Operands (CmpSizeLe _ _ x) = [x]
    i64Operands (CalcNumGroups w _ group_size) = [w, group_size]

    strideOperand SplitContiguous = []
    strideOperand (SplitStrided stride) = [stride]

-- | A host-level operation, parameterised by the type @op@ of the
-- other operations it can contain.
data HostOp lore op
  = -- | A segmented operation.
    SegOp (SegOp SegLevel lore)
  | -- | A size-level query or computation.
    SizeOp SizeOp
  | -- | Some other operation, of the parameter type @op@.
    OtherOp op
  deriving (Eq, Ord, Show)

-- | Substitution is delegated to the wrapped operation.
instance (ASTLore lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs hostop = case hostop of
    SegOp op -> SegOp $ substituteNames substs op
    SizeOp op -> SizeOp $ substituteNames substs op
    OtherOp op -> OtherOp $ substituteNames substs op

-- | Renaming is delegated to the wrapped operation.
instance (ASTLore lore, Rename op) => Rename (HostOp lore op) where
  rename hostop = case hostop of
    SegOp op -> fmap SegOp (rename op)
    SizeOp op -> fmap SizeOp (rename op)
    OtherOp op -> fmap OtherOp (rename op)

-- | Safety and cheapness are those of the wrapped operation.
instance (ASTLore lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp hostop = case hostop of
    SegOp op -> safeOp op
    SizeOp op -> safeOp op
    OtherOp op -> safeOp op

  cheapOp hostop = case hostop of
    SegOp op -> cheapOp op
    SizeOp op -> cheapOp op
    OtherOp op -> cheapOp op

-- The result type of a 'HostOp' is the result type of the wrapped operation.
instance TypedOp op => TypedOp (HostOp lore op) where
  opType hostOp =
    case hostOp of
      SegOp segOp -> opType segOp
      OtherOp otherOp -> opType otherOp
      SizeOp sizeOp -> opType sizeOp

-- Aliasing and consumption information for a 'HostOp' is taken verbatim
-- from the operation it wraps.
instance (Aliased lore, AliasedOp op, ASTLore lore) => AliasedOp (HostOp lore op) where
  opAliases hostOp =
    case hostOp of
      SegOp segOp -> opAliases segOp
      OtherOp otherOp -> opAliases otherOp
      SizeOp sizeOp -> opAliases sizeOp

  consumedInOp hostOp =
    case hostOp of
      SegOp segOp -> consumedInOp segOp
      OtherOp otherOp -> consumedInOp otherOp
      SizeOp sizeOp -> consumedInOp sizeOp

-- The free variables of a 'HostOp' are those of the wrapped operation.
instance (ASTLore lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' hostOp =
    case hostOp of
      SegOp segOp -> freeIn' segOp
      OtherOp otherOp -> freeIn' otherOp
      SizeOp sizeOp -> freeIn' sizeOp

-- Adding or removing alias annotations on a 'HostOp' converts the wrapped
-- 'SegOp' or 'OtherOp' payload. A 'SizeOp' payload has type 'SizeOp'
-- regardless of lore, so it is only re-wrapped at the new type, and the
-- alias table is not consulted for it.
instance (CanBeAliased (Op lore), CanBeAliased op, ASTLore lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases aliases hostOp =
    case hostOp of
      SegOp segOp -> SegOp (addOpAliases aliases segOp)
      OtherOp otherOp -> OtherOp (addOpAliases aliases otherOp)
      SizeOp sizeOp -> SizeOp sizeOp

  removeOpAliases hostOp =
    case hostOp of
      SegOp segOp -> SegOp (removeOpAliases segOp)
      OtherOp otherOp -> OtherOp (removeOpAliases otherOp)
      SizeOp sizeOp -> SizeOp sizeOp

-- Stripping wisdom from a 'HostOp' strips it from the wrapped 'SegOp' or
-- 'OtherOp'. A 'SizeOp' payload does not depend on the lore and is simply
-- re-wrapped at the wisdom-free type.
instance (CanBeWise (Op lore), CanBeWise op, ASTLore lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom hostOp =
    case hostOp of
      SegOp segOp -> SegOp (removeOpWisdom segOp)
      OtherOp otherOp -> OtherOp (removeOpWisdom otherOp)
      SizeOp sizeOp -> SizeOp sizeOp

-- Indexing into the result of a 'HostOp' is delegated to the wrapped
-- 'SegOp' or 'OtherOp'. Any other constructor yields no index information.
instance (ASTLore lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k hostOp idxs =
    case hostOp of
      SegOp segOp -> ST.indexOp vtable k segOp idxs
      OtherOp otherOp -> ST.indexOp vtable k otherOp idxs
      _ -> Nothing

-- A 'HostOp' is printed as its wrapped operation, with no extra syntax
-- indicating which constructor it came from.
instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr hostOp =
    case hostOp of
      SegOp segOp -> ppr segOp
      OtherOp otherOp -> ppr otherOp
      SizeOp sizeOp -> ppr sizeOp

-- Metrics for a 'HostOp' are collected from the operation it wraps.
instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics hostOp =
    case hostOp of
      SegOp segOp -> opMetrics segOp
      OtherOp otherOp -> opMetrics otherOp
      SizeOp sizeOp -> opMetrics sizeOp

-- | Validate the 'SegLevel' of a 'SegOp' against the level of the
-- enclosing 'SegOp', if there is one.
--
-- * With no enclosing level, the number of groups and the group size must
--   both have type @i64@.
-- * A 'SegOp' may not be nested inside a 'SegThread'-level 'SegOp'.
-- * A nested level must not be identical to the enclosing level.
-- * A nested level must use the same number of groups and group size as
--   the enclosing level.
checkSegLevel ::
  TC.Checkable lore =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM lore ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int64] $ unCount $ segNumGroups lvl
  TC.require [Prim int64] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread {}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  -- The enclosing level is not SegThread here (handled above), so an
  -- identical nested level would be a redundant re-nesting.
  | x == y = TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
    TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
    return ()

-- | Type-check a 'HostOp' whose lore carries alias annotations.
--
-- A 'SegOp' is checked with 'typeCheckSegOp', validating its level
-- against the (optional) enclosing level via 'checkSegLevel'. That check
-- runs under 'TC.checkOpWith', given @checker@ applied to the 'SegOp''s
-- own level. An 'OtherOp' is handed to the supplied operation checker, and
-- a 'SizeOp' is checked with 'typeCheckSizeOp'.
typeCheckHostOp ::
  TC.Checkable lore =>
  (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ()) ->
  Maybe SegLevel ->
  (op -> TC.TypeM lore ()) ->
  HostOp (Aliases lore) op ->
  TC.TypeM lore ()
typeCheckHostOp checker outerLvl checkOther hostOp =
  case hostOp of
    SegOp segOp ->
      TC.checkOpWith (checker (segLevel segOp)) $
        typeCheckSegOp (checkSegLevel outerLvl) segOp
    OtherOp otherOp -> checkOther otherOp
    SizeOp sizeOp -> typeCheckSizeOp sizeOp