{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}

-- |
-- Module      : Language.Halide.Expr
-- Description : Scalar expressions
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Expr
  ( Expr (..)
  , Var
  , RVar
  , VarOrRVar
  , Int32
  , mkExpr
  , mkVar
  , mkRVar
  , cast
  , eq
  , neq
  , lt
  , lte
  , gt
  , gte
  , bool
  , undef
    -- | For debugging, it's often useful to observe the value of an expression when it's evaluated. If you
    -- have a complex expression that does not depend on any buffers or indices, you can 'evaluate' it.
  , evaluate
    -- | However, often an expression is only used within a definition of a pipeline, and it's impossible to
    -- call 'evaluate' on it. In such cases, it can be wrapped with 'printed' to indicate to Halide that the
    -- value of the expression should be dumped to screen when it's computed.
  , printed
  , toIntImm

    -- * Internal
  , exprToForeignPtr
  , cxxConstructExpr
  -- , wrapCxxExpr
  , wrapCxxRVar
  , wrapCxxVarOrRVar
  , wrapCxxParameter
  , asExpr
  , asVar
  , asRVar
  , asVarOrRVar
  , asScalarParam
  , asVectorOf
  , mkScalarParameter
  , withMany
  , binaryOp
  , unaryOp
  , checkType
  )
where

import Control.Exception (bracket)
import Control.Monad (unless)
import Data.IORef
import Data.Int (Int32)
import Data.Proxy
import Data.Ratio (denominator, numerator)
import Data.Text (Text, unpack)
import Data.Text.Encoding qualified as T
import Data.Vector.Storable.Mutable qualified as SM
import Foreign.ForeignPtr
import Foreign.Marshal (alloca, allocaArray, peekArray, toBool, with)
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import Foreign.Storable (peek)
import GHC.Stack (HasCallStack)
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (min)

-- Template Haskell: bring the Halide C++ headers and the inline-C type context
-- into scope for the quasiquotes below (NOTE(review): defined in one of the
-- imported Language.Halide.* modules — presumably Context).
importHalide

-- Template Haskell: presumably generates the 'CxxConstructible' instances that
-- let 'cxxConstruct' allocate and construct these C++ classes in place.
instanceCxxConstructible "Halide::Expr"
instanceCxxConstructible "Halide::Var"
instanceCxxConstructible "Halide::RVar"
instanceCxxConstructible "Halide::VarOrRVar"

-- Template Haskell: generate the 'IsHalideType' instances ('halideTypeFor' and
-- 'toCxxExpr') for the standard scalar types: Float, Double, CFloat, CDouble,
-- Int8, Int16, Int32, Int64, Word8, Word16, Word32 and Word64.
defineIsHalideTypeInstances

-- Template Haskell: presumably generates 'HasCxxVector' instances so that these
-- C++ classes can be passed around in @std::vector@s (NOTE(review): confirm
-- against the definition of instanceHasCxxVector).
-- The CxxConstructible instances for these classes are generated at the top of
-- the module and must not be repeated here.
instanceHasCxxVector "Halide::Expr"
instanceHasCxxVector "Halide::Var"
instanceHasCxxVector "Halide::RVar"
instanceHasCxxVector "Halide::VarOrRVar"

-- | Booleans are represented in Halide as 1-bit unsigned integers (@UInt(1)@).
--
-- 'True' is encoded as @1@ and 'False' as @0@ before being cast to @UInt(1)@.
instance IsHalideType Bool where
  halideTypeFor _ = HalideType HalideTypeUInt 1 1
  toCxxExpr b =
    let x = if b then 1 else 0
     in cxxConstruct $ \ptr ->
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{cast(Halide::UInt(1), Halide::Expr{$(int x)})} } |]

