{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module Language.Halide.Expr
( Expr (..)
, Var
, RVar
, VarOrRVar
, Int32
, mkExpr
, mkVar
, mkRVar
, cast
, eq
, neq
, lt
, lte
, gt
, gte
, bool
, undef
, evaluate
, printed
, toIntImm
, exprToForeignPtr
, cxxConstructExpr
, wrapCxxRVar
, wrapCxxVarOrRVar
, wrapCxxParameter
, asExpr
, asVar
, asRVar
, asVarOrRVar
, asScalarParam
, asVectorOf
, mkScalarParameter
, withMany
, binaryOp
, unaryOp
, checkType
)
where
import Control.Exception (bracket)
import Control.Monad (unless)
import Data.IORef
import Data.Int (Int32)
import Data.Proxy
import Data.Ratio (denominator, numerator)
import Data.Text (Text, unpack)
import Data.Text.Encoding qualified as T
import Data.Vector.Storable.Mutable qualified as SM
import Foreign.ForeignPtr
import Foreign.Marshal (alloca, allocaArray, peekArray, toBool, with)
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import Foreign.Storable (peek)
import GHC.Stack (HasCallStack)
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (min)
-- Presumably sets up the inline-C/C++ context (Halide headers and the
-- C++ <-> Haskell type mappings) that the quasi-quotes in this module rely
-- on; defined in Language.Halide.Context. TODO confirm.
importHalide
-- Presumably derives 'CxxConstructible' instances for these C++ types, which
-- 'cxxConstruct' (used throughout this module to placement-new objects into
-- Haskell-managed memory) requires. Inferred from the splice name; confirm.
instanceCxxConstructible "Halide::Expr"
instanceCxxConstructible "Halide::Var"
instanceCxxConstructible "Halide::RVar"
instanceCxxConstructible "Halide::VarOrRVar"
-- Template Haskell: generate the 'IsHalideType' instances ('halideTypeFor'
-- and 'toCxxExpr') for the primitive scalar types Float, Double, CFloat,
-- CDouble, Int8, Int16, Int32, Int64, Word8, Word16, Word32 and Word64.
defineIsHalideTypeInstances
-- Presumably generates the glue for passing std::vector<T> of these C++
-- types between Haskell and C++ (inferred from the splice name; confirm
-- where instanceHasCxxVector is defined).
instanceHasCxxVector "Halide::Expr"
instanceHasCxxVector "Halide::Var"
instanceHasCxxVector "Halide::RVar"
instanceHasCxxVector "Halide::VarOrRVar"
-- | Booleans are represented in Halide as 1-bit unsigned integers
-- (@Halide::UInt(1)@, one lane).
instance IsHalideType Bool where
  halideTypeFor :: proxy Bool -> HalideType
  halideTypeFor = const (HalideType HalideTypeUInt 1 1)

  toCxxExpr :: Bool -> IO (ForeignPtr CxxExpr)
  toCxxExpr flag =
    -- Hand the flag to C++ as an int (0 or 1), then narrow it to UInt(1).
    let x = fromIntegral (fromEnum flag)
     in cxxConstruct $ \ptr ->
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{cast(Halide::UInt(1), Halide::Expr{$(int x)})} } |]
-- | A lone 'Expr' is treated as a one-element argument list.
type instance FromTuple (Expr a) = Arguments '[Expr a]
-- | A Halide expression whose element type is tracked by the phantom
-- parameter @a@.
--
-- Besides general expressions, the same type also carries loop variables
-- and scalar parameters, each backed by its own C++ object.
data Expr a
  = -- | A general @Halide::Expr@.
    Expr (ForeignPtr CxxExpr)
  | -- | A @Halide::Var@.
    Var (ForeignPtr CxxVar)
  | -- | A @Halide::RVar@ (reduction variable).
    RVar (ForeignPtr CxxRVar)
  | -- | A scalar parameter. The reference may hold 'Nothing' when no
    -- @Halide::Internal::Parameter@ object is attached.
    ScalarParam (IORef (Maybe (ForeignPtr CxxParameter)))

-- | A pure (loop) variable; always 32-bit signed.
type Var = Expr Int32

-- | A reduction variable; always 32-bit signed.
type RVar = Expr Int32

