{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-- {-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- |
-- Module      : Language.Halide.Func
-- Description : Functions / Arrays
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Func
  ( -- * Defining pipelines
    Func (..)
  , FuncTy (..)
  , Stage (..)
  , buffer
  , scalar
  , define
  , (!)
  , realize

    -- * Scheduling
  , Schedulable (..)
  , TailStrategy (..)

    -- ** 'Func'-specific
  , computeRoot
  , getStage
  , getLoopLevel
  , getLoopLevelAtStage
  , asUsed
  , asUsedBy
  , copyToDevice
  , copyToHost
  , storeAt
  , computeAt
  , dim
  , estimate
  , bound
  , getArgs
  -- , deepCopy

    -- * Update definitions
  , update
  , hasUpdateDefinitions
  , getUpdateStage

    -- * Debugging
  , prettyLoopNest

    -- * Internal
  , IndexTuple
  , asBufferParam
  , withFunc
  , withBufferParam
  , wrapCxxFunc
  , CxxStage
  , wrapCxxStage
  , withCxxStage
  )
where

import Control.Exception (bracket)
import Control.Monad (forM)
import Data.IORef
import Data.Kind (Type)
import Data.Proxy
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Foreign.ForeignPtr
import Foreign.Marshal (toBool, with)
import Foreign.Ptr (Ptr, castPtr)
import GHC.Stack (HasCallStack)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Dimension
import Language.Halide.Expr
import Language.Halide.LoopLevel
import Language.Halide.Target
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (min, tail)

-- | Haskell counterpart of [Halide::Stage](https://halide-lang.org/docs/class_halide_1_1_stage.html).
--
-- This type has no constructors: it is only used as a tag for raw and
-- foreign pointers to C++ stages (e.g. @'ForeignPtr' 'CxxStage'@ inside 'Stage').
data CxxStage

-- Top-level Template Haskell declaration splice.
-- NOTE(review): presumably it sets up the inline-C++ context (Halide headers
-- and type mappings) used by the @[C.throwBlock| ... |]@ quasi-quotes in this
-- module — confirm against its definition.
importHalide

-- | A function in Halide. Conceptually, it can be thought of as a lazy
-- @n@-dimensional buffer of type @a@.
--
-- This is a wrapper around the [@Halide::Func@](https://halide-lang.org/docs/class_halide_1_1_func.html)
-- C++ type.
--
-- The phantom index @t@ tells which constructor is in use: 'FuncTy' for
-- functions we defined ourselves and 'ParamTy' for pipeline parameters.
data Func (t :: FuncTy) (n :: Nat) (a :: Type) where
  -- | A function defined by us, backed by a C++ function object (@CxxFunc@).
  Func :: {-# UNPACK #-} !(ForeignPtr CxxFunc) -> Func 'FuncTy n a
  -- | A pipeline parameter. The underlying C++ image parameter
  -- (@CxxImageParam@) is kept behind a mutable reference and may be absent
  -- ('Nothing').
  Param :: {-# UNPACK #-} !(IORef (Maybe (ForeignPtr CxxImageParam))) -> Func 'ParamTy n a

-- | Function type. It can either be 'FuncTy' which means that we have defined the function ourselves,
-- or 'ParamTy' which means that it's a parameter to our pipeline.
--
-- It is used (promoted to the kind level) as the first index of t'Func'.
data FuncTy
  = -- | A function defined by us.
    FuncTy
  | -- | A parameter to the pipeline.
    ParamTy
  deriving stock (Show, Eq, Ord)

-- | A single definition of a t'Func'.
--
-- Wraps a foreign pointer to a C++ stage ('CxxStage', the counterpart of
-- @Halide::Stage@). The indices @n@ and @a@ match those of the t'Func' the
-- stage was obtained from.
newtype Stage (n :: Nat) (a :: Type) = Stage (ForeignPtr CxxStage)

-- | Different ways to handle a tail case in a split when the split factor does
-- not provably divide the extent.
--
-- This is the Haskell counterpart of [@Halide::TailStrategy@](https://halide-lang.org/docs/namespace_halide.html#a6c6557df562bd7850664e70fdb8fea0f).
data TailStrategy
  = -- | Round up the extent to be a multiple of the split factor.
    --
    -- Not legal for RVars, as it would change the meaning of the algorithm.
    --
    -- * Pros: generates the simplest, fastest code.
    -- * Cons: if used on a stage that reads from the input or writes to the
    -- output, constrains the input or output size to be a multiple of the
    -- split factor.
    TailRoundUp
  | -- | Guard the inner loop with an if statement that prevents evaluation
    -- beyond the original extent.
    --
    -- Always legal. The if statement is treated like a boundary condition, and
    -- factored out into a loop epilogue if possible.
    --
    -- * Pros: no redundant re-evaluation; does not constrain input or output sizes.
    -- * Cons: increases code size due to separate tail-case handling;
    -- vectorization will scalarize in the tail case to handle the if
    -- statement.
    TailGuardWithIf
  | -- | Guard the loads and stores in the loop with an if statement that
    -- prevents evaluation beyond the original extent.
    --
    -- Always legal. The if statement is treated like a boundary condition, and
    -- factored out into a loop epilogue if possible.
    --
    -- * Pros: no redundant re-evaluation; does not constrain input or output
    -- sizes.
    -- * Cons: increases code size due to separate tail-case handling.
    TailPredicate
  | -- | Guard the loads in the loop with an if statement that prevents
    -- evaluation beyond the original extent.
    --
    -- Only legal for innermost splits. Not legal for RVars, as it would change
    -- the meaning of the algorithm. The if statement is treated like a
    -- boundary condition, and factored out into a loop epilogue if possible.
    --
    -- * Pros: does not constrain input sizes, output size constraints are
    -- simpler than full predication.
    -- * Cons: increases code size due to separate tail-case handling,
    -- constrains the output size to be a multiple of the split factor.
    TailPredicateLoads
  | -- | Guard the stores in the loop with an if statement that prevents
    -- evaluation beyond the original extent.
    --
    -- Only legal for innermost splits. Not legal for RVars, as it would change
    -- the meaning of the algorithm. The if statement is treated like a
    -- boundary condition, and factored out into a loop epilogue if possible.
    --
    -- * Pros: does not constrain output sizes, input size constraints are
    -- simpler than full predication.
    -- * Cons: increases code size due to separate tail-case handling,
    -- constrains the input size to be a multiple of the split factor.
    TailPredicateStores
  | -- | Prevent evaluation beyond the original extent by shifting the tail
    -- case inwards, re-evaluating some points near the end.
    --
    -- Only legal for pure variables in pure definitions. If the inner loop is
    -- very simple, the tail case is treated like a boundary condition and
    -- factored out into an epilogue.
    --
    -- This is a good trade-off between several factors. Like 'TailRoundUp', it
    -- supports vectorization well, because the inner loop is always a fixed
    -- size with no data-dependent branching. It increases code size slightly
    -- for inner loops due to the epilogue handling, but not for outer loops
    -- (e.g. loops over tiles). If used on a stage that reads from an input or
    -- writes to an output, this strategy only requires that the input/output
    -- extent be at least the split factor, instead of a multiple of the split
    -- factor as with 'TailRoundUp'.
    TailShiftInwards
  | -- | For pure definitions use 'TailShiftInwards'.
    --
    -- For pure vars in update definitions use 'TailRoundUp'. For RVars in update
    -- definitions use 'TailGuardWithIf'.
    TailAuto
  -- NOTE(review): 'split' converts this type with 'fromEnum' before casting it
  -- to @Halide::TailStrategy@ in C++, but no 'Enum' instance is derived here,
  -- so one must be defined elsewhere — confirm it matches Halide's enum values.
  deriving stock (Eq, Ord, Show)

-- | Constraint saying that @i@ is a tuple whose components are all
-- @'Expr' Int32@ (for example, loop variables).
--
-- The list of component types @ts@ is determined by @i@, so callers never
-- have to spell it out.
type IndexTuple i ts = (All ((~) (Expr Int32)) ts, IsTuple (Arguments ts) i)

-- | Common scheduling functions, shared by 'Stage' and t'Func'.
--
-- In the instances defined in this module, the functions returning
-- @IO (f n a)@ give back the value they were applied to, so calls can be
-- chained.
class (KnownNat n, IsHalideType a) => Schedulable f n a where
  -- | Vectorize the dimension.
  vectorize :: VarOrRVar -> f n a -> IO (f n a)

  -- | Unroll the dimension.
  unroll :: VarOrRVar -> f n a -> IO (f n a)

  -- | Reorder variables to have the given nesting order, from innermost out.
  reorder :: [VarOrRVar] -> f n a -> IO (f n a)

  -- | Split a dimension into inner and outer subdimensions with the given names, where the inner dimension
  -- iterates from @0@ to @factor-1@.
  --
  -- The inner and outer subdimensions can then be dealt with using the other scheduling calls. It's okay
  -- to reuse the old variable name as either the inner or outer variable. The first argument specifies
  -- how the tail should be handled if the split factor does not provably divide the extent.
  split :: TailStrategy -> VarOrRVar -> (VarOrRVar, VarOrRVar) -> Expr Int32 -> f n a -> IO (f n a)

  -- | Join two dimensions into a single fused dimension.
  --
  -- The fused dimension covers the product of the extents of the inner and outer dimensions given.
  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> f n a -> IO (f n a)

  -- | Mark the dimension to be traversed serially.
  serial :: VarOrRVar -> f n a -> IO (f n a)

  -- | Mark the dimension to be traversed in parallel.
  parallel :: VarOrRVar -> f n a -> IO (f n a)

  -- | Specialize the schedule for the case when the given condition holds.
  --
  -- Returns the 'Stage' that corresponds to the specialized case.
  --
  -- For more info, see [Halide::Stage::specialize](https://halide-lang.org/docs/class_halide_1_1_stage.html).
  specialize :: Expr Bool -> f n a -> IO (Stage n a)

  -- | Add a specialization that fails with the given error message.
  --
  -- For more info, see [Halide::Stage::specialize_fail](https://halide-lang.org/docs/class_halide_1_1_stage.html).
  specializeFail :: Text -> f n a -> IO ()

  -- | Map between one and three dimensions (given as a tuple of @'Expr' Int32@)
  -- to GPU blocks, using the given device API.
  gpuBlocks :: (IndexTuple i ts, 1 <= Length ts, Length ts <= 3) => DeviceAPI -> i -> f n a -> IO (f n a)

  -- | Map between one and three dimensions (given as a tuple of @'Expr' Int32@)
  -- to GPU threads, using the given device API.
  gpuThreads :: (IndexTuple i ts, 1 <= Length ts, Length ts <= 3) => DeviceAPI -> i -> f n a -> IO (f n a)

  -- | Map the dimension to GPU lanes, using the given device API.
  gpuLanes :: DeviceAPI -> VarOrRVar -> f n a -> IO (f n a)

  -- | Schedule the iteration over this stage to be fused with another stage from outermost loop to a
  -- given LoopLevel.
  --
  -- For more info, see [Halide::Stage::compute_with](https://halide-lang.org/docs/class_halide_1_1_stage.html#a82a2ae25a009d6a2d52cb407a25f0a5b).
  computeWith :: LoopAlignStrategy -> f n a -> LoopLevel t -> IO ()

instance (KnownNat n, IsHalideType a) => Schedulable Stage n a where
  -- Every method follows the same pattern:
  -- 1. Borrow the raw @Halide::Stage*@ with 'withCxxStage'.
  -- 2. Marshal the arguments to C++.
  -- 3. Call the matching @Halide::Stage@ member inside
  --    @handle_halide_exceptions@, within a 'C.throwBlock'. 'C.throwBlock'
  --    rethrows C++ exceptions in Haskell; @handle_halide_exceptions@ is
  --    defined on the C++ side.
  -- The scheduling directives then return the very same 'Stage', so calls
  -- can be chained.

  vectorize var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->vectorize(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  unroll var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->unroll(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  -- The variables are marshalled into a @std::vector<Halide::VarOrRVar>@.
  reorder args stage = do
    withMany asVarOrRVar args $ \args' ->
      withCxxStage stage $ \stage' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stage')->reorder(
              *$(const std::vector<Halide::VarOrRVar>* args'));
          });
        } |]
    pure stage

  split tail old (outer, inner) factor stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar old $ \old' ->
        asVarOrRVar outer $ \outer' ->
          asVarOrRVar inner $ \inner' ->
            asExpr factor $ \factor' ->
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  $(Halide::Stage* stage')->split(
                    *$(const Halide::VarOrRVar* old'),
                    *$(const Halide::VarOrRVar* outer'),
                    *$(const Halide::VarOrRVar* inner'),
                    *$(const Halide::Expr* factor'),
                    static_cast<Halide::TailStrategy>($(int t)));
                });
              } |]
    pure stage
    where
      -- The tail strategy crosses the FFI boundary as an int and is cast back
      -- to @Halide::TailStrategy@ on the C++ side.
      t = fromIntegral . fromEnum $ tail

  fuse (outer, inner) fused stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar outer $ \outer' ->
        asVarOrRVar inner $ \inner' ->
          asVarOrRVar fused $ \fused' ->
            [C.throwBlock| void {
              handle_halide_exceptions([=](){
                $(Halide::Stage* stage')->fuse(
                  *$(const Halide::VarOrRVar* outer'),
                  *$(const Halide::VarOrRVar* inner'),
                  *$(const Halide::VarOrRVar* fused'));
              });
            } |]
    pure stage

  serial var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->serial(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  parallel var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->parallel(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  -- Unlike the directives above, this returns a different stage: the result
  -- of @specialize@ is copied into a heap-allocated @Halide::Stage@ and handed
  -- to 'wrapCxxStage'.
  specialize cond stage =
    withCxxStage stage $ \stage' ->
      asExpr cond $ \cond' ->
        wrapCxxStage
          =<< [C.throwBlock| Halide::Stage* {
                return handle_halide_exceptions([=](){
                  return new Halide::Stage{$(Halide::Stage* stage')->specialize(
                    *$(const Halide::Expr* cond'))};
                });
              } |]

  -- The message is UTF-8 encoded and passed to C++ as a (pointer, length)
  -- pair, from which a @std::string@ is built.
  specializeFail (T.encodeUtf8 -> s) stage =
    withCxxStage stage $ \stage' ->
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Stage* stage')->specialize_fail(
            std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)});
        });
      } |]

  -- The index tuple becomes a @std::vector<Halide::VarOrRVar>@. The class
  -- constraints bound its length to 1..3, and the switch dispatches to the
  -- matching @gpu_blocks@ overload.
  gpuBlocks (fromIntegral . fromEnum -> api) vars stage = do
    withCxxStage stage $ \stage' ->
      asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            auto const& vars = *$(const std::vector<Halide::VarOrRVar>* vars');
            auto& stage = *$(Halide::Stage* stage');
            auto const device = static_cast<Halide::DeviceAPI>($(int api));
            switch (vars.size()) {
              case 1: stage.gpu_blocks(vars.at(0), device);
                      break;
              case 2: stage.gpu_blocks(vars.at(0), vars.at(1), device);
                      break;
              case 3: stage.gpu_blocks(vars.at(0), vars.at(1), vars.at(2), device);
                      break;
              default: throw std::runtime_error{"unexpected number of arguments in gpuBlocks"};
            }
          });
        } |]
    pure stage

  -- Same dispatch scheme as gpuBlocks, targeting the @gpu_threads@ overloads.
  gpuThreads (fromIntegral . fromEnum -> api) vars stage = do
    withCxxStage stage $ \stage' ->
      asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            auto const& vars = *$(const std::vector<Halide::VarOrRVar>* vars');
            auto& stage = *$(Halide::Stage* stage');
            auto const device = static_cast<Halide::DeviceAPI>($(int api));
            switch (vars.size()) {
              case 1: stage.gpu_threads(vars.at(0), device);
                      break;
              case 2: stage.gpu_threads(vars.at(0), vars.at(1), device);
                      break;
              case 3: stage.gpu_threads(vars.at(0), vars.at(1), vars.at(2), device);
                      break;
              default: throw std::runtime_error{"unexpected number of arguments in gpuThreads"};
            }
          });
        } |]
    pure stage

  gpuLanes (fromIntegral . fromEnum -> api) var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->gpu_lanes(
              *$(const Halide::VarOrRVar* var'),
              static_cast<Halide::DeviceAPI>($(int api)));
          });
        } |]
    pure stage

  computeWith (fromIntegral . fromEnum -> align) stage level =
    withCxxStage stage $ \stage' ->
      withCxxLoopLevel level $ \level' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stage')->compute_with(
              *$(const Halide::LoopLevel* level'),
              static_cast<Halide::LoopAlignStrategy>($(int align)));
          });
        } |]