-- A single 'Expr' is treated as a one-element argument list (it pairs with the
-- @IsTuple (Arguments '[Expr a]) (Expr a)@ instance in this module).
type instance FromTuple (Expr a) = Arguments '[Expr a]

-- | A scalar expression in Halide.
--
-- The type parameter @a@ is a phantom type naming the Haskell type of the value (e.g. 'Float' or
-- 'Int32'); when an expression is built with 'cxxConstructExpr', its runtime Halide type is checked
-- against @a@.
--
-- To have a nice experience writing arithmetic expressions in terms of @Expr@s, we want to define 'Num',
-- 'Floating' etc. instances for @Expr@. Unfortunately, that means that we encode v'Expr', v'Var', v'RVar',
-- and v'ScalarParam' by the same type, and passing an @Expr@ to a function that expects a @Var@ will produce
-- a runtime error.
data Expr a
  = -- | Scalar expression (wraps a C++ @Halide::Expr@).
    Expr (ForeignPtr CxxExpr)
  | -- | Index variable (wraps a C++ @Halide::Var@).
    Var (ForeignPtr CxxVar)
  | -- | Reduction variable (wraps a C++ @Halide::RVar@).
    RVar (ForeignPtr CxxRVar)
  | -- | Scalar parameter.
    --
    -- The 'IORef' is initialized with 'Nothing' and filled in on the first
    -- call to 'asExpr'.
    ScalarParam (IORef (Maybe (ForeignPtr CxxParameter)))

-- | A v'Var': an index variable of type 'Int32'.
--
-- This is only a synonym for @'Expr' 'Int32'@, so the compiler does not distinguish it from any
-- other 'Int32' expression.
type Var = Expr Int32

-- | An v'RVar': a reduction variable of type 'Int32'. Also just a synonym for @'Expr' 'Int32'@.
type RVar = Expr Int32

-- | Either a v'Var' or an v'RVar'. Also just a synonym for @'Expr' 'Int32'@.
type VarOrRVar = Expr Int32

-- | Create a scalar expression from a Haskell value.
--
-- The value is turned into a C++ @Halide::Expr@ with 'toCxxExpr' and wrapped in an v'Expr'.
mkExpr :: IsHalideType a => a -> Expr a
mkExpr value = unsafePerformIO $! do
  fp <- toCxxExpr value
  pure (Expr fp)

-- | Create a named index variable.
--
-- The name is UTF-8 encoded and passed to the @Halide::Var@ constructor.
mkVar :: Text -> IO (Expr Int32)
mkVar name = do
  fp <- cxxConstruct $ \ptr ->
    [CU.exp| void {
      new ($(Halide::Var* ptr)) Halide::Var{std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}} } |]
  pure (Var fp)
  where
    s = T.encodeUtf8 name

-- | Create a named reduction variable.
--
-- A one-dimensional @Halide::RDom@ spanning @[min, min + extent)@ is constructed and converted to
-- its single @Halide::RVar@.
--
-- For more information about reduction variables, see [@Halide::RDom@](https://halide-lang.org/docs/class_halide_1_1_r_dom.html).
mkRVar
  :: Text
  -- ^ name
  -> Expr Int32
  -- ^ min index
  -> Expr Int32
  -- ^ extent
  -> IO (Expr Int32)
mkRVar name start extent =
  asExpr start $ \start' ->
    asExpr extent $ \extent' -> do
      rvar <-
        [CU.exp| Halide::RVar* {
              new Halide::RVar{static_cast<Halide::RVar>(Halide::RDom{
                *$(const Halide::Expr* start'),
                *$(const Halide::Expr* extent'),
                std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}
                })}
            } |]
      wrapCxxRVar rvar
  where
    s = T.encodeUtf8 name

-- | An undefined value of type @a@, created with @Halide::undef@.
--
-- For more information, see [@Halide::undef@](https://halide-lang.org/docs/namespace_halide.html#a9389bcacbed602df70eae94826312e03).
undef :: forall a. IsHalideType a => Expr a
undef = unsafePerformIO $ do
  let ty = halideTypeFor (Proxy @a)
  with ty $ \tyPtr ->
    cxxConstructExpr $ \ptr ->
      [CU.exp| void {
        new ($(Halide::Expr* ptr))
          Halide::Expr{Halide::undef(Halide::Type{*$(const halide_type_t* tyPtr)})} } |]
{-# NOINLINE undef #-}

-- | Convert a scalar expression to another Halide type with @Halide::cast@.
--
-- Choose the target type with TypeApplications, e.g. @cast \@Float x@.
cast :: forall to from. (IsHalideType to, IsHalideType from) => Expr from -> Expr to
cast expr = unsafePerformIO $
  asExpr expr $ \src -> do
    let target = halideTypeFor (Proxy @to)
    with target $ \t ->
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          Halide::cast(Halide::Type{*$(halide_type_t* t)}, *$(Halide::Expr* src))} } |]

-- | Wrap an expression with Halide's @print@, so its value is printed every time it is evaluated.
--
-- The result has the same value as the input. The output goes through the Halide runtime's print
-- handler, which by default writes to standard error (not stdout). This is useful for debugging
-- Halide pipelines.
printed :: IsHalideType a => Expr a -> Expr a
printed = unaryOp wrapInPrint
  where
    wrapInPrint e ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{print(*$(Halide::Expr* e))} } |]

infix 4 `eq`, `neq`, `lt`, `lte`, `gt`, `gte`

-- | Lifted '==': compares two expressions and yields a boolean 'Expr'.
eq :: IsHalideType a => Expr a -> Expr a -> Expr Bool
eq x y = binaryOp build x y
  where
    build lhs rhs out =
      [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{
        (*$(Halide::Expr* lhs)) == (*$(Halide::Expr* rhs))} } |]