-- | Either a 'Var' or an 'RVar'.
type VarOrRVar = Expr Int32
-- | Lift a Haskell scalar into a constant Halide expression of the
-- corresponding type (via 'toCxxExpr').
mkExpr :: IsHalideType a => a -> Expr a
mkExpr = unsafePerformIO . fmap Expr . toCxxExpr
-- | Create a new @Halide::Var@ with the given name (passed to C++ as UTF-8).
mkVar :: Text -> IO (Expr Int32)
mkVar name = Var <$> cxxConstruct construct
  where
    s = T.encodeUtf8 name
    construct ptr =
      [CU.exp| void {
        new ($(Halide::Var* ptr)) Halide::Var{std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}} } |]
-- | Create a named reduction variable from a one-dimensional
-- @Halide::RDom@ with the given minimum and extent.
--
-- The variable is copied out of the RDom into a heap-allocated
-- @Halide::RVar@ whose ownership passes to 'wrapCxxRVar'.
mkRVar
  :: Text
  -- ^ name of the reduction variable
  -> Expr Int32
  -- ^ minimum of the domain
  -> Expr Int32
  -- ^ extent of the domain
  -> IO (Expr Int32)
mkRVar name start len =
  asExpr start $ \start' ->
    asExpr len $ \len' ->
      [CU.exp| Halide::RVar* {
        new Halide::RVar{static_cast<Halide::RVar>(Halide::RDom{
          *$(const Halide::Expr* start'),
          *$(const Halide::Expr* len'),
          std::string{$bs-ptr:nameBytes, static_cast<size_t>($bs-len:nameBytes)}
        })}
      } |]
        >>= wrapCxxRVar
  where
    nameBytes = T.encodeUtf8 name
-- | An undefined value of type @a@ (@Halide::undef@), built from the
-- runtime type descriptor of @a@.
undef :: forall a. IsHalideType a => Expr a
undef =
  unsafePerformIO . with (halideTypeFor (Proxy @a)) $ \tp ->
    cxxConstructExpr $ \ptr ->
      [CU.exp| void {
        new ($(Halide::Expr* ptr))
          Halide::Expr{Halide::undef(Halide::Type{*$(const halide_type_t* tp)})} } |]
{-# NOINLINE undef #-}
-- | Convert an expression to another element type with @Halide::cast@.
--
-- The target type is chosen with a type application, e.g.
-- @cast \@Float \@Int32 e@.
cast :: forall to from. (IsHalideType to, IsHalideType from) => Expr from -> Expr to
cast expr = unsafePerformIO (asExpr expr convert)
  where
    target = halideTypeFor (Proxy @to)
    convert e =
      with target $ \t ->
        cxxConstructExpr $ \ptr ->
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            Halide::cast(Halide::Type{*$(halide_type_t* t)}, *$(Halide::Expr* e))} } |]
-- | Wrap an expression in Halide's @print@, so that its value is printed
-- whenever the generated code evaluates it. The value itself is unchanged.
printed :: IsHalideType a => Expr a -> Expr a
printed expr = unaryOp wrapInPrint expr
  where
    wrapInPrint e ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{print(*$(Halide::Expr* e))} } |]
-- The comparison combinators are non-associative with precedence 4, the same
-- as Prelude's '==' and '<', so they combine with arithmetic without
-- extra parentheses.
infix 4 `eq`, `neq`, `lt`, `lte`, `gt`, `gte`
-- | Build the boolean Halide expression @a == b@.
eq :: IsHalideType a => Expr a -> Expr a -> Expr Bool
eq lhs rhs = binaryOp build lhs rhs
  where
    build a b ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        (*$(Halide::Expr* a)) == (*$(Halide::Expr* b))} } |]
-- | Build the boolean Halide expression @a != b@.
neq :: IsHalideType a => Expr a -> Expr a -> Expr Bool
neq lhs rhs = binaryOp build lhs rhs
  where
    build a b ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        (*$(Halide::Expr* a)) != (*$(Halide::Expr* b))} } |]
-- | Build the boolean Halide expression @a < b@.
lt :: IsHalideType a => Expr a -> Expr a -> Expr Bool
lt lhs rhs = binaryOp build lhs rhs
  where
    build a b ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        (*$(Halide::Expr* a)) < (*$(Halide::Expr* b))} } |]
-- | Build the boolean Halide expression @a <= b@.
lte :: IsHalideType a => Expr a -> Expr a -> Expr Bool
lte lhs rhs = binaryOp build lhs rhs
  where
    build a b ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        (*$(Halide::Expr* a)) <= (*$(Halide::Expr* b))} } |]
-- | Build the boolean Halide expression @a > b@.
gt :: IsHalideType a => Expr a -> Expr a -> Expr Bool
gt lhs rhs = binaryOp build lhs rhs
  where
    build a b ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        (*$(Halide::Expr* a)) > (*$(Halide::Expr* b))} } |]