-- | Lift a one-argument 'Stage' scheduling action to a t'Func'.
--
-- The action is run on the function's stage (obtained with 'getStage'). The
-- 'Stage' it returns is discarded and the original t'Func' is returned, so
-- calls can be chained.
viaStage1
  :: (KnownNat n, IsHalideType b)
  => (a -> Stage n b -> IO (Stage n b))
  -> a
  -> Func t n b
  -> IO (Func t n b)
viaStage1 f a1 func = do
  _ <- f a1 =<< getStage func
  pure func

-- | Lift a two-argument 'Stage' scheduling action to a t'Func'.
--
-- The action is run on the function's stage (obtained with 'getStage'). The
-- 'Stage' it returns is discarded and the original t'Func' is returned, so
-- calls can be chained.
viaStage2
  :: (KnownNat n, IsHalideType b)
  => (a1 -> a2 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> Func t n b
  -> IO (Func t n b)
viaStage2 f a1 a2 func = do
  _ <- f a1 a2 =<< getStage func
  pure func

{-
viaStage3
  :: (KnownNat n, IsHalideType b)
  => (a1 -> a2 -> a3 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> a3
  -> Func t n b
  -> IO (Func t n b)
viaStage3 f a1 a2 a3 func = do
  _ <- f a1 a2 a3 =<< getStage func
  pure func
-}

-- | Lift a four-argument 'Stage' scheduling action (such as 'split') to a
-- t'Func'.
--
-- The action is run on the function's stage (obtained with 'getStage'). The
-- 'Stage' it returns is discarded and the original t'Func' is returned, so
-- calls can be chained.
viaStage4
  :: (KnownNat n, IsHalideType b)
  => (a1 -> a2 -> a3 -> a4 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> a3
  -> a4
  -> Func t n b
  -> IO (Func t n b)
viaStage4 f a1 a2 a3 a4 func = do
  _ <- f a1 a2 a3 a4 =<< getStage func
  pure func

instance (KnownNat n, IsHalideType a) => Schedulable (Func t) n a where
  -- Applied to the function's stage; the Func itself is returned.
  vectorize = viaStage1 vectorize
  -- Applied to the function's stage; the Func itself is returned.
  unroll = viaStage1 unroll
  -- Applied to the function's stage; the Func itself is returned.
  reorder = viaStage1 reorder
  -- Four explicit arguments (tail strategy, old var, (outer, inner), factor)
  -- are forwarded to the function's stage via viaStage4.
  split = viaStage4 split
  -- Applied to the function's stage; the Func itself is returned.
  fuse = viaStage2 fuse
  -- Applied to the function's stage; the Func itself is returned.
  serial = viaStage1 serial
  -- Applied to the function's stage; the Func itself is returned.
  parallel = viaStage1 parallel
  -- Returns the specialized Stage produced by the Stage instance.
  specialize cond func = getStage func >>= specialize cond
  -- Forwarded to the function's stage.
  specializeFail msg func = getStage func >>= specializeFail msg
  -- Delegate to the 'Stage' instance via 'viaStage2'.
  gpuBlocks deviceApi indices = viaStage2 gpuBlocks deviceApi indices
  -- Delegate to the 'Stage' instance via 'viaStage2'.
  gpuThreads deviceApi indices = viaStage2 gpuThreads deviceApi indices
  -- Delegate to the 'Stage' instance via 'viaStage2'.
  gpuLanes deviceApi var = viaStage2 gpuLanes deviceApi var
  -- Forward to the 'Stage' instance, operating on the stage from 'getStage'.
  computeWith strategy func level = do
    stage <- getStage func
    computeWith strategy stage level

instance Enum TailStrategy where
  -- The numeric encoding is taken directly from the C++ @Halide::TailStrategy@
  -- enumerators, so it always agrees with the linked Halide library.
  fromEnum strategy =
    fromIntegral $ case strategy of
      TailRoundUp -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::RoundUp) } |]
      TailGuardWithIf -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::GuardWithIf) } |]
      TailPredicate -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Predicate) } |]
      TailPredicateLoads -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateLoads) } |]
      TailPredicateStores -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateStores) } |]
      TailShiftInwards -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::ShiftInwards) } |]
      TailAuto -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Auto) } |]

  -- Inverse of 'fromEnum': look the (C int) code up in a table of all strategies.
  -- Unknown codes are a programming error.
  toEnum k =
    case lookup (fromIntegral k) halideCodes of
      Just strategy -> strategy
      Nothing -> error $ "invalid TailStrategy: " <> show k
    where
      halideCodes =
        [ ([CU.pure| int { static_cast<int>(Halide::TailStrategy::RoundUp) } |], TailRoundUp)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::GuardWithIf) } |], TailGuardWithIf)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::Predicate) } |], TailPredicate)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateLoads) } |], TailPredicateLoads)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateStores) } |], TailPredicateStores)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::ShiftInwards) } |], TailShiftInwards)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::Auto) } |], TailAuto)
        ]