-- | Lifted '/=': compares two expressions and yields a boolean 'Expr'.
neq :: IsHalideType a => Expr a -> Expr a -> Expr Bool
neq x y = binaryOp build x y
  where
    build lhs rhs out =
      [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{
        (*$(Halide::Expr* lhs)) != (*$(Halide::Expr* rhs))} } |]

-- | Lifted '<': compares two expressions and yields a boolean 'Expr'.
lt :: IsHalideType a => Expr a -> Expr a -> Expr Bool
lt x y = binaryOp build x y
  where
    build lhs rhs out =
      [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{
        (*$(Halide::Expr* lhs)) < (*$(Halide::Expr* rhs))} } |]

-- | Lifted '<=': compares two expressions and yields a boolean 'Expr'.
lte :: IsHalideType a => Expr a -> Expr a -> Expr Bool
lte x y = binaryOp build x y
  where
    build lhs rhs out =
      [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{
        (*$(Halide::Expr* lhs)) <= (*$(Halide::Expr* rhs))} } |]

-- | Lifted '>': compares two expressions and yields a boolean 'Expr'.
gt :: IsHalideType a => Expr a -> Expr a -> Expr Bool
gt x y = binaryOp build x y
  where
    build lhs rhs out =
      [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{
        (*$(Halide::Expr* lhs)) > (*$(Halide::Expr* rhs))} } |]

-- | Lifted '>=': compares two expressions and yields a boolean 'Expr'.
gte :: IsHalideType a => Expr a -> Expr a -> Expr Bool
gte x y = binaryOp build x y
  where
    build lhs rhs out =
      [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{
        (*$(Halide::Expr* lhs)) >= (*$(Halide::Expr* rhs))} } |]

-- | An 'Expr'-level analogue of 'Data.Bool.bool', built on @Halide::select@.
--
-- Note the argument order differs from Prelude's: the condition comes first, then the value
-- chosen when it is true, then the value chosen when it is false.
bool :: IsHalideType a => Expr Bool -> Expr a -> Expr a -> Expr a
bool condExpr trueExpr falseExpr = unsafePerformIO $
  asExpr condExpr $ \c ->
    asExpr trueExpr $ \onTrue ->
      asExpr falseExpr $ \onFalse ->
        cxxConstructExpr $ \ptr ->
          [CU.exp| void {
            new ($(Halide::Expr* ptr)) Halide::Expr{
              Halide::select(*$(Halide::Expr* c),
                *$(Halide::Expr* onTrue), *$(Halide::Expr* onFalse))} } |]

-- | Evaluate a scalar expression by JIT-compiling and realizing a tiny one-element pipeline.
--
-- The expression must not contain parameters; if it does, an exception is thrown.
evaluate :: forall a. IsHalideType a => Expr a -> IO a
evaluate expr =
  asExpr expr $ \e -> do
    out <- SM.new 1
    withHalideBuffer out $ \(buffer :: Ptr (HalideBuffer 1 a)) -> do
      let b = castPtr buffer :: Ptr RawHalideBuffer
      -- A one-dimensional Func whose every point is the expression; realizing it
      -- into the single-element buffer fills out[0].
      [C.throwBlock| void {
        handle_halide_exceptions([=]() {
          Halide::Func f;
          Halide::Var i;
          f(i) = *$(Halide::Expr* e);
          f.realize(Halide::Pipeline::RealizationArg{$(halide_buffer_t* b)});
        });
      } |]
    SM.read out 0