-- | Build the boolean Halide expression @a >= b@.
gte :: IsHalideType a => Expr a -> Expr a -> Expr Bool
gte lhs rhs = binaryOp build lhs rhs
  where
    build a b ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        (*$(Halide::Expr* a)) >= (*$(Halide::Expr* b))} } |]
-- | @bool cond t f@ is @t@ where @cond@ holds and @f@ otherwise
-- (@Halide::select@).
bool :: IsHalideType a => Expr Bool -> Expr a -> Expr a -> Expr a
bool condExpr trueExpr falseExpr =
  unsafePerformIO $
    asExpr condExpr $ \p ->
      asExpr trueExpr $ \t ->
        asExpr falseExpr $ \f ->
          cxxConstructExpr (select3 p t f)
  where
    select3 p t f ptr =
      [CU.exp| void {
        new ($(Halide::Expr* ptr)) Halide::Expr{
          Halide::select(*$(Halide::Expr* p),
                         *$(Halide::Expr* t), *$(Halide::Expr* f))} } |]
-- | Compute the value of a scalar expression.
--
-- Defines @f(i) = expr@ for a fresh @Halide::Func@ and realizes it into a
-- newly allocated one-element buffer, then returns that element. C++
-- exceptions raised while realizing (routed through
-- @handle_halide_exceptions@) are rethrown in Haskell by 'C.throwBlock'.
evaluate :: forall a. IsHalideType a => Expr a -> IO a
evaluate expr =
  asExpr expr $ \e -> do
    out <- SM.new 1
    withHalideBuffer out $ \buffer ->
      let b = castPtr (buffer :: Ptr (HalideBuffer 1 a))
       in [C.throwBlock| void {
            handle_halide_exceptions([=]() {
              Halide::Func f;
              Halide::Var i;
              f(i) = *$(Halide::Expr* e);
              f.realize(Halide::Pipeline::RealizationArg{$(halide_buffer_t* b)});
            });
          } |]
    SM.read out 0
-- | Extract the value of an integer immediate node
-- (@Halide::Internal::IntImm@).
--
-- Returns 'Nothing' when the expression is not an 'IntImm'.
toIntImm :: IsHalideType a => Expr a -> Maybe Int
toIntImm expr = unsafePerformIO $
  asExpr expr $ \expr' ->
    alloca $ \out -> do
      -- The C++ side copies the immediate into @out@, so no pointer into
      -- the IR node escapes the block.
      found <-
        toBool
          <$> [CU.block| bool {
                auto const* node = $(const Halide::Expr* expr')->as<Halide::Internal::IntImm>();
                if (node == nullptr) { return false; }
                *$(int64_t* out) = node->value;
                return true;
              } |]
      if found
        then Just . fromIntegral <$> peek out
        else pure Nothing
-- | A single 'Expr' is isomorphic to a one-element argument list.
instance IsTuple (Arguments '[Expr a]) (Expr a) where
  toTuple :: Arguments '[Expr a] -> Expr a
  toTuple (x ::: Nil) = x

  fromTuple :: Expr a -> Arguments '[Expr a]
  fromTuple = (::: Nil)
-- | Render an expression with Halide's C++ pretty-printer.
--
-- A 'ScalarParam' is shown by its parameter name, or as the literal
-- @ScalarParam@ when no parameter object is attached.
instance IsHalideType a => Show (Expr a) where
  show :: Expr a -> String
  show = unpack . unsafePerformIO . render
    where
      render (Expr fp) =
        withForeignPtr fp $ \x ->
          [CU.exp| std::string* { to_string_via_iostream(*$(const Halide::Expr* x)) } |]
            >>= peekAndDeleteCxxString
      render (Var fp) =
        withForeignPtr fp $ \x ->
          [CU.exp| std::string* { to_string_via_iostream(*$(const Halide::Var* x)) } |]
            >>= peekAndDeleteCxxString
      render (RVar fp) =
        withForeignPtr fp $ \x ->
          [CU.exp| std::string* { to_string_via_iostream(*$(const Halide::RVar* x)) } |]
            >>= peekAndDeleteCxxString
      render (ScalarParam r) =
        readIORef r >>= \case
          Nothing -> pure "ScalarParam"
          Just fp ->
            withForeignPtr fp $ \x ->
              [CU.exp| std::string* {
                new std::string{$(const Halide::Internal::Parameter* x)->name()} } |]
                >>= peekAndDeleteCxxString
-- | Arithmetic on Halide expressions.
--
-- Each operation builds a new Halide expression node; integer literals
-- become constant expressions through 'mkExpr'.
instance (IsHalideType a, Num a) => Num (Expr a) where
  fromInteger :: Integer -> Expr a
  fromInteger x = mkExpr (fromInteger x :: a)

  (+) :: Expr a -> Expr a -> Expr a
  (+) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) + *$(Halide::Expr* b)} } |]

  (-) :: Expr a -> Expr a -> Expr a
  (-) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) - *$(Halide::Expr* b)} } |]

  (*) :: Expr a -> Expr a -> Expr a
  (*) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) * *$(Halide::Expr* b)} } |]

  -- Unsigned inputs are returned unchanged. For other types, Halide::abs of
  -- a signed integer yields an unsigned type of the same width, so the
  -- result is cast back to the input's type.
  abs :: Expr a -> Expr a
  abs = unaryOp $ \a ptr ->
    [CU.block| void {
      if ($(Halide::Expr* a)->type().is_uint()) {
        new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a)};
      }
      else {
        new ($(Halide::Expr* ptr)) Halide::Expr{
          Halide::cast($(Halide::Expr* a)->type(), Halide::abs(*$(Halide::Expr* a)))};
      }
    } |]

  negate :: Expr a -> Expr a
  negate = unaryOp $ \a ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{ -(*$(Halide::Expr* a))} } |]

  -- Prelude semantics: 1 for positive values, -1 for negative values, and
  -- the argument itself otherwise. Zero therefore maps to zero, and for
  -- floating-point types NaN and -0.0 pass through unchanged, matching
  -- GHC's Double instance. For unsigned types the (-1) branch is never
  -- selected.
  signum :: Expr a -> Expr a
  signum x = bool (x `gt` 0) 1 (bool (x `lt` 0) (-1) x)