-- vectorize
--   :: (KnownNat n, IsHalideType a)
--   => TailStrategy
--   -> Func t n a
--   -> Expr Int32
--   -- ^ Variable to vectorize
--   -> Expr Int32
--   -- ^ Split factor
--   -> IO ()
-- vectorize strategy func var factor =
--   withFunc func $ \f ->
--     asVarOrRVar var $ \x ->
--       asExpr factor $ \n ->
--         [C.throwBlock| void {
--           $(Halide::Func* f)->vectorize(*$(Halide::VarOrRVar* x), *$(Halide::Expr* n),
--                                         static_cast<Halide::TailStrategy>($(int tail)));
--         } |]
--   where
--     tail = fromIntegral (fromEnum strategy)

-- Split a dimension by the given factor, then unroll the inner dimension.
--
-- This is how you unroll a loop of unknown size by some constant factor. After
-- this call, @var@ refers to the outer dimension of the split.
-- unroll
--   :: (KnownNat n, IsHalideType a)
--   => TailStrategy
--   -> Func t n a
--   -> Expr Int32
--   -- ^ Variable @var@ to unroll
--   -> Expr Int32
--   -- ^ Split factor
--   -> IO ()
-- unroll strategy func var factor =
--   withFunc func $ \f ->
--     asVarOrRVar var $ \x ->
--       asExpr factor $ \n ->
--         [C.throwBlock| void {
--           $(Halide::Func* f)->unroll(*$(Halide::VarOrRVar* x), *$(Halide::Expr* n),
--                                      static_cast<Halide::TailStrategy>($(int tail)));
--         } |]
--   where
--     tail = fromIntegral (fromEnum strategy)