-- | Convert expression to integer immediate.
--
-- Tries to extract the value of an expression if it is a compile-time constant. If the expression
-- isn't known at compile-time of the Halide pipeline, returns 'Nothing'.
--
-- Only expressions whose C++ node is literally a @Halide::Internal::IntImm@ are recognized; no
-- simplification is attempted first. The 64-bit immediate is converted to 'Int' with
-- 'fromIntegral'.
toIntImm :: IsHalideType a => Expr a -> Maybe Int
toIntImm expr = unsafePerformIO $
  asExpr expr $ \expr' ->
    alloca $ \valuePtr -> do
      -- Copy the immediate out while still inside C++ instead of returning a pointer
      -- into the IR node, so the Haskell side never depends on the node's lifetime.
      -- The Expr is inspected through the pointer directly (no refcounted copy).
      found <-
        toBool
          <$> [CU.block| bool {
                const auto* node = $(const Halide::Expr* expr')->as<Halide::Internal::IntImm>();
                if (node == nullptr) { return false; }
                *$(int64_t* valuePtr) = node->value;
                return true;
              } |]
      if found
        then Just . fromIntegral <$> peek valuePtr
        else pure Nothing

-- | A single 'Expr' is the one-element case of an argument list.
instance IsTuple (Arguments '[Expr a]) (Expr a) where
  toTuple args = case args of
    x ::: Nil -> x
  fromTuple = (::: Nil)

-- | Render an expression with Halide's @operator<<@ (via @to_string_via_iostream@).
--
-- A scalar parameter is shown by its name, or as @ScalarParam@ if the underlying
-- @Halide::Internal::Parameter@ has not been created yet.
instance IsHalideType a => Show (Expr a) where
  show = \case
    Expr fp ->
      render fp $ \x -> [CU.exp| std::string* { to_string_via_iostream(*$(const Halide::Expr* x)) } |]
    Var fp ->
      render fp $ \x -> [CU.exp| std::string* { to_string_via_iostream(*$(const Halide::Var* x)) } |]
    RVar fp ->
      render fp $ \x -> [CU.exp| std::string* { to_string_via_iostream(*$(const Halide::RVar* x)) } |]
    ScalarParam ref -> unpack . unsafePerformIO $
      readIORef ref >>= \case
        Just fp ->
          withForeignPtr fp $ \x ->
            [CU.exp| std::string* {
                  new std::string{$(const Halide::Internal::Parameter* x)->name()} } |]
              >>= peekAndDeleteCxxString
        Nothing -> pure "ScalarParam"
    where
      -- Run a C++ action that returns a heap-allocated std::string describing the
      -- wrapped object, and convert it with peekAndDeleteCxxString.
      render fp toCxxString =
        unpack . unsafePerformIO $
          withForeignPtr fp (\x -> toCxxString x >>= peekAndDeleteCxxString)

-- | Arithmetic on Halide expressions, mapped onto the corresponding C++ operators.
instance (IsHalideType a, Num a) => Num (Expr a) where
  fromInteger :: Integer -> Expr a
  fromInteger x = mkExpr (fromInteger x :: a)
  (+) :: Expr a -> Expr a -> Expr a
  (+) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) + *$(Halide::Expr* b)} } |]
  (-) :: Expr a -> Expr a -> Expr a
  (-) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) - *$(Halide::Expr* b)} } |]
  (*) :: Expr a -> Expr a -> Expr a
  (*) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) * *$(Halide::Expr* b)} } |]

  abs :: Expr a -> Expr a
  abs = unaryOp $ \a ptr ->
    -- If the type is unsigned, then abs does nothing. Also note that for signed
    -- integers, in Halide abs returns the unsigned version, so we manually
    -- cast it back.
    [CU.block| void {
      if ($(Halide::Expr* a)->type().is_uint()) {
        new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a)};
      }
      else {
        new ($(Halide::Expr* ptr)) Halide::Expr{
          Halide::cast($(Halide::Expr* a)->type(), Halide::abs(*$(Halide::Expr* a)))};
      }
    } |]
  negate :: Expr a -> Expr a
  negate = unaryOp $ \a ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{ -(*$(Halide::Expr* a))} } |]
  signum :: Expr a -> Expr a
  signum = unaryOp $ \a ptr ->
    -- Mirrors GHC's signum for Float/Double: 1 for positive, -1 for negative,
    -- otherwise the argument itself (so 0, -0.0 and NaN pass through unchanged;
    -- for integers the fallback is exactly 0). Unsigned values can't be negative,
    -- so their -1 branch is omitted. The result keeps the argument's type.
    [CU.block| void {
      const Halide::Expr& x = *$(Halide::Expr* a);
      const Halide::Type t = x.type();
      const Halide::Expr zero = Halide::Internal::make_zero(t);
      const Halide::Expr one = Halide::Internal::make_one(t);
      if (t.is_uint()) {
        new ($(Halide::Expr* ptr)) Halide::Expr{Halide::select(x > zero, one, x)};
      }
      else {
        new ($(Halide::Expr* ptr)) Halide::Expr{
          Halide::select(x > zero, one, Halide::select(x < zero, Halide::cast(t, -1), x))};
      }
    } |]

-- | Division and fractional literals for Halide expressions.
instance (IsHalideType a, Fractional a) => Fractional (Expr a) where
  (/) :: Expr a -> Expr a -> Expr a
  (/) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) / *$(Halide::Expr* b)} } |]
  fromRational :: Rational -> Expr a
  -- Convert the literal on the Haskell side, then embed it as a single constant.
  -- Building @fromInteger num / fromInteger den@ instead would round numerator and
  -- denominator separately (double rounding) and overflow for large denominators:
  -- e.g. @1.0e-40 :: Expr Float@ became @1 / Infinity = 0@ instead of a subnormal.
  fromRational r = mkExpr (fromRational r :: a)