-- | Division and fractional literals for floating-point expressions.
instance (IsHalideType a, Fractional a) => Fractional (Expr a) where
  (/) :: Expr a -> Expr a -> Expr a
  (/) = binaryOp $ \a b ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{*$(Halide::Expr* a) / *$(Halide::Expr* b)} } |]

  -- Convert the literal with the element type's own (correctly rounded)
  -- 'fromRational' and embed the result as a constant. Dividing
  -- 'fromInteger' of the numerator by 'fromInteger' of the denominator
  -- instead rounds twice and overflows for large denominators, e.g.
  -- 1.5e-320 :: Expr Double would become 3 / Infinity = 0.
  fromRational :: Rational -> Expr a
  fromRational r = mkExpr (fromRational r :: a)
-- | Transcendental functions, each lowered to the Halide function of the
-- same name (@Halide::exp@, @Halide::sin@, ...); '(**)' maps to
-- @Halide::pow@.
instance (IsHalideType a, Floating a) => Floating (Expr a) where
  -- The constant is created as a Double expression and cast to the target
  -- element type.
  pi :: Expr a
  pi = cast @a @Double $! mkExpr (pi :: Double)

  exp :: Expr a -> Expr a
  exp = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::exp(*$(Halide::Expr* x))} } |]

  log :: Expr a -> Expr a
  log = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::log(*$(Halide::Expr* x))} } |]

  sqrt :: Expr a -> Expr a
  sqrt = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::sqrt(*$(Halide::Expr* x))} } |]

  (**) :: Expr a -> Expr a -> Expr a
  (**) = binaryOp $ \x y out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::pow(*$(Halide::Expr* x), *$(Halide::Expr* y))} } |]

  sin :: Expr a -> Expr a
  sin = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::sin(*$(Halide::Expr* x))} } |]

  cos :: Expr a -> Expr a
  cos = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::cos(*$(Halide::Expr* x))} } |]

  tan :: Expr a -> Expr a
  tan = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::tan(*$(Halide::Expr* x))} } |]

  asin :: Expr a -> Expr a
  asin = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::asin(*$(Halide::Expr* x))} } |]

  acos :: Expr a -> Expr a
  acos = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::acos(*$(Halide::Expr* x))} } |]

  atan :: Expr a -> Expr a
  atan = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::atan(*$(Halide::Expr* x))} } |]

  sinh :: Expr a -> Expr a
  sinh = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::sinh(*$(Halide::Expr* x))} } |]

  cosh :: Expr a -> Expr a
  cosh = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::cosh(*$(Halide::Expr* x))} } |]

  tanh :: Expr a -> Expr a
  tanh = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::tanh(*$(Halide::Expr* x))} } |]

  asinh :: Expr a -> Expr a
  asinh = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::asinh(*$(Halide::Expr* x))} } |]

  acosh :: Expr a -> Expr a
  acosh = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::acosh(*$(Halide::Expr* x))} } |]

  atanh :: Expr a -> Expr a
  atanh = unaryOp $ \x out ->
    [CU.exp| void { new ($(Halide::Expr* out)) Halide::Expr{Halide::atanh(*$(Halide::Expr* x))} } |]