-- Reorder variables to have the given nesting order, from innermost out.
-- reorder
--   :: forall t n a i ts
--    . ( IsTuple (Arguments ts) i
--      , All ((~) (Expr Int32)) ts
--      , Length ts ~ n
--      , KnownNat n
--      , IsHalideType a
--      )
--   => Func t n a
--   -> i
--   -> IO ()
-- reorder func args =
--   asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple args) $ \v -> do
--     withFunc func $ \f ->
--       [C.throwBlock| void { $(Halide::Func* f)->reorder(*$(std::vector<Halide::VarOrRVar>* v)); } |]

-- | Statically declare the range over which the function will be evaluated in the general case.
--
-- This provides a basis for the auto scheduler to make trade-offs and scheduling decisions.
-- The auto generated schedules might break when the sizes of the dimensions are very different from the
-- estimates specified. These estimates are used only by the auto scheduler if the function is a pipeline output.
--
-- The index variable must be a pure variable (it is converted with @asVar@). Errors reported by
-- Halide, for example presumably when the variable is not one of the function's arguments, are
-- rethrown as Haskell exceptions instead of aborting the process.
estimate
  :: (KnownNat n, IsHalideType a)
  => Expr Int32
  -- ^ index variable
  -> Expr Int32
  -- ^ @min@ estimate
  -> Expr Int32
  -- ^ @extent@ estimate
  -> Func t n a
  -> IO ()
estimate var minValue extentValue func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr minValue $ \minExpr ->
        asExpr extentValue $ \extentExpr ->
          -- set_estimate can raise a Halide error; a C++ exception must not unwind
          -- through an unsafe FFI call, so go through throwBlock like the other
          -- scheduling primitives in this module.
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->set_estimate(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]

-- | Statically declare the range over which a function should be evaluated.
--
-- This can let Halide perform some optimizations. E.g. if you know there are going to be 4 color channels,
-- you can completely vectorize the color channel dimension without the overhead of splitting it up.
-- If bounds inference decides that it requires more of this function than the bounds you have stated,
-- a runtime error will occur when you try to run your pipeline.
--
-- The index variable must be a pure variable (it is converted with @asVar@). Errors reported by
-- Halide while recording the bound are rethrown as Haskell exceptions instead of aborting the process.
bound
  :: (KnownNat n, IsHalideType a)
  => Expr Int32
  -- ^ index variable
  -> Expr Int32
  -- ^ @min@ of the bound
  -> Expr Int32
  -- ^ @extent@ of the bound
  -> Func t n a
  -> IO ()
bound var minValue extentValue func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr minValue $ \minExpr ->
        asExpr extentValue $ \extentExpr ->
          -- Func::bound can raise a Halide error; route it through throwBlock so it
          -- surfaces as a Haskell exception rather than crossing an unsafe FFI call.
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->bound(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]