-- | Elementary functions on Halide expressions.
--
-- Every method other than 'pi' calls the @Halide::@ function of the same name; '**' maps to
-- @Halide::pow@.
instance (IsHalideType a, Floating a) => Floating (Expr a) where
  -- pi is taken in double precision and then converted to the element type.
  pi = let piD = mkExpr (pi :: Double) in piD `seq` cast @a @Double piD
  exp = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::exp(*$(Halide::Expr* src))} } |]
  log = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::log(*$(Halide::Expr* src))} } |]
  sqrt = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::sqrt(*$(Halide::Expr* src))} } |]
  (**) = binaryOp $ \base expo dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::pow(*$(Halide::Expr* base), *$(Halide::Expr* expo))} } |]
  sin = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::sin(*$(Halide::Expr* src))} } |]
  cos = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::cos(*$(Halide::Expr* src))} } |]
  tan = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::tan(*$(Halide::Expr* src))} } |]
  asin = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::asin(*$(Halide::Expr* src))} } |]
  acos = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::acos(*$(Halide::Expr* src))} } |]
  atan = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::atan(*$(Halide::Expr* src))} } |]
  sinh = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::sinh(*$(Halide::Expr* src))} } |]
  cosh = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::cosh(*$(Halide::Expr* src))} } |]
  tanh = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::tanh(*$(Halide::Expr* src))} } |]
  asinh = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::asinh(*$(Halide::Expr* src))} } |]
  acosh = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::acosh(*$(Halide::Expr* src))} } |]
  atanh = unaryOp $ \src dst ->
    [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{Halide::atanh(*$(Halide::Expr* src))} } |]

-- | Construct a @Halide::Expr@ with the given callback and wrap it in an v'Expr'.
--
-- The callback receives the pointer handed out by 'cxxConstruct'; every caller in this module
-- builds the expression there with placement @new@. Afterwards the runtime Halide type of the
-- expression is checked against @a@ with 'checkType'.
cxxConstructExpr :: forall a. (HasCallStack, IsHalideType a) => (Ptr CxxExpr -> IO ()) -> IO (Expr a)
cxxConstructExpr construct =
  cxxConstruct construct >>= \fp ->
    withForeignPtr fp (checkType @a) >> pure (Expr fp)

-- | Take ownership of a heap-allocated @Halide::RVar@ and wrap it in a Haskell value.
--
-- The pointer is deleted by a finalizer when the resulting 'ForeignPtr' is collected.
--
-- __Note:__ v'RVar' objects correspond to expressions of type 'Int32'.
wrapCxxRVar :: Ptr CxxRVar -> IO (Expr Int32)
wrapCxxRVar p = do
  fp <- newForeignPtr deleter p
  pure (RVar fp)
  where
    deleter = [C.funPtr| void deleteExpr(Halide::RVar *p) { delete p; } |]

-- | Take ownership of a heap-allocated @Halide::VarOrRVar@ and convert it to an v'Expr'.
--
-- Depending on its @is_rvar@ flag, the contained variable is copied into a new @Halide::RVar@
-- (wrapped with 'wrapCxxRVar') or a new @Halide::Var@. The @Halide::VarOrRVar@ itself is always
-- deleted. Both variants are expressions of type 'Int32'.
wrapCxxVarOrRVar :: Ptr CxxVarOrRVar -> IO (Expr Int32)
wrapCxxVarOrRVar p =
  -- 'bracket' guarantees the owned pointer is deleted even if wrapping throws
  -- (e.g. an allocation failure or an asynchronous exception); previously it leaked.
  bracket (pure p) release convert
  where
    release q = [CU.exp| void { delete $(const Halide::VarOrRVar* q) } |]
    convert q = do
      isRVar <- toBool <$> [CU.exp| bool { $(const Halide::VarOrRVar* q)->is_rvar } |]
      if isRVar
        then wrapCxxRVar =<< [CU.exp| Halide::RVar* { new Halide::RVar{$(const Halide::VarOrRVar* q)->rvar} } |]
        else fmap Var . cxxConstruct $ \ptr ->
          [CU.exp| void { new ($(Halide::Var* ptr)) Halide::Var{$(const Halide::VarOrRVar* q)->var} } |]

-- | Values whose Halide type (a 'HalideType', i.e. @halide_type_t@) can be
-- queried at runtime, typically by asking the underlying C++ object.
class HasHalideType a where
  -- | Retrieve the Halide type of the given value.
  getHalideType :: a -> IO HalideType