-- | Construct a @Halide::Expr@ in Haskell-managed memory with the given
-- placement-new action, then verify with 'checkType' that its runtime type
-- matches @a@ before wrapping it.
cxxConstructExpr :: forall a. (HasCallStack, IsHalideType a) => (Ptr CxxExpr -> IO ()) -> IO (Expr a)
cxxConstructExpr construct = do
  fp <- cxxConstruct construct
  Expr fp <$ withForeignPtr fp (checkType @a)
-- | Take ownership of a heap-allocated @Halide::RVar@; the C++ object is
-- deleted by the ForeignPtr finalizer.
wrapCxxRVar :: Ptr CxxRVar -> IO (Expr Int32)
wrapCxxRVar rvarPtr = RVar <$> newForeignPtr deleter rvarPtr
  where
    deleter = [C.funPtr| void deleteExpr(Halide::RVar *p) { delete p; } |]
-- | Convert a heap-allocated @Halide::VarOrRVar@ into an 'Expr', taking
-- ownership of the pointer.
--
-- Depending on @is_rvar@, the contained RVar or Var is copied into a new
-- Haskell-managed object ('RVar' or 'Var'). The input object is always
-- deleted afterwards; 'bracket' ensures the delete also runs if the
-- conversion throws, so the C++ object cannot leak.
wrapCxxVarOrRVar :: Ptr CxxVarOrRVar -> IO (Expr Int32)
wrapCxxVarOrRVar p = bracket (pure p) release convert
  where
    convert q = do
      isRVar <- toBool <$> [CU.exp| bool { $(const Halide::VarOrRVar* q)->is_rvar } |]
      if isRVar
        then wrapCxxRVar =<< [CU.exp| Halide::RVar* { new Halide::RVar{$(const Halide::VarOrRVar* q)->rvar} } |]
        else fmap Var . cxxConstruct $ \ptr ->
          [CU.exp| void { new ($(Halide::Var* ptr)) Halide::Var{$(const Halide::VarOrRVar* q)->var} } |]
    release q = [CU.exp| void { delete $(const Halide::VarOrRVar* q) } |]
-- | Values whose Halide element type can be queried at runtime.
--
-- 'checkType' uses this to compare the type reported by the C++ side with the
-- type expected on the Haskell side.
class HasHalideType a where
  -- | Return the Halide type carried by the given value.
  getHalideType :: a -> IO HalideType
-- | Query the Halide type of an 'Expr', whichever constructor it uses.
--
-- * 'Expr': asks the wrapped @Halide::Expr@ for its @type()@.
-- * 'Var' / 'RVar': converts the object to a @Halide::Expr@ on the C++ side
--   and asks for its @type()@.
-- * 'ScalarParam': reports the type of the cached
--   @Halide::Internal::Parameter@ if one has already been created. Otherwise
--   the type cannot be determined here, because this instance has no
--   'IsHalideType' constraint, and an error is raised.
instance HasHalideType (Expr a) where
  getHalideType :: Expr a -> IO HalideType
  getHalideType (Expr fp) = withForeignPtr fp getHalideType
  getHalideType (Var fp) =
    withForeignPtr fp $ \e -> alloca $ \t -> do
      [CU.block| void {
        *$(halide_type_t* t) = static_cast<halide_type_t>(
          static_cast<Halide::Expr>(*$(Halide::Var* e)).type()); } |]
      peek t
  getHalideType (RVar fp) =
    withForeignPtr fp $ \e -> alloca $ \t -> do
      [CU.block| void {
        *$(halide_type_t* t) = static_cast<halide_type_t>(
          static_cast<Halide::Expr>(*$(Halide::RVar* e)).type()); } |]
      peek t
  getHalideType (ScalarParam r) =
    readIORef r >>= \case
      -- The parameter already exists: ask it directly.
      Just fp -> withForeignPtr fp getHalideType
      -- Not created yet; creating it would need IsHalideType a, which we lack.
      Nothing ->
        error "getHalideType: the ScalarParam has not been materialized yet, so its type is unknown"