-- | Get the index arguments of the function.
--
-- The returned list contains exactly @n@ elements. The arguments are copied out of a temporary
-- @std::vector@, which is always freed (via 'bracket') even if a copy fails.
getArgs :: (KnownNat n, IsHalideType a) => Func t n a -> IO [Var]
getArgs func =
  withFunc func $ \func' -> do
    let allocate =
          [CU.exp| std::vector<Halide::Var>* {
            new std::vector<Halide::Var>{$(const Halide::Func* func')->args()} } |]
        destroy v = [CU.exp| void { delete $(std::vector<Halide::Var>* v) } |]
    bracket allocate destroy $ \v -> do
      count <- [CU.exp| size_t { $(const std::vector<Halide::Var>* v)->size() } |]
      -- `count` is an unsigned CSize: `[0 .. count - 1]` would wrap around to maxBound
      -- when the vector is empty (e.g. a zero-dimensional Func), so bound the indices
      -- explicitly instead.
      forM (takeWhile (< count) [0 ..]) $ \i ->
        fmap Var . cxxConstruct $ \ptr ->
          [CU.exp| void {
            new ($(Halide::Var* ptr)) Halide::Var{$(const std::vector<Halide::Var>* v)->at($(size_t i))} } |]

-- | Schedule this function to be computed once, in full, ahead of everything that uses it.
--
-- See [Halide::Func::compute_root](https://halide-lang.org/docs/class_halide_1_1_func.html#a29df45a4a16a63eb81407261a9783060) for more info.
-- The same 'Func' is returned so that scheduling calls can be chained.
computeRoot :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func t n a)
computeRoot func = func <$ withFunc func markRoot
  where
    markRoot f =
      [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->compute_root(); }); } |]

-- | Create (or fetch) an identity wrapper of this Func that is used only by 'f'.
--
-- During compilation, Halide replaces all calls to this Func done by 'f' with calls to the wrapper.
-- If this Func is already wrapped for use in 'f', the existing wrapper is returned.
--
-- For more info, see [Halide::Func::in](https://halide-lang.org/docs/class_halide_1_1_func.html#a9d619f2d0111ea5bf640781d1324d050).
asUsedBy
  :: (KnownNat n, KnownNat m, IsHalideType a, IsHalideType b)
  => Func t1 n a
  -> Func 'FuncTy m b
  -> IO (Func 'FuncTy n a)
asUsedBy g f =
  withFunc g $ \gPtr ->
    withFunc f $ \fPtr -> do
      -- A heap copy of the wrapper handle; ownership passes to wrapCxxFunc.
      wrapperPtr <-
        [CU.exp| Halide::Func* {
          new Halide::Func{$(Halide::Func* gPtr)->in(*$(Halide::Func* fPtr))} } |]
      wrapCxxFunc wrapperPtr

-- | Create (or fetch) the global identity wrapper, which wraps all calls to this Func by any other Func.
--
-- If a global wrapper already exists, it is returned. The global identity wrapper is only used by
-- callers for which no custom wrapper has been specified (see 'asUsedBy').
asUsed :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func 'FuncTy n a)
asUsed f =
  withFunc f $ \fPtr -> do
    -- A heap copy of the wrapper handle; ownership passes to wrapCxxFunc.
    wrapperPtr <- [CU.exp| Halide::Func* { new Halide::Func{$(Halide::Func* fPtr)->in()} } |]
    wrapCxxFunc wrapperPtr

-- | Declare that this function should be implemented by a call to @halide_buffer_copy@ targeting the
-- given device API.
--
-- Asserts that the @Func@ has a pure definition which is a simple call to a single input, and no update
-- definitions. The wrapper @Func@s returned by 'asUsed' are suitable candidates. Consumes all pure variables,
-- and rewrites the @Func@ to have an extern definition that calls @halide_buffer_copy@.
--
-- The same 'Func' is returned so that scheduling calls can be chained.
copyToDevice :: (KnownNat n, IsHalideType a) => DeviceAPI -> Func t n a -> IO (Func t n a)
copyToDevice deviceApi func = func <$ withFunc func issueCopy
  where
    -- The Haskell enum is passed to C++ as a plain int and cast back there.
    api = fromIntegral (fromEnum deviceApi)
    issueCopy f =
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Func* f)->copy_to_device(static_cast<Halide::DeviceAPI>($(int api)));
        });
      } |]

-- | Shorthand for @'copyToDevice' 'DeviceHost'@.
copyToHost :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Func t n a)
copyToHost func = copyToDevice DeviceHost func