-- | Query the Halide type of the C++ object behind an 'Expr'.
--
-- For a 'ScalarParam', the type is read from the cached
-- @Halide::Internal::Parameter@. Querying a 'ScalarParam' whose parameter has
-- not been created yet is an error, because no IsHalideType evidence is
-- available here to create it.
instance HasHalideType (Expr a) where
  getHalideType :: Expr a -> IO HalideType
  getHalideType (Expr fp) = withForeignPtr fp getHalideType
  getHalideType (Var fp) =
    withForeignPtr fp $ \e -> alloca $ \t -> do
      [CU.block| void {
        *$(halide_type_t* t) = static_cast<halide_type_t>(
          static_cast<Halide::Expr>(*$(Halide::Var* e)).type()); } |]
      peek t
  getHalideType (RVar fp) =
    withForeignPtr fp $ \e -> alloca $ \t -> do
      [CU.block| void {
        *$(halide_type_t* t) = static_cast<halide_type_t>(
          static_cast<Halide::Expr>(*$(Halide::RVar* e)).type()); } |]
      peek t
  getHalideType (ScalarParam r) =
    readIORef r >>= \case
      Just fp -> withForeignPtr fp getHalideType
      Nothing ->
        error "getHalideType: the ScalarParam has not been initialized yet"

-- | Ask C++ for the type of a raw @Halide::Expr@.
instance HasHalideType (Ptr CxxExpr) where
  getHalideType :: Ptr CxxExpr -> IO HalideType
  getHalideType exprPtr = alloca readInto
    where
      -- write the C++ type into the scratch slot, then read it back
      readInto typePtr = do
        [CU.block| void {
          *$(halide_type_t* typePtr) = static_cast<halide_type_t>($(Halide::Expr* exprPtr)->type()); } |]
        peek typePtr

-- | This instance reports 'Int32' directly without calling into C++: v'Var'
-- objects correspond to expressions of type 'Int32'.
instance HasHalideType (Ptr CxxVar) where
  getHalideType :: Ptr CxxVar -> IO HalideType
  getHalideType = const . pure . halideTypeFor $ Proxy @Int32

-- | This instance reports 'Int32' directly without calling into C++: v'RVar'
-- objects correspond to expressions of type 'Int32'.
instance HasHalideType (Ptr CxxRVar) where
  getHalideType :: Ptr CxxRVar -> IO HalideType
  getHalideType = const . pure . halideTypeFor $ Proxy @Int32

-- | Ask C++ for the type of a raw @Halide::Internal::Parameter@.
instance HasHalideType (Ptr CxxParameter) where
  getHalideType :: Ptr CxxParameter -> IO HalideType
  getHalideType paramPtr = alloca readInto
    where
      -- write the C++ type into the scratch slot, then read it back
      readInto typePtr = do
        [CU.block| void {
          *$(halide_type_t* typePtr) = static_cast<halide_type_t>($(Halide::Internal::Parameter* paramPtr)->type()); } |]
        peek typePtr

-- | Take ownership of a raw @Halide::Internal::Parameter@ pointer.
--
-- The pointer must come from C++ @new@: the finalizer attached to the
-- returned 'ForeignPtr' releases it with @delete@.
wrapCxxParameter :: Ptr CxxParameter -> IO (ForeignPtr CxxParameter)
wrapCxxParameter = newForeignPtr deleter
  where
    deleter :: FinalizerPtr CxxParameter
    deleter = [C.funPtr| void deleteParameter(Halide::Internal::Parameter *p) { delete p; } |]

-- | Assert that the runtime (C++) type of a value matches its compile-time
-- Haskell type @a@.
--
-- For example, given @(x :: 'Expr' a)@, this compares @x.type()@ as reported by
-- C++ with @'halideTypeFor' (Proxy \@a)@ and calls 'error' when they differ.
-- Because @a@ does not occur in the argument type, it must be supplied with a
-- type application, e.g. @checkType \@a p@.
checkType :: forall a t. (HasCallStack, IsHalideType a, HasHalideType t) => t -> IO ()
checkType x = do
  actual <- getHalideType x
  let expected = halideTypeFor (Proxy @a)
  if actual == expected
    then pure ()
    else
      error $
        "Type mismatch: C++ Expr has type "
          <> show actual
          <> ", but its Haskell counterpart has type "
          <> show expected

-- | Allocate a new scalar @Halide::Internal::Parameter@ of type @a@.
--
-- If a name is given, it is passed to the C++ constructor (UTF-8 encoded).
-- Otherwise the unnamed constructor is used. The raw pointer is wrapped in a
-- 'ForeignPtr' before 'checkType' runs, so the parameter is released even if
-- the type check fails.
mkScalarParameter
  :: forall a
   . IsHalideType a
  => Maybe Text
  -> IO (ForeignPtr CxxParameter)
mkScalarParameter maybeName = do
  p <- with (halideTypeFor (Proxy @a)) $ \t -> do
    let createWithoutName :: IO (Ptr CxxParameter)
        createWithoutName =
          [CU.exp| Halide::Internal::Parameter* {
            new Halide::Internal::Parameter{Halide::Type{*$(halide_type_t* t)}, false, 0} } |]
        createWithName :: Text -> IO (Ptr CxxParameter)
        createWithName name =
          let s = T.encodeUtf8 name
           in [CU.exp| Halide::Internal::Parameter* {
                new Halide::Internal::Parameter{
                  Halide::Type{*$(halide_type_t* t)},
                  false,
                  0,
                  std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}}
              } |]
    maybe createWithoutName createWithName maybeName
  -- take ownership first so that a failing checkType cannot leak the parameter
  fp <- wrapCxxParameter p
  withForeignPtr fp (checkType @a)
  pure fp