-- | Read the type of a raw @Halide::Expr@ by calling its @type()@ method and
-- converting the result to a @halide_type_t@ in a temporary buffer.
instance HasHalideType (Ptr CxxExpr) where
  getHalideType :: Ptr CxxExpr -> IO HalideType
  getHalideType exprPtr =
    alloca $ \out ->
      [CU.block| void {
        *$(halide_type_t* out) = static_cast<halide_type_t>($(Halide::Expr* exprPtr)->type()); } |]
        >> peek out
-- | A @Halide::Var@ is reported as 'Int32'; the pointer itself is never read.
instance HasHalideType (Ptr CxxVar) where
  getHalideType :: Ptr CxxVar -> IO HalideType
  getHalideType = const (pure (halideTypeFor (Proxy :: Proxy Int32)))
-- | A @Halide::RVar@ is reported as 'Int32'; the pointer itself is never read.
instance HasHalideType (Ptr CxxRVar) where
  getHalideType :: Ptr CxxRVar -> IO HalideType
  getHalideType _ = pure int32Type
    where
      int32Type = halideTypeFor (Proxy :: Proxy Int32)
-- | Read the type stored in a @Halide::Internal::Parameter@ via its @type()@
-- method, converted to a @halide_type_t@ in a temporary buffer.
instance HasHalideType (Ptr CxxParameter) where
  getHalideType :: Ptr CxxParameter -> IO HalideType
  getHalideType paramPtr =
    alloca $ \out ->
      [CU.block| void {
        *$(halide_type_t* out) = static_cast<halide_type_t>($(Halide::Internal::Parameter* paramPtr)->type()); } |]
        >> peek out
-- | Take ownership of a heap-allocated @Halide::Internal::Parameter@.
--
-- The attached finalizer runs C++ @delete@ on the pointer when the returned
-- 'ForeignPtr' is garbage-collected, so the pointer must come from @new@.
wrapCxxParameter :: Ptr CxxParameter -> IO (ForeignPtr CxxParameter)
wrapCxxParameter rawParam = newForeignPtr finalizer rawParam
  where
    finalizer :: FinalizerPtr CxxParameter
    finalizer = [C.funPtr| void deleteParameter(Halide::Internal::Parameter *p) { delete p; } |]
-- | Sanity-check that the runtime Halide type of @x@ equals the type of the
-- Haskell type parameter @a@.
--
-- The runtime type comes from 'getHalideType' and the expected one from
-- @halideTypeFor (Proxy :: Proxy a)@. On a mismatch, 'error' is called, so an
-- 'ErrorCall' carrying the call stack is thrown when the action runs.
checkType :: forall a t. (HasCallStack, IsHalideType a, HasHalideType t) => t -> IO ()
checkType x = do
  actual <- getHalideType x
  let expected = halideTypeFor (Proxy :: Proxy a)
  if actual == expected
    then pure ()
    else
      error . concat $
        [ "Type mismatch: C++ Expr has type "
        , show actual
        , ", but its Haskell counterpart has type "
        , show expected
        ]
-- | Allocate a new scalar @Halide::Internal::Parameter@ whose type is the
-- Halide type of @a@, optionally with a name (encoded as UTF-8).
--
-- The raw C++ pointer is handed to 'wrapCxxParameter' /before/ the
-- 'checkType' sanity check runs. That way the parameter is reclaimed by its
-- finalizer even when the check fails, instead of leaking.
mkScalarParameter :: forall a. IsHalideType a => Maybe Text -> IO (ForeignPtr CxxParameter)
mkScalarParameter maybeName = do
  fp <- with (halideTypeFor (Proxy @a)) $ \t -> do
    -- NOTE(review): the extra constructor arguments (false, 0) presumably mean
    -- "not a buffer" and "zero dimensions"; confirm against Halide's
    -- Internal::Parameter constructor.
    let createWithoutName =
          [CU.exp| Halide::Internal::Parameter* {
            new Halide::Internal::Parameter{Halide::Type{*$(halide_type_t* t)}, false, 0} } |]
        createWithName name =
          let s = T.encodeUtf8 name
           in [CU.exp| Halide::Internal::Parameter* {
                new Halide::Internal::Parameter{
                  Halide::Type{*$(halide_type_t* t)},
                  false,
                  0,
                  std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}}
              } |]
    -- Take ownership immediately so nothing below can leak the allocation.
    wrapCxxParameter =<< maybe createWithoutName createWithName maybeName
  withForeignPtr fp (checkType @a)
  pure fp