-- Split a dimension into inner and outer subdimensions with the given names, where the inner dimension
-- iterates from @0@ to @factor-1@.
--
-- The inner and outer subdimensions can then be dealt with using the other scheduling calls. It's okay
-- to reuse the old variable name as either the inner or outer variable. The first argument specifies
-- how the tail should be handled if the split factor does not provably divide the extent.
-- split
--   :: (KnownNat n, IsHalideType a)
--   => TailStrategy
--   -- ^ how to treat the remainder
--   -> Func t n a
--   -> Expr Int32
--   -- ^ loop variable to split
--   -> Expr Int32
--   -- ^ new outer loop variable
--   -> Expr Int32
--   -- ^ new inner loop variable
--   -> Expr Int32
--   -- ^ split factor
--   -> IO (Func t n a)
-- split tail func old outer inner factor = do
--   withFunc func $ \f ->
--     asVarOrRVar old $ \old' ->
--       asVarOrRVar outer $ \outer' ->
--         asVarOrRVar inner $ \inner' ->
--           asExpr factor $ \factor' ->
--             [C.throwBlock| void {
--               handle_halide_exceptions([=](){
--                 $(Halide::Func* f)->split(
--                   *$(const Halide::VarOrRVar* old'),
--                   *$(const Halide::VarOrRVar* outer'),
--                   *$(const Halide::VarOrRVar* inner'),
--                   *$(const Halide::Expr* factor'),
--                   static_cast<Halide::TailStrategy>($(int t)));
--               }); } |]
--   pure func
--   where
--     t = fromIntegral . fromEnum $ tail

-- Join two dimensions into a single fused dimension.
--
-- The fused dimension covers the product of the extents of the inner and outer dimensions given.
-- fuse
--   :: (KnownNat n, IsHalideType a)
--   => Func t n a
--   -> Expr Int32
--   -- ^ inner loop variable
--   -> Expr Int32
--   -- ^ outer loop variable
--   -> Expr Int32
--   -- ^ new fused loop variable
--   -> IO (Func t n a)
-- fuse func outer inner fused = do
--   withFunc func $ \f ->
--     asVarOrRVar outer $ \outer' ->
--       asVarOrRVar inner $ \inner' ->
--         asVarOrRVar fused $ \fused' ->
--           [CU.exp| void {
--                 $(Halide::Func* f)->fuse(
--                   *$(const Halide::VarOrRVar* outer'),
--                   *$(const Halide::VarOrRVar* inner'),
--                   *$(const Halide::VarOrRVar* fused')) } |]
--   pure func

-- withVarOrRVarMany :: [Expr Int32] -> (Int -> Ptr (CxxVector CxxVarOrRVar) -> IO a) -> IO a
-- withVarOrRVarMany xs f =
--   bracket allocate destroy $ \v -> do
--     let go !k [] = f k v
--         go !k (y : ys) = withVarOrRVarMany y $ \p -> do
--           [CU.exp| void { $(std::vector<Halide::Expr>* v)->push_back(*$(Halide::VarOrRVar* p)) } |]
--           go (k + 1) ys
--     go 0 xs
--   where
--     count = fromIntegral (length xs)

--   withFunc func $ \f ->
--     withVarOrRVarMany vars $ \count v -> do
--       unless natVal (Proxy @n)
--       handleHalideExceptionsM
--         [C.tryBlock| void {
--           $(Halide::Func* f)->reorder(*$(std::vector<Halide::VarOrRVar>* v));
--         } |]
--
-- class Curry (args :: [Type]) (r :: Type) (f :: Type) | args r -> f where
--   curryG :: (Arguments args -> r) -> f

-- | Allocate a fresh @Halide::ImageParam@ with element type @a@ and @n@ dimensions, optionally
-- giving it a name. The result owns the C++ object and deletes it when garbage-collected.
mkBufferParameter
  :: forall n a. (KnownNat n, IsHalideType a) => Maybe Text -> IO (ForeignPtr CxxImageParam)
mkBufferParameter maybeName =
  with (halideTypeFor (Proxy @a)) $ \t -> do
    let rank = fromIntegral (natVal (Proxy @n))
        finalizer = [C.funPtr| void deleteImageParam(Halide::ImageParam* p) { delete p; } |]
    -- Both constructors read the element type through `t`, so they must run inside `with`.
    rawPtr <- case maybeName of
      Nothing ->
        [CU.exp| Halide::ImageParam* {
          new Halide::ImageParam{Halide::Type{*$(halide_type_t* t)}, $(int rank)} } |]
      Just name -> do
        let s = T.encodeUtf8 name
        [CU.exp| Halide::ImageParam* {
          new Halide::ImageParam{
                Halide::Type{*$(halide_type_t* t)},
                $(int rank),
                std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}} } |]
    newForeignPtr finalizer rawPtr

-- | Return the @ImageParam@ cached in the 'IORef', creating and caching one on first use.
--
-- The name is consulted only when a new parameter has to be created.
getBufferParameter
  :: forall n a
   . (KnownNat n, IsHalideType a)
  => Maybe Text
  -> IORef (Maybe (ForeignPtr CxxImageParam))
  -> IO (ForeignPtr CxxImageParam)
getBufferParameter name r = do
  cached <- readIORef r
  case cached of
    Just existing -> pure existing
    Nothing -> do
      fresh <- mkBufferParameter @n @a name
      writeIORef r (Just fresh)
      pure fresh

-- | Like 'withFunc', but for pipeline parameters: run an action with the raw @Halide::ImageParam*@.
--
-- The parameter is created lazily on first use. The pointer is only valid inside the action.
withBufferParam
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Func 'ParamTy n a
  -> (Ptr CxxImageParam -> IO b)
  -> IO b
withBufferParam (Param r) action = do
  param <- getBufferParameter @n @a Nothing r
  withForeignPtr param action

-- instance (KnownNat n, IsHalideType a) => Named (Func 'ParamTy n a) where
--   setName :: Func 'ParamTy n a -> Text -> IO ()
--   setName (Param r) name = do
--     readIORef r >>= \case
--       Just _ -> error "the name of this Func has already been set"
--       Nothing -> do
--         fp <- mkBufferParameter @n @a (Just name)
--         writeIORef r (Just fp)

-- | Run an 'IO' action with the raw @Halide::Func*@ behind a 'Func'.
--
-- For a pipeline parameter, the pointer refers to a @Halide::Func@ view of its @ImageParam@
-- (see 'forceFunc'). The pointer must not be used after the action returns.
withFunc :: (KnownNat n, IsHalideType a) => Func t n a -> (Ptr CxxFunc -> IO b) -> IO b
withFunc f action = withForeignPtr (funcToForeignPtr f) action

-- | Take ownership of a heap-allocated @Halide::Func@ and wrap it as a 'Func'.
--
-- The pointer must come from C++ @new@: the attached finalizer releases it with @delete@.
wrapCxxFunc :: Ptr CxxFunc -> IO (Func 'FuncTy n a)
wrapCxxFunc rawPtr = Func <$> newForeignPtr finalizer rawPtr
  where
    finalizer = [C.funPtr| void deleteFunc(Halide::Func *x) { delete x; } |]

-- | View any 'Func' as a plain @'FuncTy@ one.
--
-- Regular functions are returned unchanged. For a 'Param', the underlying @ImageParam@ is
-- obtained (created on first use) and a new @Halide::Func@ handle referring to it is returned.
forceFunc :: forall t n a. (KnownNat n, IsHalideType a) => Func t n a -> IO (Func 'FuncTy n a)
forceFunc = \case
  x@(Func _) -> pure x
  Param r -> do
    param <- getBufferParameter @n @a Nothing r
    withForeignPtr param $ \p -> do
      funcPtr <-
        [CU.exp| Halide::Func* {
          new Halide::Func{static_cast<Halide::Func>(*$(Halide::ImageParam* p))} } |]
      wrapCxxFunc funcPtr

-- | Pure accessor for the @Halide::Func@ behind a 'Func'.
--
-- NOTE(review): implemented with 'unsafePerformIO'. For a 'Param' this may allocate the
-- underlying @ImageParam@ on first use (via 'forceFunc'); later calls reuse the cached one.
funcToForeignPtr :: (KnownNat n, IsHalideType a) => Func t n a -> ForeignPtr CxxFunc
funcToForeignPtr x = unsafePerformIO $! do
  forced <- forceFunc x
  case forced of
    Func fp -> pure fp

-- | Define a Halide function.
--
-- @define "f" i e@ defines a Halide function called "f" such that @f[i] = e@.
--
-- Here, @i@ is an @n@-element tuple of t'Var', i.e. the following are all valid:
--
-- >>> [x, y, z] <- mapM mkVar ["x", "y", "z"]
-- >>> f1 <- define "f1" x (0 :: Expr Float)
-- >>> f2 <- define "f2" (x, y) (0 :: Expr Float)
-- >>> f3 <- define "f3" (x, y, z) (0 :: Expr Float)
--
-- If Halide rejects the definition, the error is rethrown as a Haskell exception (as in 'update')
-- rather than aborting the process.
define
  :: ( IsTuple (Arguments ts) i
     , All ((~) Var) ts
     , Length ts ~ n
     , KnownNat n
     , IsHalideType a
     )
  => Text
  -> i
  -> Expr a
  -> IO (Func 'FuncTy n a)
define name args expr =
  asVectorOf @((~) (Expr Int32)) asVar (fromTuple args) $ \x -> do
    let s = T.encodeUtf8 name
    asExpr expr $ \y ->
      wrapCxxFunc
        -- The assignment `f(args) = expr` is where Halide validates the definition and
        -- may throw; guard it exactly like `update` does instead of letting a C++
        -- exception escape through an unsafe FFI call. `f` is captured by reference so
        -- the definition lands on the handle we return.
        =<< [C.throwBlock| Halide::Func* {
              Halide::Func f{std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}};
              handle_halide_exceptions([=, &f](){
                f(*$(std::vector<Halide::Var>* x)) = *$(Halide::Expr* y);
              });
              return new Halide::Func{f};
            } |]

-- | Add an update definition to a Halide function.
--
-- @update f i e@ gives @f@ an additional definition performing @f[i] = e@, where the index
-- tuple @i@ consists of @'Expr' 'Int32'@ values.
update
  :: ( IsTuple (Arguments ts) i
     , All ((~) (Expr Int32)) ts
     , Length ts ~ n
     , KnownNat n
     , IsHalideType a
     )
  => Func 'FuncTy n a
  -> i
  -> Expr a
  -> IO ()
update func args expr =
  withFunc func $ \funcPtr ->
    asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \indices ->
      asExpr expr $ \value ->
        -- C++ exceptions raised by the assignment are rethrown in Haskell by throwBlock.
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Func* funcPtr)->operator()(*$(std::vector<Halide::Expr>* indices)) = *$(Halide::Expr* value);
          });
        } |]

infix 9 !

-- | Index a Halide function. @f ! i@ plays the role of @f[i]@: it builds an expression that
-- reads the lazily computed array @f@ at the index tuple @i@.
(!)
  :: ( IsTuple (Arguments ts) i
     , All ((~) (Expr Int32)) ts
     , Length ts ~ n
     , KnownNat n
     , IsHalideType a
     )
  => Func t n a
  -> i
  -> Expr a
func ! args =
  -- NOTE(review): exposed as a pure function via unsafePerformIO. For a 'Param', withFunc may
  -- lazily create the underlying ImageParam on first use.
  unsafePerformIO $
    withFunc func $ \funcPtr ->
      asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \indices ->
        cxxConstructExpr $ \ptr ->
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            $(Halide::Func* funcPtr)->operator()(*$(std::vector<Halide::Expr>* indices))} } |]

-- | Get dimension @k@ (0-based) of a pipeline parameter.
--
-- Calls 'error' when @k@ is outside @[0, n)@.
dim
  :: forall n a
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Int
  -> Func 'ParamTy n a
  -> IO Dimension
dim k func
  | k < 0 || k >= rank =
      error $
        "invalid dimension index: "
          <> show k
          <> "; Func is "
          <> show (natVal (Proxy @n))
          <> "-dimensional"
  | otherwise =
      withBufferParam func $ \f -> do
        dimPtr <-
          [CU.exp| Halide::Internal::Dimension* {
            new Halide::Internal::Dimension{$(Halide::ImageParam* f)->dim($(int cIndex))} } |]
        wrapCxxDimension dimPtr
  where
    rank = fromIntegral (natVal (Proxy @n))
    cIndex = fromIntegral k

-- \| Write out the loop nests specified by the schedule for this function.
--
-- Helpful for understanding what a schedule is doing.
--
-- For more info, see
-- [@Halide::Func::print_loop_nest@](https://halide-lang.org/docs/class_halide_1_1_func.html#a03f839d9e13cae4b87a540aa618589ae)
-- printLoopNest :: (KnownNat n, IsHalideType r) => Func n r -> IO ()
-- printLoopNest func = withFunc func $ \f ->
--   [C.exp| void { $(Halide::Func* f)->print_loop_nest() } |]

-- | Render the loop nests specified by the schedule for this function as 'Text'.
--
-- Helpful for understanding what a schedule is doing. The C++ call is wrapped in
-- @handle_halide_exceptions@, and the resulting heap-allocated @std::string@ is
-- handed to 'peekAndDeleteCxxString'.
--
-- For more info, see
-- [@Halide::Func::print_loop_nest@](https://halide-lang.org/docs/class_halide_1_1_func.html#a03f839d9e13cae4b87a540aa618589ae)
prettyLoopNest :: (KnownNat n, IsHalideType r) => Func t n r -> IO Text
prettyLoopNest func = withFunc func $ \f -> do
  rendered <-
    [C.throwBlock| std::string* {
      return handle_halide_exceptions([=]() {
        return new std::string{Halide::Internal::print_loop_nest(
          std::vector<Halide::Internal::Function>{$(Halide::Func* f)->function()})};
      });
    } |]
  peekAndDeleteCxxString rendered

-- | Evaluate this function over a rectangular domain.
--
-- A CPU buffer with the requested extents is allocated via 'allocaCpuBuffer'.
-- The function is realized into that buffer, and then the continuation runs on it.
-- Errors from @Halide::Func::realize@ go through @handle_halide_exceptions@.
realize
  :: forall n a t b
   . (KnownNat n, IsHalideType a)
  => Func t n a
  -- ^ Function to evaluate
  -> [Int]
  -- ^ Domain over which to evaluate
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ What to do with the buffer afterwards. The buffer exists only for the
  -- duration of this continuation, so do not return it directly.
  -> IO b
realize func shape action =
  withFunc func $ \f ->
    allocaCpuBuffer shape $ \buf -> do
      realizeInto f (castPtr buf)
      action buf
  where
    -- Run the Halide pipeline, writing its output into the raw buffer.
    realizeInto f raw =
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Func* f)->realize(
            Halide::Pipeline::RealizationArg{$(halide_buffer_t* raw)});
        });
      } |]

-- \| Evaluate this function over a one-dimensional domain and return the
-- resulting buffer or buffers.
-- realize1D
--   :: forall a t
--    . IsHalideType a
--   => Int
--   -- ^ @size@ of the domain. The function will be evaluated on @[0, ..., size -1]@
--   -> Func t 1 a
--   -- ^ Function to evaluate
--   -> IO (Vector a)
-- realize1D size func = do
--   buf <- SM.new size
--   withHalideBuffer @1 @a buf $ \x -> do
--     let b = castPtr x
--     withFunc func $ \f ->
--       [CU.exp| void {
--         $(Halide::Func* f)->realize(
--           Halide::Pipeline::RealizationArg{$(halide_buffer_t* b)}) } |]
--   S.unsafeFreeze buf

-- | A view pattern to specify the name of a buffer argument.
--
-- The name is recorded through 'unsafePerformIO' (via 'getBufferParameter'), which
-- lets 'buffer' be used inside a pure view pattern. The returned t'Func' is the
-- argument itself.
--
-- Example usage:
--
-- >>> :{
-- _ <- compile $ \(buffer "src" -> src) -> do
--   i <- mkVar "i"
--   define "dest" i $ (src ! i :: Expr Float)
-- :}
--
-- or if we want to specify the dimension and type, we can use type applications:
--
-- >>> :{
-- _ <- compile $ \(buffer @1 @Float "src" -> src) -> do
--   i <- mkVar "i"
--   define "dest" i $ src ! i
-- :}
buffer :: forall n a. (KnownNat n, IsHalideType a) => Text -> Func 'ParamTy n a -> Func 'ParamTy n a
buffer name p@(Param r) =
  unsafePerformIO $
    p <$ getBufferParameter @n @a (Just name) r

-- | Similar to 'buffer', but for scalar parameters.
--
-- The first use creates a named scalar parameter and stores it in the
-- expression's 'IORef'. Calling 'scalar' on a parameter whose name was already
-- set, or on an 'Expr' that is not a scalar parameter, calls 'error'.
--
-- Example usage:
--
-- >>> :{
-- _ <- compile $ \(scalar @Float "a" -> a) -> do
--   i <- mkVar "i"
--   define "dest" i $ a
-- :}
scalar :: forall a. IsHalideType a => Text -> Expr a -> Expr a
scalar name = \case
  ScalarParam r -> unsafePerformIO $ do
    existing <- readIORef r
    case existing of
      Just _ -> error "the name of this Expr has already been set"
      Nothing -> mkScalarParameter @a (Just name) >>= writeIORef r . Just
    pure (ScalarParam r)
  _ -> error "cannot set the name of an expression that is not a parameter"

-- | Take ownership of a heap-allocated @Halide::Stage@ and wrap it into a 'Stage'.
--
-- The attached finalizer runs C++ @delete@ on the pointer, so the pointer must
-- have been produced by @new@.
wrapCxxStage :: (KnownNat n, IsHalideType a) => Ptr CxxStage -> IO (Stage n a)
wrapCxxStage ptr = Stage <$> newForeignPtr deleter ptr
  where
    deleter :: FinalizerPtr CxxStage
    deleter = [C.funPtr| void deleteStage(Halide::Stage* p) { delete p; } |]

-- | Run an action with the raw @Halide::Stage@ pointer underlying a 'Stage'.
--
-- The pointer is only guaranteed to stay alive while the action runs
-- ('withForeignPtr' semantics).
withCxxStage :: (KnownNat n, IsHalideType a) => Stage n a -> (Ptr CxxStage -> IO b) -> IO b
withCxxStage stage = case stage of
  Stage fp -> withForeignPtr fp

-- | Get the pure stage of a 'Func' for the purposes of scheduling it.
--
-- A new heap-allocated @Halide::Stage@ is built by converting the underlying
-- @Halide::Func@, and ownership passes to 'wrapCxxStage'.
getStage :: (KnownNat n, IsHalideType a) => Func t n a -> IO (Stage n a)
getStage func =
  withFunc func $ \f ->
    wrapCxxStage
      =<< [CU.exp| Halide::Stage* { new Halide::Stage{static_cast<Halide::Stage>(*$(Halide::Func* f))} } |]

-- | Return 'True' when the function has update definitions, 'False' otherwise.
--
-- Thin wrapper around @Halide::Func::has_update_definition@.
hasUpdateDefinitions :: (KnownNat n, IsHalideType a) => Func t n a -> IO Bool
hasUpdateDefinitions func =
  withFunc func $ \f -> do
    flag <- [CU.exp| bool { $(const Halide::Func* f)->has_update_definition() } |]
    pure (toBool flag)

-- | Get a handle to an update step for the purposes of scheduling it.
--
-- @k@ is the zero-based index of the update definition.
--
-- * A negative index is rejected with 'error' before reaching Halide. Halide only
--   checks the upper bound itself, so a negative index would otherwise be used
--   unchecked.
-- * The C++ call runs inside @handle_halide_exceptions@, like the other Halide
--   calls in this module. A Halide error, such as an index past the last update
--   definition, therefore surfaces as a Haskell exception instead of escaping an
--   unsafe foreign call.
getUpdateStage :: (KnownNat n, IsHalideType a) => Int -> Func 'FuncTy n a -> IO (Stage n a)
getUpdateStage k func
  | k < 0 =
      error $ "getUpdateStage: update index must be non-negative, but got " <> show k
  | otherwise =
      withFunc func $ \func' ->
        let k' = fromIntegral k
         in wrapCxxStage
              =<< [C.throwBlock| Halide::Stage* {
                    return handle_halide_exceptions([=](){
                      return new Halide::Stage{$(Halide::Func* func')->update($(int k'))};
                    });
                  } |]