-- | Return the parameter cached in the 'IORef'. If the reference is still
-- empty, create the parameter with 'mkScalarParameter', store it, and return it.
--
-- The name is used only when the parameter is created. It is ignored if the
-- reference already holds a parameter.
--
-- NOTE(review): the read and the write are separate steps rather than one
-- atomic update, so two threads forcing the same empty reference at the same
-- time could each create a parameter — confirm callers never race here.
getScalarParameter
  :: forall a
   . IsHalideType a
  => Maybe Text
  -> IORef (Maybe (ForeignPtr CxxParameter))
  -> IO (ForeignPtr CxxParameter)
getScalarParameter name r = readIORef r >>= maybe create pure
  where
    create = do
      fresh <- mkScalarParameter @a name
      writeIORef r (Just fresh)
      pure fresh

-- | Ensure the value is backed by a real @Halide::Expr@.
--
-- An 'Expr' that already wraps a @Halide::Expr@ is returned unchanged.
-- v'Var' and v'RVar' values are converted to a new @Halide::Expr@. For a
-- 'ScalarParam', the underlying @Halide::Internal::Parameter@ is created on
-- demand (see 'getScalarParameter') and turned into a variable expression that
-- refers to it.
forceExpr :: forall a. (HasCallStack, IsHalideType a) => Expr a -> IO (Expr a)
forceExpr = \case
  x@(Expr _) -> pure x
  Var fp ->
    withForeignPtr fp $ \varPtr ->
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          static_cast<Halide::Expr>(*$(Halide::Var* varPtr))} } |]
  RVar fp ->
    withForeignPtr fp $ \rvarPtr ->
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          static_cast<Halide::Expr>(*$(Halide::RVar* rvarPtr))} } |]
  ScalarParam r -> do
    paramFp <- getScalarParameter @a Nothing r
    withForeignPtr paramFp $ \paramPtr ->
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          Halide::Internal::Variable::make(
            $(Halide::Internal::Parameter* paramPtr)->type(),
            $(Halide::Internal::Parameter* paramPtr)->name(),
            *$(Halide::Internal::Parameter* paramPtr))} } |]

-- | Run an 'IO' action with a pointer to the underlying @Halide::Expr@.
--
-- A v'Var', v'RVar', or 'ScalarParam' is first converted to a @Halide::Expr@
-- via 'exprToForeignPtr'.
asExpr :: IsHalideType a => Expr a -> (Ptr CxxExpr -> IO b) -> IO b
asExpr = withForeignPtr . exprToForeignPtr

-- | Apply a pointer-yielding function such as 'asExpr', 'asVar', 'asRVar', or
-- 'asVarOrRVar' to every element of a heterogeneous 'Arguments' list. The
-- action then runs on a temporary C++ vector of the results.
--
-- Each pointer is appended with 'cxxVectorPushBack'. The conversions are nested
-- continuations, so they all stay open while @action@ runs. The vector is
-- deleted afterwards via 'bracket'.
--
-- Example usage:
--
-- > asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple args) $ \v -> do
-- >   withFunc func $ \f ->
-- >     [C.throwBlock| void { $(Halide::Func* f)->reorder(
-- >                             *$(std::vector<Halide::VarOrRVar>* v)); } |]
asVectorOf
  :: forall c k ts a
   . (All c ts, HasCxxVector k)
  => (forall t b. c t => t -> (Ptr k -> IO b) -> IO b)
  -> Arguments ts
  -> (Ptr (CxxVector k) -> IO a)
  -> IO a
asVectorOf asPtr args action =
  bracket (newCxxVector Nothing) deleteCxxVector $ \vec -> pushAll args vec
  where
    pushAll :: All c ts' => Arguments ts' -> Ptr (CxxVector k) -> IO a
    pushAll Nil vec = action vec
    pushAll (x ::: rest) vec =
      asPtr x $ \elemPtr -> do
        cxxVectorPushBack vec elemPtr
        pushAll rest vec

-- | Homogeneous-list analogue of 'asVectorOf'. It converts every element with
-- @asPtr@, appends each resulting pointer to a temporary C++ vector, and runs
-- @action@ on the vector.
--
-- All conversions stay open while @action@ runs. The vector is deleted
-- afterwards via 'bracket'.
withMany
  :: forall k t a
   . (HasCxxVector k)
  => (t -> (Ptr k -> IO a) -> IO a)
  -> [t]
  -> (Ptr (CxxVector k) -> IO a)
  -> IO a