-- | Return the parameter cached in the 'IORef', creating it on first use.
--
-- On a cache miss, a new parameter is built with 'mkScalarParameter' (using
-- @name@) and stored as @Just@ in the 'IORef'. Later calls return the cached
-- value and ignore @name@.
--
-- NOTE(review): the read and the write are separate operations, so two
-- concurrent first uses could each create a parameter.
getScalarParameter
  :: forall a
   . IsHalideType a
  => Maybe Text
  -> IORef (Maybe (ForeignPtr CxxParameter))
  -> IO (ForeignPtr CxxParameter)
getScalarParameter name cache =
  readIORef cache >>= maybe create pure
  where
    create = do
      fresh <- mkScalarParameter @a name
      writeIORef cache (Just fresh)
      pure fresh
-- | Materialize any 'Expr' constructor as a @Halide::Expr@-backed 'Expr'.
--
-- * 'Expr' values are returned unchanged.
-- * 'Var' / 'RVar' are converted with a C++ @static_cast<Halide::Expr>@.
-- * 'ScalarParam' first obtains its @Halide::Internal::Parameter@ through
--   'getScalarParameter' (which creates and caches it on first use). It then
--   builds a @Halide::Internal::Variable@ that refers to that parameter.
--
-- Every new C++ expression is constructed through 'cxxConstructExpr'.
forceExpr :: forall a. (HasCallStack, IsHalideType a) => Expr a -> IO (Expr a)
forceExpr = \case
  e@(Expr _) -> pure e
  Var fp ->
    withForeignPtr fp $ \v ->
      cxxConstructExpr $ \dst ->
        [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{
          static_cast<Halide::Expr>(*$(Halide::Var* v))} } |]
  RVar fp ->
    withForeignPtr fp $ \rv ->
      cxxConstructExpr $ \dst ->
        [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{
          static_cast<Halide::Expr>(*$(Halide::RVar* rv))} } |]
  ScalarParam cache -> do
    param <- getScalarParameter @a Nothing cache
    withForeignPtr param $ \pp ->
      cxxConstructExpr $ \dst ->
        [CU.exp| void { new ($(Halide::Expr* dst)) Halide::Expr{
          Halide::Internal::Variable::make(
            $(Halide::Internal::Parameter* pp)->type(),
            $(Halide::Internal::Parameter* pp)->name(),
            *$(Halide::Internal::Parameter* pp))} } |]
-- | Run an action on a raw @Halide::Expr@ pointer for the given 'Expr'.
--
-- Non-'Expr' constructors are first materialized via 'exprToForeignPtr'. The
-- pointer is only valid inside the action.
asExpr :: IsHalideType a => Expr a -> (Ptr CxxExpr -> IO b) -> IO b
asExpr = withForeignPtr . exprToForeignPtr
-- | Collect a heterogeneous argument list into a temporary C++ vector and run
-- an action on it.
--
-- Each argument is turned into a @Ptr k@ with @asPtr@ and handed to
-- 'cxxVectorPushBack'. The continuations nest, so every element pointer stays
-- valid while @action@ runs. The vector is created with 'newCxxVector' and
-- released with 'deleteCxxVector' via 'bracket'.
asVectorOf
  :: forall c k ts a
   . (All c ts, HasCxxVector k)
  => (forall t b. c t => t -> (Ptr k -> IO b) -> IO b)
  -> Arguments ts
  -> (Ptr (CxxVector k) -> IO a)
  -> IO a
asVectorOf asPtr args action =
  bracket (newCxxVector Nothing) deleteCxxVector $ \vec -> pushAll vec args
  where
    -- Needs an explicit signature: the type-level list shrinks on each call
    -- (polymorphic recursion), so it cannot be inferred.
    pushAll :: All c ts' => Ptr (CxxVector k) -> Arguments ts' -> IO a
    pushAll vec Nil = action vec
    pushAll vec (y ::: ys) =
      asPtr y $ \elemPtr -> do
        cxxVectorPushBack vec elemPtr
        pushAll vec ys
-- | Homogeneous-list counterpart of 'asVectorOf'.
--
-- Each list element is turned into a @Ptr k@ with @asPtr@ and pushed onto a
-- temporary C++ vector; then @action@ runs on that vector. Because the
-- continuations nest, all element pointers remain valid during @action@. The
-- vector is released with 'deleteCxxVector' via 'bracket'.
withMany
  :: forall k t a
   . (HasCxxVector k)
  => (t -> (Ptr k -> IO a) -> IO a)
  -> [t]
  -> (Ptr (CxxVector k) -> IO a)
  -> IO a
withMany asPtr args action =
  bracket (newCxxVector Nothing) deleteCxxVector $ \vec ->
    let step x rest = asPtr x $ \elemPtr -> cxxVectorPushBack vec elemPtr >> rest
     in foldr step (action vec) args