-- | Identify the loop nest corresponding to some dimension of some function.
--
-- The function, the loop variable and the stage index are passed directly to the
-- @Halide::LoopLevel@ constructor, inside @handle_halide_exceptions@. The result
-- must be a locked loop level; any other kind is reported with 'error'.
getLoopLevelAtStage
  :: (KnownNat n, IsHalideType a)
  => Func t n a
  -> Expr Int32
  -> Int
  -- ^ update index
  -> IO (LoopLevel 'LockedTy)
getLoopLevelAtStage func var stageIndex =
  withFunc func $ \f ->
    asVarOrRVar var $ \v -> do
      let stage = fromIntegral stageIndex
      SomeLoopLevel level <-
        wrapCxxLoopLevel
          =<< [C.throwBlock| Halide::LoopLevel* {
                return handle_halide_exceptions([=](){
                  return new Halide::LoopLevel{*$(const Halide::Func* f),
                                               *$(const Halide::VarOrRVar* v),
                                               $(int stage)};
                });
              } |]
      -- Matching on the 'LoopLevel' constructor refines the existential index to 'LockedTy.
      case level of
        LoopLevel _ -> pure level
        _ -> error $ "getLoopLevelAtStage: got " <> show level <> ", but expected a LoopLevel 'LockedTy"

-- | Same as 'getLoopLevelAtStage' except that the stage is @-1@.
--
-- The value @-1@ is forwarded unchanged as the stage index.
getLoopLevel :: (KnownNat n, IsHalideType a) => Func t n a -> Expr Int32 -> IO (LoopLevel 'LockedTy)
getLoopLevel func var = getLoopLevelAtStage func var defaultStageIndex
  where
    defaultStageIndex :: Int
    defaultStageIndex = -1

-- | Allocate storage for this function within a particular loop level.
--
-- Scheduling storage is optional, and can be used to separate the loop level at which storage is allocated
-- from the loop level at which computation occurs to trade off between locality and redundant work.
-- The same 'Func' is returned, so scheduling calls can be chained.
--
-- For more info, see [Halide::Func::store_at](https://halide-lang.org/docs/class_halide_1_1_func.html#a417c08f8aa3a5cdf9146fba948b65193).
storeAt :: (KnownNat n, IsHalideType a) => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
storeAt func level = func <$ schedule
  where
    -- Mutates the underlying Halide::Func's schedule in place.
    schedule =
      withFunc func $ \f ->
        withCxxLoopLevel level $ \l ->
          [CU.exp| void { $(Halide::Func* f)->store_at(*$(const Halide::LoopLevel* l)) } |]

-- | Schedule a function to be computed within the iteration over a given loop level.
--
-- The same 'Func' is returned, so scheduling calls can be chained.
--
-- For more info, see [Halide::Func::compute_at](https://halide-lang.org/docs/class_halide_1_1_func.html#a800cbcc3ca5e3d3fa1707f6e1990ec83).
computeAt :: (KnownNat n, IsHalideType a) => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
computeAt func level = func <$ schedule
  where
    -- Mutates the underlying Halide::Func's schedule in place.
    schedule =
      withFunc func $ \f ->
        withCxxLoopLevel level $ \l ->
          [CU.exp| void { $(Halide::Func* f)->compute_at(*$(const Halide::LoopLevel* l)) } |]

-- | Wrap a buffer into a t'Func'.
--
-- Suppose we are defining a pipeline that adds together two vectors, and we'd like to call 'realize' to
-- evaluate it directly. How do we pass the vectors to the t'Func'? 'asBufferParam' allows one to do exactly this.
--
-- > asBufferParam [1, 2, 3] $ \a ->
-- >   asBufferParam [4, 5, 6] $ \b -> do
-- >     i <- mkVar "i"
-- >     f <- define "vectorAdd" i $ a ! i + b ! i
-- >     realize f [3] $ \result ->
-- >       print =<< peekToList result
--
-- An unnamed buffer parameter is created and bound to the buffer, and the
-- resulting t'Func' is passed to the continuation.
asBufferParam
  :: forall n a t b
   . IsHalideBuffer t n a
  => t
  -- ^ Object to treat as a buffer
  -> (Func 'ParamTy n a -> IO b)
  -- ^ What to do with the __temporary__ buffer
  -> IO b
asBufferParam arr action =
  withHalideBuffer @n @a arr $ \arr' -> do
    param <- mkBufferParameter @n @a Nothing
    -- Bind the raw halide_buffer_t to the ImageParam. NOTE(review): Halide::Buffer<>
    -- presumably wraps the existing memory without copying it, so the parameter must
    -- not be used after withHalideBuffer returns. Confirm against the Halide docs.
    withForeignPtr param $ \param' ->
      let buf = (castPtr arr' :: Ptr RawHalideBuffer)
       in [CU.block| void {
            $(Halide::ImageParam* param')->set(Halide::Buffer<>{*$(const halide_buffer_t* buf)});
          } |]
    action . Param =<< newIORef (Just param)