withMany asPtr args action =
  bracket (newCxxVector Nothing) deleteCxxVector $ \vec ->
    -- build nested continuations: convert x, push it, then continue with the rest
    let pushThen x rest = asPtr x $ \elemPtr -> cxxVectorPushBack vec elemPtr >> rest
     in foldr pushThen (action vec) args

-- | Run an 'IO' action with the underlying @Halide::Var@.
--
-- Calls 'error' if the expression is not a v'Var'.
asVar :: HasCallStack => Expr Int32 -> (Ptr CxxVar -> IO b) -> IO b
asVar expr = case expr of
  Var fp -> withForeignPtr fp
  _ -> error "the expression is not a Var"

-- | Run an 'IO' action with the underlying @Halide::RVar@.
--
-- Calls 'error' if the expression is not an v'RVar'.
asRVar :: HasCallStack => Expr Int32 -> (Ptr CxxRVar -> IO b) -> IO b
asRVar expr = case expr of
  RVar fp -> withForeignPtr fp
  _ -> error "the expression is not an RVar"

-- | Run an 'IO' action with the underlying v'Var' or v'RVar' wrapped in a
-- temporary @Halide::VarOrRVar@.
--
-- The @Halide::VarOrRVar@ is allocated with @new@ for the duration of the
-- action. It is deleted afterwards through 'bracket', even if the action
-- throws. Calls 'error' if the expression is neither a v'Var' nor an v'RVar'.
asVarOrRVar :: HasCallStack => VarOrRVar -> (Ptr CxxVarOrRVar -> IO b) -> IO b
asVarOrRVar x action = case x of
  Var fp ->
    withForeignPtr fp $ \varPtr ->
      withTemporary [CU.exp| Halide::VarOrRVar* { new Halide::VarOrRVar{*$(Halide::Var* varPtr)} } |]
  RVar fp ->
    withForeignPtr fp $ \rvarPtr ->
      withTemporary [CU.exp| Halide::VarOrRVar* { new Halide::VarOrRVar{*$(Halide::RVar* rvarPtr)} } |]
  _ -> error "the expression is not a Var or an RVar"
  where
    withTemporary allocate =
      bracket
        allocate
        (\tmp -> [CU.exp| void { delete $(Halide::VarOrRVar* tmp) } |])
        action

-- | Run an 'IO' action with the underlying @Halide::Internal::Parameter@ of a
-- 'ScalarParam'. If the parameter does not exist yet, it is created first (see
-- 'getScalarParameter').
--
-- Calls 'error' if the expression is not a 'ScalarParam'.
asScalarParam
  :: forall a b
   . (HasCallStack, IsHalideType a)
  => Expr a
  -> (Ptr CxxParameter -> IO b)
  -> IO b
asScalarParam (ScalarParam r) action = do
  fp <- getScalarParameter @a Nothing r
  withForeignPtr fp action
asScalarParam _ _ = error "the expression is not a ScalarParam"

-- | Get the underlying 'ForeignPtr' to a @Halide::Expr@.
--
-- A v'Var', v'RVar', or 'ScalarParam' is first converted with 'forceExpr'.
-- The conversion runs under 'unsafePerformIO' so that this function can be
-- used from pure code.
exprToForeignPtr :: IsHalideType a => Expr a -> ForeignPtr CxxExpr
exprToForeignPtr x =
  unsafePerformIO $! do
    forced <- forceExpr x
    case forced of
      Expr fp -> pure fp
      -- forceExpr always yields the Expr constructor
      _ -> error "this cannot happen"

-- | Lift a unary operation on @Halide::Expr@ to work on 'Expr'.
--
-- The callback gets a pointer to the argument (converted via 'asExpr') and a
-- destination pointer. It must construct the result in the destination (see
-- 'cxxConstructExpr').
unaryOp :: IsHalideType a => (Ptr CxxExpr -> Ptr CxxExpr -> IO ()) -> Expr a -> Expr a
unaryOp f a =
  unsafePerformIO . asExpr a $ \argPtr ->
    cxxConstructExpr (f argPtr)

-- | Lift a binary operation on @Halide::Expr@ to work on 'Expr'.
--
-- The callback gets pointers to both arguments (each converted via 'asExpr')
-- and a destination pointer. It must construct the result in the destination
-- (see 'cxxConstructExpr').
binaryOp
  :: (IsHalideType a, IsHalideType b, IsHalideType c)
  => (Ptr CxxExpr -> Ptr CxxExpr -> Ptr CxxExpr -> IO ())
  -> Expr a
  -> Expr b
  -> Expr c
binaryOp f a b =
  unsafePerformIO . asExpr a $ \lhs ->
    asExpr b $ \rhs ->
      cxxConstructExpr (f lhs rhs)