-- | Run an action on the raw @Halide::Var@ behind a 'Var'.
--
-- Any other constructor is a programming error and triggers 'error'.
asVar :: HasCallStack => Expr Int32 -> (Ptr CxxVar -> IO b) -> IO b
asVar = \case
  Var fp -> withForeignPtr fp
  _ -> error "the expression is not a Var"
-- | Run an action on the raw @Halide::RVar@ behind an 'RVar'.
--
-- Any other constructor is a programming error and triggers 'error'.
asRVar :: HasCallStack => Expr Int32 -> (Ptr CxxRVar -> IO b) -> IO b
asRVar = \case
  RVar fp -> withForeignPtr fp
  _ -> error "the expression is not an RVar"
-- | Run an action on a temporary @Halide::VarOrRVar@ built from a 'Var' or an
-- 'RVar'.
--
-- The wrapper is allocated with C++ @new@ from a copy of the underlying
-- object and @delete@d afterwards via 'bracket'. Its pointer is only valid
-- inside the action. Any other constructor triggers 'error'.
asVarOrRVar :: HasCallStack => VarOrRVar -> (Ptr CxxVarOrRVar -> IO b) -> IO b
asVarOrRVar x action =
  case x of
    Var fp ->
      withForeignPtr fp $ \v ->
        withTemporary
          [CU.exp| Halide::VarOrRVar* { new Halide::VarOrRVar{*$(Halide::Var* v)} } |]
    RVar fp ->
      withForeignPtr fp $ \rv ->
        withTemporary
          [CU.exp| Halide::VarOrRVar* { new Halide::VarOrRVar{*$(Halide::RVar* rv)} } |]
    _ -> error "the expression is not a Var or an RVar"
  where
    -- Allocate the wrapper, run the caller's action on it, and always free it.
    withTemporary acquire = bracket acquire release action
    release q = [CU.exp| void { delete $(Halide::VarOrRVar* q) } |]
-- | Run an action on the @Halide::Internal::Parameter@ behind a 'ScalarParam'.
--
-- The parameter is created and cached on first use (see 'getScalarParameter').
-- Any other constructor triggers 'error'.
asScalarParam :: forall a b. (HasCallStack, IsHalideType a) => Expr a -> (Ptr CxxParameter -> IO b) -> IO b
asScalarParam expr action =
  case expr of
    ScalarParam cache -> getScalarParameter @a Nothing cache >>= flip withForeignPtr action
    _ -> error "the expression is not a ScalarParam"
-- | Obtain the @Halide::Expr@ foreign pointer for any 'Expr' constructor.
--
-- 'forceExpr' always yields the 'Expr' constructor, so the fallback branch is
-- unreachable.
--
-- NOTE(review): 'unsafePerformIO' treats 'forceExpr' as pure. For 'Var' /
-- 'RVar' it allocates a fresh C++ object on each call. For 'ScalarParam' it may
-- also write the parameter cache 'IORef' (via 'getScalarParameter'), which is
-- idempotent only when uses do not race.
exprToForeignPtr :: IsHalideType a => Expr a -> ForeignPtr CxxExpr
exprToForeignPtr x =
  unsafePerformIO $! forceExpr x >>= extract
  where
    extract (Expr fp) = pure fp
    extract _ = error "this cannot happen"
-- | Lift a C++ unary expression builder to a pure 'Expr' function.
--
-- The builder receives the operand's @Halide::Expr*@ and an uninitialized
-- destination that it must construct in place (via 'cxxConstructExpr').
unaryOp :: IsHalideType a => (Ptr CxxExpr -> Ptr CxxExpr -> IO ()) -> Expr a -> Expr a
unaryOp f a = unsafePerformIO (asExpr a (cxxConstructExpr . f))
-- | Lift a C++ binary expression builder to a pure 'Expr' function.
--
-- The builder receives both operands' @Halide::Expr*@ and an uninitialized
-- destination that it must construct in place (via 'cxxConstructExpr'). The
-- result's Haskell type @c@ is checked against the C++ result by
-- 'cxxConstructExpr'.
binaryOp
  :: (IsHalideType a, IsHalideType b, IsHalideType c)
  => (Ptr CxxExpr -> Ptr CxxExpr -> Ptr CxxExpr -> IO ())
  -> Expr a
  -> Expr b
  -> Expr c
binaryOp f a b =
  unsafePerformIO . asExpr a $ \aPtr ->
    asExpr b (cxxConstructExpr . f aPtr)